Added WIP stuff for the Astaroth DSL compiler rewrite. Once this branch is finished only a single source file will be needed (file ending .ac). This revision is needed to decouple absolutely all implementation-specific stuff (f.ex. AC_dsx) from the core library and make life easier for everyone. The plan is to provide a standard library header written in the DSL containing the derivative operations instead of hardcoding them in the CUDA implementation.
This commit is contained in:
@@ -25,6 +25,7 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
@@ -35,9 +36,9 @@
|
||||
|
||||
ASTNode* root = NULL;
|
||||
|
||||
static const char inout_name_prefix[] = "handle_";
|
||||
typedef enum { STENCIL_ASSEMBLY, STENCIL_PROCESS, STENCIL_HEADER } CompilationType;
|
||||
static CompilationType compilation_type;
|
||||
// Output files
|
||||
static FILE* DSLHEADER = NULL;
|
||||
static FILE* CUDAHEADER = NULL;
|
||||
|
||||
/*
|
||||
* =============================================================================
|
||||
@@ -64,15 +65,13 @@ static const char* translation_table[TRANSLATION_TABLE_SIZE] = {
|
||||
[SCALARARRAY] = "const AcReal* __restrict__",
|
||||
[COMPLEX] = "acComplex",
|
||||
// Type qualifiers
|
||||
[KERNEL] = "template <int step_number> static __global__",
|
||||
//__launch_bounds__(RK_THREADBLOCK_SIZE,
|
||||
// RK_LAUNCH_BOUND_MIN_BLOCKS),
|
||||
[PREPROCESSED] = "static __device__ "
|
||||
"__forceinline__",
|
||||
[CONSTANT] = "const",
|
||||
[IN] = "in",
|
||||
[OUT] = "out",
|
||||
[UNIFORM] = "uniform",
|
||||
[KERNEL] = "template <int step_number> static __global__",
|
||||
[DEVICE] = "static __device__",
|
||||
[PREPROCESSED] = "static __device__ __forceinline__",
|
||||
[CONSTANT] = "const",
|
||||
[IN] = "in",
|
||||
[OUT] = "out",
|
||||
[UNIFORM] = "uniform",
|
||||
// ETC
|
||||
[INPLACE_INC] = "++",
|
||||
[INPLACE_DEC] = "--",
|
||||
@@ -121,7 +120,7 @@ typedef enum {
|
||||
NUM_SYMBOLTYPES
|
||||
} SymbolType;
|
||||
|
||||
#define MAX_ID_LEN (128)
|
||||
#define MAX_ID_LEN (256)
|
||||
typedef struct {
|
||||
SymbolType type;
|
||||
int type_qualifier;
|
||||
@@ -129,135 +128,61 @@ typedef struct {
|
||||
char identifier[MAX_ID_LEN];
|
||||
} Symbol;
|
||||
|
||||
#define SYMBOL_TABLE_SIZE (4096)
|
||||
#define SYMBOL_TABLE_SIZE (65536)
|
||||
static Symbol symbol_table[SYMBOL_TABLE_SIZE] = {};
|
||||
static int num_symbols = 0;
|
||||
|
||||
static int
|
||||
#define MAX_NESTS (32)
|
||||
static size_t num_symbols[MAX_NESTS] = {};
|
||||
static size_t current_nest = 0;
|
||||
|
||||
static Symbol*
|
||||
symboltable_lookup(const char* identifier)
|
||||
{
|
||||
if (!identifier)
|
||||
return -1;
|
||||
return NULL;
|
||||
|
||||
for (int i = 0; i < num_symbols; ++i)
|
||||
for (size_t i = 0; i < num_symbols[current_nest]; ++i)
|
||||
if (strcmp(identifier, symbol_table[i].identifier) == 0)
|
||||
return i;
|
||||
return &symbol_table[i];
|
||||
|
||||
return -1;
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void
|
||||
add_symbol(const SymbolType type, const int tqualifier, const int tspecifier, const char* id)
|
||||
{
|
||||
assert(num_symbols < SYMBOL_TABLE_SIZE);
|
||||
assert(num_symbols[current_nest] < SYMBOL_TABLE_SIZE);
|
||||
|
||||
symbol_table[num_symbols].type = type;
|
||||
symbol_table[num_symbols].type_qualifier = tqualifier;
|
||||
symbol_table[num_symbols].type_specifier = tspecifier;
|
||||
strcpy(symbol_table[num_symbols].identifier, id);
|
||||
symbol_table[num_symbols[current_nest]].type = type;
|
||||
symbol_table[num_symbols[current_nest]].type_qualifier = tqualifier;
|
||||
symbol_table[num_symbols[current_nest]].type_specifier = tspecifier;
|
||||
strcpy(symbol_table[num_symbols[current_nest]].identifier, id);
|
||||
|
||||
++num_symbols;
|
||||
++num_symbols[current_nest];
|
||||
}
|
||||
|
||||
static void
|
||||
rm_symbol(const int handle)
|
||||
{
|
||||
assert(handle >= 0 && handle < num_symbols);
|
||||
assert(num_symbols > 0);
|
||||
|
||||
if (&symbol_table[handle] != &symbol_table[num_symbols - 1])
|
||||
memcpy(&symbol_table[handle], &symbol_table[num_symbols - 1], sizeof(Symbol));
|
||||
--num_symbols;
|
||||
}
|
||||
|
||||
static void
|
||||
print_symbol(const int handle)
|
||||
print_symbol(const size_t handle)
|
||||
{
|
||||
assert(handle < SYMBOL_TABLE_SIZE);
|
||||
|
||||
const char* fields[] = {translate(symbol_table[handle].type_qualifier),
|
||||
translate(symbol_table[handle].type_specifier),
|
||||
symbol_table[handle].identifier};
|
||||
const size_t num_fields = sizeof(fields) / sizeof(fields[0]);
|
||||
const char* fields[] = {
|
||||
translate(symbol_table[handle].type_qualifier),
|
||||
translate(symbol_table[handle].type_specifier),
|
||||
symbol_table[handle].identifier,
|
||||
};
|
||||
|
||||
const size_t num_fields = sizeof(fields) / sizeof(fields[0]);
|
||||
for (size_t i = 0; i < num_fields; ++i)
|
||||
if (fields[i])
|
||||
printf("%s ", fields[i]);
|
||||
}
|
||||
|
||||
static void
|
||||
translate_latest_symbol(void)
|
||||
{
|
||||
const int handle = num_symbols - 1;
|
||||
assert(handle < SYMBOL_TABLE_SIZE);
|
||||
|
||||
Symbol* symbol = &symbol_table[handle];
|
||||
|
||||
// FUNCTION
|
||||
if (symbol->type == SYMBOLTYPE_FUNCTION) {
|
||||
// KERNEL FUNCTION
|
||||
if (symbol->type_qualifier == KERNEL) {
|
||||
printf("%s %s\n%s", translate(symbol->type_qualifier),
|
||||
translate(symbol->type_specifier), symbol->identifier);
|
||||
}
|
||||
// PREPROCESSED FUNCTION
|
||||
else if (symbol->type_qualifier == PREPROCESSED) {
|
||||
printf("%s %s\npreprocessed_%s", translate(symbol->type_qualifier),
|
||||
translate(symbol->type_specifier), symbol->identifier);
|
||||
}
|
||||
// OTHER FUNCTION
|
||||
else {
|
||||
const char* regular_function_decorator = "static __device__ "
|
||||
"__forceinline__";
|
||||
printf("%s %s %s\n%s", regular_function_decorator,
|
||||
translate(symbol->type_qualifier) ? translate(symbol->type_qualifier) : "",
|
||||
translate(symbol->type_specifier), symbol->identifier);
|
||||
}
|
||||
}
|
||||
// FUNCTION PARAMETER
|
||||
else if (symbol->type == SYMBOLTYPE_FUNCTION_PARAMETER) {
|
||||
if (symbol->type_qualifier == IN || symbol->type_qualifier == OUT) {
|
||||
if (compilation_type == STENCIL_ASSEMBLY)
|
||||
printf("const __restrict__ %s* %s", translate(symbol->type_specifier),
|
||||
symbol->identifier);
|
||||
else if (compilation_type == STENCIL_PROCESS)
|
||||
printf("const %sData& %s", translate(symbol->type_specifier), symbol->identifier);
|
||||
else
|
||||
printf("Invalid compilation type %d, IN and OUT qualifiers not supported\n",
|
||||
compilation_type);
|
||||
}
|
||||
else {
|
||||
print_symbol(handle);
|
||||
}
|
||||
}
|
||||
// UNIFORM
|
||||
else if (symbol->type_qualifier == UNIFORM) {
|
||||
// if (compilation_type != STENCIL_HEADER) {
|
||||
// printf("ERROR: %s can only be used in stencil headers\n", translation_table[UNIFORM]);
|
||||
//}
|
||||
/* Do nothing */
|
||||
}
|
||||
// IN / OUT
|
||||
else if (symbol->type != SYMBOLTYPE_FUNCTION_PARAMETER &&
|
||||
(symbol->type_qualifier == IN || symbol->type_qualifier == OUT)) {
|
||||
|
||||
printf("static __device__ const %s %s%s",
|
||||
symbol->type_specifier == SCALARFIELD ? "int" : "int3", inout_name_prefix,
|
||||
symbol_table[handle].identifier);
|
||||
if (symbol->type_specifier == VECTOR)
|
||||
printf(" = make_int3");
|
||||
}
|
||||
// OTHER
|
||||
else {
|
||||
print_symbol(handle);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void
|
||||
print_symbol_table(void)
|
||||
{
|
||||
for (int i = 0; i < num_symbols; ++i) {
|
||||
printf("%d: ", i);
|
||||
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||
printf("%lu: ", i);
|
||||
const char* fields[] = {translate(symbol_table[i].type_qualifier),
|
||||
translate(symbol_table[i].type_specifier),
|
||||
symbol_table[i].identifier};
|
||||
@@ -279,377 +204,205 @@ print_symbol_table(void)
|
||||
|
||||
/*
|
||||
* =============================================================================
|
||||
* State
|
||||
* Traversal state
|
||||
* =============================================================================
|
||||
*/
|
||||
static bool inside_declaration = false;
|
||||
static bool inside_function_declaration = false;
|
||||
static bool inside_function_parameter_declaration = false;
|
||||
|
||||
static bool inside_kernel = false;
|
||||
static bool inside_preprocessed = false;
|
||||
|
||||
static int scope_start = 0;
|
||||
|
||||
/*
|
||||
* =============================================================================
|
||||
* AST traversal
|
||||
* =============================================================================
|
||||
*/
|
||||
|
||||
static int compound_statement_nests = 0;
|
||||
static void
|
||||
translate_latest_symbol(void)
|
||||
{
|
||||
// TODO
|
||||
}
|
||||
|
||||
static void
|
||||
traverse(const ASTNode* node)
|
||||
{
|
||||
// Prefix logic %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
if (node->type == NODE_FUNCTION_DECLARATION)
|
||||
inside_function_declaration = true;
|
||||
if (node->type == NODE_FUNCTION_PARAMETER_DECLARATION)
|
||||
inside_function_parameter_declaration = true;
|
||||
if (node->type == NODE_DECLARATION)
|
||||
inside_declaration = true;
|
||||
// Prefix translation
|
||||
if (translate(node->prefix))
|
||||
fprintf(CUDAHEADER, "%s", translate(node->prefix));
|
||||
|
||||
if (!inside_declaration && translate(node->prefix))
|
||||
printf("%s", translate(node->prefix));
|
||||
// Prefix logic
|
||||
if (node->type == NODE_COMPOUND_STATEMENT) {
|
||||
assert(current_nest < MAX_NESTS);
|
||||
|
||||
if (node->type == NODE_COMPOUND_STATEMENT)
|
||||
++compound_statement_nests;
|
||||
|
||||
// BOILERPLATE START////////////////////////////////////////////////////////
|
||||
if (node->type == NODE_TYPE_QUALIFIER && node->token == KERNEL)
|
||||
inside_kernel = true;
|
||||
|
||||
// Kernel parameter boilerplate
|
||||
const char* kernel_parameter_boilerplate = "GEN_KERNEL_PARAM_BOILERPLATE";
|
||||
if (inside_kernel && node->type == NODE_FUNCTION_PARAMETER_DECLARATION) {
|
||||
printf("%s", kernel_parameter_boilerplate);
|
||||
|
||||
if (node->lhs != NULL) {
|
||||
printf("Compilation error: function parameters for Kernel functions not allowed!\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
++current_nest;
|
||||
num_symbols[current_nest] = num_symbols[current_nest - 1];
|
||||
}
|
||||
|
||||
// Kernel builtin variables boilerplate (read input/output arrays and setup
|
||||
// indices)
|
||||
const char* kernel_builtin_variables_boilerplate = "GEN_KERNEL_BUILTIN_VARIABLES_"
|
||||
"BOILERPLATE();";
|
||||
if (inside_kernel && node->type == NODE_COMPOUND_STATEMENT && compound_statement_nests == 1) {
|
||||
printf("%s ", kernel_builtin_variables_boilerplate);
|
||||
|
||||
for (int i = 0; i < num_symbols; ++i) {
|
||||
if (symbol_table[i].type_qualifier == IN) {
|
||||
printf("const %sData %s = READ(%s%s);\n", translate(symbol_table[i].type_specifier),
|
||||
symbol_table[i].identifier, inout_name_prefix, symbol_table[i].identifier);
|
||||
}
|
||||
else if (symbol_table[i].type_qualifier == OUT) {
|
||||
printf("%s %s = READ_OUT(%s%s);", translate(symbol_table[i].type_specifier),
|
||||
symbol_table[i].identifier, inout_name_prefix, symbol_table[i].identifier);
|
||||
// printf("%s %s = buffer.out[%s%s][IDX(vertexIdx.x, vertexIdx.y, vertexIdx.z)];\n",
|
||||
// translate(symbol_table[i].type_specifier), symbol_table[i].identifier,
|
||||
// inout_name_prefix, symbol_table[i].identifier);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Preprocessed parameter boilerplate
|
||||
if (node->type == NODE_TYPE_QUALIFIER && node->token == PREPROCESSED)
|
||||
inside_preprocessed = true;
|
||||
static const char preprocessed_parameter_boilerplate
|
||||
[] = "const int3& vertexIdx, const int3& globalVertexIdx, ";
|
||||
if (inside_preprocessed && node->type == NODE_FUNCTION_PARAMETER_DECLARATION)
|
||||
printf("%s ", preprocessed_parameter_boilerplate);
|
||||
// BOILERPLATE END////////////////////////////////////////////////////////
|
||||
|
||||
// Enter LHS
|
||||
// Traverse LHS
|
||||
if (node->lhs)
|
||||
traverse(node->lhs);
|
||||
|
||||
// Infix logic %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
if (!inside_declaration && translate(node->infix))
|
||||
printf("%s ", translate(node->infix));
|
||||
// Infix translation
|
||||
if (translate(node->infix))
|
||||
fprintf(CUDAHEADER, "%s", translate(node->infix));
|
||||
|
||||
if (node->type == NODE_FUNCTION_DECLARATION)
|
||||
inside_function_declaration = false;
|
||||
// Infix logic
|
||||
// TODO
|
||||
|
||||
// If the node is a subscript expression and the expression list inside it is not empty
|
||||
if (node->type == NODE_MULTIDIM_SUBSCRIPT_EXPRESSION && node->rhs)
|
||||
printf("IDX(");
|
||||
|
||||
// Do a regular translation
|
||||
if (!inside_declaration) {
|
||||
const int handle = symboltable_lookup(node->buffer);
|
||||
if (handle >= 0) { // The variable exists in the symbol table
|
||||
const Symbol* symbol = &symbol_table[handle];
|
||||
|
||||
if (symbol->type_qualifier == UNIFORM) {
|
||||
if (inside_kernel && symbol->type_specifier == SCALARARRAY) {
|
||||
printf("buffer.profiles[%s] ", symbol->identifier);
|
||||
}
|
||||
else {
|
||||
printf("DCONST(%s) ", symbol->identifier);
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Do a regular translation
|
||||
if (translate(node->token))
|
||||
printf("%s ", translate(node->token));
|
||||
if (node->buffer) {
|
||||
if (node->type == NODE_REAL_NUMBER) {
|
||||
printf("%s(%s) ", translate(SCALAR),
|
||||
node->buffer); // Cast to correct precision
|
||||
}
|
||||
else {
|
||||
printf("%s ", node->buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Do a regular translation
|
||||
if (translate(node->token))
|
||||
printf("%s ", translate(node->token));
|
||||
if (node->buffer) {
|
||||
if (node->type == NODE_REAL_NUMBER) {
|
||||
printf("%s(%s) ", translate(SCALAR), node->buffer); // Cast to correct precision
|
||||
}
|
||||
else {
|
||||
printf("%s ", node->buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (node->type == NODE_FUNCTION_DECLARATION) {
|
||||
scope_start = num_symbols;
|
||||
}
|
||||
|
||||
// Enter RHS
|
||||
// Traverse RHS
|
||||
if (node->rhs)
|
||||
traverse(node->rhs);
|
||||
|
||||
// Postfix logic %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||
// If the node is a subscript expression and the expression list inside it is not empty
|
||||
if (node->type == NODE_MULTIDIM_SUBSCRIPT_EXPRESSION && node->rhs)
|
||||
printf(")"); // Closing bracket of IDX()
|
||||
// Postfix translation
|
||||
if (translate(node->postfix))
|
||||
fprintf(CUDAHEADER, "%s", translate(node->postfix));
|
||||
|
||||
// Generate writeback boilerplate for OUT fields
|
||||
if (inside_kernel && node->type == NODE_COMPOUND_STATEMENT && compound_statement_nests == 1) {
|
||||
for (int i = 0; i < num_symbols; ++i) {
|
||||
if (symbol_table[i].type_qualifier == OUT) {
|
||||
printf("WRITE_OUT(%s%s, %s);\n", inout_name_prefix, symbol_table[i].identifier,
|
||||
symbol_table[i].identifier);
|
||||
// printf("buffer.out[%s%s][IDX(vertexIdx.x, vertexIdx.y, vertexIdx.z)] = %s;\n",
|
||||
// inout_name_prefix, symbol_table[i].identifier, symbol_table[i].identifier);
|
||||
}
|
||||
// Translate existing symbols
|
||||
const Symbol* symbol = symboltable_lookup(node->buffer);
|
||||
if (symbol) {
|
||||
// Uniforms
|
||||
if (symbol->type_qualifier == UNIFORM) {
|
||||
fprintf(CUDAHEADER, "DCONST(%s) ", symbol->identifier);
|
||||
}
|
||||
}
|
||||
|
||||
if (!inside_declaration && translate(node->postfix))
|
||||
printf("%s", translate(node->postfix));
|
||||
|
||||
// Add new symbols to the symbol table
|
||||
if (node->type == NODE_DECLARATION) {
|
||||
inside_declaration = false;
|
||||
int stype;
|
||||
ASTNode* tmp = node->parent;
|
||||
while (tmp->type == NODE_DECLARATION_LIST)
|
||||
tmp = tmp->parent;
|
||||
|
||||
int tqual = 0;
|
||||
int tspec = 0;
|
||||
if (node->lhs && node->lhs->lhs) {
|
||||
if (node->lhs->lhs->type == NODE_TYPE_QUALIFIER)
|
||||
tqual = node->lhs->lhs->token;
|
||||
else if (node->lhs->lhs->type == NODE_TYPE_SPECIFIER)
|
||||
tspec = node->lhs->lhs->token;
|
||||
}
|
||||
if (node->lhs && node->lhs->rhs) {
|
||||
if (node->lhs->rhs->type == NODE_TYPE_SPECIFIER)
|
||||
tspec = node->lhs->rhs->token;
|
||||
}
|
||||
if (tmp->type == NODE_FUNCTION_DECLARATION)
|
||||
stype = SYMBOLTYPE_FUNCTION;
|
||||
else if (tmp->type == NODE_FUNCTION_PARAMETER_DECLARATION)
|
||||
stype = SYMBOLTYPE_FUNCTION_PARAMETER;
|
||||
else
|
||||
stype = SYMBOLTYPE_OTHER;
|
||||
|
||||
// Determine symbol type
|
||||
SymbolType symboltype = SYMBOLTYPE_OTHER;
|
||||
if (inside_function_declaration)
|
||||
symboltype = SYMBOLTYPE_FUNCTION;
|
||||
else if (inside_function_parameter_declaration)
|
||||
symboltype = SYMBOLTYPE_FUNCTION_PARAMETER;
|
||||
const ASTNode* tdeclaration = node->lhs;
|
||||
const int tqualifier = tdeclaration->rhs ? tdeclaration->lhs->token : 0;
|
||||
const int tspecifier = tdeclaration->rhs ? tdeclaration->rhs->token
|
||||
: tdeclaration->lhs->token;
|
||||
|
||||
// Determine identifier
|
||||
if (node->rhs->type == NODE_IDENTIFIER) {
|
||||
add_symbol(symboltype, tqual, tspec, node->rhs->buffer); // Ordinary
|
||||
translate_latest_symbol();
|
||||
const char* identifier = node->rhs->type == NODE_IDENTIFIER ? node->rhs->buffer
|
||||
: node->rhs->lhs->buffer;
|
||||
add_symbol(stype, tqualifier, tspecifier, identifier);
|
||||
|
||||
// Translate the new symbol
|
||||
if (tqualifier == UNIFORM) {
|
||||
// Do nothing
|
||||
}
|
||||
else {
|
||||
add_symbol(symboltype, tqual, tspec,
|
||||
node->rhs->lhs->buffer); // Array
|
||||
translate_latest_symbol();
|
||||
// Traverse the expression once again, this time with
|
||||
// "inside_declaration" flag off
|
||||
printf("%s ", translate(node->rhs->infix));
|
||||
if (node->rhs->rhs)
|
||||
traverse(node->rhs->rhs);
|
||||
printf("%s ", translate(node->rhs->postfix));
|
||||
else if (tqualifier == KERNEL) {
|
||||
fprintf(CUDAHEADER, "%s %s\n%s", //
|
||||
translate(tqualifier), translate(tspecifier), identifier);
|
||||
}
|
||||
else if (tqualifier == DEVICE) {
|
||||
fprintf(CUDAHEADER, "%s %s\n%s", //
|
||||
translate(tqualifier), translate(tspecifier), identifier);
|
||||
}
|
||||
else if (tqualifier == PREPROCESSED) {
|
||||
fprintf(CUDAHEADER, "%s %s\npreprocessed_%s", //
|
||||
translate(tqualifier), translate(tspecifier), identifier);
|
||||
}
|
||||
else if (stype == SYMBOLTYPE_FUNCTION_PARAMETER) {
|
||||
tmp = tmp->parent;
|
||||
assert(tmp->type = NODE_FUNCTION_DECLARATION);
|
||||
const Symbol* parent_function = symboltable_lookup(tmp->lhs->rhs->buffer);
|
||||
if (parent_function->type_qualifier == DEVICE)
|
||||
fprintf(CUDAHEADER, "%s %s\ndeviceparam_%s", //
|
||||
translate(tqualifier), translate(tspecifier), identifier);
|
||||
else if (parent_function->type_qualifier == PREPROCESSED)
|
||||
fprintf(CUDAHEADER, "%s %s\npreprocessedparam_%s", //
|
||||
translate(tqualifier), translate(tspecifier), identifier);
|
||||
else
|
||||
fprintf(CUDAHEADER, "%s %s\notherparam_%s", //
|
||||
translate(tqualifier), translate(tspecifier), identifier);
|
||||
}
|
||||
else { // Do a regular translation
|
||||
// fprintf(CUDAHEADER, "%s %s %s", //
|
||||
// translate(tqualifier), translate(tspecifier), identifier);
|
||||
}
|
||||
}
|
||||
|
||||
if (node->type == NODE_COMPOUND_STATEMENT)
|
||||
--compound_statement_nests;
|
||||
|
||||
if (node->type == NODE_FUNCTION_PARAMETER_DECLARATION)
|
||||
inside_function_parameter_declaration = false;
|
||||
|
||||
if (node->type == NODE_FUNCTION_DEFINITION) {
|
||||
while (num_symbols > scope_start)
|
||||
rm_symbol(num_symbols - 1);
|
||||
|
||||
inside_kernel = false;
|
||||
inside_preprocessed = false;
|
||||
// Postfix logic
|
||||
if (node->type == NODE_COMPOUND_STATEMENT) {
|
||||
assert(current_nest > 0);
|
||||
--current_nest;
|
||||
printf("Dropped rest of the symbol table, from %lu to %lu\n", num_symbols[current_nest + 1],
|
||||
num_symbols[current_nest]);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: these should use the generic type names SCALAR and VECTOR
|
||||
static void
|
||||
generate_preprocessed_structures(void)
|
||||
{
|
||||
// PREPROCESSED DATA STRUCT
|
||||
printf("\n");
|
||||
printf("typedef struct {\n");
|
||||
for (int i = 0; i < num_symbols; ++i) {
|
||||
if (symbol_table[i].type_qualifier == PREPROCESSED)
|
||||
printf("%s %s;\n", translate(symbol_table[i].type_specifier),
|
||||
symbol_table[i].identifier);
|
||||
}
|
||||
printf("} %sData;\n", translate(SCALAR));
|
||||
|
||||
// FILLING THE DATA STRUCT
|
||||
printf("static __device__ __forceinline__ AcRealData\
|
||||
read_data(const int3& vertexIdx,\
|
||||
const int3& globalVertexIdx,\
|
||||
AcReal* __restrict__ buf[], const int handle)\
|
||||
{\n\
|
||||
%sData data;\n",
|
||||
translate(SCALAR));
|
||||
|
||||
for (int i = 0; i < num_symbols; ++i) {
|
||||
if (symbol_table[i].type_qualifier == PREPROCESSED)
|
||||
printf("data.%s = preprocessed_%s(vertexIdx, globalVertexIdx, buf[handle]);\n",
|
||||
symbol_table[i].identifier, symbol_table[i].identifier);
|
||||
}
|
||||
printf("return data;\n");
|
||||
printf("}\n");
|
||||
|
||||
// FUNCTIONS FOR ACCESSING MEMBERS OF THE PREPROCESSED STRUCT
|
||||
for (int i = 0; i < num_symbols; ++i) {
|
||||
if (symbol_table[i].type_qualifier == PREPROCESSED)
|
||||
printf("static __device__ __forceinline__ %s\
|
||||
%s(const AcRealData& data)\
|
||||
{\n\
|
||||
return data.%s;\
|
||||
}\n",
|
||||
translate(symbol_table[i].type_specifier), symbol_table[i].identifier,
|
||||
symbol_table[i].identifier);
|
||||
}
|
||||
|
||||
// Syntactic sugar: generate also a Vector data struct
|
||||
printf("\
|
||||
typedef struct {\
|
||||
AcRealData x;\
|
||||
AcRealData y;\
|
||||
AcRealData z;\
|
||||
} AcReal3Data;\
|
||||
\
|
||||
static __device__ __forceinline__ AcReal3Data\
|
||||
read_data(const int3& vertexIdx,\
|
||||
const int3& globalVertexIdx,\
|
||||
AcReal* __restrict__ buf[], const int3& handle)\
|
||||
{\
|
||||
AcReal3Data data;\
|
||||
\
|
||||
data.x = read_data(vertexIdx, globalVertexIdx, buf, handle.x);\
|
||||
data.y = read_data(vertexIdx, globalVertexIdx, buf, handle.y);\
|
||||
data.z = read_data(vertexIdx, globalVertexIdx, buf, handle.z);\
|
||||
\
|
||||
return data;\
|
||||
}\
|
||||
");
|
||||
// TODO
|
||||
}
|
||||
|
||||
static void
|
||||
generate_header(void)
|
||||
{
|
||||
printf("\n#pragma once\n");
|
||||
fprintf(DSLHEADER, "#pragma once\n");
|
||||
|
||||
// Int params
|
||||
printf("#define AC_FOR_USER_INT_PARAM_TYPES(FUNC)");
|
||||
for (int i = 0; i < num_symbols; ++i) {
|
||||
if (symbol_table[i].type_specifier == INT) {
|
||||
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
fprintf(DSLHEADER, "#define AC_FOR_USER_INT_PARAM_TYPES(FUNC)");
|
||||
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||
if (symbol_table[i].type_specifier == INT && symbol_table[i].type_qualifier == UNIFORM) {
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
}
|
||||
}
|
||||
printf("\n\n");
|
||||
fprintf(DSLHEADER, "\n\n");
|
||||
|
||||
// Int3 params
|
||||
printf("#define AC_FOR_USER_INT3_PARAM_TYPES(FUNC)");
|
||||
for (int i = 0; i < num_symbols; ++i) {
|
||||
if (symbol_table[i].type_specifier == INT3) {
|
||||
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
fprintf(DSLHEADER, "#define AC_FOR_USER_INT3_PARAM_TYPES(FUNC)");
|
||||
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||
if (symbol_table[i].type_specifier == INT3 && symbol_table[i].type_qualifier == UNIFORM) {
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
}
|
||||
}
|
||||
printf("\n\n");
|
||||
fprintf(DSLHEADER, "\n\n");
|
||||
|
||||
// Scalar params
|
||||
printf("#define AC_FOR_USER_REAL_PARAM_TYPES(FUNC)");
|
||||
for (int i = 0; i < num_symbols; ++i) {
|
||||
if (symbol_table[i].type_specifier == SCALAR) {
|
||||
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
fprintf(DSLHEADER, "#define AC_FOR_USER_REAL_PARAM_TYPES(FUNC)");
|
||||
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||
if (symbol_table[i].type_specifier == SCALAR && symbol_table[i].type_qualifier == UNIFORM) {
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
}
|
||||
}
|
||||
printf("\n\n");
|
||||
fprintf(DSLHEADER, "\n\n");
|
||||
|
||||
// Vector params
|
||||
printf("#define AC_FOR_USER_REAL3_PARAM_TYPES(FUNC)");
|
||||
for (int i = 0; i < num_symbols; ++i) {
|
||||
if (symbol_table[i].type_specifier == VECTOR) {
|
||||
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
fprintf(DSLHEADER, "#define AC_FOR_USER_REAL3_PARAM_TYPES(FUNC)");
|
||||
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||
if (symbol_table[i].type_specifier == VECTOR && symbol_table[i].type_qualifier == UNIFORM) {
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
}
|
||||
}
|
||||
printf("\n\n");
|
||||
fprintf(DSLHEADER, "\n\n");
|
||||
|
||||
// Scalar fields
|
||||
printf("#define AC_FOR_VTXBUF_HANDLES(FUNC)");
|
||||
for (int i = 0; i < num_symbols; ++i) {
|
||||
if (symbol_table[i].type_specifier == SCALARFIELD) {
|
||||
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
fprintf(DSLHEADER, "#define AC_FOR_VTXBUF_HANDLES(FUNC)");
|
||||
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||
if (symbol_table[i].type_specifier == SCALARFIELD &&
|
||||
symbol_table[i].type_qualifier == UNIFORM) {
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
}
|
||||
}
|
||||
printf("\n\n");
|
||||
fprintf(DSLHEADER, "\n\n");
|
||||
|
||||
// Scalar arrays
|
||||
printf("#define AC_FOR_SCALARARRAY_HANDLES(FUNC)");
|
||||
for (int i = 0; i < num_symbols; ++i) {
|
||||
if (symbol_table[i].type_specifier == SCALARARRAY) {
|
||||
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
fprintf(DSLHEADER, "#define AC_FOR_SCALARARRAY_HANDLES(FUNC)");
|
||||
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||
if (symbol_table[i].type_specifier == SCALARARRAY &&
|
||||
symbol_table[i].type_qualifier == UNIFORM) {
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
}
|
||||
}
|
||||
printf("\n\n");
|
||||
|
||||
/*
|
||||
printf("\n");
|
||||
printf("typedef struct {\n");
|
||||
for (int i = 0; i < num_symbols; ++i) {
|
||||
if (symbol_table[i].type_qualifier == PREPROCESSED)
|
||||
printf("%s %s;\n", translate(symbol_table[i].type_specifier),
|
||||
symbol_table[i].identifier);
|
||||
}
|
||||
printf("} %sData;\n", translate(SCALAR));
|
||||
*/
|
||||
fprintf(DSLHEADER, "\n\n");
|
||||
}
|
||||
|
||||
static void
|
||||
generate_library_hooks(void)
|
||||
{
|
||||
for (int i = 0; i < num_symbols; ++i) {
|
||||
if (symbol_table[i].type_qualifier == KERNEL) {
|
||||
printf("GEN_DEVICE_FUNC_HOOK(%s)\n", symbol_table[i].identifier);
|
||||
// printf("GEN_NODE_FUNC_HOOK(%s)\n", symbol_table[i].identifier);
|
||||
for (int i = 0; i < num_symbols[current_nest]; ++i) {
|
||||
if (symbol_table[i].type_qualifier == KERNEL && symbol_table[i].type_qualifier == UNIFORM) {
|
||||
fprintf(CUDAHEADER, "GEN_DEVICE_FUNC_HOOK(%s)\n", symbol_table[i].identifier);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -657,49 +410,29 @@ generate_library_hooks(void)
|
||||
int
|
||||
main(int argc, char** argv)
|
||||
{
|
||||
if (argc == 2) {
|
||||
if (!strcmp(argv[1], "-sas"))
|
||||
compilation_type = STENCIL_ASSEMBLY;
|
||||
else if (!strcmp(argv[1], "-sps"))
|
||||
compilation_type = STENCIL_PROCESS;
|
||||
else if (!strcmp(argv[1], "-sdh"))
|
||||
compilation_type = STENCIL_HEADER;
|
||||
else {
|
||||
printf("Unknown flag %s. Generating stencil assembly.\n", argv[1]);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
else {
|
||||
printf("Usage: ./acc [flags]\n"
|
||||
"Flags:\n"
|
||||
"\t-sas - Generates code for the stencil assembly stage\n"
|
||||
"\t-sps - Generates code for the stencil processing stage\n"
|
||||
"\t-hh - Generates stencil definitions from a header file\n");
|
||||
printf("\n");
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
root = astnode_create(NODE_UNKNOWN, NULL, NULL);
|
||||
|
||||
const int retval = yyparse();
|
||||
if (retval) {
|
||||
printf("COMPILATION FAILED\n");
|
||||
fprintf(stderr, "COMPILATION FAILED\n");
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
// Traverse
|
||||
traverse(root);
|
||||
if (compilation_type == STENCIL_ASSEMBLY)
|
||||
generate_preprocessed_structures();
|
||||
else if (compilation_type == STENCIL_HEADER)
|
||||
generate_header();
|
||||
else if (compilation_type == STENCIL_PROCESS)
|
||||
generate_library_hooks();
|
||||
DSLHEADER = fopen("user_defines.h", "w+");
|
||||
CUDAHEADER = fopen("user_kernels.h", "w+");
|
||||
assert(DSLHEADER);
|
||||
assert(CUDAHEADER);
|
||||
|
||||
// print_symbol_table();
|
||||
traverse(root);
|
||||
generate_header();
|
||||
generate_library_hooks();
|
||||
|
||||
print_symbol_table();
|
||||
|
||||
// Cleanup
|
||||
fclose(DSLHEADER);
|
||||
fclose(CUDAHEADER);
|
||||
astnode_destroy(root);
|
||||
// printf("COMPILATION SUCCESS\n");
|
||||
fprintf(stdout, "COMPILATION SUCCESS\n");
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
Reference in New Issue
Block a user