Some optimizations for DSL compilation. Also a new feature: in-place addition and subtraction (`+=` and `-=`) are now allowed.

This commit is contained in:
jpekkila
2019-10-07 16:33:24 +03:00
parent f7c079be2a
commit 0e1d1b9fb4
5 changed files with 19 additions and 11 deletions

View File

@@ -51,6 +51,8 @@ L [a-zA-Z_]
"++" { return INPLACE_INC; } "++" { return INPLACE_INC; }
"--" { return INPLACE_DEC; } "--" { return INPLACE_DEC; }
"+=" { return INPLACE_ADD; }
"-=" { return INPLACE_SUB; }
[-+*/;=\[\]{}(),\.<>] { return yytext[0]; } /* Characters */ [-+*/;=\[\]{}(),\.<>] { return yytext[0]; } /* Characters */

View File

@@ -21,7 +21,7 @@ int yyget_lineno();
%token IF ELSE FOR WHILE ELIF %token IF ELSE FOR WHILE ELIF
%token LEQU LAND LOR LLEQU %token LEQU LAND LOR LLEQU
%token KERNEL DEVICE PREPROCESSED %token KERNEL DEVICE PREPROCESSED
%token INPLACE_INC INPLACE_DEC %token INPLACE_INC INPLACE_DEC INPLACE_ADD INPLACE_SUB
%% %%
@@ -188,6 +188,8 @@ binary_operator: '+'
| LAND { $$ = astnode_create(NODE_UNKNOWN, NULL, NULL); astnode_set_buffer(yytext, $$); } | LAND { $$ = astnode_create(NODE_UNKNOWN, NULL, NULL); astnode_set_buffer(yytext, $$); }
| LOR { $$ = astnode_create(NODE_UNKNOWN, NULL, NULL); astnode_set_buffer(yytext, $$); } | LOR { $$ = astnode_create(NODE_UNKNOWN, NULL, NULL); astnode_set_buffer(yytext, $$); }
| LLEQU { $$ = astnode_create(NODE_UNKNOWN, NULL, NULL); astnode_set_buffer(yytext, $$); } | LLEQU { $$ = astnode_create(NODE_UNKNOWN, NULL, NULL); astnode_set_buffer(yytext, $$); }
| INPLACE_ADD { $$ = astnode_create(NODE_UNKNOWN, NULL, NULL); astnode_set_buffer(yytext, $$); }
| INPLACE_SUB { $$ = astnode_create(NODE_UNKNOWN, NULL, NULL); astnode_set_buffer(yytext, $$); }
; ;
unary_operator: '-' /* C-style casts are disallowed, would otherwise be defined here */ { $$ = astnode_create(NODE_UNKNOWN, NULL, NULL); $$->infix = yytext[0]; } unary_operator: '-' /* C-style casts are disallowed, would otherwise be defined here */ { $$ = astnode_create(NODE_UNKNOWN, NULL, NULL); $$->infix = yytext[0]; }

View File

@@ -66,7 +66,7 @@ static const char* translation_table[TRANSLATION_TABLE_SIZE] = {
[COMPLEX] = "acComplex", [COMPLEX] = "acComplex",
// Type qualifiers // Type qualifiers
[KERNEL] = "template <int step_number> static __global__", [KERNEL] = "template <int step_number> static __global__",
[DEVICE] = "static __device__", [DEVICE] = "static __device__ __forceinline__",
[PREPROCESSED] = "static __device__ __forceinline__", [PREPROCESSED] = "static __device__ __forceinline__",
[CONSTANT] = "const", [CONSTANT] = "const",
[IN] = "in", [IN] = "in",
@@ -138,6 +138,9 @@ static size_t current_nest = 0;
static Symbol* static Symbol*
symboltable_lookup(const char* identifier) symboltable_lookup(const char* identifier)
{ {
// TODO assert that the symbol is not a function! cannot be since we allow overloads->conflicts if not
// explicit
if (!identifier) if (!identifier)
return NULL; return NULL;
@@ -379,6 +382,7 @@ traverse(const ASTNode* node)
tmp = tmp->parent; tmp = tmp->parent;
assert(tmp->type = NODE_FUNCTION_DECLARATION); assert(tmp->type = NODE_FUNCTION_DECLARATION);
// TODO FIX not to use symboltable_lookup
const Symbol* parent_function = symboltable_lookup(tmp->lhs->rhs->buffer); const Symbol* parent_function = symboltable_lookup(tmp->lhs->rhs->buffer);
assert(parent_function); assert(parent_function);

View File

@@ -27,7 +27,7 @@ first_derivative(Scalar pencil[], Scalar inv_ds)
Scalar res = 0; Scalar res = 0;
for (int i = 1; i <= MID; ++i) { for (int i = 1; i <= MID; ++i) {
res = res + coefficients[i] * (pencil[MID + i] - pencil[MID - i]); res += coefficients[i] * (pencil[MID + i] - pencil[MID - i]);
} }
return res * inv_ds; return res * inv_ds;
@@ -50,7 +50,7 @@ second_derivative(Scalar pencil[], Scalar inv_ds)
Scalar res = coefficients[0] * pencil[MID]; Scalar res = coefficients[0] * pencil[MID];
for (int i = 1; i <= MID; ++i) { for (int i = 1; i <= MID; ++i) {
res = res + coefficients[i] * (pencil[MID + i] + pencil[MID - i]); res += coefficients[i] * (pencil[MID + i] + pencil[MID - i]);
} }
return res * inv_ds * inv_ds; return res * inv_ds * inv_ds;
@@ -76,8 +76,8 @@ cross_derivative(Scalar pencil_a[], Scalar pencil_b[], Scalar inv_ds_a, Scalar i
Scalar res = 0.0; Scalar res = 0.0;
for (int i = 1; i <= MID; ++i) { for (int i = 1; i <= MID; ++i) {
res = res + coefficients[i] * (pencil_a[MID + i] + pencil_a[MID - i] - pencil_b[MID + i] - res += coefficients[i] *
pencil_b[MID - i]); (pencil_a[MID + i] + pencil_a[MID - i] - pencil_b[MID + i] - pencil_b[MID - i]);
} }
return res * inv_ds_a * inv_ds_b; return res * inv_ds_a * inv_ds_b;
} }
@@ -311,7 +311,7 @@ contract(const Matrix mat)
Scalar res = 0; Scalar res = 0;
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
res = res + dot(mat.row[i], mat.row[i]); res += dot(mat.row[i], mat.row[i]);
} }
return res; return res;

View File

@@ -29,7 +29,7 @@
#include <assert.h> #include <assert.h>
static __device__ __forceinline__ int static __device__ constexpr int
IDX(const int i) IDX(const int i)
{ {
return i; return i;
@@ -95,7 +95,7 @@ write(AcReal* __restrict__ out[], const int handle, const int idx, const AcReal
out[handle][idx] = value; out[handle][idx] = value;
} }
static __device__ void static __device__ __forceinline__ void
write(AcReal* __restrict__ out[], const int3 vec, const int idx, const AcReal3 value) write(AcReal* __restrict__ out[], const int3 vec, const int idx, const AcReal3 value)
{ {
write(out, vec.x, idx, value.x); write(out, vec.x, idx, value.x);
@@ -103,13 +103,13 @@ write(AcReal* __restrict__ out[], const int3 vec, const int idx, const AcReal3 v
write(out, vec.z, idx, value.z); write(out, vec.z, idx, value.z);
} }
static __device__ AcReal static __device__ __forceinline__ AcReal
read_out(const int idx, AcReal* __restrict__ field[], const int handle) read_out(const int idx, AcReal* __restrict__ field[], const int handle)
{ {
return field[handle][idx]; return field[handle][idx];
} }
static __device__ AcReal3 static __device__ __forceinline__ AcReal3
read_out(const int idx, AcReal* __restrict__ field[], const int3 handle) read_out(const int idx, AcReal* __restrict__ field[], const int3 handle)
{ {
return (AcReal3){read_out(idx, field, handle.x), read_out(idx, field, handle.y), return (AcReal3){read_out(idx, field, handle.x), read_out(idx, field, handle.y),