acc: Removed debug prints, old code. Also the scope of the declarations made inside a for statement is now properly tracked

This commit is contained in:
jpekkila
2019-10-08 00:20:57 +03:00
parent 08f155cbec
commit 44a86f5e80
3 changed files with 22 additions and 42 deletions

View File

@@ -84,7 +84,7 @@ iteration_statement: WHILE expression compound_statement
| FOR for_expression compound_statement { $$ = astnode_create(NODE_ITERATION_STATEMENT, $2, $3); $$->prefix = FOR; } | FOR for_expression compound_statement { $$ = astnode_create(NODE_ITERATION_STATEMENT, $2, $3); $$->prefix = FOR; }
; ;
for_expression: '(' for_init_param for_other_params ')' { $$ = astnode_create(NODE_UNKNOWN, $2, $3); $$->prefix = '('; $$->postfix = ')'; } for_expression: '(' for_init_param for_other_params ')' { $$ = astnode_create(NODE_FOR_EXPRESSION, $2, $3); $$->prefix = '('; $$->postfix = ')'; }
; ;
for_init_param: expression ';' { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); $$->postfix = ';'; } for_init_param: expression ';' { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); $$->postfix = ';'; }

View File

@@ -33,7 +33,8 @@
FUNC(NODE_COMPOUND_STATEMENT), \ FUNC(NODE_COMPOUND_STATEMENT), \
FUNC(NODE_FUNCTION_PARAMETER_DECLARATION), \ FUNC(NODE_FUNCTION_PARAMETER_DECLARATION), \
FUNC(NODE_MULTIDIM_SUBSCRIPT_EXPRESSION), \ FUNC(NODE_MULTIDIM_SUBSCRIPT_EXPRESSION), \
FUNC(NODE_REAL_NUMBER) FUNC(NODE_REAL_NUMBER), \
FUNC(NODE_FOR_EXPRESSION)
// clang-format on // clang-format on
typedef enum { FOR_NODE_TYPES(GEN_ID), NUM_NODE_TYPES } NodeType; typedef enum { FOR_NODE_TYPES(GEN_ID), NUM_NODE_TYPES } NodeType;

View File

@@ -40,6 +40,9 @@ ASTNode* root = NULL;
static FILE* DSLHEADER = NULL; static FILE* DSLHEADER = NULL;
static FILE* CUDAHEADER = NULL; static FILE* CUDAHEADER = NULL;
static const char* dslheader_filename = "user_defines.h";
static const char* cudaheader_filename = "user_kernels.h";
/* /*
* ============================================================================= * =============================================================================
* Translation * Translation
@@ -101,7 +104,7 @@ translate(const int token)
assert(token < TRANSLATION_TABLE_SIZE); assert(token < TRANSLATION_TABLE_SIZE);
if (token > 0) { if (token > 0) {
if (!translation_table[token]) if (!translation_table[token])
printf("ERROR: unidentified token %d\n", token); fprintf(stderr, "Error: unidentified token %d\n", token);
assert(translation_table[token]); assert(translation_table[token]);
} }
@@ -161,6 +164,7 @@ add_symbol(const SymbolType type, const int tqualifier, const int tspecifier, co
"Syntax error. Symbol '%s' is ambiguous, declared multiple times in the same scope" "Syntax error. Symbol '%s' is ambiguous, declared multiple times in the same scope"
" (shadowing).\n", " (shadowing).\n",
id); id);
assert(0);
} }
symbol_table[num_symbols[current_nest]].type = type; symbol_table[num_symbols[current_nest]].type = type;
@@ -234,28 +238,12 @@ print_symbol_table(void)
*/ */
static bool inside_declaration = false; static bool inside_declaration = false;
static bool inside_kernel = false; static bool inside_kernel = false;
/* /*
* ============================================================================= * =============================================================================
* AST traversal * AST traversal
* ============================================================================= * =============================================================================
*/ */
/*
static bool
introspect(const ASTNode* node, const NodeType type)
{
assert(node);
ASTNode* parent = node->parent;
while (parent) {
if (parent->type == type)
return true;
else
parent = parent->parent;
}
return false;
}
*/
static void static void
traverse(const ASTNode* node) traverse(const ASTNode* node)
{ {
@@ -265,8 +253,6 @@ traverse(const ASTNode* node)
// Prefix logic // Prefix logic
if (node->type == NODE_COMPOUND_STATEMENT) { if (node->type == NODE_COMPOUND_STATEMENT) {
// if (node->type == NODE_FUNCTION_PARAMETER_DECLARATION ||
// node->type == NODE_ITERATION_STATEMENT) {
assert(current_nest < MAX_NESTS); assert(current_nest < MAX_NESTS);
++current_nest; ++current_nest;
@@ -281,7 +267,6 @@ traverse(const ASTNode* node)
// Boilerplates // Boilerplates
const ASTNode* typedecl = node->parent->lhs->lhs; const ASTNode* typedecl = node->parent->lhs->lhs;
const ASTNode* typequal = typedecl->lhs; const ASTNode* typequal = typedecl->lhs;
printf("typedecl %d\n", typedecl->type);
assert(typedecl->type == NODE_TYPE_DECLARATION); assert(typedecl->type == NODE_TYPE_DECLARATION);
if (typequal->type == NODE_TYPE_QUALIFIER) { if (typequal->type == NODE_TYPE_QUALIFIER) {
if (typequal->token == KERNEL) { if (typequal->token == KERNEL) {
@@ -347,6 +332,9 @@ traverse(const ASTNode* node)
stype = SYMBOLTYPE_FUNCTION; stype = SYMBOLTYPE_FUNCTION;
else if (tmp->type == NODE_FUNCTION_PARAMETER_DECLARATION) else if (tmp->type == NODE_FUNCTION_PARAMETER_DECLARATION)
stype = SYMBOLTYPE_FUNCTION_PARAMETER; stype = SYMBOLTYPE_FUNCTION_PARAMETER;
else if (node->parent && node->parent->parent && node->parent->parent->parent &&
node->parent->parent->parent->type == NODE_FOR_EXPRESSION)
stype = SYMBOLTYPE_FUNCTION_PARAMETER;
else else
stype = SYMBOLTYPE_OTHER; stype = SYMBOLTYPE_OTHER;
@@ -358,7 +346,6 @@ traverse(const ASTNode* node)
const char* identifier = node->rhs->type == NODE_IDENTIFIER ? node->rhs->buffer const char* identifier = node->rhs->type == NODE_IDENTIFIER ? node->rhs->buffer
: node->rhs->lhs->buffer; : node->rhs->lhs->buffer;
add_symbol(stype, tqualifier, tspecifier, identifier); add_symbol(stype, tqualifier, tspecifier, identifier);
printf("Added %s\n", identifier);
// Translate the new symbol // Translate the new symbol
if (tqualifier == UNIFORM) { if (tqualifier == UNIFORM) {
@@ -387,9 +374,7 @@ traverse(const ASTNode* node)
// TODO FIX not to use symboltable_lookup // TODO FIX not to use symboltable_lookup
const Symbol* parent_function = symboltable_lookup(tmp->lhs->rhs->buffer); const Symbol* parent_function = symboltable_lookup(tmp->lhs->rhs->buffer);
assert(parent_function); if (parent_function && (tqualifier == IN || tqualifier == OUT)) {
if (tqualifier == IN || tqualifier == OUT) {
if (tmp->lhs->lhs->lhs->token == DEVICE) { if (tmp->lhs->lhs->lhs->token == DEVICE) {
fprintf(CUDAHEADER, "const %sData& %s", // fprintf(CUDAHEADER, "const %sData& %s", //
translate(tspecifier), identifier); translate(tspecifier), identifier);
@@ -431,9 +416,6 @@ traverse(const ASTNode* node)
if (symbol) { if (symbol) {
// Uniforms // Uniforms
if (symbol->type_qualifier == UNIFORM) { if (symbol->type_qualifier == UNIFORM) {
printf("INSIDE KERNEL %d %s, type spec %d vs %d, %d\n", inside_kernel,
symbol->identifier, symbol->type_specifier, SCALARARRAY,
symbol->type_specifier == SCALARARRAY);
if (inside_kernel && symbol->type_specifier == SCALARARRAY) if (inside_kernel && symbol->type_specifier == SCALARARRAY)
fprintf(CUDAHEADER, "buffer.profiles[%s] ", symbol->identifier); fprintf(CUDAHEADER, "buffer.profiles[%s] ", symbol->identifier);
else else
@@ -473,16 +455,11 @@ traverse(const ASTNode* node)
assert(current_nest > 0); assert(current_nest > 0);
--current_nest; --current_nest;
// Drop function parameters // Drop function parameters (incl. those declared in for statements)
while (symbol_table[num_symbols[current_nest] - 1].type == SYMBOLTYPE_FUNCTION_PARAMETER) while (symbol_table[num_symbols[current_nest] - 1].type == SYMBOLTYPE_FUNCTION_PARAMETER)
--num_symbols[current_nest]; --num_symbols[current_nest];
inside_kernel = false; inside_kernel = false;
// Drop temporaries declared with iteration statements
// TODO
printf("Dropped rest of the symbol table, from %lu to %lu\n", num_symbols[current_nest + 1],
num_symbols[current_nest]);
// Kernel writeback boilerplate // Kernel writeback boilerplate
if (node->parent->type == NODE_FUNCTION_DEFINITION) { if (node->parent->type == NODE_FUNCTION_DEFINITION) {
@@ -562,7 +539,7 @@ generate_preprocessed_structures(void)
rewind(CUDAHEADER); rewind(CUDAHEADER);
const size_t buflen = fread(buffer, sizeof(char), max_buflen, CUDAHEADER); const size_t buflen = fread(buffer, sizeof(char), max_buflen, CUDAHEADER);
fclose(CUDAHEADER); fclose(CUDAHEADER);
CUDAHEADER = fopen("user_kernels.h", "w+"); CUDAHEADER = fopen(cudaheader_filename, "w+");
fprintf(CUDAHEADER, "#pragma once\n"); fprintf(CUDAHEADER, "#pragma once\n");
fprintf(CUDAHEADER, "typedef struct {\n"); fprintf(CUDAHEADER, "typedef struct {\n");
@@ -673,12 +650,12 @@ main(int argc, char** argv)
const int retval = yyparse(); const int retval = yyparse();
if (retval) { if (retval) {
fprintf(stderr, "COMPILATION FAILED\n"); fprintf(stderr, "Fatal error: DSL compilation failed\n");
return EXIT_FAILURE; return EXIT_FAILURE;
} }
DSLHEADER = fopen("user_defines.h", "w+"); DSLHEADER = fopen(dslheader_filename, "w+");
CUDAHEADER = fopen("user_kernels.h", "w+"); CUDAHEADER = fopen(cudaheader_filename, "w+");
assert(DSLHEADER); assert(DSLHEADER);
assert(CUDAHEADER); assert(CUDAHEADER);
@@ -708,12 +685,14 @@ main(int argc, char** argv)
generate_preprocessed_structures(); generate_preprocessed_structures();
generate_library_hooks(); generate_library_hooks();
print_symbol_table(); // print_symbol_table();
// Cleanup // Cleanup
fclose(DSLHEADER); fclose(DSLHEADER);
fclose(CUDAHEADER); fclose(CUDAHEADER);
astnode_destroy(root); astnode_destroy(root);
fprintf(stdout, "COMPILATION SUCCESS\n");
fprintf(stdout, "-- Generated %s\n", dslheader_filename);
fprintf(stdout, "-- Generated %s\n", cudaheader_filename);
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }