Added WIP stuff for the Astaroth DSL compiler rewrite. Once this branch is finished only a single source file will be needed (file ending .ac). This revision is needed to decouple absolutely all implementation-specific stuff (f.ex. AC_dsx) from the core library and make life easier for everyone. The plan is to provide a standard library header written in the DSL containing the derivative operations instead of hardcoding them in the CUDA implementation.

This commit is contained in:
jpekkila
2019-10-02 21:03:59 +03:00
parent 15cc71895d
commit cc3c2eb926
6 changed files with 1627 additions and 485 deletions

View File

@@ -22,6 +22,7 @@ L [a-zA-Z_]
"ScalarArray" { return SCALARARRAY; }
"Kernel" { return KERNEL; } /* Function specifiers */
"Device" { return DEVICE; }
"Preprocessed" { return PREPROCESSED; }
"const" { return CONSTANT; }

View File

@@ -20,7 +20,7 @@ int yyget_lineno();
%token VOID INT INT3 COMPLEX
%token IF ELSE FOR WHILE ELIF
%token LEQU LAND LOR LLEQU
%token KERNEL PREPROCESSED
%token KERNEL DEVICE PREPROCESSED
%token INPLACE_INC INPLACE_DEC
%%
@@ -66,6 +66,7 @@ compound_statement: '{' '}'
statement: selection_statement { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); }
| iteration_statement { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); }
| exec_statement ';' { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); $$->postfix = ';'; }
| compound_statement { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); }
;
selection_statement: IF expression else_selection_statement { $$ = astnode_create(NODE_UNKNOWN, $2, $3); $$->prefix = IF; }
@@ -115,8 +116,8 @@ return_statement: /* Empty */
* =============================================================================
*/
declaration_list: declaration { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); }
| declaration_list ',' declaration { $$ = astnode_create(NODE_UNKNOWN, $1, $3); $$->infix = ','; }
declaration_list: declaration { $$ = astnode_create(NODE_DECLARATION_LIST, $1, NULL); }
| declaration_list ',' declaration { $$ = astnode_create(NODE_DECLARATION_LIST, $1, $3); $$->infix = ','; }
;
declaration: type_declaration identifier { $$ = astnode_create(NODE_DECLARATION, $1, $2); } // Note: accepts only one type qualifier. Good or not?
@@ -127,8 +128,8 @@ array_declaration: identifier '[' ']'
| identifier '[' expression ']' { $$ = astnode_create(NODE_UNKNOWN, $1, $3); $$->infix = '['; $$->postfix = ']'; }
;
type_declaration: type_specifier { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); }
| type_qualifier type_specifier { $$ = astnode_create(NODE_UNKNOWN, $1, $2); }
type_declaration: type_specifier { $$ = astnode_create(NODE_TYPE_DECLARATION, $1, NULL); }
| type_qualifier type_specifier { $$ = astnode_create(NODE_TYPE_DECLARATION, $1, $2); }
;
/*
@@ -196,6 +197,7 @@ unary_operator: '-' /* C-style casts are disallowed, would otherwise be defined
;
type_qualifier: KERNEL { $$ = astnode_create(NODE_TYPE_QUALIFIER, NULL, NULL); $$->token = KERNEL; }
| DEVICE { $$ = astnode_create(NODE_TYPE_QUALIFIER, NULL, NULL); $$->token = DEVICE; }
| PREPROCESSED { $$ = astnode_create(NODE_TYPE_QUALIFIER, NULL, NULL); $$->token = PREPROCESSED; }
| CONSTANT { $$ = astnode_create(NODE_TYPE_QUALIFIER, NULL, NULL); $$->token = CONSTANT; }
| IN { $$ = astnode_create(NODE_TYPE_QUALIFIER, NULL, NULL); $$->token = IN; }

View File

@@ -21,6 +21,8 @@
FUNC(NODE_DEFINITION), \
FUNC(NODE_GLOBAL_DEFINITION), \
FUNC(NODE_DECLARATION), \
FUNC(NODE_DECLARATION_LIST), \
FUNC(NODE_TYPE_DECLARATION), \
FUNC(NODE_TYPE_QUALIFIER), \
FUNC(NODE_TYPE_SPECIFIER), \
FUNC(NODE_IDENTIFIER), \
@@ -32,34 +34,11 @@
FUNC(NODE_REAL_NUMBER)
// clang-format on
/*
// Recreating strdup is not needed when using the GNU compiler.
// Let's also just say that anything but the GNU
// compiler is NOT supported, since there are also
// some gcc-specific calls in the files generated
// by flex and being completely compiler-independent is
// not a priority right now
#ifndef strdup
static inline char*
strdup(const char* in)
{
const size_t len = strlen(in) + 1;
char* out = malloc(len);
if (out) {
memcpy(out, in, len);
return out;
} else {
return NULL;
}
}
#endif
*/
typedef enum { FOR_NODE_TYPES(GEN_ID), NUM_NODE_TYPES } NodeType;
typedef struct astnode_s {
int id;
struct astnode_s* parent;
struct astnode_s* lhs;
struct astnode_s* rhs;
NodeType type; // Type of the AST node
@@ -85,6 +64,12 @@ astnode_create(const NodeType type, ASTNode* lhs, ASTNode* rhs)
node->prefix = node->infix = node->postfix = 0;
if (lhs)
node->lhs->parent = node;
if (rhs)
node->rhs->parent = node;
return node;
}
@@ -106,19 +91,21 @@ astnode_destroy(ASTNode* node)
free(node);
}
static inline void
astnode_print(const ASTNode* node)
{
const char* node_type_names[] = {FOR_NODE_TYPES(GEN_STR)};
printf("%s (%p)\n", node_type_names[node->type], node);
printf("\tid: %d\n", node->id);
printf("\tparent: %p\n", node->parent);
printf("\tlhs: %p\n", node->lhs);
printf("\trhs: %p\n", node->rhs);
printf("\tbuffer: %s\n", node->buffer);
printf("\ttoken: %d\n", node->token);
printf("\tprefix: %d ('%c')\n", node->prefix, node->prefix);
printf("\tinfix: %d ('%c')\n", node->infix, node->infix);
printf("\tpostfix: %d ('%c')\n", node->postfix, node->postfix);
}
extern ASTNode* root;
/*
typedef enum {
SCOPE_BLOCK
} ScopeType;
typedef struct symbol_s {
int type_specifier;
char* identifier;
int scope;
struct symbol_s* next;
} Symbol;
extern ASTNode* symbol_table;
*/

View File

@@ -25,6 +25,7 @@
*
*/
#include <assert.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
@@ -35,9 +36,9 @@
ASTNode* root = NULL;
static const char inout_name_prefix[] = "handle_";
typedef enum { STENCIL_ASSEMBLY, STENCIL_PROCESS, STENCIL_HEADER } CompilationType;
static CompilationType compilation_type;
// Output files
static FILE* DSLHEADER = NULL;
static FILE* CUDAHEADER = NULL;
/*
* =============================================================================
@@ -64,15 +65,13 @@ static const char* translation_table[TRANSLATION_TABLE_SIZE] = {
[SCALARARRAY] = "const AcReal* __restrict__",
[COMPLEX] = "acComplex",
// Type qualifiers
[KERNEL] = "template <int step_number> static __global__",
//__launch_bounds__(RK_THREADBLOCK_SIZE,
// RK_LAUNCH_BOUND_MIN_BLOCKS),
[PREPROCESSED] = "static __device__ "
"__forceinline__",
[CONSTANT] = "const",
[IN] = "in",
[OUT] = "out",
[UNIFORM] = "uniform",
[KERNEL] = "template <int step_number> static __global__",
[DEVICE] = "static __device__",
[PREPROCESSED] = "static __device__ __forceinline__",
[CONSTANT] = "const",
[IN] = "in",
[OUT] = "out",
[UNIFORM] = "uniform",
// ETC
[INPLACE_INC] = "++",
[INPLACE_DEC] = "--",
@@ -121,7 +120,7 @@ typedef enum {
NUM_SYMBOLTYPES
} SymbolType;
#define MAX_ID_LEN (128)
#define MAX_ID_LEN (256)
typedef struct {
SymbolType type;
int type_qualifier;
@@ -129,135 +128,61 @@ typedef struct {
char identifier[MAX_ID_LEN];
} Symbol;
#define SYMBOL_TABLE_SIZE (4096)
#define SYMBOL_TABLE_SIZE (65536)
static Symbol symbol_table[SYMBOL_TABLE_SIZE] = {};
static int num_symbols = 0;
static int
#define MAX_NESTS (32)
static size_t num_symbols[MAX_NESTS] = {};
static size_t current_nest = 0;
static Symbol*
symboltable_lookup(const char* identifier)
{
if (!identifier)
return -1;
return NULL;
for (int i = 0; i < num_symbols; ++i)
for (size_t i = 0; i < num_symbols[current_nest]; ++i)
if (strcmp(identifier, symbol_table[i].identifier) == 0)
return i;
return &symbol_table[i];
return -1;
return NULL;
}
static void
add_symbol(const SymbolType type, const int tqualifier, const int tspecifier, const char* id)
{
assert(num_symbols < SYMBOL_TABLE_SIZE);
assert(num_symbols[current_nest] < SYMBOL_TABLE_SIZE);
symbol_table[num_symbols].type = type;
symbol_table[num_symbols].type_qualifier = tqualifier;
symbol_table[num_symbols].type_specifier = tspecifier;
strcpy(symbol_table[num_symbols].identifier, id);
symbol_table[num_symbols[current_nest]].type = type;
symbol_table[num_symbols[current_nest]].type_qualifier = tqualifier;
symbol_table[num_symbols[current_nest]].type_specifier = tspecifier;
strcpy(symbol_table[num_symbols[current_nest]].identifier, id);
++num_symbols;
++num_symbols[current_nest];
}
static void
rm_symbol(const int handle)
{
assert(handle >= 0 && handle < num_symbols);
assert(num_symbols > 0);
if (&symbol_table[handle] != &symbol_table[num_symbols - 1])
memcpy(&symbol_table[handle], &symbol_table[num_symbols - 1], sizeof(Symbol));
--num_symbols;
}
static void
print_symbol(const int handle)
print_symbol(const size_t handle)
{
assert(handle < SYMBOL_TABLE_SIZE);
const char* fields[] = {translate(symbol_table[handle].type_qualifier),
translate(symbol_table[handle].type_specifier),
symbol_table[handle].identifier};
const size_t num_fields = sizeof(fields) / sizeof(fields[0]);
const char* fields[] = {
translate(symbol_table[handle].type_qualifier),
translate(symbol_table[handle].type_specifier),
symbol_table[handle].identifier,
};
const size_t num_fields = sizeof(fields) / sizeof(fields[0]);
for (size_t i = 0; i < num_fields; ++i)
if (fields[i])
printf("%s ", fields[i]);
}
static void
translate_latest_symbol(void)
{
const int handle = num_symbols - 1;
assert(handle < SYMBOL_TABLE_SIZE);
Symbol* symbol = &symbol_table[handle];
// FUNCTION
if (symbol->type == SYMBOLTYPE_FUNCTION) {
// KERNEL FUNCTION
if (symbol->type_qualifier == KERNEL) {
printf("%s %s\n%s", translate(symbol->type_qualifier),
translate(symbol->type_specifier), symbol->identifier);
}
// PREPROCESSED FUNCTION
else if (symbol->type_qualifier == PREPROCESSED) {
printf("%s %s\npreprocessed_%s", translate(symbol->type_qualifier),
translate(symbol->type_specifier), symbol->identifier);
}
// OTHER FUNCTION
else {
const char* regular_function_decorator = "static __device__ "
"__forceinline__";
printf("%s %s %s\n%s", regular_function_decorator,
translate(symbol->type_qualifier) ? translate(symbol->type_qualifier) : "",
translate(symbol->type_specifier), symbol->identifier);
}
}
// FUNCTION PARAMETER
else if (symbol->type == SYMBOLTYPE_FUNCTION_PARAMETER) {
if (symbol->type_qualifier == IN || symbol->type_qualifier == OUT) {
if (compilation_type == STENCIL_ASSEMBLY)
printf("const __restrict__ %s* %s", translate(symbol->type_specifier),
symbol->identifier);
else if (compilation_type == STENCIL_PROCESS)
printf("const %sData& %s", translate(symbol->type_specifier), symbol->identifier);
else
printf("Invalid compilation type %d, IN and OUT qualifiers not supported\n",
compilation_type);
}
else {
print_symbol(handle);
}
}
// UNIFORM
else if (symbol->type_qualifier == UNIFORM) {
// if (compilation_type != STENCIL_HEADER) {
// printf("ERROR: %s can only be used in stencil headers\n", translation_table[UNIFORM]);
//}
/* Do nothing */
}
// IN / OUT
else if (symbol->type != SYMBOLTYPE_FUNCTION_PARAMETER &&
(symbol->type_qualifier == IN || symbol->type_qualifier == OUT)) {
printf("static __device__ const %s %s%s",
symbol->type_specifier == SCALARFIELD ? "int" : "int3", inout_name_prefix,
symbol_table[handle].identifier);
if (symbol->type_specifier == VECTOR)
printf(" = make_int3");
}
// OTHER
else {
print_symbol(handle);
}
}
static inline void
print_symbol_table(void)
{
for (int i = 0; i < num_symbols; ++i) {
printf("%d: ", i);
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
printf("%lu: ", i);
const char* fields[] = {translate(symbol_table[i].type_qualifier),
translate(symbol_table[i].type_specifier),
symbol_table[i].identifier};
@@ -279,377 +204,205 @@ print_symbol_table(void)
/*
* =============================================================================
* State
* Traversal state
* =============================================================================
*/
static bool inside_declaration = false;
static bool inside_function_declaration = false;
static bool inside_function_parameter_declaration = false;
static bool inside_kernel = false;
static bool inside_preprocessed = false;
static int scope_start = 0;
/*
* =============================================================================
* AST traversal
* =============================================================================
*/
static int compound_statement_nests = 0;
static void
translate_latest_symbol(void)
{
// TODO
}
static void
traverse(const ASTNode* node)
{
// Prefix logic %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if (node->type == NODE_FUNCTION_DECLARATION)
inside_function_declaration = true;
if (node->type == NODE_FUNCTION_PARAMETER_DECLARATION)
inside_function_parameter_declaration = true;
if (node->type == NODE_DECLARATION)
inside_declaration = true;
// Prefix translation
if (translate(node->prefix))
fprintf(CUDAHEADER, "%s", translate(node->prefix));
if (!inside_declaration && translate(node->prefix))
printf("%s", translate(node->prefix));
// Prefix logic
if (node->type == NODE_COMPOUND_STATEMENT) {
assert(current_nest < MAX_NESTS);
if (node->type == NODE_COMPOUND_STATEMENT)
++compound_statement_nests;
// BOILERPLATE START////////////////////////////////////////////////////////
if (node->type == NODE_TYPE_QUALIFIER && node->token == KERNEL)
inside_kernel = true;
// Kernel parameter boilerplate
const char* kernel_parameter_boilerplate = "GEN_KERNEL_PARAM_BOILERPLATE";
if (inside_kernel && node->type == NODE_FUNCTION_PARAMETER_DECLARATION) {
printf("%s", kernel_parameter_boilerplate);
if (node->lhs != NULL) {
printf("Compilation error: function parameters for Kernel functions not allowed!\n");
exit(EXIT_FAILURE);
}
++current_nest;
num_symbols[current_nest] = num_symbols[current_nest - 1];
}
// Kernel builtin variables boilerplate (read input/output arrays and setup
// indices)
const char* kernel_builtin_variables_boilerplate = "GEN_KERNEL_BUILTIN_VARIABLES_"
"BOILERPLATE();";
if (inside_kernel && node->type == NODE_COMPOUND_STATEMENT && compound_statement_nests == 1) {
printf("%s ", kernel_builtin_variables_boilerplate);
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_qualifier == IN) {
printf("const %sData %s = READ(%s%s);\n", translate(symbol_table[i].type_specifier),
symbol_table[i].identifier, inout_name_prefix, symbol_table[i].identifier);
}
else if (symbol_table[i].type_qualifier == OUT) {
printf("%s %s = READ_OUT(%s%s);", translate(symbol_table[i].type_specifier),
symbol_table[i].identifier, inout_name_prefix, symbol_table[i].identifier);
// printf("%s %s = buffer.out[%s%s][IDX(vertexIdx.x, vertexIdx.y, vertexIdx.z)];\n",
// translate(symbol_table[i].type_specifier), symbol_table[i].identifier,
// inout_name_prefix, symbol_table[i].identifier);
}
}
}
// Preprocessed parameter boilerplate
if (node->type == NODE_TYPE_QUALIFIER && node->token == PREPROCESSED)
inside_preprocessed = true;
static const char preprocessed_parameter_boilerplate
[] = "const int3& vertexIdx, const int3& globalVertexIdx, ";
if (inside_preprocessed && node->type == NODE_FUNCTION_PARAMETER_DECLARATION)
printf("%s ", preprocessed_parameter_boilerplate);
// BOILERPLATE END////////////////////////////////////////////////////////
// Enter LHS
// Traverse LHS
if (node->lhs)
traverse(node->lhs);
// Infix logic %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if (!inside_declaration && translate(node->infix))
printf("%s ", translate(node->infix));
// Infix translation
if (translate(node->infix))
fprintf(CUDAHEADER, "%s", translate(node->infix));
if (node->type == NODE_FUNCTION_DECLARATION)
inside_function_declaration = false;
// Infix logic
// TODO
// If the node is a subscript expression and the expression list inside it is not empty
if (node->type == NODE_MULTIDIM_SUBSCRIPT_EXPRESSION && node->rhs)
printf("IDX(");
// Do a regular translation
if (!inside_declaration) {
const int handle = symboltable_lookup(node->buffer);
if (handle >= 0) { // The variable exists in the symbol table
const Symbol* symbol = &symbol_table[handle];
if (symbol->type_qualifier == UNIFORM) {
if (inside_kernel && symbol->type_specifier == SCALARARRAY) {
printf("buffer.profiles[%s] ", symbol->identifier);
}
else {
printf("DCONST(%s) ", symbol->identifier);
}
}
else {
// Do a regular translation
if (translate(node->token))
printf("%s ", translate(node->token));
if (node->buffer) {
if (node->type == NODE_REAL_NUMBER) {
printf("%s(%s) ", translate(SCALAR),
node->buffer); // Cast to correct precision
}
else {
printf("%s ", node->buffer);
}
}
}
}
else {
// Do a regular translation
if (translate(node->token))
printf("%s ", translate(node->token));
if (node->buffer) {
if (node->type == NODE_REAL_NUMBER) {
printf("%s(%s) ", translate(SCALAR), node->buffer); // Cast to correct precision
}
else {
printf("%s ", node->buffer);
}
}
}
}
if (node->type == NODE_FUNCTION_DECLARATION) {
scope_start = num_symbols;
}
// Enter RHS
// Traverse RHS
if (node->rhs)
traverse(node->rhs);
// Postfix logic %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
// If the node is a subscript expression and the expression list inside it is not empty
if (node->type == NODE_MULTIDIM_SUBSCRIPT_EXPRESSION && node->rhs)
printf(")"); // Closing bracket of IDX()
// Postfix translation
if (translate(node->postfix))
fprintf(CUDAHEADER, "%s", translate(node->postfix));
// Generate writeback boilerplate for OUT fields
if (inside_kernel && node->type == NODE_COMPOUND_STATEMENT && compound_statement_nests == 1) {
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_qualifier == OUT) {
printf("WRITE_OUT(%s%s, %s);\n", inout_name_prefix, symbol_table[i].identifier,
symbol_table[i].identifier);
// printf("buffer.out[%s%s][IDX(vertexIdx.x, vertexIdx.y, vertexIdx.z)] = %s;\n",
// inout_name_prefix, symbol_table[i].identifier, symbol_table[i].identifier);
}
// Translate existing symbols
const Symbol* symbol = symboltable_lookup(node->buffer);
if (symbol) {
// Uniforms
if (symbol->type_qualifier == UNIFORM) {
fprintf(CUDAHEADER, "DCONST(%s) ", symbol->identifier);
}
}
if (!inside_declaration && translate(node->postfix))
printf("%s", translate(node->postfix));
// Add new symbols to the symbol table
if (node->type == NODE_DECLARATION) {
inside_declaration = false;
int stype;
ASTNode* tmp = node->parent;
while (tmp->type == NODE_DECLARATION_LIST)
tmp = tmp->parent;
int tqual = 0;
int tspec = 0;
if (node->lhs && node->lhs->lhs) {
if (node->lhs->lhs->type == NODE_TYPE_QUALIFIER)
tqual = node->lhs->lhs->token;
else if (node->lhs->lhs->type == NODE_TYPE_SPECIFIER)
tspec = node->lhs->lhs->token;
}
if (node->lhs && node->lhs->rhs) {
if (node->lhs->rhs->type == NODE_TYPE_SPECIFIER)
tspec = node->lhs->rhs->token;
}
if (tmp->type == NODE_FUNCTION_DECLARATION)
stype = SYMBOLTYPE_FUNCTION;
else if (tmp->type == NODE_FUNCTION_PARAMETER_DECLARATION)
stype = SYMBOLTYPE_FUNCTION_PARAMETER;
else
stype = SYMBOLTYPE_OTHER;
// Determine symbol type
SymbolType symboltype = SYMBOLTYPE_OTHER;
if (inside_function_declaration)
symboltype = SYMBOLTYPE_FUNCTION;
else if (inside_function_parameter_declaration)
symboltype = SYMBOLTYPE_FUNCTION_PARAMETER;
const ASTNode* tdeclaration = node->lhs;
const int tqualifier = tdeclaration->rhs ? tdeclaration->lhs->token : 0;
const int tspecifier = tdeclaration->rhs ? tdeclaration->rhs->token
: tdeclaration->lhs->token;
// Determine identifier
if (node->rhs->type == NODE_IDENTIFIER) {
add_symbol(symboltype, tqual, tspec, node->rhs->buffer); // Ordinary
translate_latest_symbol();
const char* identifier = node->rhs->type == NODE_IDENTIFIER ? node->rhs->buffer
: node->rhs->lhs->buffer;
add_symbol(stype, tqualifier, tspecifier, identifier);
// Translate the new symbol
if (tqualifier == UNIFORM) {
// Do nothing
}
else {
add_symbol(symboltype, tqual, tspec,
node->rhs->lhs->buffer); // Array
translate_latest_symbol();
// Traverse the expression once again, this time with
// "inside_declaration" flag off
printf("%s ", translate(node->rhs->infix));
if (node->rhs->rhs)
traverse(node->rhs->rhs);
printf("%s ", translate(node->rhs->postfix));
else if (tqualifier == KERNEL) {
fprintf(CUDAHEADER, "%s %s\n%s", //
translate(tqualifier), translate(tspecifier), identifier);
}
else if (tqualifier == DEVICE) {
fprintf(CUDAHEADER, "%s %s\n%s", //
translate(tqualifier), translate(tspecifier), identifier);
}
else if (tqualifier == PREPROCESSED) {
fprintf(CUDAHEADER, "%s %s\npreprocessed_%s", //
translate(tqualifier), translate(tspecifier), identifier);
}
else if (stype == SYMBOLTYPE_FUNCTION_PARAMETER) {
tmp = tmp->parent;
assert(tmp->type = NODE_FUNCTION_DECLARATION);
const Symbol* parent_function = symboltable_lookup(tmp->lhs->rhs->buffer);
if (parent_function->type_qualifier == DEVICE)
fprintf(CUDAHEADER, "%s %s\ndeviceparam_%s", //
translate(tqualifier), translate(tspecifier), identifier);
else if (parent_function->type_qualifier == PREPROCESSED)
fprintf(CUDAHEADER, "%s %s\npreprocessedparam_%s", //
translate(tqualifier), translate(tspecifier), identifier);
else
fprintf(CUDAHEADER, "%s %s\notherparam_%s", //
translate(tqualifier), translate(tspecifier), identifier);
}
else { // Do a regular translation
// fprintf(CUDAHEADER, "%s %s %s", //
// translate(tqualifier), translate(tspecifier), identifier);
}
}
if (node->type == NODE_COMPOUND_STATEMENT)
--compound_statement_nests;
if (node->type == NODE_FUNCTION_PARAMETER_DECLARATION)
inside_function_parameter_declaration = false;
if (node->type == NODE_FUNCTION_DEFINITION) {
while (num_symbols > scope_start)
rm_symbol(num_symbols - 1);
inside_kernel = false;
inside_preprocessed = false;
// Postfix logic
if (node->type == NODE_COMPOUND_STATEMENT) {
assert(current_nest > 0);
--current_nest;
printf("Dropped rest of the symbol table, from %lu to %lu\n", num_symbols[current_nest + 1],
num_symbols[current_nest]);
}
}
// TODO: these should use the generic type names SCALAR and VECTOR
static void
generate_preprocessed_structures(void)
{
// PREPROCESSED DATA STRUCT
printf("\n");
printf("typedef struct {\n");
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_qualifier == PREPROCESSED)
printf("%s %s;\n", translate(symbol_table[i].type_specifier),
symbol_table[i].identifier);
}
printf("} %sData;\n", translate(SCALAR));
// FILLING THE DATA STRUCT
printf("static __device__ __forceinline__ AcRealData\
read_data(const int3& vertexIdx,\
const int3& globalVertexIdx,\
AcReal* __restrict__ buf[], const int handle)\
{\n\
%sData data;\n",
translate(SCALAR));
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_qualifier == PREPROCESSED)
printf("data.%s = preprocessed_%s(vertexIdx, globalVertexIdx, buf[handle]);\n",
symbol_table[i].identifier, symbol_table[i].identifier);
}
printf("return data;\n");
printf("}\n");
// FUNCTIONS FOR ACCESSING MEMBERS OF THE PREPROCESSED STRUCT
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_qualifier == PREPROCESSED)
printf("static __device__ __forceinline__ %s\
%s(const AcRealData& data)\
{\n\
return data.%s;\
}\n",
translate(symbol_table[i].type_specifier), symbol_table[i].identifier,
symbol_table[i].identifier);
}
// Syntactic sugar: generate also a Vector data struct
printf("\
typedef struct {\
AcRealData x;\
AcRealData y;\
AcRealData z;\
} AcReal3Data;\
\
static __device__ __forceinline__ AcReal3Data\
read_data(const int3& vertexIdx,\
const int3& globalVertexIdx,\
AcReal* __restrict__ buf[], const int3& handle)\
{\
AcReal3Data data;\
\
data.x = read_data(vertexIdx, globalVertexIdx, buf, handle.x);\
data.y = read_data(vertexIdx, globalVertexIdx, buf, handle.y);\
data.z = read_data(vertexIdx, globalVertexIdx, buf, handle.z);\
\
return data;\
}\
");
// TODO
}
static void
generate_header(void)
{
printf("\n#pragma once\n");
fprintf(DSLHEADER, "#pragma once\n");
// Int params
printf("#define AC_FOR_USER_INT_PARAM_TYPES(FUNC)");
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_specifier == INT) {
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
fprintf(DSLHEADER, "#define AC_FOR_USER_INT_PARAM_TYPES(FUNC)");
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == INT && symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
}
}
printf("\n\n");
fprintf(DSLHEADER, "\n\n");
// Int3 params
printf("#define AC_FOR_USER_INT3_PARAM_TYPES(FUNC)");
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_specifier == INT3) {
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
fprintf(DSLHEADER, "#define AC_FOR_USER_INT3_PARAM_TYPES(FUNC)");
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == INT3 && symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
}
}
printf("\n\n");
fprintf(DSLHEADER, "\n\n");
// Scalar params
printf("#define AC_FOR_USER_REAL_PARAM_TYPES(FUNC)");
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_specifier == SCALAR) {
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
fprintf(DSLHEADER, "#define AC_FOR_USER_REAL_PARAM_TYPES(FUNC)");
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == SCALAR && symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
}
}
printf("\n\n");
fprintf(DSLHEADER, "\n\n");
// Vector params
printf("#define AC_FOR_USER_REAL3_PARAM_TYPES(FUNC)");
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_specifier == VECTOR) {
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
fprintf(DSLHEADER, "#define AC_FOR_USER_REAL3_PARAM_TYPES(FUNC)");
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == VECTOR && symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
}
}
printf("\n\n");
fprintf(DSLHEADER, "\n\n");
// Scalar fields
printf("#define AC_FOR_VTXBUF_HANDLES(FUNC)");
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_specifier == SCALARFIELD) {
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
fprintf(DSLHEADER, "#define AC_FOR_VTXBUF_HANDLES(FUNC)");
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == SCALARFIELD &&
symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
}
}
printf("\n\n");
fprintf(DSLHEADER, "\n\n");
// Scalar arrays
printf("#define AC_FOR_SCALARARRAY_HANDLES(FUNC)");
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_specifier == SCALARARRAY) {
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
fprintf(DSLHEADER, "#define AC_FOR_SCALARARRAY_HANDLES(FUNC)");
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_specifier == SCALARARRAY &&
symbol_table[i].type_qualifier == UNIFORM) {
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
}
}
printf("\n\n");
/*
printf("\n");
printf("typedef struct {\n");
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_qualifier == PREPROCESSED)
printf("%s %s;\n", translate(symbol_table[i].type_specifier),
symbol_table[i].identifier);
}
printf("} %sData;\n", translate(SCALAR));
*/
fprintf(DSLHEADER, "\n\n");
}
static void
generate_library_hooks(void)
{
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_qualifier == KERNEL) {
printf("GEN_DEVICE_FUNC_HOOK(%s)\n", symbol_table[i].identifier);
// printf("GEN_NODE_FUNC_HOOK(%s)\n", symbol_table[i].identifier);
for (int i = 0; i < num_symbols[current_nest]; ++i) {
if (symbol_table[i].type_qualifier == KERNEL && symbol_table[i].type_qualifier == UNIFORM) {
fprintf(CUDAHEADER, "GEN_DEVICE_FUNC_HOOK(%s)\n", symbol_table[i].identifier);
}
}
}
@@ -657,49 +410,29 @@ generate_library_hooks(void)
int
main(int argc, char** argv)
{
if (argc == 2) {
if (!strcmp(argv[1], "-sas"))
compilation_type = STENCIL_ASSEMBLY;
else if (!strcmp(argv[1], "-sps"))
compilation_type = STENCIL_PROCESS;
else if (!strcmp(argv[1], "-sdh"))
compilation_type = STENCIL_HEADER;
else {
printf("Unknown flag %s. Generating stencil assembly.\n", argv[1]);
return EXIT_FAILURE;
}
}
else {
printf("Usage: ./acc [flags]\n"
"Flags:\n"
"\t-sas - Generates code for the stencil assembly stage\n"
"\t-sps - Generates code for the stencil processing stage\n"
"\t-hh - Generates stencil definitions from a header file\n");
printf("\n");
return EXIT_FAILURE;
}
root = astnode_create(NODE_UNKNOWN, NULL, NULL);
const int retval = yyparse();
if (retval) {
printf("COMPILATION FAILED\n");
fprintf(stderr, "COMPILATION FAILED\n");
return EXIT_FAILURE;
}
// Traverse
traverse(root);
if (compilation_type == STENCIL_ASSEMBLY)
generate_preprocessed_structures();
else if (compilation_type == STENCIL_HEADER)
generate_header();
else if (compilation_type == STENCIL_PROCESS)
generate_library_hooks();
DSLHEADER = fopen("user_defines.h", "w+");
CUDAHEADER = fopen("user_kernels.h", "w+");
assert(DSLHEADER);
assert(CUDAHEADER);
// print_symbol_table();
traverse(root);
generate_header();
generate_library_hooks();
print_symbol_table();
// Cleanup
fclose(DSLHEADER);
fclose(CUDAHEADER);
astnode_destroy(root);
// printf("COMPILATION SUCCESS\n");
fprintf(stdout, "COMPILATION SUCCESS\n");
return EXIT_SUCCESS;
}

705
acc/src/code_generator0.c Normal file
View File

@@ -0,0 +1,705 @@
/*
Copyright (C) 2014-2019, Johannes Pekkilae, Miikka Vaeisalae.
This file is part of Astaroth.
Astaroth is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Astaroth is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with Astaroth. If not, see <http://www.gnu.org/licenses/>.
*/
/**
* @file
* \brief Brief info.
*
* Detailed info.
*
*/
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "acc.tab.h"
#include "ast.h"
ASTNode* root = NULL;
static const char inout_name_prefix[] = "handle_";
typedef enum { STENCIL_ASSEMBLY, STENCIL_PROCESS, STENCIL_HEADER } CompilationType;
static CompilationType compilation_type;
/*
* =============================================================================
* Translation
* =============================================================================
*/
#define TRANSLATION_TABLE_SIZE (1024)
static const char* translation_table[TRANSLATION_TABLE_SIZE] = {
[0] = NULL,
// Control flow
[IF] = "if",
[ELSE] = "else",
[ELIF] = "else if",
[WHILE] = "while",
[FOR] = "for",
// Type specifiers
[VOID] = "void",
[INT] = "int",
[INT3] = "int3",
[SCALAR] = "AcReal",
[VECTOR] = "AcReal3",
[MATRIX] = "AcMatrix",
[SCALARFIELD] = "AcReal",
[SCALARARRAY] = "const AcReal* __restrict__",
[COMPLEX] = "acComplex",
// Type qualifiers
[KERNEL] = "template <int step_number> static __global__",
//__launch_bounds__(RK_THREADBLOCK_SIZE,
// RK_LAUNCH_BOUND_MIN_BLOCKS),
[PREPROCESSED] = "static __device__ "
"__forceinline__",
[CONSTANT] = "const",
[IN] = "in",
[OUT] = "out",
[UNIFORM] = "uniform",
// ETC
[INPLACE_INC] = "++",
[INPLACE_DEC] = "--",
// Unary
[','] = ",",
[';'] = ";\n",
['('] = "(",
[')'] = ")",
['['] = "[",
[']'] = "]",
['{'] = "{\n",
['}'] = "}\n",
['='] = "=",
['+'] = "+",
['-'] = "-",
['/'] = "/",
['*'] = "*",
['<'] = "<",
['>'] = ">",
['!'] = "!",
['.'] = "."};
static const char*
translate(const int token)
{
assert(token >= 0);
assert(token < TRANSLATION_TABLE_SIZE);
if (token > 0) {
if (!translation_table[token])
printf("ERROR: unidentified token %d\n", token);
assert(translation_table[token]);
}
return translation_table[token];
}
/*
* =============================================================================
* Symbols
* =============================================================================
*/
typedef enum {
SYMBOLTYPE_FUNCTION,
SYMBOLTYPE_FUNCTION_PARAMETER,
SYMBOLTYPE_OTHER,
NUM_SYMBOLTYPES
} SymbolType;
#define MAX_ID_LEN (128)
typedef struct {
SymbolType type;
int type_qualifier;
int type_specifier;
char identifier[MAX_ID_LEN];
} Symbol;
#define SYMBOL_TABLE_SIZE (4096)
static Symbol symbol_table[SYMBOL_TABLE_SIZE] = {};
static int num_symbols = 0;
static int
symboltable_lookup(const char* identifier)
{
if (!identifier)
return -1;
for (int i = 0; i < num_symbols; ++i)
if (strcmp(identifier, symbol_table[i].identifier) == 0)
return i;
return -1;
}
static void
add_symbol(const SymbolType type, const int tqualifier, const int tspecifier, const char* id)
{
assert(num_symbols < SYMBOL_TABLE_SIZE);
symbol_table[num_symbols].type = type;
symbol_table[num_symbols].type_qualifier = tqualifier;
symbol_table[num_symbols].type_specifier = tspecifier;
strcpy(symbol_table[num_symbols].identifier, id);
++num_symbols;
}
static void
rm_symbol(const int handle)
{
assert(handle >= 0 && handle < num_symbols);
assert(num_symbols > 0);
if (&symbol_table[handle] != &symbol_table[num_symbols - 1])
memcpy(&symbol_table[handle], &symbol_table[num_symbols - 1], sizeof(Symbol));
--num_symbols;
}
static void
print_symbol(const int handle)
{
assert(handle < SYMBOL_TABLE_SIZE);
const char* fields[] = {translate(symbol_table[handle].type_qualifier),
translate(symbol_table[handle].type_specifier),
symbol_table[handle].identifier};
const size_t num_fields = sizeof(fields) / sizeof(fields[0]);
for (size_t i = 0; i < num_fields; ++i)
if (fields[i])
printf("%s ", fields[i]);
}
static void
translate_latest_symbol(void)
{
const int handle = num_symbols - 1;
assert(handle < SYMBOL_TABLE_SIZE);
Symbol* symbol = &symbol_table[handle];
// FUNCTION
if (symbol->type == SYMBOLTYPE_FUNCTION) {
// KERNEL FUNCTION
if (symbol->type_qualifier == KERNEL) {
printf("%s %s\n%s", translate(symbol->type_qualifier),
translate(symbol->type_specifier), symbol->identifier);
}
// PREPROCESSED FUNCTION
else if (symbol->type_qualifier == PREPROCESSED) {
printf("%s %s\npreprocessed_%s", translate(symbol->type_qualifier),
translate(symbol->type_specifier), symbol->identifier);
}
// OTHER FUNCTION
else {
const char* regular_function_decorator = "static __device__ "
"__forceinline__";
printf("%s %s %s\n%s", regular_function_decorator,
translate(symbol->type_qualifier) ? translate(symbol->type_qualifier) : "",
translate(symbol->type_specifier), symbol->identifier);
}
}
// FUNCTION PARAMETER
else if (symbol->type == SYMBOLTYPE_FUNCTION_PARAMETER) {
if (symbol->type_qualifier == IN || symbol->type_qualifier == OUT) {
if (compilation_type == STENCIL_ASSEMBLY)
printf("const __restrict__ %s* %s", translate(symbol->type_specifier),
symbol->identifier);
else if (compilation_type == STENCIL_PROCESS)
printf("const %sData& %s", translate(symbol->type_specifier), symbol->identifier);
else
printf("Invalid compilation type %d, IN and OUT qualifiers not supported\n",
compilation_type);
}
else {
print_symbol(handle);
}
}
// UNIFORM
else if (symbol->type_qualifier == UNIFORM) {
// if (compilation_type != STENCIL_HEADER) {
// printf("ERROR: %s can only be used in stencil headers\n", translation_table[UNIFORM]);
//}
/* Do nothing */
}
// IN / OUT
else if (symbol->type != SYMBOLTYPE_FUNCTION_PARAMETER &&
(symbol->type_qualifier == IN || symbol->type_qualifier == OUT)) {
printf("static __device__ const %s %s%s",
symbol->type_specifier == SCALARFIELD ? "int" : "int3", inout_name_prefix,
symbol_table[handle].identifier);
if (symbol->type_specifier == VECTOR)
printf(" = make_int3");
}
// OTHER
else {
print_symbol(handle);
}
}
static inline void
print_symbol_table(void)
{
for (int i = 0; i < num_symbols; ++i) {
printf("%d: ", i);
const char* fields[] = {translate(symbol_table[i].type_qualifier),
translate(symbol_table[i].type_specifier),
symbol_table[i].identifier};
const size_t num_fields = sizeof(fields) / sizeof(fields[0]);
for (size_t j = 0; j < num_fields; ++j)
if (fields[j])
printf("%s ", fields[j]);
if (symbol_table[i].type == SYMBOLTYPE_FUNCTION)
printf("(function)");
else if (symbol_table[i].type == SYMBOLTYPE_FUNCTION_PARAMETER)
printf("(function parameter)");
else
printf("(other)");
printf("\n");
}
}
/*
* =============================================================================
* State
* =============================================================================
*/
static bool inside_declaration = false;
static bool inside_function_declaration = false;
static bool inside_function_parameter_declaration = false;
static bool inside_kernel = false;
static bool inside_preprocessed = false;
static int scope_start = 0;
/*
* =============================================================================
* AST traversal
* =============================================================================
*/
static int compound_statement_nests = 0;
static void
traverse(const ASTNode* node)
{
// Prefix logic %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if (node->type == NODE_FUNCTION_DECLARATION)
inside_function_declaration = true;
if (node->type == NODE_FUNCTION_PARAMETER_DECLARATION)
inside_function_parameter_declaration = true;
if (node->type == NODE_DECLARATION)
inside_declaration = true;
if (!inside_declaration && translate(node->prefix))
printf("%s", translate(node->prefix));
if (node->type == NODE_COMPOUND_STATEMENT)
++compound_statement_nests;
// BOILERPLATE START////////////////////////////////////////////////////////
if (node->type == NODE_TYPE_QUALIFIER && node->token == KERNEL)
inside_kernel = true;
// Kernel parameter boilerplate
const char* kernel_parameter_boilerplate = "GEN_KERNEL_PARAM_BOILERPLATE";
if (inside_kernel && node->type == NODE_FUNCTION_PARAMETER_DECLARATION) {
printf("%s", kernel_parameter_boilerplate);
if (node->lhs != NULL) {
printf("Compilation error: function parameters for Kernel functions not allowed!\n");
exit(EXIT_FAILURE);
}
}
// Kernel builtin variables boilerplate (read input/output arrays and setup
// indices)
const char* kernel_builtin_variables_boilerplate = "GEN_KERNEL_BUILTIN_VARIABLES_"
"BOILERPLATE();";
if (inside_kernel && node->type == NODE_COMPOUND_STATEMENT && compound_statement_nests == 1) {
printf("%s ", kernel_builtin_variables_boilerplate);
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_qualifier == IN) {
printf("const %sData %s = READ(%s%s);\n", translate(symbol_table[i].type_specifier),
symbol_table[i].identifier, inout_name_prefix, symbol_table[i].identifier);
}
else if (symbol_table[i].type_qualifier == OUT) {
printf("%s %s = READ_OUT(%s%s);", translate(symbol_table[i].type_specifier),
symbol_table[i].identifier, inout_name_prefix, symbol_table[i].identifier);
// printf("%s %s = buffer.out[%s%s][IDX(vertexIdx.x, vertexIdx.y, vertexIdx.z)];\n",
// translate(symbol_table[i].type_specifier), symbol_table[i].identifier,
// inout_name_prefix, symbol_table[i].identifier);
}
}
}
// Preprocessed parameter boilerplate
if (node->type == NODE_TYPE_QUALIFIER && node->token == PREPROCESSED)
inside_preprocessed = true;
static const char preprocessed_parameter_boilerplate
[] = "const int3& vertexIdx, const int3& globalVertexIdx, ";
if (inside_preprocessed && node->type == NODE_FUNCTION_PARAMETER_DECLARATION)
printf("%s ", preprocessed_parameter_boilerplate);
// BOILERPLATE END////////////////////////////////////////////////////////
// Enter LHS
if (node->lhs)
traverse(node->lhs);
// Infix logic %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if (!inside_declaration && translate(node->infix))
printf("%s ", translate(node->infix));
if (node->type == NODE_FUNCTION_DECLARATION)
inside_function_declaration = false;
// If the node is a subscript expression and the expression list inside it is not empty
if (node->type == NODE_MULTIDIM_SUBSCRIPT_EXPRESSION && node->rhs)
printf("IDX(");
// Do a regular translation
if (!inside_declaration) {
const int handle = symboltable_lookup(node->buffer);
if (handle >= 0) { // The variable exists in the symbol table
const Symbol* symbol = &symbol_table[handle];
if (symbol->type_qualifier == UNIFORM) {
if (inside_kernel && symbol->type_specifier == SCALARARRAY) {
printf("buffer.profiles[%s] ", symbol->identifier);
}
else {
printf("DCONST(%s) ", symbol->identifier);
}
}
else {
// Do a regular translation
if (translate(node->token))
printf("%s ", translate(node->token));
if (node->buffer) {
if (node->type == NODE_REAL_NUMBER) {
printf("%s(%s) ", translate(SCALAR),
node->buffer); // Cast to correct precision
}
else {
printf("%s ", node->buffer);
}
}
}
}
else {
// Do a regular translation
if (translate(node->token))
printf("%s ", translate(node->token));
if (node->buffer) {
if (node->type == NODE_REAL_NUMBER) {
printf("%s(%s) ", translate(SCALAR), node->buffer); // Cast to correct precision
}
else {
printf("%s ", node->buffer);
}
}
}
}
if (node->type == NODE_FUNCTION_DECLARATION) {
scope_start = num_symbols;
}
// Enter RHS
if (node->rhs)
traverse(node->rhs);
// Postfix logic %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
// If the node is a subscript expression and the expression list inside it is not empty
if (node->type == NODE_MULTIDIM_SUBSCRIPT_EXPRESSION && node->rhs)
printf(")"); // Closing bracket of IDX()
// Generate writeback boilerplate for OUT fields
if (inside_kernel && node->type == NODE_COMPOUND_STATEMENT && compound_statement_nests == 1) {
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_qualifier == OUT) {
printf("WRITE_OUT(%s%s, %s);\n", inout_name_prefix, symbol_table[i].identifier,
symbol_table[i].identifier);
// printf("buffer.out[%s%s][IDX(vertexIdx.x, vertexIdx.y, vertexIdx.z)] = %s;\n",
// inout_name_prefix, symbol_table[i].identifier, symbol_table[i].identifier);
}
}
}
if (!inside_declaration && translate(node->postfix))
printf("%s", translate(node->postfix));
if (node->type == NODE_DECLARATION) {
inside_declaration = false;
int tqual = 0;
int tspec = 0;
if (node->lhs && node->lhs->lhs) {
if (node->lhs->lhs->type == NODE_TYPE_QUALIFIER)
tqual = node->lhs->lhs->token;
else if (node->lhs->lhs->type == NODE_TYPE_SPECIFIER)
tspec = node->lhs->lhs->token;
}
if (node->lhs && node->lhs->rhs) {
if (node->lhs->rhs->type == NODE_TYPE_SPECIFIER)
tspec = node->lhs->rhs->token;
}
// Determine symbol type
SymbolType symboltype = SYMBOLTYPE_OTHER;
if (inside_function_declaration)
symboltype = SYMBOLTYPE_FUNCTION;
else if (inside_function_parameter_declaration)
symboltype = SYMBOLTYPE_FUNCTION_PARAMETER;
// Determine identifier
if (node->rhs->type == NODE_IDENTIFIER) {
add_symbol(symboltype, tqual, tspec, node->rhs->buffer); // Ordinary
translate_latest_symbol();
}
else {
add_symbol(symboltype, tqual, tspec,
node->rhs->lhs->buffer); // Array
translate_latest_symbol();
// Traverse the expression once again, this time with
// "inside_declaration" flag off
printf("%s ", translate(node->rhs->infix));
if (node->rhs->rhs)
traverse(node->rhs->rhs);
printf("%s ", translate(node->rhs->postfix));
}
}
if (node->type == NODE_COMPOUND_STATEMENT)
--compound_statement_nests;
if (node->type == NODE_FUNCTION_PARAMETER_DECLARATION)
inside_function_parameter_declaration = false;
if (node->type == NODE_FUNCTION_DEFINITION) {
while (num_symbols > scope_start)
rm_symbol(num_symbols - 1);
inside_kernel = false;
inside_preprocessed = false;
}
}
// TODO: these should use the generic type names SCALAR and VECTOR
static void
generate_preprocessed_structures(void)
{
// PREPROCESSED DATA STRUCT
printf("\n");
printf("typedef struct {\n");
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_qualifier == PREPROCESSED)
printf("%s %s;\n", translate(symbol_table[i].type_specifier),
symbol_table[i].identifier);
}
printf("} %sData;\n", translate(SCALAR));
// FILLING THE DATA STRUCT
printf("static __device__ __forceinline__ AcRealData\
read_data(const int3& vertexIdx,\
const int3& globalVertexIdx,\
AcReal* __restrict__ buf[], const int handle)\
{\n\
%sData data;\n",
translate(SCALAR));
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_qualifier == PREPROCESSED)
printf("data.%s = preprocessed_%s(vertexIdx, globalVertexIdx, buf[handle]);\n",
symbol_table[i].identifier, symbol_table[i].identifier);
}
printf("return data;\n");
printf("}\n");
// FUNCTIONS FOR ACCESSING MEMBERS OF THE PREPROCESSED STRUCT
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_qualifier == PREPROCESSED)
printf("static __device__ __forceinline__ %s\
%s(const AcRealData& data)\
{\n\
return data.%s;\
}\n",
translate(symbol_table[i].type_specifier), symbol_table[i].identifier,
symbol_table[i].identifier);
}
// Syntactic sugar: generate also a Vector data struct
printf("\
typedef struct {\
AcRealData x;\
AcRealData y;\
AcRealData z;\
} AcReal3Data;\
\
static __device__ __forceinline__ AcReal3Data\
read_data(const int3& vertexIdx,\
const int3& globalVertexIdx,\
AcReal* __restrict__ buf[], const int3& handle)\
{\
AcReal3Data data;\
\
data.x = read_data(vertexIdx, globalVertexIdx, buf, handle.x);\
data.y = read_data(vertexIdx, globalVertexIdx, buf, handle.y);\
data.z = read_data(vertexIdx, globalVertexIdx, buf, handle.z);\
\
return data;\
}\
");
}
static void
generate_header(void)
{
printf("\n#pragma once\n");
// Int params
printf("#define AC_FOR_USER_INT_PARAM_TYPES(FUNC)");
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_specifier == INT) {
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
}
}
printf("\n\n");
// Int3 params
printf("#define AC_FOR_USER_INT3_PARAM_TYPES(FUNC)");
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_specifier == INT3) {
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
}
}
printf("\n\n");
// Scalar params
printf("#define AC_FOR_USER_REAL_PARAM_TYPES(FUNC)");
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_specifier == SCALAR) {
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
}
}
printf("\n\n");
// Vector params
printf("#define AC_FOR_USER_REAL3_PARAM_TYPES(FUNC)");
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_specifier == VECTOR) {
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
}
}
printf("\n\n");
// Scalar fields
printf("#define AC_FOR_VTXBUF_HANDLES(FUNC)");
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_specifier == SCALARFIELD) {
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
}
}
printf("\n\n");
// Scalar arrays
printf("#define AC_FOR_SCALARARRAY_HANDLES(FUNC)");
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_specifier == SCALARARRAY) {
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
}
}
printf("\n\n");
/*
printf("\n");
printf("typedef struct {\n");
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_qualifier == PREPROCESSED)
printf("%s %s;\n", translate(symbol_table[i].type_specifier),
symbol_table[i].identifier);
}
printf("} %sData;\n", translate(SCALAR));
*/
}
static void
generate_library_hooks(void)
{
for (int i = 0; i < num_symbols; ++i) {
if (symbol_table[i].type_qualifier == KERNEL) {
printf("GEN_DEVICE_FUNC_HOOK(%s)\n", symbol_table[i].identifier);
// printf("GEN_NODE_FUNC_HOOK(%s)\n", symbol_table[i].identifier);
}
}
}
int
main(int argc, char** argv)
{
if (argc == 2) {
if (!strcmp(argv[1], "-sas"))
compilation_type = STENCIL_ASSEMBLY;
else if (!strcmp(argv[1], "-sps"))
compilation_type = STENCIL_PROCESS;
else if (!strcmp(argv[1], "-sdh"))
compilation_type = STENCIL_HEADER;
else {
printf("Unknown flag %s. Generating stencil assembly.\n", argv[1]);
return EXIT_FAILURE;
}
}
else {
printf("Usage: ./acc [flags]\n"
"Flags:\n"
"\t-sas - Generates code for the stencil assembly stage\n"
"\t-sps - Generates code for the stencil processing stage\n"
"\t-hh - Generates stencil definitions from a header file\n");
printf("\n");
return EXIT_FAILURE;
}
root = astnode_create(NODE_UNKNOWN, NULL, NULL);
const int retval = yyparse();
if (retval) {
printf("COMPILATION FAILED\n");
return EXIT_FAILURE;
}
// Traverse
traverse(root);
if (compilation_type == STENCIL_ASSEMBLY)
generate_preprocessed_structures();
else if (compilation_type == STENCIL_HEADER)
generate_header();
else if (compilation_type == STENCIL_PROCESS)
generate_library_hooks();
// print_symbol_table();
// Cleanup
astnode_destroy(root);
// printf("COMPILATION SUCCESS\n");
return EXIT_SUCCESS;
}