diff --git a/CMakeLists.txt b/CMakeLists.txt index b2722d1..5152379 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,7 +55,8 @@ message(STATUS "AC module dir: ${DSL_MODULE_DIR}") file(GLOB DSL_SOURCES ${DSL_MODULE_DIR}/* ${CMAKE_SOURCE_DIR}/acc/stdlib/*) set(DSL_HEADERS "${PROJECT_BINARY_DIR}/user_kernels.h" - "${PROJECT_BINARY_DIR}/user_defines.h") + "${PROJECT_BINARY_DIR}/user_defines.h" + "${PROJECT_BINARY_DIR}/astaroth.f90") add_custom_command ( COMMENT "Building ACC objects ${DSL_MODULE_DIR}" @@ -99,6 +100,7 @@ if (BUILD_SAMPLES) add_subdirectory(samples/bwtest) add_subdirectory(samples/genbenchmarkscripts) add_subdirectory(samples/mpi_reduce_bench) + add_subdirectory(samples/fortrantest) endif() if (BUILD_STANDALONE) diff --git a/acc/src/code_generator.c b/acc/src/code_generator.c index f69d186..8738590 100644 --- a/acc/src/code_generator.c +++ b/acc/src/code_generator.c @@ -39,9 +39,11 @@ ASTNode* root = NULL; // Output files static FILE* DSLHEADER = NULL; static FILE* CUDAHEADER = NULL; +static FILE* FHEADER = NULL; static const char* dslheader_filename = "user_defines.h"; static const char* cudaheader_filename = "user_kernels.h"; +static const char* fheader_filename = "astaroth.f90"; // Forward declaration of yyparse int yyparse(void); @@ -98,7 +100,8 @@ static const char* translation_table[TRANSLATION_TABLE_SIZE] = { ['<'] = "<", ['>'] = ">", ['!'] = "!", - ['.'] = "."}; + ['.'] = ".", +}; static const char* translate(const int token) @@ -261,9 +264,8 @@ traverse(const ASTNode* node) if (typequal->token == KERNEL) { fprintf(CUDAHEADER, "GEN_KERNEL_PARAM_BOILERPLATE"); if (node->lhs != NULL) { - fprintf( - stderr, - "Syntax error: function parameters for Kernel functions not allowed!\n"); + fprintf(stderr, "Syntax error: function parameters for Kernel functions not " + "allowed!\n"); } } else if (typequal->token == PREPROCESSED) { @@ -597,8 +599,10 @@ generate_preprocessed_structures(void) } static void -generate_header(void) +generate_headers(void) { + int enumcounter = 0; + fprintf(DSLHEADER, "#pragma once\n"); // Int params @@ -606,56 +610,103 @@ generate_header(void) for (size_t i = 0; i < num_symbols[current_nest]; ++i) { if (symbol_table[i].type_specifier == INT && symbol_table[i].type_qualifier == UNIFORM) { fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier); + fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier, + enumcounter); + ++enumcounter; } } fprintf(DSLHEADER, "\n\n"); + fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_INT_PARAMS = %d\n\n", enumcounter); // Int3 params fprintf(DSLHEADER, "#define AC_FOR_USER_INT3_PARAM_TYPES(FUNC)"); + enumcounter = 0; for (size_t i = 0; i < num_symbols[current_nest]; ++i) { if (symbol_table[i].type_specifier == INT3 && symbol_table[i].type_qualifier == UNIFORM) { fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier); + + fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier, + enumcounter); + ++enumcounter; } } fprintf(DSLHEADER, "\n\n"); + fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_INT3_PARAMS = %d\n\n", enumcounter); // Scalar params fprintf(DSLHEADER, "#define AC_FOR_USER_REAL_PARAM_TYPES(FUNC)"); + enumcounter = 0; for (size_t i = 0; i < num_symbols[current_nest]; ++i) { if (symbol_table[i].type_specifier == SCALAR && symbol_table[i].type_qualifier == UNIFORM) { fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier); + + fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier, + enumcounter); + ++enumcounter; } } fprintf(DSLHEADER, "\n\n"); + fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_REAL_PARAMS = %d\n\n", enumcounter); // Vector params fprintf(DSLHEADER, "#define AC_FOR_USER_REAL3_PARAM_TYPES(FUNC)"); + enumcounter = 0; for (size_t i = 0; i < num_symbols[current_nest]; ++i) { if (symbol_table[i].type_specifier == VECTOR && symbol_table[i].type_qualifier == UNIFORM) { fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier); + + fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier, + enumcounter); + ++enumcounter; } } fprintf(DSLHEADER, "\n\n"); + fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_REAL3_PARAMS = %d\n\n", enumcounter); // Scalar fields fprintf(DSLHEADER, "#define AC_FOR_VTXBUF_HANDLES(FUNC)"); + enumcounter = 0; for (size_t i = 0; i < num_symbols[current_nest]; ++i) { if (symbol_table[i].type_specifier == SCALARFIELD && symbol_table[i].type_qualifier == UNIFORM) { fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier); + + fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier, + enumcounter); + ++enumcounter; } } fprintf(DSLHEADER, "\n\n"); + fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_VTXBUF_HANDLES = %d\n\n", enumcounter); // Scalar arrays fprintf(DSLHEADER, "#define AC_FOR_SCALARARRAY_HANDLES(FUNC)"); + enumcounter = 0; for (size_t i = 0; i < num_symbols[current_nest]; ++i) { if (symbol_table[i].type_specifier == SCALARARRAY && symbol_table[i].type_qualifier == UNIFORM) { fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier); + + fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier, + enumcounter); + ++enumcounter; } } fprintf(DSLHEADER, "\n\n"); + fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_SCALARRAY_HANDLES = %d\n\n", enumcounter); + + // Do Fortran-specific + const char* fortran_structs = R"( + integer, parameter :: precision = c_float ! TODO WARNING + + type, bind(C) :: AcMeshInfo + integer(c_int), dimension(AC_NUM_INT_PARAMS) :: int_params + integer(c_int), dimension(AC_NUM_INT3_PARAMS, 3) :: int3_params + real(precision), dimension(AC_NUM_REAL_PARAMS) :: real_params + real(precision), dimension(AC_NUM_REAL3_PARAMS, 3) :: real3_params + end type AcMeshInfo + )"; + fprintf(FHEADER, "%s\n", fortran_structs); } static void @@ -681,20 +732,21 @@ main(void) DSLHEADER = fopen(dslheader_filename, "w+"); CUDAHEADER = fopen(cudaheader_filename, "w+"); + FHEADER = fopen(fheader_filename, "w+"); assert(DSLHEADER); assert(CUDAHEADER); + assert(FHEADER); // Add built-in param symbols - for (size_t i = 0; i < ARRAY_SIZE(builtin_int_params); ++i) { + for (size_t i = 0; i < ARRAY_SIZE(builtin_int_params); ++i) add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, builtin_int_params[i]); - } - for (size_t i = 0; i < ARRAY_SIZE(builtin_int3_params); ++i) { + + for (size_t i = 0; i < ARRAY_SIZE(builtin_int3_params); ++i) add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT3, builtin_int3_params[i]); - } // Generate traverse(root); - generate_header(); + generate_headers(); generate_preprocessed_structures(); generate_library_hooks(); @@ -703,9 +755,12 @@ main(void) // Cleanup fclose(DSLHEADER); fclose(CUDAHEADER); + fclose(FHEADER); astnode_destroy(root); fprintf(stdout, "-- Generated %s\n", dslheader_filename); fprintf(stdout, "-- Generated %s\n", cudaheader_filename); + fprintf(stdout, "-- Generated %s\n", fheader_filename); + return EXIT_SUCCESS; } diff --git a/include/astaroth.h b/include/astaroth.h index 47beb88..d72d598 100644 --- a/include/astaroth.h +++ b/include/astaroth.h @@ -322,15 +322,13 @@ AcResult acGridIntegrate(const Stream stream, const AcReal dt); AcResult acGridPeriodicBoundconds(const Stream stream); /** TODO */ -AcResult -acGridReduceScal(const Stream stream, const ReductionType rtype, - const VertexBufferHandle vtxbuf_handle, AcReal* result); +AcResult acGridReduceScal(const Stream stream, const ReductionType rtype, + const VertexBufferHandle vtxbuf_handle, AcReal* result); /** TODO */ -AcResult -acGridReduceVec(const Stream stream, const ReductionType rtype, - const VertexBufferHandle vtxbuf0, const VertexBufferHandle vtxbuf1, - const VertexBufferHandle vtxbuf2, AcReal* result); +AcResult acGridReduceVec(const Stream stream, const ReductionType rtype, + const VertexBufferHandle vtxbuf0, const VertexBufferHandle vtxbuf1, + const VertexBufferHandle vtxbuf2, AcReal* result); #endif // AC_MPI_ENABLED /* diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 81bcf14..1a70f93 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -1,7 +1,7 @@ find_package(CUDAToolkit) ## Astaroth Core -add_library(astaroth_core STATIC device.cc node.cc astaroth.cc) +add_library(astaroth_core STATIC device.cc node.cc astaroth.cc astaroth_fortran.cc) target_link_libraries(astaroth_core astaroth_utils astaroth_kernels CUDA::cudart CUDA::cuda_driver) ## Options