acc: Removed debug prints, old code. Also the scope of the declarations made inside a for statement is now properly tracked

This commit is contained in:
jpekkila
2019-10-08 00:20:57 +03:00
parent 08f155cbec
commit 44a86f5e80
3 changed files with 22 additions and 42 deletions

View File

@@ -84,7 +84,7 @@ iteration_statement: WHILE expression compound_statement
| FOR for_expression compound_statement { $$ = astnode_create(NODE_ITERATION_STATEMENT, $2, $3); $$->prefix = FOR; }
;
for_expression: '(' for_init_param for_other_params ')' { $$ = astnode_create(NODE_UNKNOWN, $2, $3); $$->prefix = '('; $$->postfix = ')'; }
for_expression: '(' for_init_param for_other_params ')' { $$ = astnode_create(NODE_FOR_EXPRESSION, $2, $3); $$->prefix = '('; $$->postfix = ')'; }
;
for_init_param: expression ';' { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); $$->postfix = ';'; }

View File

@@ -33,7 +33,8 @@
FUNC(NODE_COMPOUND_STATEMENT), \
FUNC(NODE_FUNCTION_PARAMETER_DECLARATION), \
FUNC(NODE_MULTIDIM_SUBSCRIPT_EXPRESSION), \
FUNC(NODE_REAL_NUMBER)
FUNC(NODE_REAL_NUMBER), \
FUNC(NODE_FOR_EXPRESSION)
// clang-format on
typedef enum { FOR_NODE_TYPES(GEN_ID), NUM_NODE_TYPES } NodeType;

View File

@@ -40,6 +40,9 @@ ASTNode* root = NULL;
static FILE* DSLHEADER = NULL;
static FILE* CUDAHEADER = NULL;
static const char* dslheader_filename = "user_defines.h";
static const char* cudaheader_filename = "user_kernels.h";
/*
* =============================================================================
* Translation
@@ -101,7 +104,7 @@ translate(const int token)
assert(token < TRANSLATION_TABLE_SIZE);
if (token > 0) {
if (!translation_table[token])
printf("ERROR: unidentified token %d\n", token);
fprintf(stderr, "Error: unidentified token %d\n", token);
assert(translation_table[token]);
}
@@ -161,6 +164,7 @@ add_symbol(const SymbolType type, const int tqualifier, const int tspecifier, co
"Syntax error. Symbol '%s' is ambiguous, declared multiple times in the same scope"
" (shadowing).\n",
id);
assert(0);
}
symbol_table[num_symbols[current_nest]].type = type;
@@ -234,28 +238,12 @@ print_symbol_table(void)
*/
static bool inside_declaration = false;
static bool inside_kernel = false;
/*
* =============================================================================
* AST traversal
* =============================================================================
*/
/*
static bool
introspect(const ASTNode* node, const NodeType type)
{
assert(node);
ASTNode* parent = node->parent;
while (parent) {
if (parent->type == type)
return true;
else
parent = parent->parent;
}
return false;
}
*/
static void
traverse(const ASTNode* node)
{
@@ -265,8 +253,6 @@ traverse(const ASTNode* node)
// Prefix logic
if (node->type == NODE_COMPOUND_STATEMENT) {
// if (node->type == NODE_FUNCTION_PARAMETER_DECLARATION ||
// node->type == NODE_ITERATION_STATEMENT) {
assert(current_nest < MAX_NESTS);
++current_nest;
@@ -281,7 +267,6 @@ traverse(const ASTNode* node)
// Boilerplates
const ASTNode* typedecl = node->parent->lhs->lhs;
const ASTNode* typequal = typedecl->lhs;
printf("typedecl %d\n", typedecl->type);
assert(typedecl->type == NODE_TYPE_DECLARATION);
if (typequal->type == NODE_TYPE_QUALIFIER) {
if (typequal->token == KERNEL) {
@@ -347,6 +332,9 @@ traverse(const ASTNode* node)
stype = SYMBOLTYPE_FUNCTION;
else if (tmp->type == NODE_FUNCTION_PARAMETER_DECLARATION)
stype = SYMBOLTYPE_FUNCTION_PARAMETER;
else if (node->parent && node->parent->parent && node->parent->parent->parent &&
node->parent->parent->parent->type == NODE_FOR_EXPRESSION)
stype = SYMBOLTYPE_FUNCTION_PARAMETER;
else
stype = SYMBOLTYPE_OTHER;
@@ -358,7 +346,6 @@ traverse(const ASTNode* node)
const char* identifier = node->rhs->type == NODE_IDENTIFIER ? node->rhs->buffer
: node->rhs->lhs->buffer;
add_symbol(stype, tqualifier, tspecifier, identifier);
printf("Added %s\n", identifier);
// Translate the new symbol
if (tqualifier == UNIFORM) {
@@ -387,9 +374,7 @@ traverse(const ASTNode* node)
// TODO FIX not to use symboltable_lookup
const Symbol* parent_function = symboltable_lookup(tmp->lhs->rhs->buffer);
assert(parent_function);
if (tqualifier == IN || tqualifier == OUT) {
if (parent_function && (tqualifier == IN || tqualifier == OUT)) {
if (tmp->lhs->lhs->lhs->token == DEVICE) {
fprintf(CUDAHEADER, "const %sData& %s", //
translate(tspecifier), identifier);
@@ -431,9 +416,6 @@ traverse(const ASTNode* node)
if (symbol) {
// Uniforms
if (symbol->type_qualifier == UNIFORM) {
printf("INSIDE KERNEL %d %s, type spec %d vs %d, %d\n", inside_kernel,
symbol->identifier, symbol->type_specifier, SCALARARRAY,
symbol->type_specifier == SCALARARRAY);
if (inside_kernel && symbol->type_specifier == SCALARARRAY)
fprintf(CUDAHEADER, "buffer.profiles[%s] ", symbol->identifier);
else
@@ -473,16 +455,11 @@ traverse(const ASTNode* node)
assert(current_nest > 0);
--current_nest;
// Drop function parameters
// Drop function parameters (incl. those declared in for statements)
while (symbol_table[num_symbols[current_nest] - 1].type == SYMBOLTYPE_FUNCTION_PARAMETER)
--num_symbols[current_nest];
inside_kernel = false;
// Drop temporaries declared with iteration statements
// TODO
printf("Dropped rest of the symbol table, from %lu to %lu\n", num_symbols[current_nest + 1],
num_symbols[current_nest]);
// Kernel writeback boilerplate
if (node->parent->type == NODE_FUNCTION_DEFINITION) {
@@ -562,7 +539,7 @@ generate_preprocessed_structures(void)
rewind(CUDAHEADER);
const size_t buflen = fread(buffer, sizeof(char), max_buflen, CUDAHEADER);
fclose(CUDAHEADER);
CUDAHEADER = fopen("user_kernels.h", "w+");
CUDAHEADER = fopen(cudaheader_filename, "w+");
fprintf(CUDAHEADER, "#pragma once\n");
fprintf(CUDAHEADER, "typedef struct {\n");
@@ -673,12 +650,12 @@ main(int argc, char** argv)
const int retval = yyparse();
if (retval) {
fprintf(stderr, "COMPILATION FAILED\n");
fprintf(stderr, "Fatal error: DSL compilation failed\n");
return EXIT_FAILURE;
}
DSLHEADER = fopen("user_defines.h", "w+");
CUDAHEADER = fopen("user_kernels.h", "w+");
DSLHEADER = fopen(dslheader_filename, "w+");
CUDAHEADER = fopen(cudaheader_filename, "w+");
assert(DSLHEADER);
assert(CUDAHEADER);
@@ -708,12 +685,14 @@ main(int argc, char** argv)
generate_preprocessed_structures();
generate_library_hooks();
print_symbol_table();
// print_symbol_table();
// Cleanup
fclose(DSLHEADER);
fclose(CUDAHEADER);
astnode_destroy(root);
fprintf(stdout, "COMPILATION SUCCESS\n");
fprintf(stdout, "-- Generated %s\n", dslheader_filename);
fprintf(stdout, "-- Generated %s\n", cudaheader_filename);
return EXIT_SUCCESS;
}