The generated fortran header is now consistent with fortran conventions. Also cleaned up the C version of the header.

This commit is contained in:
jpekkila
2020-06-29 01:06:30 +03:00
parent d0ca1f8195
commit 6cab3586cf
3 changed files with 65 additions and 46 deletions

View File

@@ -43,7 +43,7 @@ static FILE* FHEADER = NULL;
static const char* dslheader_filename = "user_defines.h"; static const char* dslheader_filename = "user_defines.h";
static const char* cudaheader_filename = "user_kernels.h"; static const char* cudaheader_filename = "user_kernels.h";
static const char* fheader_filename = "astaroth.f90"; static const char* fheader_filename = "astaroth_fortran.h";
// Forward declaration of yyparse // Forward declaration of yyparse
int yyparse(void); int yyparse(void);
@@ -601,12 +601,32 @@ generate_preprocessed_structures(void)
static void static void
generate_headers(void) generate_headers(void)
{ {
int enumcounter = 0; // Fortran interface
const char* fortran_interface = R"(
! -*-f90-*- (for emacs) vim:set filetype=fortran: (for vim)
! Utils (see astaroth_fortran.cc for definitions)
external acupdatebuiltinparams
external acgetdevicecount
! Device interface (see astaroth_fortran.cc for definitions)
external acdevicecreate, acdevicedestroy
external acdeviceprintinfo
external acdeviceloadmeshinfo
external acdeviceloadmesh, acdevicestoremesh
external acdeviceintegratesubstep
external acdeviceperiodicboundconds
external acdeviceswapbuffers
external acdevicereducescal, acdevicereducevec
external acdevicesynchronizestream
)";
fprintf(FHEADER, "%s\n", fortran_interface);
fprintf(DSLHEADER, "#pragma once\n"); fprintf(DSLHEADER, "#pragma once\n");
// Int params // Int params
fprintf(DSLHEADER, "#define AC_FOR_USER_INT_PARAM_TYPES(FUNC)"); fprintf(DSLHEADER, "#define AC_FOR_USER_INT_PARAM_TYPES(FUNC)");
int enumcounter = 0;
for (size_t i = 0; i < num_symbols[current_nest]; ++i) { for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == INT && symbol_table[i].type_qualifier == UNIFORM) { if (symbol_table[i].type_specifier == INT && symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier); fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
@@ -695,17 +715,6 @@ generate_headers(void)
fprintf(DSLHEADER, "\n\n"); fprintf(DSLHEADER, "\n\n");
fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_SCALARRAY_HANDLES = %d\n\n", enumcounter); fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_SCALARRAY_HANDLES = %d\n\n", enumcounter);
// Do Fortran-specific
const char* fortran_structs = R"(
type, bind(C) :: AcMeshInfo
integer(c_int), dimension(AC_NUM_INT_PARAMS) :: int_params
integer(c_int), dimension(AC_NUM_INT3_PARAMS, 3) :: int3_params
real, dimension(AC_NUM_REAL_PARAMS) :: real_params
real, dimension(AC_NUM_REAL3_PARAMS, 3) :: real3_params
end type AcMeshInfo
)";
fprintf(FHEADER, "%s\n", fortran_structs);
// Streams // Streams
const size_t nstreams = 20; const size_t nstreams = 20;
for (size_t i = 0; i < nstreams; ++i) { for (size_t i = 0; i < nstreams; ++i) {
@@ -719,7 +728,7 @@ generate_headers(void)
fprintf(FHEADER, "integer(c_int), parameter :: STREAM_DEFAULT = STREAM_0\n"); fprintf(FHEADER, "integer(c_int), parameter :: STREAM_DEFAULT = STREAM_0\n");
fprintf(FHEADER, "integer(c_int), parameter :: STREAM_ALL = NUM_STREAMS\n"); fprintf(FHEADER, "integer(c_int), parameter :: STREAM_ALL = NUM_STREAMS\n");
fprintf(DSLHEADER, "typedef int Stream;\n"); fprintf(DSLHEADER, "typedef int Stream;\n\n");
// Reduction types // Reduction types
size_t counter = 0; size_t counter = 0;
@@ -746,6 +755,17 @@ generate_headers(void)
fprintf(DSLHEADER, "typedef int ReductionType;\n"); fprintf(DSLHEADER, "typedef int ReductionType;\n");
fprintf(DSLHEADER, "#define NUM_REDUCTION_TYPES (%lu)\n", counter); fprintf(DSLHEADER, "#define NUM_REDUCTION_TYPES (%lu)\n", counter);
fprintf(FHEADER, "integer(c_int), parameter :: NUM_REDUCTION_TYPES = %lu\n", counter); fprintf(FHEADER, "integer(c_int), parameter :: NUM_REDUCTION_TYPES = %lu\n", counter);
// Fortran structs
const char* fortran_structs = R"(
type, bind(C) :: AcMeshInfo
integer(c_int), dimension(AC_NUM_INT_PARAMS) :: int_params
integer(c_int), dimension(AC_NUM_INT3_PARAMS, 3) :: int3_params
real, dimension(AC_NUM_REAL_PARAMS) :: real_params
real, dimension(AC_NUM_REAL3_PARAMS, 3) :: real3_params
end type AcMeshInfo
)";
fprintf(FHEADER, "%s\n", fortran_structs);
} }
static void static void

View File

@@ -4,38 +4,46 @@
#include "astaroth_utils.h" #include "astaroth_utils.h"
#include "errchk.h" #include "errchk.h"
/**
* Utils
*/
void
acupdatebuiltinparams_(AcMeshInfo* info)
{
acUpdateBuiltinParams(info);
}
void
acgetdevicecount_(int* count)
{
ERRCHK_CUDA_ALWAYS(cudaGetDeviceCount(count));
}
/**
* Device
*/
void void
acdevicecreate_(const int* id, const AcMeshInfo* info, Device* handle) acdevicecreate_(const int* id, const AcMeshInfo* info, Device* handle)
{ {
// TODO errorcheck
acDeviceCreate(*id, *info, handle); acDeviceCreate(*id, *info, handle);
} }
void void
acdevicedestroy_(Device* device) acdevicedestroy_(Device* device)
{ {
// TODO errorcheck
acDeviceDestroy(*device); acDeviceDestroy(*device);
} }
void void
acdeviceprintinfo_(const Device* device) acdeviceprintinfo_(const Device* device)
{ {
// TODO errorcheck
acDevicePrintInfo(*device); acDevicePrintInfo(*device);
} }
void void
acupdatebuiltinparams_(AcMeshInfo* info) acdeviceloadmeshinfo_(const Device* device, const AcMeshInfo* info)
{ {
// TODO errorcheck acDeviceLoadMeshInfo(*device, *info);
acUpdateBuiltinParams(info);
}
void
acdeviceswapbuffers_(const Device* device)
{
acDeviceSwapBuffers(*device);
} }
void void
@@ -81,10 +89,15 @@ void
acdeviceperiodicboundconds_(const Device* device, const Stream* stream, const int3* start, acdeviceperiodicboundconds_(const Device* device, const Stream* stream, const int3* start,
const int3* end) const int3* end)
{ {
acDevicePeriodicBoundconds(*device, *stream, *start, *end); acDevicePeriodicBoundconds(*device, *stream, *start, *end);
} }
void
acdeviceswapbuffers_(const Device* device)
{
acDeviceSwapBuffers(*device);
}
void void
acdevicereducescal_(const Device* device, const Stream* stream, const ReductionType* rtype, acdevicereducescal_(const Device* device, const Stream* stream, const ReductionType* rtype,
const VertexBufferHandle* vtxbuf_handle, AcReal* result) const VertexBufferHandle* vtxbuf_handle, AcReal* result)
@@ -105,15 +118,3 @@ acdevicesynchronizestream_(const Device* device, const Stream* stream)
{ {
acDeviceSynchronizeStream(*device, *stream); acDeviceSynchronizeStream(*device, *stream);
} }
void
acdeviceloadmeshinfo_(const Device* device, const AcMeshInfo* info)
{
acDeviceLoadMeshInfo(*device, *info);
}
void
acgetdevicecount_(int* count)
{
ERRCHK_CUDA_ALWAYS(cudaGetDeviceCount(count));
}

View File

@@ -10,6 +10,8 @@ extern "C" {
*/ */
void acupdatebuiltinparams_(AcMeshInfo* info); void acupdatebuiltinparams_(AcMeshInfo* info);
void acgetdevicecount_(int* count);
/** /**
* Device * Device
*/ */
@@ -19,9 +21,7 @@ void acdevicedestroy_(Device* device);
void acdeviceprintinfo_(const Device* device); void acdeviceprintinfo_(const Device* device);
void acupdatebuiltinparams_(AcMeshInfo* info); void acdeviceloadmeshinfo_(const Device* device, const AcMeshInfo* info);
void acdeviceswapbuffers_(const Device* device);
void acdeviceloadmesh_(const Device* device, const Stream* stream, const AcMeshInfo* info, void acdeviceloadmesh_(const Device* device, const Stream* stream, const AcMeshInfo* info,
const int* num_farrays, AcReal* farray); const int* num_farrays, AcReal* farray);
@@ -34,6 +34,8 @@ void acdeviceintegratesubstep_(const Device* device, const Stream* stream, const
void acdeviceperiodicboundconds_(const Device* device, const Stream* stream, const int3* start, void acdeviceperiodicboundconds_(const Device* device, const Stream* stream, const int3* start,
const int3* end); const int3* end);
void acdeviceswapbuffers_(const Device* device);
void acdevicereducescal_(const Device* device, const Stream* stream, const ReductionType* rtype, void acdevicereducescal_(const Device* device, const Stream* stream, const ReductionType* rtype,
const VertexBufferHandle* vtxbuf_handle, AcReal* result); const VertexBufferHandle* vtxbuf_handle, AcReal* result);
@@ -43,10 +45,6 @@ void acdevicereducevec_(const Device* device, const Stream* stream, const Reduct
void acdevicesynchronizestream_(const Device* device, const Stream* stream); void acdevicesynchronizestream_(const Device* device, const Stream* stream);
void acdeviceloadmeshinfo_(const Device* device, const AcMeshInfo* info);
void acgetdevicecount_(int* count);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif