From 6cab3586cfd2c7585b2bb7647fed53e950a45b83 Mon Sep 17 00:00:00 2001 From: jpekkila Date: Mon, 29 Jun 2020 01:06:30 +0300 Subject: [PATCH] The generated fortran header is now consistent with fortran conventions. Also cleaned up the C version of the header. --- acc/src/code_generator.c | 48 +++++++++++++++------- src/core/astaroth_fortran.cc | 51 ++++++++++++------------ {include => src/core}/astaroth_fortran.h | 12 +++--- 3 files changed, 65 insertions(+), 46 deletions(-) rename {include => src/core}/astaroth_fortran.h (97%) diff --git a/acc/src/code_generator.c b/acc/src/code_generator.c index 11711c0..02aa6e0 100644 --- a/acc/src/code_generator.c +++ b/acc/src/code_generator.c @@ -43,7 +43,7 @@ static FILE* FHEADER = NULL; static const char* dslheader_filename = "user_defines.h"; static const char* cudaheader_filename = "user_kernels.h"; -static const char* fheader_filename = "astaroth.f90"; +static const char* fheader_filename = "astaroth_fortran.h"; // Forward declaration of yyparse int yyparse(void); @@ -601,12 +601,32 @@ generate_preprocessed_structures(void) static void generate_headers(void) { - int enumcounter = 0; + // Fortran interface + const char* fortran_interface = R"( +! -*-f90-*- (for emacs) vim:set filetype=fortran: (for vim) + +! Utils (see astaroth_fortran.cc for definitions) +external acupdatebuiltinparams +external acgetdevicecount + +! Device interface (see astaroth_fortran.cc for definitions) +external acdevicecreate, acdevicedestroy +external acdeviceprintinfo +external acdeviceloadmeshinfo +external acdeviceloadmesh, acdevicestoremesh +external acdeviceintegratesubstep +external acdeviceperiodicboundconds +external acdeviceswapbuffers +external acdevicereducescal, acdevicereducevec +external acdevicesynchronizestream + )"; + fprintf(FHEADER, "%s\n", fortran_interface); fprintf(DSLHEADER, "#pragma once\n"); // Int params fprintf(DSLHEADER, "#define AC_FOR_USER_INT_PARAM_TYPES(FUNC)"); + int enumcounter = 0; for (size_t i = 0; i < num_symbols[current_nest]; ++i) { if (symbol_table[i].type_specifier == INT && symbol_table[i].type_qualifier == UNIFORM) { fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier); @@ -695,17 +715,6 @@ generate_headers(void) fprintf(DSLHEADER, "\n\n"); fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_SCALARRAY_HANDLES = %d\n\n", enumcounter); - // Do Fortran-specific - const char* fortran_structs = R"( - type, bind(C) :: AcMeshInfo - integer(c_int), dimension(AC_NUM_INT_PARAMS) :: int_params - integer(c_int), dimension(AC_NUM_INT3_PARAMS, 3) :: int3_params - real, dimension(AC_NUM_REAL_PARAMS) :: real_params - real, dimension(AC_NUM_REAL3_PARAMS, 3) :: real3_params - end type AcMeshInfo - )"; - fprintf(FHEADER, "%s\n", fortran_structs); - // Streams const size_t nstreams = 20; for (size_t i = 0; i < nstreams; ++i) { @@ -719,7 +728,7 @@ generate_headers(void) fprintf(FHEADER, "integer(c_int), parameter :: STREAM_DEFAULT = STREAM_0\n"); fprintf(FHEADER, "integer(c_int), parameter :: STREAM_ALL = NUM_STREAMS\n"); - fprintf(DSLHEADER, "typedef int Stream;\n"); + fprintf(DSLHEADER, "typedef int Stream;\n\n"); // Reduction types size_t counter = 0; @@ -746,6 +755,17 @@ generate_headers(void) fprintf(DSLHEADER, "typedef int ReductionType;\n"); fprintf(DSLHEADER, "#define NUM_REDUCTION_TYPES (%lu)\n", counter); fprintf(FHEADER, "integer(c_int), parameter :: NUM_REDUCTION_TYPES = %lu\n", counter); + + // Fortran structs + const char* fortran_structs = R"( +type, bind(C) :: AcMeshInfo + integer(c_int), dimension(AC_NUM_INT_PARAMS) :: int_params + integer(c_int), dimension(AC_NUM_INT3_PARAMS, 3) :: int3_params + real, dimension(AC_NUM_REAL_PARAMS) :: real_params + real, dimension(AC_NUM_REAL3_PARAMS, 3) :: real3_params +end type AcMeshInfo + )"; + fprintf(FHEADER, "%s\n", fortran_structs); } static void diff --git a/src/core/astaroth_fortran.cc b/src/core/astaroth_fortran.cc index d974aa1..80ee033 100644 --- a/src/core/astaroth_fortran.cc +++ b/src/core/astaroth_fortran.cc @@ -4,38 +4,46 @@ #include "astaroth_utils.h" #include "errchk.h" +/** + * Utils + */ +void +acupdatebuiltinparams_(AcMeshInfo* info) +{ + acUpdateBuiltinParams(info); +} + +void +acgetdevicecount_(int* count) +{ + ERRCHK_CUDA_ALWAYS(cudaGetDeviceCount(count)); +} + +/** + * Device + */ void acdevicecreate_(const int* id, const AcMeshInfo* info, Device* handle) { - // TODO errorcheck acDeviceCreate(*id, *info, handle); } void acdevicedestroy_(Device* device) { - // TODO errorcheck acDeviceDestroy(*device); } void acdeviceprintinfo_(const Device* device) { - // TODO errorcheck acDevicePrintInfo(*device); } void -acupdatebuiltinparams_(AcMeshInfo* info) +acdeviceloadmeshinfo_(const Device* device, const AcMeshInfo* info) { - // TODO errorcheck - acUpdateBuiltinParams(info); -} - -void -acdeviceswapbuffers_(const Device* device) -{ - acDeviceSwapBuffers(*device); + acDeviceLoadMeshInfo(*device, *info); } void @@ -81,10 +89,15 @@ void acdeviceperiodicboundconds_(const Device* device, const Stream* stream, const int3* start, const int3* end) { - acDevicePeriodicBoundconds(*device, *stream, *start, *end); } +void +acdeviceswapbuffers_(const Device* device) +{ + acDeviceSwapBuffers(*device); +} + void acdevicereducescal_(const Device* device, const Stream* stream, const ReductionType* rtype, const VertexBufferHandle* vtxbuf_handle, AcReal* result) @@ -105,15 +118,3 @@ acdevicesynchronizestream_(const Device* device, const Stream* stream) { acDeviceSynchronizeStream(*device, *stream); } - -void -acdeviceloadmeshinfo_(const Device* device, const AcMeshInfo* info) -{ - acDeviceLoadMeshInfo(*device, *info); -} - -void -acgetdevicecount_(int* count) -{ - ERRCHK_CUDA_ALWAYS(cudaGetDeviceCount(count)); -} diff --git a/include/astaroth_fortran.h b/src/core/astaroth_fortran.h similarity index 97% rename from include/astaroth_fortran.h rename to src/core/astaroth_fortran.h index 9c02c1f..fa0ba3a 100644 --- a/include/astaroth_fortran.h +++ b/src/core/astaroth_fortran.h @@ -10,6 +10,8 @@ extern "C" { */ void acupdatebuiltinparams_(AcMeshInfo* info); +void acgetdevicecount_(int* count); + /** * Device */ @@ -19,9 +21,7 @@ void acdevicedestroy_(Device* device); void acdeviceprintinfo_(const Device* device); -void acupdatebuiltinparams_(AcMeshInfo* info); - -void acdeviceswapbuffers_(const Device* device); +void acdeviceloadmeshinfo_(const Device* device, const AcMeshInfo* info); void acdeviceloadmesh_(const Device* device, const Stream* stream, const AcMeshInfo* info, const int* num_farrays, AcReal* farray); @@ -34,6 +34,8 @@ void acdeviceintegratesubstep_(const Device* device, const Stream* stream, const void acdeviceperiodicboundconds_(const Device* device, const Stream* stream, const int3* start, const int3* end); +void acdeviceswapbuffers_(const Device* device); + void acdevicereducescal_(const Device* device, const Stream* stream, const ReductionType* rtype, const VertexBufferHandle* vtxbuf_handle, AcReal* result); @@ -43,10 +45,6 @@ void acdevicereducevec_(const Device* device, const Stream* stream, const Reduct void acdevicesynchronizestream_(const Device* device, const Stream* stream); -void acdeviceloadmeshinfo_(const Device* device, const AcMeshInfo* info); - -void acgetdevicecount_(int* count); - #ifdef __cplusplus } // extern "C" #endif