diff --git a/acc/accrevision/stencil_kernel.ac b/acc/accrevision/stencil_kernel.ac index 21dc17f..abe1053 100644 --- a/acc/accrevision/stencil_kernel.ac +++ b/acc/accrevision/stencil_kernel.ac @@ -1,3 +1,5 @@ +#include + #define LDENSITY (1) #define LHYDRO (1) #define LMAGNETIC (1) @@ -8,6 +10,8 @@ #define LSINK (0) #define AC_THERMAL_CONDUCTIVITY (AcReal(0.001)) // TODO: make an actual config parameter +#define H_CONST (0) // TODO: make an actual config parameter +#define C_CONST (0) // TODO: make an actual config parameter // Int params uniform int AC_max_steps; @@ -20,9 +24,6 @@ uniform int AC_start_step; uniform Scalar AC_dt; uniform Scalar AC_max_time; // Spacing -uniform Scalar AC_dsx; -uniform Scalar AC_dsy; -uniform Scalar AC_dsz; uniform Scalar AC_dsmin; // physical grid uniform Scalar AC_xlen; @@ -96,9 +97,6 @@ uniform Scalar AC_GM_star; uniform Scalar AC_unit_mass; uniform Scalar AC_sq2GM_star; uniform Scalar AC_cs2_sound; -uniform Scalar AC_inv_dsx; -uniform Scalar AC_inv_dsy; -uniform Scalar AC_inv_dsz; /* * ============================================================================= @@ -135,19 +133,6 @@ uniform ScalarField VTXBUF_LNRHO; uniform ScalarField VTXBUF_ACCRETION; #endif - -Preprocessed Scalar -value(in ScalarField vertex) -{ - return vertex[vertexIdx]; -} - -Preprocessed Vector -gradient(in ScalarField vertex) -{ - return (Vector){derx(vertexIdx, vertex), dery(vertexIdx, vertex), derz(vertexIdx, vertex)}; -} - #if LUPWD Preprocessed Scalar @@ -197,24 +182,6 @@ der6z_upwd(in ScalarField vertex) #endif -Preprocessed Matrix -hessian(in ScalarField vertex) -{ - Matrix hessian; - - hessian.row[0] = (Vector){derxx(vertexIdx, vertex), derxy(vertexIdx, vertex), - derxz(vertexIdx, vertex)}; - hessian.row[1] = (Vector){hessian.row[0].y, deryy(vertexIdx, vertex), deryz(vertexIdx, vertex)}; - hessian.row[2] = (Vector){hessian.row[0].z, hessian.row[1].z, derzz(vertexIdx, vertex)}; - - return hessian; -} - -Device Vector -value(in VectorField uu) -{ - return (Vector){value(uu.x), value(uu.y), value(uu.z)}; -} #if LUPWD Device Scalar @@ -492,9 +459,8 @@ induction(in VectorField uu, in VectorField aa) Device Scalar lnT(in ScalarField ss, in ScalarField lnrho) { - const Scalar lnT = AC_lnT0 + AC_gamma * value(ss) / AC_cp_sound + + return AC_lnT0 + AC_gamma * value(ss) / AC_cp_sound + (AC_gamma - Scalar(1.0)) * (value(lnrho) - AC_lnrho0); - return lnT; } // Nabla dot (K nabla T) / (rho T) diff --git a/acc/src/acc.y b/acc/src/acc.y index bbf9139..94abf7c 100644 --- a/acc/src/acc.y +++ b/acc/src/acc.y @@ -80,8 +80,8 @@ else_selection_statement: compound_statement elif_selection_statement: ELIF expression else_selection_statement { $$ = astnode_create(NODE_UNKNOWN, $2, $3); $$->prefix = ELIF; } ; -iteration_statement: WHILE expression compound_statement { $$ = astnode_create(NODE_UNKNOWN, $2, $3); $$->prefix = WHILE; } - | FOR for_expression compound_statement { $$ = astnode_create(NODE_UNKNOWN, $2, $3); $$->prefix = FOR; } +iteration_statement: WHILE expression compound_statement { $$ = astnode_create(NODE_ITERATION_STATEMENT, $2, $3); $$->prefix = WHILE; } + | FOR for_expression compound_statement { $$ = astnode_create(NODE_ITERATION_STATEMENT, $2, $3); $$->prefix = FOR; } ; for_expression: '(' for_init_param for_other_params ')' { $$ = astnode_create(NODE_UNKNOWN, $2, $3); $$->prefix = '('; $$->postfix = ')'; } @@ -124,8 +124,8 @@ declaration: type_declaration identifier | type_declaration array_declaration { $$ = astnode_create(NODE_DECLARATION, $1, $2); } ; -array_declaration: identifier '[' ']' { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); $$->infix = '['; $$->postfix = ']'; } - | identifier '[' expression ']' { $$ = astnode_create(NODE_UNKNOWN, $1, $3); $$->infix = '['; $$->postfix = ']'; } +array_declaration: identifier '[' ']' { $$ = astnode_create(NODE_ARRAY_DECLARATION, $1, NULL); $$->infix = '['; $$->postfix = ']'; } + | identifier '[' expression ']' { $$ = astnode_create(NODE_ARRAY_DECLARATION, $1, $3); $$->infix = '['; $$->postfix = ']'; } ; type_declaration: type_specifier { $$ = astnode_create(NODE_TYPE_DECLARATION, $1, NULL); } diff --git a/acc/src/ast.h b/acc/src/ast.h index 9311a24..ebd53c6 100644 --- a/acc/src/ast.h +++ b/acc/src/ast.h @@ -20,8 +20,10 @@ FUNC(NODE_UNKNOWN), \ FUNC(NODE_DEFINITION), \ FUNC(NODE_GLOBAL_DEFINITION), \ + FUNC(NODE_ITERATION_STATEMENT), \ FUNC(NODE_DECLARATION), \ FUNC(NODE_DECLARATION_LIST), \ + FUNC(NODE_ARRAY_DECLARATION), \ FUNC(NODE_TYPE_DECLARATION), \ FUNC(NODE_TYPE_QUALIFIER), \ FUNC(NODE_TYPE_SPECIFIER), \ diff --git a/acc/src/code_generator.c b/acc/src/code_generator.c index ba1484e..2b8d12a 100644 --- a/acc/src/code_generator.c +++ b/acc/src/code_generator.c @@ -153,6 +153,13 @@ add_symbol(const SymbolType type, const int tqualifier, const int tspecifier, co { assert(num_symbols[current_nest] < SYMBOL_TABLE_SIZE); + if (symboltable_lookup(id) && type != SYMBOLTYPE_FUNCTION) { + fprintf(stderr, + "Syntax error. Symbol '%s' is ambiguous, declared multiple times in the same scope" + " (shadowing).\n", + id); + } + symbol_table[num_symbols[current_nest]].type = type; symbol_table[num_symbols[current_nest]].type_qualifier = tqualifier; symbol_table[num_symbols[current_nest]].type_specifier = tspecifier; @@ -222,73 +229,107 @@ print_symbol_table(void) * Traversal state * ============================================================================= */ +static bool inside_declaration = false; /* * ============================================================================= * AST traversal * ============================================================================= */ +/* +static bool +introspect(const ASTNode* node, const NodeType type) +{ + assert(node); + + ASTNode* parent = node->parent; + while (parent) { + if (parent->type == type) + return true; + else + parent = parent->parent; + } + return false; +} +*/ static void traverse(const ASTNode* node) { // Prefix translation - if (translate(node->prefix)) + if (!inside_declaration && translate(node->prefix)) fprintf(CUDAHEADER, "%s", translate(node->prefix)); // Prefix logic if (node->type == NODE_COMPOUND_STATEMENT) { + // if (node->type == NODE_FUNCTION_PARAMETER_DECLARATION || + // node->type == NODE_ITERATION_STATEMENT) { assert(current_nest < MAX_NESTS); ++current_nest; num_symbols[current_nest] = num_symbols[current_nest - 1]; } + if (node->type == NODE_DECLARATION) + inside_declaration = true; + + if (node->type == NODE_FUNCTION_PARAMETER_DECLARATION) { + // Boilerplates + const ASTNode* typedecl = node->parent->lhs->lhs; + const ASTNode* typequal = typedecl->lhs; + printf("typedecl %d\n", typedecl->type); + assert(typedecl->type == NODE_TYPE_DECLARATION); + if (typequal->type == NODE_TYPE_QUALIFIER) { + if (typequal->token == KERNEL) { + fprintf(CUDAHEADER, "GEN_KERNEL_PARAM_BOILERPLATE"); + if (node->lhs != NULL) { + fprintf( + stderr, + "Syntax error: function parameters for Kernel functions not allowed!\n"); + } + } + else if (typequal->token == PREPROCESSED) { + fprintf(CUDAHEADER, "GEN_PREPROCESSED_PARAM_BOILERPLATE, "); + } + } + } + + if (node->type == NODE_COMPOUND_STATEMENT) { + if (node->parent->type == NODE_FUNCTION_DEFINITION) { + const Symbol* symbol = symboltable_lookup(node->parent->lhs->lhs->rhs->buffer); + if (symbol && symbol->type_qualifier == KERNEL) { + fprintf(CUDAHEADER, "GEN_KERNEL_BUILTIN_VARIABLES_BOILERPLATE();"); + for (int i = 0; i < num_symbols[current_nest]; ++i) { + if (symbol_table[i].type_qualifier == IN) { + fprintf(CUDAHEADER, "const %sData %s = READ(handle_%s);\n", + translate(symbol_table[i].type_specifier), + symbol_table[i].identifier, symbol_table[i].identifier); + } + else if (symbol_table[i].type_qualifier == OUT) { + fprintf(CUDAHEADER, "%s %s = READ_OUT(handle_%s);", + translate(symbol_table[i].type_specifier), + symbol_table[i].identifier, symbol_table[i].identifier); + } + } + } + } + } // Traverse LHS if (node->lhs) traverse(node->lhs); // Infix translation - if (translate(node->infix)) + if (!inside_declaration && translate(node->infix)) fprintf(CUDAHEADER, "%s", translate(node->infix)); // Infix logic - // TODO + // If the node is a subscript expression and the expression list inside it is not empty + if (node->type == NODE_MULTIDIM_SUBSCRIPT_EXPRESSION && node->rhs) + fprintf(CUDAHEADER, "IDX("); // Traverse RHS if (node->rhs) traverse(node->rhs); - // Postfix translation - if (translate(node->postfix)) - fprintf(CUDAHEADER, "%s", translate(node->postfix)); - - // Translate existing symbols - const Symbol* symbol = symboltable_lookup(node->buffer); - if (symbol) { - // Uniforms - if (symbol->type_qualifier == UNIFORM) { - fprintf(CUDAHEADER, "DCONST(%s) ", symbol->identifier); - } - else { - // print_symbol2(symbol); - } - } - else { - /* - // Translate literals - if (translate(node->token)) - printf("%s ", translate(node->token)); - if (node->buffer) { - if (node->type == NODE_REAL_NUMBER) { - printf("%s(%s) ", translate(SCALAR), node->buffer); // Cast to correct precision - } - else { - printf("%s ", node->buffer); - } - } - */ - } - // Add new symbols to the symbol table if (node->type == NODE_DECLARATION) { int stype; @@ -311,6 +352,7 @@ traverse(const ASTNode* node) const char* identifier = node->rhs->type == NODE_IDENTIFIER ? node->rhs->buffer : node->rhs->lhs->buffer; add_symbol(stype, tqualifier, tspecifier, identifier); + printf("Added %s\n", identifier); // Translate the new symbol if (tqualifier == UNIFORM) { @@ -328,22 +370,41 @@ traverse(const ASTNode* node) fprintf(CUDAHEADER, "%s %s\npreprocessed_%s", // translate(tqualifier), translate(tspecifier), identifier); } + else if (stype == SYMBOLTYPE_FUNCTION) { + // Stencil assembly stage device function + fprintf(CUDAHEADER, "%s %s\n%s", // + translate(DEVICE), translate(tspecifier), identifier); + } else if (stype == SYMBOLTYPE_FUNCTION_PARAMETER) { tmp = tmp->parent; assert(tmp->type = NODE_FUNCTION_DECLARATION); + const Symbol* parent_function = symboltable_lookup(tmp->lhs->rhs->buffer); assert(parent_function); if (tqualifier == IN || tqualifier == OUT) { - if (parent_function->type_qualifier == 0 || - parent_function->type_qualifier == PREPROCESSED) { - fprintf(CUDAHEADER, "const __restrict__ %s* %s", // - translate(tspecifier), identifier); - } - else { + if (tmp->lhs->lhs->lhs->token == DEVICE) { fprintf(CUDAHEADER, "const %sData& %s", // translate(tspecifier), identifier); } + else { + fprintf(CUDAHEADER, "const __restrict__ %s* %s", // + translate(tspecifier), identifier); + } + + /* + if (parent_function->type_qualifier == 0 || + parent_function->type_qualifier == PREPROCESSED) { + fprintf(CUDAHEADER, "const __restrict__ %s* %s", // + translate(tspecifier), identifier); + } + else { + fprintf(CUDAHEADER, "const %sData& %s", // + translate(tspecifier), identifier); + }*/ + } + else { + print_symbol2(&symbol_table[num_symbols[current_nest] - 1]); } } else if (tqualifier == IN || tqualifier == OUT) { // Global in/out declarator @@ -356,22 +417,183 @@ traverse(const ASTNode* node) // Do a regular translation print_symbol2(&symbol_table[num_symbols[current_nest] - 1]); } + + if (node->rhs->type == NODE_ARRAY_DECLARATION) { + // Traverse the expression once again, this time with + // "inside_declaration" flag off + inside_declaration = false; + fprintf(CUDAHEADER, "%s ", translate(node->rhs->infix)); + if (node->rhs->rhs) + traverse(node->rhs->rhs); + fprintf(CUDAHEADER, "%s ", translate(node->rhs->postfix)); + } + } + else { + // Translate existing symbols + const Symbol* symbol = symboltable_lookup(node->buffer); + + if (symbol) { + // Uniforms + if (symbol->type_qualifier == UNIFORM) { + fprintf(CUDAHEADER, "DCONST(%s) ", symbol->identifier); + } + else if (node->parent->type != NODE_DECLARATION) { + // Regular translation + if (translate(node->token)) + fprintf(CUDAHEADER, "%s ", translate(node->token)); + if (node->buffer) + fprintf(CUDAHEADER, "%s ", node->buffer); + } + } + else if (!inside_declaration) { + // Literal translation + if (translate(node->token)) + fprintf(CUDAHEADER, "%s ", translate(node->token)); + if (node->buffer) { + if (node->type == NODE_REAL_NUMBER) { + fprintf(CUDAHEADER, "%s(%s) ", translate(SCALAR), + node->buffer); // Cast to correct precision + } + else { + fprintf(CUDAHEADER, "%s ", node->buffer); + } + } + } } // Postfix logic + // If the node is a subscript expression and the expression list inside it is not empty + if (node->type == NODE_MULTIDIM_SUBSCRIPT_EXPRESSION && node->rhs) + fprintf(CUDAHEADER, ")"); // Closing bracket of IDX() + if (node->type == NODE_COMPOUND_STATEMENT) { + // if (node->type == NODE_FUNCTION_DEFINITION || node->type == NODE_ITERATION_STATEMENT) { assert(current_nest > 0); --current_nest; + // Drop function parameters + while (symbol_table[num_symbols[current_nest] - 1].type == SYMBOLTYPE_FUNCTION_PARAMETER) + --num_symbols[current_nest]; + + // Drop temporaries declared with iteration statements + // TODO + printf("Dropped rest of the symbol table, from %lu to %lu\n", num_symbols[current_nest + 1], num_symbols[current_nest]); + + // Kernel writeback boilerplate + if (node->parent->type == NODE_FUNCTION_DEFINITION) { + const Symbol* symbol = symboltable_lookup(node->parent->lhs->lhs->rhs->buffer); + if (symbol && symbol->type_qualifier == KERNEL) { + for (int i = 0; i < num_symbols[current_nest]; ++i) { + if (symbol_table[i].type_qualifier == OUT) { + fprintf(CUDAHEADER, "WRITE_OUT(handle_%s, %s);\n", + symbol_table[i].identifier, symbol_table[i].identifier); + } + } + } + } } + if (node->type == NODE_DECLARATION) + inside_declaration = false; + + // Postfix translation + if (!inside_declaration && translate(node->postfix)) + fprintf(CUDAHEADER, "%s", translate(node->postfix)); +} + +static void +gen_preprocessed_forward_declarations(void) +{ } static void generate_preprocessed_structures(void) { - // TODO + // Data structure + fprintf(CUDAHEADER, "\n"); + + // Read data to the data struct + fprintf(CUDAHEADER, "static __device__ __forceinline__ AcRealData\ + read_data(const int3& vertexIdx,\ + const int3& globalVertexIdx,\ + AcReal* __restrict__ buf[], const int handle)\ + {\n\ + %sData data;\n", + translate(SCALAR)); + + for (size_t i = 0; i < num_symbols[current_nest]; ++i) { + if (symbol_table[i].type_qualifier == PREPROCESSED) + fprintf(CUDAHEADER, + "data.%s = preprocessed_%s(vertexIdx, globalVertexIdx, buf[handle]);\n", + symbol_table[i].identifier, symbol_table[i].identifier); + } + fprintf(CUDAHEADER, "return data;\n"); + fprintf(CUDAHEADER, "}\n"); + + // Functions for accessing the data struct members + for (size_t i = 0; i < num_symbols[current_nest]; ++i) { + if (symbol_table[i].type_qualifier == PREPROCESSED) + fprintf(CUDAHEADER, "static __device__ __forceinline__ %s\ + %s(const AcRealData& data)\ + {\n\ + return data.%s;\ + }\n", + translate(symbol_table[i].type_specifier), symbol_table[i].identifier, + symbol_table[i].identifier); + } + + // Syntactic sugar: Vector data struct + fprintf(CUDAHEADER, "static __device__ __forceinline__ AcReal3Data\ + read_data(const int3& vertexIdx,\ + const int3& globalVertexIdx,\ + AcReal* __restrict__ buf[], const int3& handle)\ + {\ + AcReal3Data data;\ + \ + data.x = read_data(vertexIdx, globalVertexIdx, buf, handle.x);\ + data.y = read_data(vertexIdx, globalVertexIdx, buf, handle.y);\ + data.z = read_data(vertexIdx, globalVertexIdx, buf, handle.z);\ + \ + return data;\ + }\ + "); + + const size_t max_buflen = 65536; + char buffer[max_buflen]; + rewind(CUDAHEADER); + const size_t buflen = fread(buffer, sizeof(char), max_buflen, CUDAHEADER); + fclose(CUDAHEADER); + CUDAHEADER = fopen("user_kernels.h", "w+"); + + fprintf(CUDAHEADER, "#pragma once\n"); + fprintf(CUDAHEADER, "typedef struct {\n"); + for (size_t i = 0; i < num_symbols[current_nest]; ++i) { + if (symbol_table[i].type_qualifier == PREPROCESSED) + fprintf(CUDAHEADER, "%s %s;\n", translate(symbol_table[i].type_specifier), + symbol_table[i].identifier); + } + fprintf(CUDAHEADER, "} %sData;\n", translate(SCALAR)); + fprintf(CUDAHEADER, "typedef struct {\ + AcRealData x;\ + AcRealData y;\ + AcRealData z;\ + } AcReal3Data;\n"); + fprintf(CUDAHEADER, "static __device__ AcRealData\ + read_data(const int3& vertexIdx,\ + const int3& globalVertexIdx,\ + AcReal* __restrict__ buf[], const int handle);\n"); + fprintf(CUDAHEADER, "static __device__ AcReal3Data\ + read_data(const int3& vertexIdx,\ + const int3& globalVertexIdx,\ + AcReal* __restrict__ buf[], const int3& handle);\n"); + for (size_t i = 0; i < num_symbols[current_nest]; ++i) { + if (symbol_table[i].type_qualifier == PREPROCESSED) + fprintf(CUDAHEADER, "static __device__ %s %s(const AcRealData& data);\n", + translate(symbol_table[i].type_specifier), symbol_table[i].identifier); + } + + fwrite(buffer, sizeof(char), buflen, CUDAHEADER); } static void @@ -440,7 +662,7 @@ static void generate_library_hooks(void) { for (int i = 0; i < num_symbols[current_nest]; ++i) { - if (symbol_table[i].type_qualifier == KERNEL && symbol_table[i].type_qualifier == UNIFORM) { + if (symbol_table[i].type_qualifier == KERNEL) { fprintf(CUDAHEADER, "GEN_DEVICE_FUNC_HOOK(%s)\n", symbol_table[i].identifier); } } @@ -464,6 +686,7 @@ main(int argc, char** argv) traverse(root); generate_header(); + generate_preprocessed_structures(); generate_library_hooks(); print_symbol_table(); diff --git a/acc/stdlib/stdderiv.h b/acc/stdlib/stdderiv.h index 8ddca8c..1fbefd9 100644 --- a/acc/stdlib/stdderiv.h +++ b/acc/stdlib/stdderiv.h @@ -2,6 +2,13 @@ #define STENCIL_ORDER (6) #endif +uniform Scalar AC_dsx; +uniform Scalar AC_dsy; +uniform Scalar AC_dsz; +uniform Scalar AC_inv_dsx; +uniform Scalar AC_inv_dsy; +uniform Scalar AC_inv_dsz; + Scalar first_derivative(Scalar pencil[], Scalar inv_ds) { @@ -212,6 +219,12 @@ value(in ScalarField vertex) return vertex[vertexIdx]; } +Device Vector +value(in VectorField uu) +{ + return (Vector){value(uu.x), value(uu.y), value(uu.z)}; +} + Preprocessed Vector gradient(in ScalarField vertex) { @@ -221,12 +234,106 @@ gradient(in ScalarField vertex) Preprocessed Matrix hessian(in ScalarField vertex) { - Matrix hessian; + Matrix mat; - hessian.row[0] = (Vector){derxx(vertexIdx, vertex), derxy(vertexIdx, vertex), - derxz(vertexIdx, vertex)}; - hessian.row[1] = (Vector){hessian.row[0].y, deryy(vertexIdx, vertex), deryz(vertexIdx, vertex)}; - hessian.row[2] = (Vector){hessian.row[0].z, hessian.row[1].z, derzz(vertexIdx, vertex)}; + mat.row[0] = (Vector){derxx(vertexIdx, vertex), derxy(vertexIdx, vertex), + derxz(vertexIdx, vertex)}; + mat.row[1] = (Vector){mat.row[0].y, deryy(vertexIdx, vertex), deryz(vertexIdx, vertex)}; + mat.row[2] = (Vector){mat.row[0].z, mat.row[1].z, derzz(vertexIdx, vertex)}; - return hessian; + return mat; +} + +/////////////////// NEW + +Device Scalar +laplace(in ScalarField data) +{ + return hessian(data).row[0].x + hessian(data).row[1].y + hessian(data).row[2].z; +} + +Device Scalar +divergence(in VectorField vec) +{ + return gradient(vec.x).x + gradient(vec.y).y + gradient(vec.z).z; +} + +Device Vector +laplace_vec(in VectorField vec) +{ + return (Vector){laplace(vec.x), laplace(vec.y), laplace(vec.z)}; +} + +Device Vector +curl(in VectorField vec) +{ + return (Vector){gradient(vec.z).y - gradient(vec.y).z, gradient(vec.x).z - gradient(vec.z).x, + gradient(vec.y).x - gradient(vec.x).y}; +} + +Device Vector +gradient_of_divergence(in VectorField vec) +{ + return (Vector){hessian(vec.x).row[0].x + hessian(vec.y).row[0].y + hessian(vec.z).row[0].z, + hessian(vec.x).row[1].x + hessian(vec.y).row[1].y + hessian(vec.z).row[1].z, + hessian(vec.x).row[2].x + hessian(vec.y).row[2].y + hessian(vec.z).row[2].z}; +} + +// Takes uu gradients and returns S +Device Matrix +stress_tensor(in VectorField vec) +{ + Matrix S; + + S.row[0].x = Scalar(2.0 / 3.0) * gradient(vec.x).x - + Scalar(1.0 / 3.0) * (gradient(vec.y).y + gradient(vec.z).z); + S.row[0].y = Scalar(1.0 / 2.0) * (gradient(vec.x).y + gradient(vec.y).x); + S.row[0].z = Scalar(1.0 / 2.0) * (gradient(vec.x).z + gradient(vec.z).x); + + S.row[1].y = Scalar(2.0 / 3.0) * gradient(vec.y).y - + Scalar(1.0 / 3.0) * (gradient(vec.x).x + gradient(vec.z).z); + + S.row[1].z = Scalar(1.0 / 2.0) * (gradient(vec.y).z + gradient(vec.z).y); + + S.row[2].z = Scalar(2.0 / 3.0) * gradient(vec.z).z - + Scalar(1.0 / 3.0) * (gradient(vec.x).x + gradient(vec.y).y); + + S.row[1].x = S.row[0].y; + S.row[2].x = S.row[0].z; + S.row[2].y = S.row[1].z; + + return S; +} + +Device Scalar +contract(const Matrix mat) +{ + Scalar res = 0; + + for (int i = 0; i < 3; ++i) { + res = res + dot(mat.row[i], mat.row[i]); + } + + return res; +} + +///////////////////// NEW NEW BLAS + +Device Scalar +length(const Vector vec) +{ + return sqrt(vec.x * vec.x + vec.y * vec.y + vec.z * vec.z); +} + +Device Scalar +reciprocal_len(const Vector vec) +{ + return rsqrt(vec.x * vec.x + vec.y * vec.y + vec.z * vec.z); +} + +Device Vector +normalized(const Vector vec) +{ + const Scalar inv_len = reciprocal_len(vec); + return inv_len * vec; }