Rewrote the Astaroth DSL compiler. More information and cleanup in the next commits.
This commit is contained in:
@@ -1,3 +1,5 @@
|
|||||||
|
#include <stdderiv.h>
|
||||||
|
|
||||||
#define LDENSITY (1)
|
#define LDENSITY (1)
|
||||||
#define LHYDRO (1)
|
#define LHYDRO (1)
|
||||||
#define LMAGNETIC (1)
|
#define LMAGNETIC (1)
|
||||||
@@ -8,6 +10,8 @@
|
|||||||
#define LSINK (0)
|
#define LSINK (0)
|
||||||
|
|
||||||
#define AC_THERMAL_CONDUCTIVITY (AcReal(0.001)) // TODO: make an actual config parameter
|
#define AC_THERMAL_CONDUCTIVITY (AcReal(0.001)) // TODO: make an actual config parameter
|
||||||
|
#define H_CONST (0) // TODO: make an actual config parameter
|
||||||
|
#define C_CONST (0) // TODO: make an actual config parameter
|
||||||
|
|
||||||
// Int params
|
// Int params
|
||||||
uniform int AC_max_steps;
|
uniform int AC_max_steps;
|
||||||
@@ -20,9 +24,6 @@ uniform int AC_start_step;
|
|||||||
uniform Scalar AC_dt;
|
uniform Scalar AC_dt;
|
||||||
uniform Scalar AC_max_time;
|
uniform Scalar AC_max_time;
|
||||||
// Spacing
|
// Spacing
|
||||||
uniform Scalar AC_dsx;
|
|
||||||
uniform Scalar AC_dsy;
|
|
||||||
uniform Scalar AC_dsz;
|
|
||||||
uniform Scalar AC_dsmin;
|
uniform Scalar AC_dsmin;
|
||||||
// physical grid
|
// physical grid
|
||||||
uniform Scalar AC_xlen;
|
uniform Scalar AC_xlen;
|
||||||
@@ -96,9 +97,6 @@ uniform Scalar AC_GM_star;
|
|||||||
uniform Scalar AC_unit_mass;
|
uniform Scalar AC_unit_mass;
|
||||||
uniform Scalar AC_sq2GM_star;
|
uniform Scalar AC_sq2GM_star;
|
||||||
uniform Scalar AC_cs2_sound;
|
uniform Scalar AC_cs2_sound;
|
||||||
uniform Scalar AC_inv_dsx;
|
|
||||||
uniform Scalar AC_inv_dsy;
|
|
||||||
uniform Scalar AC_inv_dsz;
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* =============================================================================
|
* =============================================================================
|
||||||
@@ -135,19 +133,6 @@ uniform ScalarField VTXBUF_LNRHO;
|
|||||||
uniform ScalarField VTXBUF_ACCRETION;
|
uniform ScalarField VTXBUF_ACCRETION;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
Preprocessed Scalar
|
|
||||||
value(in ScalarField vertex)
|
|
||||||
{
|
|
||||||
return vertex[vertexIdx];
|
|
||||||
}
|
|
||||||
|
|
||||||
Preprocessed Vector
|
|
||||||
gradient(in ScalarField vertex)
|
|
||||||
{
|
|
||||||
return (Vector){derx(vertexIdx, vertex), dery(vertexIdx, vertex), derz(vertexIdx, vertex)};
|
|
||||||
}
|
|
||||||
|
|
||||||
#if LUPWD
|
#if LUPWD
|
||||||
|
|
||||||
Preprocessed Scalar
|
Preprocessed Scalar
|
||||||
@@ -197,24 +182,6 @@ der6z_upwd(in ScalarField vertex)
|
|||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
Preprocessed Matrix
|
|
||||||
hessian(in ScalarField vertex)
|
|
||||||
{
|
|
||||||
Matrix hessian;
|
|
||||||
|
|
||||||
hessian.row[0] = (Vector){derxx(vertexIdx, vertex), derxy(vertexIdx, vertex),
|
|
||||||
derxz(vertexIdx, vertex)};
|
|
||||||
hessian.row[1] = (Vector){hessian.row[0].y, deryy(vertexIdx, vertex), deryz(vertexIdx, vertex)};
|
|
||||||
hessian.row[2] = (Vector){hessian.row[0].z, hessian.row[1].z, derzz(vertexIdx, vertex)};
|
|
||||||
|
|
||||||
return hessian;
|
|
||||||
}
|
|
||||||
|
|
||||||
Device Vector
|
|
||||||
value(in VectorField uu)
|
|
||||||
{
|
|
||||||
return (Vector){value(uu.x), value(uu.y), value(uu.z)};
|
|
||||||
}
|
|
||||||
|
|
||||||
#if LUPWD
|
#if LUPWD
|
||||||
Device Scalar
|
Device Scalar
|
||||||
@@ -492,9 +459,8 @@ induction(in VectorField uu, in VectorField aa)
|
|||||||
Device Scalar
|
Device Scalar
|
||||||
lnT(in ScalarField ss, in ScalarField lnrho)
|
lnT(in ScalarField ss, in ScalarField lnrho)
|
||||||
{
|
{
|
||||||
const Scalar lnT = AC_lnT0 + AC_gamma * value(ss) / AC_cp_sound +
|
return AC_lnT0 + AC_gamma * value(ss) / AC_cp_sound +
|
||||||
(AC_gamma - Scalar(1.0)) * (value(lnrho) - AC_lnrho0);
|
(AC_gamma - Scalar(1.0)) * (value(lnrho) - AC_lnrho0);
|
||||||
return lnT;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Nabla dot (K nabla T) / (rho T)
|
// Nabla dot (K nabla T) / (rho T)
|
||||||
|
@@ -80,8 +80,8 @@ else_selection_statement: compound_statement
|
|||||||
elif_selection_statement: ELIF expression else_selection_statement { $$ = astnode_create(NODE_UNKNOWN, $2, $3); $$->prefix = ELIF; }
|
elif_selection_statement: ELIF expression else_selection_statement { $$ = astnode_create(NODE_UNKNOWN, $2, $3); $$->prefix = ELIF; }
|
||||||
;
|
;
|
||||||
|
|
||||||
iteration_statement: WHILE expression compound_statement { $$ = astnode_create(NODE_UNKNOWN, $2, $3); $$->prefix = WHILE; }
|
iteration_statement: WHILE expression compound_statement { $$ = astnode_create(NODE_ITERATION_STATEMENT, $2, $3); $$->prefix = WHILE; }
|
||||||
| FOR for_expression compound_statement { $$ = astnode_create(NODE_UNKNOWN, $2, $3); $$->prefix = FOR; }
|
| FOR for_expression compound_statement { $$ = astnode_create(NODE_ITERATION_STATEMENT, $2, $3); $$->prefix = FOR; }
|
||||||
;
|
;
|
||||||
|
|
||||||
for_expression: '(' for_init_param for_other_params ')' { $$ = astnode_create(NODE_UNKNOWN, $2, $3); $$->prefix = '('; $$->postfix = ')'; }
|
for_expression: '(' for_init_param for_other_params ')' { $$ = astnode_create(NODE_UNKNOWN, $2, $3); $$->prefix = '('; $$->postfix = ')'; }
|
||||||
@@ -124,8 +124,8 @@ declaration: type_declaration identifier
|
|||||||
| type_declaration array_declaration { $$ = astnode_create(NODE_DECLARATION, $1, $2); }
|
| type_declaration array_declaration { $$ = astnode_create(NODE_DECLARATION, $1, $2); }
|
||||||
;
|
;
|
||||||
|
|
||||||
array_declaration: identifier '[' ']' { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); $$->infix = '['; $$->postfix = ']'; }
|
array_declaration: identifier '[' ']' { $$ = astnode_create(NODE_ARRAY_DECLARATION, $1, NULL); $$->infix = '['; $$->postfix = ']'; }
|
||||||
| identifier '[' expression ']' { $$ = astnode_create(NODE_UNKNOWN, $1, $3); $$->infix = '['; $$->postfix = ']'; }
|
| identifier '[' expression ']' { $$ = astnode_create(NODE_ARRAY_DECLARATION, $1, $3); $$->infix = '['; $$->postfix = ']'; }
|
||||||
;
|
;
|
||||||
|
|
||||||
type_declaration: type_specifier { $$ = astnode_create(NODE_TYPE_DECLARATION, $1, NULL); }
|
type_declaration: type_specifier { $$ = astnode_create(NODE_TYPE_DECLARATION, $1, NULL); }
|
||||||
|
@@ -20,8 +20,10 @@
|
|||||||
FUNC(NODE_UNKNOWN), \
|
FUNC(NODE_UNKNOWN), \
|
||||||
FUNC(NODE_DEFINITION), \
|
FUNC(NODE_DEFINITION), \
|
||||||
FUNC(NODE_GLOBAL_DEFINITION), \
|
FUNC(NODE_GLOBAL_DEFINITION), \
|
||||||
|
FUNC(NODE_ITERATION_STATEMENT), \
|
||||||
FUNC(NODE_DECLARATION), \
|
FUNC(NODE_DECLARATION), \
|
||||||
FUNC(NODE_DECLARATION_LIST), \
|
FUNC(NODE_DECLARATION_LIST), \
|
||||||
|
FUNC(NODE_ARRAY_DECLARATION), \
|
||||||
FUNC(NODE_TYPE_DECLARATION), \
|
FUNC(NODE_TYPE_DECLARATION), \
|
||||||
FUNC(NODE_TYPE_QUALIFIER), \
|
FUNC(NODE_TYPE_QUALIFIER), \
|
||||||
FUNC(NODE_TYPE_SPECIFIER), \
|
FUNC(NODE_TYPE_SPECIFIER), \
|
||||||
|
@@ -153,6 +153,13 @@ add_symbol(const SymbolType type, const int tqualifier, const int tspecifier, co
|
|||||||
{
|
{
|
||||||
assert(num_symbols[current_nest] < SYMBOL_TABLE_SIZE);
|
assert(num_symbols[current_nest] < SYMBOL_TABLE_SIZE);
|
||||||
|
|
||||||
|
if (symboltable_lookup(id) && type != SYMBOLTYPE_FUNCTION) {
|
||||||
|
fprintf(stderr,
|
||||||
|
"Syntax error. Symbol '%s' is ambiguous, declared multiple times in the same scope"
|
||||||
|
" (shadowing).\n",
|
||||||
|
id);
|
||||||
|
}
|
||||||
|
|
||||||
symbol_table[num_symbols[current_nest]].type = type;
|
symbol_table[num_symbols[current_nest]].type = type;
|
||||||
symbol_table[num_symbols[current_nest]].type_qualifier = tqualifier;
|
symbol_table[num_symbols[current_nest]].type_qualifier = tqualifier;
|
||||||
symbol_table[num_symbols[current_nest]].type_specifier = tspecifier;
|
symbol_table[num_symbols[current_nest]].type_specifier = tspecifier;
|
||||||
@@ -222,73 +229,107 @@ print_symbol_table(void)
|
|||||||
* Traversal state
|
* Traversal state
|
||||||
* =============================================================================
|
* =============================================================================
|
||||||
*/
|
*/
|
||||||
|
static bool inside_declaration = false;
|
||||||
/*
|
/*
|
||||||
* =============================================================================
|
* =============================================================================
|
||||||
* AST traversal
|
* AST traversal
|
||||||
* =============================================================================
|
* =============================================================================
|
||||||
*/
|
*/
|
||||||
|
/*
|
||||||
|
static bool
|
||||||
|
introspect(const ASTNode* node, const NodeType type)
|
||||||
|
{
|
||||||
|
assert(node);
|
||||||
|
|
||||||
|
ASTNode* parent = node->parent;
|
||||||
|
while (parent) {
|
||||||
|
if (parent->type == type)
|
||||||
|
return true;
|
||||||
|
else
|
||||||
|
parent = parent->parent;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
static void
|
static void
|
||||||
traverse(const ASTNode* node)
|
traverse(const ASTNode* node)
|
||||||
{
|
{
|
||||||
// Prefix translation
|
// Prefix translation
|
||||||
if (translate(node->prefix))
|
if (!inside_declaration && translate(node->prefix))
|
||||||
fprintf(CUDAHEADER, "%s", translate(node->prefix));
|
fprintf(CUDAHEADER, "%s", translate(node->prefix));
|
||||||
|
|
||||||
// Prefix logic
|
// Prefix logic
|
||||||
if (node->type == NODE_COMPOUND_STATEMENT) {
|
if (node->type == NODE_COMPOUND_STATEMENT) {
|
||||||
|
// if (node->type == NODE_FUNCTION_PARAMETER_DECLARATION ||
|
||||||
|
// node->type == NODE_ITERATION_STATEMENT) {
|
||||||
assert(current_nest < MAX_NESTS);
|
assert(current_nest < MAX_NESTS);
|
||||||
|
|
||||||
++current_nest;
|
++current_nest;
|
||||||
num_symbols[current_nest] = num_symbols[current_nest - 1];
|
num_symbols[current_nest] = num_symbols[current_nest - 1];
|
||||||
}
|
}
|
||||||
|
if (node->type == NODE_DECLARATION)
|
||||||
|
inside_declaration = true;
|
||||||
|
|
||||||
|
if (node->type == NODE_FUNCTION_PARAMETER_DECLARATION) {
|
||||||
|
// Boilerplates
|
||||||
|
const ASTNode* typedecl = node->parent->lhs->lhs;
|
||||||
|
const ASTNode* typequal = typedecl->lhs;
|
||||||
|
printf("typedecl %d\n", typedecl->type);
|
||||||
|
assert(typedecl->type == NODE_TYPE_DECLARATION);
|
||||||
|
if (typequal->type == NODE_TYPE_QUALIFIER) {
|
||||||
|
if (typequal->token == KERNEL) {
|
||||||
|
fprintf(CUDAHEADER, "GEN_KERNEL_PARAM_BOILERPLATE");
|
||||||
|
if (node->lhs != NULL) {
|
||||||
|
fprintf(
|
||||||
|
stderr,
|
||||||
|
"Syntax error: function parameters for Kernel functions not allowed!\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (typequal->token == PREPROCESSED) {
|
||||||
|
fprintf(CUDAHEADER, "GEN_PREPROCESSED_PARAM_BOILERPLATE, ");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (node->type == NODE_COMPOUND_STATEMENT) {
|
||||||
|
if (node->parent->type == NODE_FUNCTION_DEFINITION) {
|
||||||
|
const Symbol* symbol = symboltable_lookup(node->parent->lhs->lhs->rhs->buffer);
|
||||||
|
if (symbol && symbol->type_qualifier == KERNEL) {
|
||||||
|
fprintf(CUDAHEADER, "GEN_KERNEL_BUILTIN_VARIABLES_BOILERPLATE();");
|
||||||
|
for (int i = 0; i < num_symbols[current_nest]; ++i) {
|
||||||
|
if (symbol_table[i].type_qualifier == IN) {
|
||||||
|
fprintf(CUDAHEADER, "const %sData %s = READ(handle_%s);\n",
|
||||||
|
translate(symbol_table[i].type_specifier),
|
||||||
|
symbol_table[i].identifier, symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
else if (symbol_table[i].type_qualifier == OUT) {
|
||||||
|
fprintf(CUDAHEADER, "%s %s = READ_OUT(handle_%s);",
|
||||||
|
translate(symbol_table[i].type_specifier),
|
||||||
|
symbol_table[i].identifier, symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Traverse LHS
|
// Traverse LHS
|
||||||
if (node->lhs)
|
if (node->lhs)
|
||||||
traverse(node->lhs);
|
traverse(node->lhs);
|
||||||
|
|
||||||
// Infix translation
|
// Infix translation
|
||||||
if (translate(node->infix))
|
if (!inside_declaration && translate(node->infix))
|
||||||
fprintf(CUDAHEADER, "%s", translate(node->infix));
|
fprintf(CUDAHEADER, "%s", translate(node->infix));
|
||||||
|
|
||||||
// Infix logic
|
// Infix logic
|
||||||
// TODO
|
// If the node is a subscript expression and the expression list inside it is not empty
|
||||||
|
if (node->type == NODE_MULTIDIM_SUBSCRIPT_EXPRESSION && node->rhs)
|
||||||
|
fprintf(CUDAHEADER, "IDX(");
|
||||||
|
|
||||||
// Traverse RHS
|
// Traverse RHS
|
||||||
if (node->rhs)
|
if (node->rhs)
|
||||||
traverse(node->rhs);
|
traverse(node->rhs);
|
||||||
|
|
||||||
// Postfix translation
|
|
||||||
if (translate(node->postfix))
|
|
||||||
fprintf(CUDAHEADER, "%s", translate(node->postfix));
|
|
||||||
|
|
||||||
// Translate existing symbols
|
|
||||||
const Symbol* symbol = symboltable_lookup(node->buffer);
|
|
||||||
if (symbol) {
|
|
||||||
// Uniforms
|
|
||||||
if (symbol->type_qualifier == UNIFORM) {
|
|
||||||
fprintf(CUDAHEADER, "DCONST(%s) ", symbol->identifier);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
// print_symbol2(symbol);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
/*
|
|
||||||
// Translate literals
|
|
||||||
if (translate(node->token))
|
|
||||||
printf("%s ", translate(node->token));
|
|
||||||
if (node->buffer) {
|
|
||||||
if (node->type == NODE_REAL_NUMBER) {
|
|
||||||
printf("%s(%s) ", translate(SCALAR), node->buffer); // Cast to correct precision
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
printf("%s ", node->buffer);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add new symbols to the symbol table
|
// Add new symbols to the symbol table
|
||||||
if (node->type == NODE_DECLARATION) {
|
if (node->type == NODE_DECLARATION) {
|
||||||
int stype;
|
int stype;
|
||||||
@@ -311,6 +352,7 @@ traverse(const ASTNode* node)
|
|||||||
const char* identifier = node->rhs->type == NODE_IDENTIFIER ? node->rhs->buffer
|
const char* identifier = node->rhs->type == NODE_IDENTIFIER ? node->rhs->buffer
|
||||||
: node->rhs->lhs->buffer;
|
: node->rhs->lhs->buffer;
|
||||||
add_symbol(stype, tqualifier, tspecifier, identifier);
|
add_symbol(stype, tqualifier, tspecifier, identifier);
|
||||||
|
printf("Added %s\n", identifier);
|
||||||
|
|
||||||
// Translate the new symbol
|
// Translate the new symbol
|
||||||
if (tqualifier == UNIFORM) {
|
if (tqualifier == UNIFORM) {
|
||||||
@@ -328,22 +370,41 @@ traverse(const ASTNode* node)
|
|||||||
fprintf(CUDAHEADER, "%s %s\npreprocessed_%s", //
|
fprintf(CUDAHEADER, "%s %s\npreprocessed_%s", //
|
||||||
translate(tqualifier), translate(tspecifier), identifier);
|
translate(tqualifier), translate(tspecifier), identifier);
|
||||||
}
|
}
|
||||||
|
else if (stype == SYMBOLTYPE_FUNCTION) {
|
||||||
|
// Stencil assembly stage device function
|
||||||
|
fprintf(CUDAHEADER, "%s %s\n%s", //
|
||||||
|
translate(DEVICE), translate(tspecifier), identifier);
|
||||||
|
}
|
||||||
else if (stype == SYMBOLTYPE_FUNCTION_PARAMETER) {
|
else if (stype == SYMBOLTYPE_FUNCTION_PARAMETER) {
|
||||||
tmp = tmp->parent;
|
tmp = tmp->parent;
|
||||||
assert(tmp->type = NODE_FUNCTION_DECLARATION);
|
assert(tmp->type = NODE_FUNCTION_DECLARATION);
|
||||||
|
|
||||||
const Symbol* parent_function = symboltable_lookup(tmp->lhs->rhs->buffer);
|
const Symbol* parent_function = symboltable_lookup(tmp->lhs->rhs->buffer);
|
||||||
assert(parent_function);
|
assert(parent_function);
|
||||||
|
|
||||||
if (tqualifier == IN || tqualifier == OUT) {
|
if (tqualifier == IN || tqualifier == OUT) {
|
||||||
if (parent_function->type_qualifier == 0 ||
|
if (tmp->lhs->lhs->lhs->token == DEVICE) {
|
||||||
parent_function->type_qualifier == PREPROCESSED) {
|
|
||||||
fprintf(CUDAHEADER, "const __restrict__ %s* %s", //
|
|
||||||
translate(tspecifier), identifier);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
fprintf(CUDAHEADER, "const %sData& %s", //
|
fprintf(CUDAHEADER, "const %sData& %s", //
|
||||||
translate(tspecifier), identifier);
|
translate(tspecifier), identifier);
|
||||||
}
|
}
|
||||||
|
else {
|
||||||
|
fprintf(CUDAHEADER, "const __restrict__ %s* %s", //
|
||||||
|
translate(tspecifier), identifier);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
if (parent_function->type_qualifier == 0 ||
|
||||||
|
parent_function->type_qualifier == PREPROCESSED) {
|
||||||
|
fprintf(CUDAHEADER, "const __restrict__ %s* %s", //
|
||||||
|
translate(tspecifier), identifier);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
fprintf(CUDAHEADER, "const %sData& %s", //
|
||||||
|
translate(tspecifier), identifier);
|
||||||
|
}*/
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
print_symbol2(&symbol_table[num_symbols[current_nest] - 1]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if (tqualifier == IN || tqualifier == OUT) { // Global in/out declarator
|
else if (tqualifier == IN || tqualifier == OUT) { // Global in/out declarator
|
||||||
@@ -356,22 +417,183 @@ traverse(const ASTNode* node)
|
|||||||
// Do a regular translation
|
// Do a regular translation
|
||||||
print_symbol2(&symbol_table[num_symbols[current_nest] - 1]);
|
print_symbol2(&symbol_table[num_symbols[current_nest] - 1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (node->rhs->type == NODE_ARRAY_DECLARATION) {
|
||||||
|
// Traverse the expression once again, this time with
|
||||||
|
// "inside_declaration" flag off
|
||||||
|
inside_declaration = false;
|
||||||
|
fprintf(CUDAHEADER, "%s ", translate(node->rhs->infix));
|
||||||
|
if (node->rhs->rhs)
|
||||||
|
traverse(node->rhs->rhs);
|
||||||
|
fprintf(CUDAHEADER, "%s ", translate(node->rhs->postfix));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// Translate existing symbols
|
||||||
|
const Symbol* symbol = symboltable_lookup(node->buffer);
|
||||||
|
|
||||||
|
if (symbol) {
|
||||||
|
// Uniforms
|
||||||
|
if (symbol->type_qualifier == UNIFORM) {
|
||||||
|
fprintf(CUDAHEADER, "DCONST(%s) ", symbol->identifier);
|
||||||
|
}
|
||||||
|
else if (node->parent->type != NODE_DECLARATION) {
|
||||||
|
// Regular translation
|
||||||
|
if (translate(node->token))
|
||||||
|
fprintf(CUDAHEADER, "%s ", translate(node->token));
|
||||||
|
if (node->buffer)
|
||||||
|
fprintf(CUDAHEADER, "%s ", node->buffer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (!inside_declaration) {
|
||||||
|
// Literal translation
|
||||||
|
if (translate(node->token))
|
||||||
|
fprintf(CUDAHEADER, "%s ", translate(node->token));
|
||||||
|
if (node->buffer) {
|
||||||
|
if (node->type == NODE_REAL_NUMBER) {
|
||||||
|
fprintf(CUDAHEADER, "%s(%s) ", translate(SCALAR),
|
||||||
|
node->buffer); // Cast to correct precision
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
fprintf(CUDAHEADER, "%s ", node->buffer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Postfix logic
|
// Postfix logic
|
||||||
|
// If the node is a subscript expression and the expression list inside it is not empty
|
||||||
|
if (node->type == NODE_MULTIDIM_SUBSCRIPT_EXPRESSION && node->rhs)
|
||||||
|
fprintf(CUDAHEADER, ")"); // Closing bracket of IDX()
|
||||||
|
|
||||||
if (node->type == NODE_COMPOUND_STATEMENT) {
|
if (node->type == NODE_COMPOUND_STATEMENT) {
|
||||||
|
// if (node->type == NODE_FUNCTION_DEFINITION || node->type == NODE_ITERATION_STATEMENT) {
|
||||||
assert(current_nest > 0);
|
assert(current_nest > 0);
|
||||||
--current_nest;
|
--current_nest;
|
||||||
|
|
||||||
|
// Drop function parameters
|
||||||
|
while (symbol_table[num_symbols[current_nest] - 1].type == SYMBOLTYPE_FUNCTION_PARAMETER)
|
||||||
|
--num_symbols[current_nest];
|
||||||
|
|
||||||
|
// Drop temporaries declared with iteration statements
|
||||||
|
// TODO
|
||||||
|
|
||||||
printf("Dropped rest of the symbol table, from %lu to %lu\n", num_symbols[current_nest + 1],
|
printf("Dropped rest of the symbol table, from %lu to %lu\n", num_symbols[current_nest + 1],
|
||||||
num_symbols[current_nest]);
|
num_symbols[current_nest]);
|
||||||
|
|
||||||
|
// Kernel writeback boilerplate
|
||||||
|
if (node->parent->type == NODE_FUNCTION_DEFINITION) {
|
||||||
|
const Symbol* symbol = symboltable_lookup(node->parent->lhs->lhs->rhs->buffer);
|
||||||
|
if (symbol && symbol->type_qualifier == KERNEL) {
|
||||||
|
for (int i = 0; i < num_symbols[current_nest]; ++i) {
|
||||||
|
if (symbol_table[i].type_qualifier == OUT) {
|
||||||
|
fprintf(CUDAHEADER, "WRITE_OUT(handle_%s, %s);\n",
|
||||||
|
symbol_table[i].identifier, symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if (node->type == NODE_DECLARATION)
|
||||||
|
inside_declaration = false;
|
||||||
|
|
||||||
|
// Postfix translation
|
||||||
|
if (!inside_declaration && translate(node->postfix))
|
||||||
|
fprintf(CUDAHEADER, "%s", translate(node->postfix));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void
|
||||||
|
gen_preprocessed_forward_declarations(void)
|
||||||
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
generate_preprocessed_structures(void)
|
generate_preprocessed_structures(void)
|
||||||
{
|
{
|
||||||
// TODO
|
// Data structure
|
||||||
|
fprintf(CUDAHEADER, "\n");
|
||||||
|
|
||||||
|
// Read data to the data struct
|
||||||
|
fprintf(CUDAHEADER, "static __device__ __forceinline__ AcRealData\
|
||||||
|
read_data(const int3& vertexIdx,\
|
||||||
|
const int3& globalVertexIdx,\
|
||||||
|
AcReal* __restrict__ buf[], const int handle)\
|
||||||
|
{\n\
|
||||||
|
%sData data;\n",
|
||||||
|
translate(SCALAR));
|
||||||
|
|
||||||
|
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||||
|
if (symbol_table[i].type_qualifier == PREPROCESSED)
|
||||||
|
fprintf(CUDAHEADER,
|
||||||
|
"data.%s = preprocessed_%s(vertexIdx, globalVertexIdx, buf[handle]);\n",
|
||||||
|
symbol_table[i].identifier, symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
fprintf(CUDAHEADER, "return data;\n");
|
||||||
|
fprintf(CUDAHEADER, "}\n");
|
||||||
|
|
||||||
|
// Functions for accessing the data struct members
|
||||||
|
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||||
|
if (symbol_table[i].type_qualifier == PREPROCESSED)
|
||||||
|
fprintf(CUDAHEADER, "static __device__ __forceinline__ %s\
|
||||||
|
%s(const AcRealData& data)\
|
||||||
|
{\n\
|
||||||
|
return data.%s;\
|
||||||
|
}\n",
|
||||||
|
translate(symbol_table[i].type_specifier), symbol_table[i].identifier,
|
||||||
|
symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Syntactic sugar: Vector data struct
|
||||||
|
fprintf(CUDAHEADER, "static __device__ __forceinline__ AcReal3Data\
|
||||||
|
read_data(const int3& vertexIdx,\
|
||||||
|
const int3& globalVertexIdx,\
|
||||||
|
AcReal* __restrict__ buf[], const int3& handle)\
|
||||||
|
{\
|
||||||
|
AcReal3Data data;\
|
||||||
|
\
|
||||||
|
data.x = read_data(vertexIdx, globalVertexIdx, buf, handle.x);\
|
||||||
|
data.y = read_data(vertexIdx, globalVertexIdx, buf, handle.y);\
|
||||||
|
data.z = read_data(vertexIdx, globalVertexIdx, buf, handle.z);\
|
||||||
|
\
|
||||||
|
return data;\
|
||||||
|
}\
|
||||||
|
");
|
||||||
|
|
||||||
|
const size_t max_buflen = 65536;
|
||||||
|
char buffer[max_buflen];
|
||||||
|
rewind(CUDAHEADER);
|
||||||
|
const size_t buflen = fread(buffer, sizeof(char), max_buflen, CUDAHEADER);
|
||||||
|
fclose(CUDAHEADER);
|
||||||
|
CUDAHEADER = fopen("user_kernels.h", "w+");
|
||||||
|
|
||||||
|
fprintf(CUDAHEADER, "#pragma once\n");
|
||||||
|
fprintf(CUDAHEADER, "typedef struct {\n");
|
||||||
|
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||||
|
if (symbol_table[i].type_qualifier == PREPROCESSED)
|
||||||
|
fprintf(CUDAHEADER, "%s %s;\n", translate(symbol_table[i].type_specifier),
|
||||||
|
symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
fprintf(CUDAHEADER, "} %sData;\n", translate(SCALAR));
|
||||||
|
fprintf(CUDAHEADER, "typedef struct {\
|
||||||
|
AcRealData x;\
|
||||||
|
AcRealData y;\
|
||||||
|
AcRealData z;\
|
||||||
|
} AcReal3Data;\n");
|
||||||
|
fprintf(CUDAHEADER, "static __device__ AcRealData\
|
||||||
|
read_data(const int3& vertexIdx,\
|
||||||
|
const int3& globalVertexIdx,\
|
||||||
|
AcReal* __restrict__ buf[], const int handle);\n");
|
||||||
|
fprintf(CUDAHEADER, "static __device__ AcReal3Data\
|
||||||
|
read_data(const int3& vertexIdx,\
|
||||||
|
const int3& globalVertexIdx,\
|
||||||
|
AcReal* __restrict__ buf[], const int3& handle);\n");
|
||||||
|
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||||
|
if (symbol_table[i].type_qualifier == PREPROCESSED)
|
||||||
|
fprintf(CUDAHEADER, "static __device__ %s %s(const AcRealData& data);\n",
|
||||||
|
translate(symbol_table[i].type_specifier), symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
|
||||||
|
fwrite(buffer, sizeof(char), buflen, CUDAHEADER);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
@@ -440,7 +662,7 @@ static void
|
|||||||
generate_library_hooks(void)
|
generate_library_hooks(void)
|
||||||
{
|
{
|
||||||
for (int i = 0; i < num_symbols[current_nest]; ++i) {
|
for (int i = 0; i < num_symbols[current_nest]; ++i) {
|
||||||
if (symbol_table[i].type_qualifier == KERNEL && symbol_table[i].type_qualifier == UNIFORM) {
|
if (symbol_table[i].type_qualifier == KERNEL) {
|
||||||
fprintf(CUDAHEADER, "GEN_DEVICE_FUNC_HOOK(%s)\n", symbol_table[i].identifier);
|
fprintf(CUDAHEADER, "GEN_DEVICE_FUNC_HOOK(%s)\n", symbol_table[i].identifier);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -464,6 +686,7 @@ main(int argc, char** argv)
|
|||||||
|
|
||||||
traverse(root);
|
traverse(root);
|
||||||
generate_header();
|
generate_header();
|
||||||
|
generate_preprocessed_structures();
|
||||||
generate_library_hooks();
|
generate_library_hooks();
|
||||||
|
|
||||||
print_symbol_table();
|
print_symbol_table();
|
||||||
|
@@ -2,6 +2,13 @@
|
|||||||
#define STENCIL_ORDER (6)
|
#define STENCIL_ORDER (6)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
uniform Scalar AC_dsx;
|
||||||
|
uniform Scalar AC_dsy;
|
||||||
|
uniform Scalar AC_dsz;
|
||||||
|
uniform Scalar AC_inv_dsx;
|
||||||
|
uniform Scalar AC_inv_dsy;
|
||||||
|
uniform Scalar AC_inv_dsz;
|
||||||
|
|
||||||
Scalar
|
Scalar
|
||||||
first_derivative(Scalar pencil[], Scalar inv_ds)
|
first_derivative(Scalar pencil[], Scalar inv_ds)
|
||||||
{
|
{
|
||||||
@@ -212,6 +219,12 @@ value(in ScalarField vertex)
|
|||||||
return vertex[vertexIdx];
|
return vertex[vertexIdx];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Device Vector
|
||||||
|
value(in VectorField uu)
|
||||||
|
{
|
||||||
|
return (Vector){value(uu.x), value(uu.y), value(uu.z)};
|
||||||
|
}
|
||||||
|
|
||||||
Preprocessed Vector
|
Preprocessed Vector
|
||||||
gradient(in ScalarField vertex)
|
gradient(in ScalarField vertex)
|
||||||
{
|
{
|
||||||
@@ -221,12 +234,106 @@ gradient(in ScalarField vertex)
|
|||||||
Preprocessed Matrix
|
Preprocessed Matrix
|
||||||
hessian(in ScalarField vertex)
|
hessian(in ScalarField vertex)
|
||||||
{
|
{
|
||||||
Matrix hessian;
|
Matrix mat;
|
||||||
|
|
||||||
hessian.row[0] = (Vector){derxx(vertexIdx, vertex), derxy(vertexIdx, vertex),
|
mat.row[0] = (Vector){derxx(vertexIdx, vertex), derxy(vertexIdx, vertex),
|
||||||
derxz(vertexIdx, vertex)};
|
derxz(vertexIdx, vertex)};
|
||||||
hessian.row[1] = (Vector){hessian.row[0].y, deryy(vertexIdx, vertex), deryz(vertexIdx, vertex)};
|
mat.row[1] = (Vector){mat.row[0].y, deryy(vertexIdx, vertex), deryz(vertexIdx, vertex)};
|
||||||
hessian.row[2] = (Vector){hessian.row[0].z, hessian.row[1].z, derzz(vertexIdx, vertex)};
|
mat.row[2] = (Vector){mat.row[0].z, mat.row[1].z, derzz(vertexIdx, vertex)};
|
||||||
|
|
||||||
return hessian;
|
return mat;
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////// NEW
|
||||||
|
|
||||||
|
Device Scalar
|
||||||
|
laplace(in ScalarField data)
|
||||||
|
{
|
||||||
|
return hessian(data).row[0].x + hessian(data).row[1].y + hessian(data).row[2].z;
|
||||||
|
}
|
||||||
|
|
||||||
|
Device Scalar
|
||||||
|
divergence(in VectorField vec)
|
||||||
|
{
|
||||||
|
return gradient(vec.x).x + gradient(vec.y).y + gradient(vec.z).z;
|
||||||
|
}
|
||||||
|
|
||||||
|
Device Vector
|
||||||
|
laplace_vec(in VectorField vec)
|
||||||
|
{
|
||||||
|
return (Vector){laplace(vec.x), laplace(vec.y), laplace(vec.z)};
|
||||||
|
}
|
||||||
|
|
||||||
|
Device Vector
|
||||||
|
curl(in VectorField vec)
|
||||||
|
{
|
||||||
|
return (Vector){gradient(vec.z).y - gradient(vec.y).z, gradient(vec.x).z - gradient(vec.z).x,
|
||||||
|
gradient(vec.y).x - gradient(vec.x).y};
|
||||||
|
}
|
||||||
|
|
||||||
|
Device Vector
|
||||||
|
gradient_of_divergence(in VectorField vec)
|
||||||
|
{
|
||||||
|
return (Vector){hessian(vec.x).row[0].x + hessian(vec.y).row[0].y + hessian(vec.z).row[0].z,
|
||||||
|
hessian(vec.x).row[1].x + hessian(vec.y).row[1].y + hessian(vec.z).row[1].z,
|
||||||
|
hessian(vec.x).row[2].x + hessian(vec.y).row[2].y + hessian(vec.z).row[2].z};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Takes uu gradients and returns S
|
||||||
|
Device Matrix
|
||||||
|
stress_tensor(in VectorField vec)
|
||||||
|
{
|
||||||
|
Matrix S;
|
||||||
|
|
||||||
|
S.row[0].x = Scalar(2.0 / 3.0) * gradient(vec.x).x -
|
||||||
|
Scalar(1.0 / 3.0) * (gradient(vec.y).y + gradient(vec.z).z);
|
||||||
|
S.row[0].y = Scalar(1.0 / 2.0) * (gradient(vec.x).y + gradient(vec.y).x);
|
||||||
|
S.row[0].z = Scalar(1.0 / 2.0) * (gradient(vec.x).z + gradient(vec.z).x);
|
||||||
|
|
||||||
|
S.row[1].y = Scalar(2.0 / 3.0) * gradient(vec.y).y -
|
||||||
|
Scalar(1.0 / 3.0) * (gradient(vec.x).x + gradient(vec.z).z);
|
||||||
|
|
||||||
|
S.row[1].z = Scalar(1.0 / 2.0) * (gradient(vec.y).z + gradient(vec.z).y);
|
||||||
|
|
||||||
|
S.row[2].z = Scalar(2.0 / 3.0) * gradient(vec.z).z -
|
||||||
|
Scalar(1.0 / 3.0) * (gradient(vec.x).x + gradient(vec.y).y);
|
||||||
|
|
||||||
|
S.row[1].x = S.row[0].y;
|
||||||
|
S.row[2].x = S.row[0].z;
|
||||||
|
S.row[2].y = S.row[1].z;
|
||||||
|
|
||||||
|
return S;
|
||||||
|
}
|
||||||
|
|
||||||
|
Device Scalar
|
||||||
|
contract(const Matrix mat)
|
||||||
|
{
|
||||||
|
Scalar res = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
res = res + dot(mat.row[i], mat.row[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////// NEW NEW BLAS
|
||||||
|
|
||||||
|
Device Scalar
|
||||||
|
length(const Vector vec)
|
||||||
|
{
|
||||||
|
return sqrt(vec.x * vec.x + vec.y * vec.y + vec.z * vec.z);
|
||||||
|
}
|
||||||
|
|
||||||
|
Device Scalar
|
||||||
|
reciprocal_len(const Vector vec)
|
||||||
|
{
|
||||||
|
return rsqrt(vec.x * vec.x + vec.y * vec.y + vec.z * vec.z);
|
||||||
|
}
|
||||||
|
|
||||||
|
Device Vector
|
||||||
|
normalized(const Vector vec)
|
||||||
|
{
|
||||||
|
const Scalar inv_len = reciprocal_len(vec);
|
||||||
|
return inv_len * vec;
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user