Major improvement: uniforms can now be set to default values. The syntax is the same as for setting any other values, f.ex. 'uniform Scalar a = 1; uniform Scalar b = 0.5 * a;'. Undefined uniforms are still allowed, but in this case the user should load a proper value into it during runtime. Default uniform values can be overwritten by calling any of the uniform loader funcions (like acDeviceLoadScalarUniform). Improved also error checking. Now there are explicit warnings if the user tries to load an invalid value into a device constant.
This commit is contained in:
@@ -224,6 +224,7 @@ print_symbol_table(void)
|
||||
*/
|
||||
static bool inside_declaration = false;
|
||||
static bool inside_kernel = false;
|
||||
static bool inside_function = false;
|
||||
|
||||
/*
|
||||
* =============================================================================
|
||||
@@ -246,6 +247,8 @@ traverse(const ASTNode* node)
|
||||
}
|
||||
if (node->type == NODE_DECLARATION)
|
||||
inside_declaration = true;
|
||||
if (node->type == NODE_FUNCTION_DEFINITION)
|
||||
inside_function = true;
|
||||
if (node->token == KERNEL)
|
||||
inside_kernel = true;
|
||||
|
||||
@@ -335,7 +338,12 @@ traverse(const ASTNode* node)
|
||||
|
||||
// Translate the new symbol
|
||||
if (tqualifier == UNIFORM) {
|
||||
// Do nothing
|
||||
if (tspecifier == SCALAR || tspecifier == VECTOR || tspecifier == INT ||
|
||||
tspecifier == INT3) {
|
||||
fprintf(CUDAHEADER, "static %s %s_DEFAULT_VALUE", translate(tspecifier),
|
||||
identifier);
|
||||
}
|
||||
// else do nothing
|
||||
}
|
||||
else if (tqualifier == KERNEL) {
|
||||
fprintf(CUDAHEADER, "%s %s\n%s", //
|
||||
@@ -404,7 +412,13 @@ traverse(const ASTNode* node)
|
||||
if (symbol->type_qualifier == UNIFORM) {
|
||||
if (inside_kernel && symbol->type_specifier == SCALARARRAY)
|
||||
fprintf(CUDAHEADER, "buffer.profiles[%s] ", symbol->identifier);
|
||||
else
|
||||
else if (!inside_function &&
|
||||
((symbol->type_specifier == SCALAR) ||
|
||||
(symbol->type_specifier == VECTOR) || (symbol->type_specifier == INT) ||
|
||||
(symbol->type_specifier == INT3))) // Global scope and an uniform which
|
||||
// can be set to a default value
|
||||
fprintf(CUDAHEADER, "%s_DEFAULT_VALUE ", symbol->identifier);
|
||||
else // Use device constants inside device functions
|
||||
fprintf(CUDAHEADER, "DCONST(%s) ", symbol->identifier);
|
||||
}
|
||||
else if (node->parent->type != NODE_DECLARATION) {
|
||||
@@ -462,12 +476,27 @@ traverse(const ASTNode* node)
|
||||
}
|
||||
if (node->type == NODE_DECLARATION)
|
||||
inside_declaration = false;
|
||||
if (node->type == NODE_FUNCTION_DEFINITION)
|
||||
inside_function = false;
|
||||
|
||||
// Postfix translation
|
||||
if (!inside_declaration && translate(node->postfix))
|
||||
fprintf(CUDAHEADER, "%s", translate(node->postfix));
|
||||
}
|
||||
|
||||
#define ARRAY_SIZE(x) (sizeof(x) / sizeof(x[0]))
|
||||
|
||||
static const char* builtin_int_params[] = {
|
||||
"AC_nx", "AC_ny", "AC_nz", "AC_mx", "AC_my",
|
||||
"AC_mz", "AC_nx_min", "AC_ny_min", "AC_nz_min", "AC_nx_max",
|
||||
"AC_ny_max", "AC_nz_max", "AC_mxy", "AC_nxy", "AC_nxyz",
|
||||
};
|
||||
|
||||
static const char* builtin_int3_params[] = {
|
||||
"AC_global_grid_n",
|
||||
"AC_multigpu_offset",
|
||||
};
|
||||
|
||||
static void
|
||||
generate_preprocessed_structures(void)
|
||||
{
|
||||
@@ -528,6 +557,16 @@ generate_preprocessed_structures(void)
|
||||
CUDAHEADER = fopen(cudaheader_filename, "w+");
|
||||
|
||||
fprintf(CUDAHEADER, "#pragma once\n");
|
||||
|
||||
// Add built-in params (the best way would be to inject these to user src with AC syntax)
|
||||
for (size_t i = 0; i < ARRAY_SIZE(builtin_int_params); ++i) {
|
||||
fprintf(CUDAHEADER, "static const int %s_DEFAULT_VALUE = 0;\n", builtin_int_params[i]);
|
||||
}
|
||||
for (size_t i = 0; i < ARRAY_SIZE(builtin_int3_params); ++i) {
|
||||
fprintf(CUDAHEADER, "static const int3 %s_DEFAULT_VALUE = make_int3(0, 0, 0);\n",
|
||||
builtin_int3_params[i]);
|
||||
}
|
||||
|
||||
fprintf(CUDAHEADER, "typedef struct {\n");
|
||||
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||
if (symbol_table[i].type_qualifier == PREPROCESSED)
|
||||
@@ -566,7 +605,7 @@ generate_header(void)
|
||||
fprintf(DSLHEADER, "#define AC_FOR_USER_INT_PARAM_TYPES(FUNC)");
|
||||
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||
if (symbol_table[i].type_specifier == INT && symbol_table[i].type_qualifier == UNIFORM) {
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
|
||||
}
|
||||
}
|
||||
fprintf(DSLHEADER, "\n\n");
|
||||
@@ -575,7 +614,7 @@ generate_header(void)
|
||||
fprintf(DSLHEADER, "#define AC_FOR_USER_INT3_PARAM_TYPES(FUNC)");
|
||||
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||
if (symbol_table[i].type_specifier == INT3 && symbol_table[i].type_qualifier == UNIFORM) {
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
|
||||
}
|
||||
}
|
||||
fprintf(DSLHEADER, "\n\n");
|
||||
@@ -584,7 +623,7 @@ generate_header(void)
|
||||
fprintf(DSLHEADER, "#define AC_FOR_USER_REAL_PARAM_TYPES(FUNC)");
|
||||
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||
if (symbol_table[i].type_specifier == SCALAR && symbol_table[i].type_qualifier == UNIFORM) {
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
|
||||
}
|
||||
}
|
||||
fprintf(DSLHEADER, "\n\n");
|
||||
@@ -593,7 +632,7 @@ generate_header(void)
|
||||
fprintf(DSLHEADER, "#define AC_FOR_USER_REAL3_PARAM_TYPES(FUNC)");
|
||||
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||
if (symbol_table[i].type_specifier == VECTOR && symbol_table[i].type_qualifier == UNIFORM) {
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
|
||||
}
|
||||
}
|
||||
fprintf(DSLHEADER, "\n\n");
|
||||
@@ -603,7 +642,7 @@ generate_header(void)
|
||||
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||
if (symbol_table[i].type_specifier == SCALARFIELD &&
|
||||
symbol_table[i].type_qualifier == UNIFORM) {
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
|
||||
}
|
||||
}
|
||||
fprintf(DSLHEADER, "\n\n");
|
||||
@@ -613,7 +652,7 @@ generate_header(void)
|
||||
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||
if (symbol_table[i].type_specifier == SCALARARRAY &&
|
||||
symbol_table[i].type_qualifier == UNIFORM) {
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||
fprintf(DSLHEADER, "\\\nFUNC(%s)", symbol_table[i].identifier);
|
||||
}
|
||||
}
|
||||
fprintf(DSLHEADER, "\n\n");
|
||||
@@ -645,25 +684,13 @@ main(void)
|
||||
assert(DSLHEADER);
|
||||
assert(CUDAHEADER);
|
||||
|
||||
// Add built-in params
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, "AC_nx");
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, "AC_ny");
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, "AC_nz");
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, "AC_mx");
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, "AC_my");
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, "AC_mz");
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, "AC_nx_min");
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, "AC_ny_min");
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, "AC_nz_min");
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, "AC_nx_max");
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, "AC_ny_max");
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, "AC_nz_max");
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, "AC_mxy");
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, "AC_nxy");
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, "AC_nxyz");
|
||||
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT3, "AC_global_grid_n");
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT3, "AC_multigpu_offset");
|
||||
// Add built-in param symbols
|
||||
for (size_t i = 0; i < ARRAY_SIZE(builtin_int_params); ++i) {
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT, builtin_int_params[i]);
|
||||
}
|
||||
for (size_t i = 0; i < ARRAY_SIZE(builtin_int3_params); ++i) {
|
||||
add_symbol(SYMBOLTYPE_OTHER, UNIFORM, INT3, builtin_int3_params[i]);
|
||||
}
|
||||
|
||||
// Generate
|
||||
traverse(root);
|
||||
|
Reference in New Issue
Block a user