Added a minimal Fortran interface to Astaroth

This commit is contained in:
jpekkila
2020-06-25 06:34:16 +03:00
parent 70ecacee7c
commit fbb8d7c7c6
4 changed files with 74 additions and 19 deletions

View File

@@ -55,7 +55,8 @@ message(STATUS "AC module dir: ${DSL_MODULE_DIR}")
file(GLOB DSL_SOURCES ${DSL_MODULE_DIR}/*
${CMAKE_SOURCE_DIR}/acc/stdlib/*)
set(DSL_HEADERS "${PROJECT_BINARY_DIR}/user_kernels.h"
"${PROJECT_BINARY_DIR}/user_defines.h")
"${PROJECT_BINARY_DIR}/user_defines.h"
"${PROJECT_BINARY_DIR}/astaroth.f90")
add_custom_command (
COMMENT "Building ACC objects ${DSL_MODULE_DIR}"
@@ -99,6 +100,7 @@ if (BUILD_SAMPLES)
add_subdirectory(samples/bwtest)
add_subdirectory(samples/genbenchmarkscripts)
add_subdirectory(samples/mpi_reduce_bench)
add_subdirectory(samples/fortrantest)
endif()
if (BUILD_STANDALONE)

View File

@@ -39,9 +39,11 @@ ASTNode* root = NULL;
// Output files
static FILE* DSLHEADER = NULL;
static FILE* CUDAHEADER = NULL;
static FILE* FHEADER = NULL;
static const char* dslheader_filename = "user_defines.h";
static const char* cudaheader_filename = "user_kernels.h";
static const char* fheader_filename = "astaroth.f90";
// Forward declaration of yyparse
int yyparse(void);
@@ -98,7 +100,8 @@ static const char* translation_table[TRANSLATION_TABLE_SIZE] = {
['<'] = "<",
['>'] = ">",
['!'] = "!",
['.'] = "."};
['.'] = ".",
};
static const char*
translate(const int token)
@@ -261,9 +264,8 @@ traverse(const ASTNode* node)
if (typequal->token == KERNEL) {
fprintf(CUDAHEADER, "GEN_KERNEL_PARAM_BOILERPLATE");
if (node->lhs != NULL) {
fprintf(
stderr,
"Syntax error: function parameters for Kernel functions not allowed!\n");
fprintf(stderr, "Syntax error: function parameters for Kernel functions not "
"allowed!\n");
}
}
else if (typequal->token == PREPROCESSED) {
@@ -597,8 +599,10 @@ generate_preprocessed_structures(void)
}
static void
generate_header(void)
generate_headers(void)
{
int enumcounter = 0;
fprintf(DSLHEADER, "#pragma once\n");
// Int params
@@ -606,56 +610,103 @@ generate_header(void)
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == INT && symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier,
enumcounter);
++enumcounter;
}
}
fprintf(DSLHEADER, "\n\n");
fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_INT_PARAMS = %d\n\n", enumcounter);
// Int3 params
fprintf(DSLHEADER, "#define AC_FOR_USER_INT3_PARAM_TYPES(FUNC)");
enumcounter = 0;
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == INT3 && symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier,
enumcounter);
++enumcounter;
}
}
fprintf(DSLHEADER, "\n\n");
fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_INT3_PARAMS = %d\n\n", enumcounter);
// Scalar params
fprintf(DSLHEADER, "#define AC_FOR_USER_REAL_PARAM_TYPES(FUNC)");
enumcounter = 0;
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == SCALAR && symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier,
enumcounter);
++enumcounter;
}
}
fprintf(DSLHEADER, "\n\n");
fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_REAL_PARAMS = %d\n\n", enumcounter);
// Vector params
fprintf(DSLHEADER, "#define AC_FOR_USER_REAL3_PARAM_TYPES(FUNC)");
enumcounter = 0;
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == VECTOR && symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier,
enumcounter);
++enumcounter;
}
}
fprintf(DSLHEADER, "\n\n");
fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_REAL3_PARAMS = %d\n\n", enumcounter);
// Scalar fields
fprintf(DSLHEADER, "#define AC_FOR_VTXBUF_HANDLES(FUNC)");
enumcounter = 0;
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == SCALARFIELD &&
symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier,
enumcounter);
++enumcounter;
}
}
fprintf(DSLHEADER, "\n\n");
fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_VTXBUF_HANDLES = %d\n\n", enumcounter);
// Scalar arrays
fprintf(DSLHEADER, "#define AC_FOR_SCALARARRAY_HANDLES(FUNC)");
enumcounter = 0;
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == SCALARARRAY &&
symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
fprintf(FHEADER, "integer(c_int), parameter :: %s = %d\n", symbol_table[i].identifier,
enumcounter);
++enumcounter;
}
}
fprintf(DSLHEADER, "\n\n");
fprintf(FHEADER, "integer(c_int), parameter :: AC_NUM_SCALARRAY_HANDLES = %d\n\n", enumcounter);
// Do Fortran-specific
const char* fortran_structs = R"(
integer, parameter :: precision = c_float ! TODO WARNING
type, bind(C) :: AcMeshInfo
integer(c_int), dimension(AC_NUM_INT_PARAMS) :: int_params
integer(c_int), dimension(AC_NUM_INT3_PARAMS, 3) :: int3_params
real(precision), dimension(AC_NUM_REAL_PARAMS) :: real_params
real(precision), dimension(AC_NUM_REAL3_PARAMS, 3) :: real3_params
end type AcMeshInfo
)";
fprintf(FHEADER, "%s\n", fortran_structs);
}
static void
@@ -681,20 +732,21 @@ main(void)
DSLHEADER = fopen(dslheader_filename, "w+");
CUDAHEADER = fopen(cudaheader_filename, "w+");
FHEADER = fopen(fheader_filename, "w+");
assert(DSLHEADER);
assert(CUDAHEADER);
assert(FHEADER);
// Add built-in param symbols
for (size_t i = 0; i < ARRAY_SIZE(builtin_int_params); ++i) {
for (size_t i = 0; i < ARRAY_SIZE(builtin_int_params); ++i)
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, builtin_int_params[i]);
}
for (size_t i = 0; i < ARRAY_SIZE(builtin_int3_params); ++i) {
for (size_t i = 0; i < ARRAY_SIZE(builtin_int3_params); ++i)
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT3, builtin_int3_params[i]);
}
// Generate
traverse(root);
generate_header();
generate_headers();
generate_preprocessed_structures();
generate_library_hooks();
@@ -703,9 +755,12 @@ main(void)
// Cleanup
fclose(DSLHEADER);
fclose(CUDAHEADER);
fclose(FHEADER);
astnode_destroy(root);
fprintf(stdout, "-- Generated %s\n", dslheader_filename);
fprintf(stdout, "-- Generated %s\n", cudaheader_filename);
fprintf(stdout, "-- Generated %s\n", fheader_filename);
return EXIT_SUCCESS;
}

View File

@@ -322,15 +322,13 @@ AcResult acGridIntegrate(const Stream stream, const AcReal dt);
AcResult acGridPeriodicBoundconds(const Stream stream);
/** TODO */
AcResult
acGridReduceScal(const Stream stream, const ReductionType rtype,
const VertexBufferHandle vtxbuf_handle, AcReal* result);
AcResult acGridReduceScal(const Stream stream, const ReductionType rtype,
const VertexBufferHandle vtxbuf_handle, AcReal* result);
/** TODO */
AcResult
acGridReduceVec(const Stream stream, const ReductionType rtype,
const VertexBufferHandle vtxbuf0, const VertexBufferHandle vtxbuf1,
const VertexBufferHandle vtxbuf2, AcReal* result);
AcResult acGridReduceVec(const Stream stream, const ReductionType rtype,
const VertexBufferHandle vtxbuf0, const VertexBufferHandle vtxbuf1,
const VertexBufferHandle vtxbuf2, AcReal* result);
#endif // AC_MPI_ENABLED
/*

View File

@@ -1,7 +1,7 @@
find_package(CUDAToolkit)
## Astaroth Core
add_library(astaroth_core STATIC device.cc node.cc astaroth.cc)
add_library(astaroth_core STATIC device.cc node.cc astaroth.cc astaroth_fortran.cc)
target_link_libraries(astaroth_core astaroth_utils astaroth_kernels CUDA::cudart CUDA::cuda_driver)
## Options