Added a minimal Fortran interface to Astaroth

This commit is contained in:
jpekkila
2020-06-25 06:34:16 +03:00
parent 70ecacee7c
commit fbb8d7c7c6
4 changed files with 74 additions and 19 deletions

View File

@@ -55,7 +55,8 @@ message(STATUS "AC module dir: ${DSL_MODULE_DIR}")
file(GLOB DSL_SOURCES ${DSL_MODULE_DIR}/* file(GLOB DSL_SOURCES ${DSL_MODULE_DIR}/*
${CMAKE_SOURCE_DIR}/acc/stdlib/*) ${CMAKE_SOURCE_DIR}/acc/stdlib/*)
set(DSL_HEADERS "${PROJECT_BINARY_DIR}/user_kernels.h" set(DSL_HEADERS "${PROJECT_BINARY_DIR}/user_kernels.h"
"${PROJECT_BINARY_DIR}/user_defines.h") "${PROJECT_BINARY_DIR}/user_defines.h"
"${PROJECT_BINARY_DIR}/astaroth.f90")
add_custom_command ( add_custom_command (
COMMENT "Building ACC objects ${DSL_MODULE_DIR}" COMMENT "Building ACC objects ${DSL_MODULE_DIR}"
@@ -99,6 +100,7 @@ if (BUILD_SAMPLES)
add_subdirectory(samples/bwtest) add_subdirectory(samples/bwtest)
add_subdirectory(samples/genbenchmarkscripts) add_subdirectory(samples/genbenchmarkscripts)
add_subdirectory(samples/mpi_reduce_bench) add_subdirectory(samples/mpi_reduce_bench)
add_subdirectory(samples/fortrantest)
endif() endif()
if (BUILD_STANDALONE) if (BUILD_STANDALONE)

View File

@@ -39,9 +39,11 @@ ASTNode* root = NULL;
// Output files // Output files
static FILE* DSLHEADER = NULL; static FILE* DSLHEADER = NULL;
static FILE* CUDAHEADER = NULL; static FILE* CUDAHEADER = NULL;
static FILE* FHEADER = NULL;
static const char* dslheader_filename = "user_defines.h"; static const char* dslheader_filename = "user_defines.h";
static const char* cudaheader_filename = "user_kernels.h"; static const char* cudaheader_filename = "user_kernels.h";
static const char* fheader_filename = "astaroth.f90";
// Forward declaration of yyparse // Forward declaration of yyparse
int yyparse(void); int yyparse(void);
@@ -98,7 +100,8 @@ static const char* translation_table[TRANSLATION_TABLE_SIZE] = {
['<'] = "<", ['<'] = "<",
['>'] = ">", ['>'] = ">",
['!'] = "!", ['!'] = "!",
['.'] = "."}; ['.'] = ".",
};
static const char* static const char*
translate(const int token) translate(const int token)
@@ -261,9 +264,8 @@ traverse(const ASTNode* node)
if (typequal->token == KERNEL) { if (typequal->token == KERNEL) {
fprintf(CUDAHEADER, "GEN_KERNEL_PARAM_BOILERPLATE"); fprintf(CUDAHEADER, "GEN_KERNEL_PARAM_BOILERPLATE");
if (node->lhs != NULL) { if (node->lhs != NULL) {
fprintf( fprintf(stderr, "Syntax error: function parameters for Kernel functions not "
stderr, "allowed!\n");
"Syntax error: function parameters for Kernel functions not allowed!\n");
} }
} }
else if (typequal->token == PREPROCESSED) { else if (typequal->token == PREPROCESSED) {
@@ -597,8 +599,10 @@ generate_preprocessed_structures(void)
} }
static void static void
generate_header(void) generate_headers(void)
{ {
int enumcounter = 0;
fprintf(DSLHEADER, "#pragma once\n"); fprintf(DSLHEADER, "#pragma once\n");
// Int params // Int params
@@ -606,56 +610,103 @@ generate_header(void)
for (size_t i = 0; i < num_symbols[current_nest]; ++i) { for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == INT && symbol_table[i].type_qualifier == UNIFORM) { if (symbol_table[i].type_specifier == INT && symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier); fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier,
enumcounter);
++enumcounter;
} }
} }
fprintf(DSLHEADER, "\n\n"); fprintf(DSLHEADER, "\n\n");
fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_INT_PARAMS = %d\n\n", enumcounter);
// Int3 params // Int3 params
fprintf(DSLHEADER, "#define AC_FOR_USER_INT3_PARAM_TYPES(FUNC)"); fprintf(DSLHEADER, "#define AC_FOR_USER_INT3_PARAM_TYPES(FUNC)");
enumcounter = 0;
for (size_t i = 0; i < num_symbols[current_nest]; ++i) { for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == INT3 && symbol_table[i].type_qualifier == UNIFORM) { if (symbol_table[i].type_specifier == INT3 && symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier); fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier,
enumcounter);
++enumcounter;
} }
} }
fprintf(DSLHEADER, "\n\n"); fprintf(DSLHEADER, "\n\n");
fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_INT3_PARAMS = %d\n\n", enumcounter);
// Scalar params // Scalar params
fprintf(DSLHEADER, "#define AC_FOR_USER_REAL_PARAM_TYPES(FUNC)"); fprintf(DSLHEADER, "#define AC_FOR_USER_REAL_PARAM_TYPES(FUNC)");
enumcounter = 0;
for (size_t i = 0; i < num_symbols[current_nest]; ++i) { for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == SCALAR && symbol_table[i].type_qualifier == UNIFORM) { if (symbol_table[i].type_specifier == SCALAR && symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier); fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier,
enumcounter);
++enumcounter;
} }
} }
fprintf(DSLHEADER, "\n\n"); fprintf(DSLHEADER, "\n\n");
fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_REAL_PARAMS = %d\n\n", enumcounter);
// Vector params // Vector params
fprintf(DSLHEADER, "#define AC_FOR_USER_REAL3_PARAM_TYPES(FUNC)"); fprintf(DSLHEADER, "#define AC_FOR_USER_REAL3_PARAM_TYPES(FUNC)");
enumcounter = 0;
for (size_t i = 0; i < num_symbols[current_nest]; ++i) { for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == VECTOR && symbol_table[i].type_qualifier == UNIFORM) { if (symbol_table[i].type_specifier == VECTOR && symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier); fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier,
enumcounter);
++enumcounter;
} }
} }
fprintf(DSLHEADER, "\n\n"); fprintf(DSLHEADER, "\n\n");
fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_REAL3_PARAMS = %d\n\n", enumcounter);
// Scalar fields // Scalar fields
fprintf(DSLHEADER, "#define AC_FOR_VTXBUF_HANDLES(FUNC)"); fprintf(DSLHEADER, "#define AC_FOR_VTXBUF_HANDLES(FUNC)");
enumcounter = 0;
for (size_t i = 0; i < num_symbols[current_nest]; ++i) { for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == SCALARFIELD && if (symbol_table[i].type_specifier == SCALARFIELD &&
symbol_table[i].type_qualifier == UNIFORM) { symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier); fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier,
enumcounter);
++enumcounter;
} }
} }
fprintf(DSLHEADER, "\n\n"); fprintf(DSLHEADER, "\n\n");
fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_VTXBUF_HANDLES = %d\n\n", enumcounter);
// Scalar arrays // Scalar arrays
fprintf(DSLHEADER, "#define AC_FOR_SCALARARRAY_HANDLES(FUNC)"); fprintf(DSLHEADER, "#define AC_FOR_SCALARARRAY_HANDLES(FUNC)");
enumcounter = 0;
for (size_t i = 0; i < num_symbols[current_nest]; ++i) { for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == SCALARARRAY && if (symbol_table[i].type_specifier == SCALARARRAY &&
symbol_table[i].type_qualifier == UNIFORM) { symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier); fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier,
enumcounter);
++enumcounter;
} }
} }
fprintf(DSLHEADER, "\n\n"); fprintf(DSLHEADER, "\n\n");
fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_SCALARRAY_HANDLES = %d\n\n", enumcounter);
// Do Fortran-specific
const char* fortran_structs = R"(
integer, parameter :: precision = c_float ! TODO WARNING
type, bind(C) :: AcMeshInfo
integer(c_int), dimension(AC_NUM_INT_PARAMS) :: int_params
integer(c_int), dimension(AC_NUM_INT3_PARAMS, 3) :: int3_params
real(precision), dimension(AC_NUM_REAL_PARAMS) :: real_params
real(precision), dimension(AC_NUM_REAL3_PARAMS, 3) :: real3_params
end type AcMeshInfo
)";
fprintf(FHEADER, "%s\n", fortran_structs);
} }
static void static void
@@ -681,20 +732,21 @@ main(void)
DSLHEADER = fopen(dslheader_filename, "w+"); DSLHEADER = fopen(dslheader_filename, "w+");
CUDAHEADER = fopen(cudaheader_filename, "w+"); CUDAHEADER = fopen(cudaheader_filename, "w+");
FHEADER = fopen(fheader_filename, "w+");
assert(DSLHEADER); assert(DSLHEADER);
assert(CUDAHEADER); assert(CUDAHEADER);
assert(FHEADER);
// Add built-in param symbols // Add built-in param symbols
for (size_t i = 0; i < ARRAY_SIZE(builtin_int_params); ++i) { for (size_t i = 0; i < ARRAY_SIZE(builtin_int_params); ++i)
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, builtin_int_params[i]); add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, builtin_int_params[i]);
}
for (size_t i = 0; i < ARRAY_SIZE(builtin_int3_params); ++i) { for (size_t i = 0; i < ARRAY_SIZE(builtin_int3_params); ++i)
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT3, builtin_int3_params[i]); add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT3, builtin_int3_params[i]);
}
// Generate // Generate
traverse(root); traverse(root);
generate_header(); generate_headers();
generate_preprocessed_structures(); generate_preprocessed_structures();
generate_library_hooks(); generate_library_hooks();
@@ -703,9 +755,12 @@ main(void)
// Cleanup // Cleanup
fclose(DSLHEADER); fclose(DSLHEADER);
fclose(CUDAHEADER); fclose(CUDAHEADER);
fclose(FHEADER);
astnode_destroy(root); astnode_destroy(root);
fprintf(stdout, "-- Generated %s\n", dslheader_filename); fprintf(stdout, "-- Generated %s\n", dslheader_filename);
fprintf(stdout, "-- Generated %s\n", cudaheader_filename); fprintf(stdout, "-- Generated %s\n", cudaheader_filename);
fprintf(stdout, "-- Generated %s\n", fheader_filename);
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }

View File

@@ -322,13 +322,11 @@ AcResult acGridIntegrate(const Stream stream, const AcReal dt);
AcResult acGridPeriodicBoundconds(const Stream stream); AcResult acGridPeriodicBoundconds(const Stream stream);
/** TODO */ /** TODO */
AcResult AcResult acGridReduceScal(const Stream stream, const ReductionType rtype,
acGridReduceScal(const Stream stream, const ReductionType rtype,
const VertexBufferHandle vtxbuf_handle, AcReal* result); const VertexBufferHandle vtxbuf_handle, AcReal* result);
/** TODO */ /** TODO */
AcResult AcResult acGridReduceVec(const Stream stream, const ReductionType rtype,
acGridReduceVec(const Stream stream, const ReductionType rtype,
const VertexBufferHandle vtxbuf0, const VertexBufferHandle vtxbuf1, const VertexBufferHandle vtxbuf0, const VertexBufferHandle vtxbuf1,
const VertexBufferHandle vtxbuf2, AcReal* result); const VertexBufferHandle vtxbuf2, AcReal* result);
#endif // AC_MPI_ENABLED #endif // AC_MPI_ENABLED

View File

@@ -1,7 +1,7 @@
find_package(CUDAToolkit) find_package(CUDAToolkit)
## Astaroth Core ## Astaroth Core
add_library(astaroth_core STATIC device.cc node.cc astaroth.cc) add_library(astaroth_core STATIC device.cc node.cc astaroth.cc astaroth_fortran.cc)
target_link_libraries(astaroth_core astaroth_utils astaroth_kernels CUDA::cudart CUDA::cuda_driver) target_link_libraries(astaroth_core astaroth_utils astaroth_kernels CUDA::cudart CUDA::cuda_driver)
## Options ## Options