Added WIP stuff for the Astaroth DSL compiler rewrite. Once this branch is finished, only a single source file (with the extension .ac) will be needed. This revision is needed to decouple absolutely all implementation-specific stuff (e.g. AC_dsx) from the core library and to make life easier for everyone. The plan is to provide a standard library header, written in the DSL, that contains the derivative operations instead of hardcoding them in the CUDA implementation.
This commit is contained in:
714
acc/accrevision/stencil_kernel.ac
Normal file
714
acc/accrevision/stencil_kernel.ac
Normal file
@@ -0,0 +1,714 @@
|
|||||||
|
// Compile-time feature switches: (1) enables the module, (0) disables it.
// They are tested with #if throughout this file.
#define LDENSITY     (1)
#define LHYDRO       (1)
#define LMAGNETIC    (1)
#define LENTROPY     (1)
#define LTEMPERATURE (0)
#define LFORCING     (1)
#define LUPWD        (1)
#define LSINK        (0)

// Conductivity constant read by heat_conduction().
// TODO: make an actual config parameter
#define AC_THERMAL_CONDUCTIVITY (AcReal(0.001))
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
// Integer parameters
// -----------------------------------------------------------------------------
uniform int AC_max_steps;
uniform int AC_save_steps;
uniform int AC_bin_steps;
uniform int AC_bc_type;
uniform int AC_start_step;

// -----------------------------------------------------------------------------
// Real parameters
// -----------------------------------------------------------------------------
uniform Scalar AC_dt;
uniform Scalar AC_max_time;

// Grid spacing
uniform Scalar AC_dsx;
uniform Scalar AC_dsy;
uniform Scalar AC_dsz;
uniform Scalar AC_dsmin;

// Physical grid
uniform Scalar AC_xlen;
uniform Scalar AC_ylen;
uniform Scalar AC_zlen;
uniform Scalar AC_xorig;
uniform Scalar AC_yorig;
uniform Scalar AC_zorig;

// Physical units
uniform Scalar AC_unit_density;
uniform Scalar AC_unit_velocity;
uniform Scalar AC_unit_length;

// Properties of the gravitating star
uniform Scalar AC_star_pos_x;
uniform Scalar AC_star_pos_y;
uniform Scalar AC_star_pos_z;
uniform Scalar AC_M_star;

// Properties of the sink particle
uniform Scalar AC_sink_pos_x;
uniform Scalar AC_sink_pos_y;
uniform Scalar AC_sink_pos_z;
uniform Scalar AC_M_sink;
uniform Scalar AC_M_sink_init;
uniform Scalar AC_M_sink_Msun;
uniform Scalar AC_soft;
uniform Scalar AC_accretion_range;
uniform Scalar AC_switch_accretion; // Compared against 0 and 1 after conversion to int

// Run parameters
uniform Scalar AC_cdt;
uniform Scalar AC_cdtv;
uniform Scalar AC_cdts;
uniform Scalar AC_nu_visc;
uniform Scalar AC_cs_sound;
uniform Scalar AC_eta;
uniform Scalar AC_mu0;
uniform Scalar AC_cp_sound;
uniform Scalar AC_gamma;
uniform Scalar AC_cv_sound;
uniform Scalar AC_lnT0;
uniform Scalar AC_lnrho0;
uniform Scalar AC_zeta;
uniform Scalar AC_trans;

// Other
uniform Scalar AC_bin_save_t;

// Initial condition parameters
uniform Scalar AC_ampl_lnrho;
uniform Scalar AC_ampl_uu;
uniform Scalar AC_angl_uu;
uniform Scalar AC_lnrho_edge;
uniform Scalar AC_lnrho_out;

// Forcing parameters configured by the user
uniform Scalar AC_forcing_magnitude;
uniform Scalar AC_relhel;
uniform Scalar AC_kmin;
uniform Scalar AC_kmax;

// Forcing parameters set by the generator
uniform Scalar AC_forcing_phase;
uniform Scalar AC_k_forcex;
uniform Scalar AC_k_forcey;
uniform Scalar AC_k_forcez;
uniform Scalar AC_kaver;
uniform Scalar AC_ff_hel_rex;
uniform Scalar AC_ff_hel_rey;
uniform Scalar AC_ff_hel_rez;
uniform Scalar AC_ff_hel_imx;
uniform Scalar AC_ff_hel_imy;
uniform Scalar AC_ff_hel_imz;

// Additional helper parameters.
// These are deduced from the other parameters: do not set them directly!
uniform Scalar AC_G_const;
uniform Scalar AC_GM_star;
uniform Scalar AC_unit_mass;
uniform Scalar AC_sq2GM_star;
uniform Scalar AC_cs2_sound;
uniform Scalar AC_inv_dsx;
uniform Scalar AC_inv_dsy;
uniform Scalar AC_inv_dsz;
|
||||||
|
|
||||||
|
/*
 * =============================================================================
 * User-defined vertex buffers
 * =============================================================================
 *
 * The declaration order is LNRHO, UU{X,Y,Z}, A{X,Y,Z}, ENTROPY, then the
 * optional TEMPERATURE and ACCRETION buffers. The first three guards declare
 * exactly the same buffers as the former LENTROPY / LMAGNETIC / LHYDRO
 * #if-#elif chain, without repeating each declaration per branch.
 */

// Logarithmic density: declared in every configuration.
uniform ScalarField VTXBUF_LNRHO;

// Velocity components.
#if LENTROPY || LMAGNETIC || LHYDRO
uniform ScalarField VTXBUF_UUX;
uniform ScalarField VTXBUF_UUY;
uniform ScalarField VTXBUF_UUZ;
#endif

// Vector potential components.
#if LENTROPY || LMAGNETIC
uniform ScalarField VTXBUF_AX;
uniform ScalarField VTXBUF_AY;
uniform ScalarField VTXBUF_AZ;
#endif

#if LENTROPY
uniform ScalarField VTXBUF_ENTROPY;
#endif

// The tt/out_tt fields are bound to this handle when LTEMPERATURE is set;
// previously the handle was used without ever being declared.
#if LTEMPERATURE
uniform ScalarField VTXBUF_TEMPERATURE;
#endif

#if LSINK
uniform ScalarField VTXBUF_ACCRETION;
#endif
|
||||||
|
|
||||||
|
|
||||||
|
// Returns the sample of `vertex` stored at the current vertex index.
Preprocessed Scalar
value(in ScalarField vertex)
{
    Scalar sample = vertex[vertexIdx];
    return sample;
}
|
||||||
|
|
||||||
|
// Returns the vector (derx, dery, derz) of `vertex` evaluated at the current
// vertex index.
Preprocessed Vector
gradient(in ScalarField vertex)
{
    Vector grad;
    grad.x = derx(vertexIdx, vertex);
    grad.y = dery(vertexIdx, vertex);
    grad.z = derz(vertexIdx, vertex);
    return grad;
}
|
||||||
|
|
||||||
|
#if LUPWD
|
||||||
|
|
||||||
|
// Upwinding stencil along x: a symmetric 7-point sum with weights
// (1, -6, 15, -20, 15, -6, 1) / 60, scaled by AC_inv_dsx.
Preprocessed Scalar
der6x_upwd(in ScalarField vertex)
{
    Scalar centre = vertex[vertexIdx.x, vertexIdx.y, vertexIdx.z];
    Scalar ring1  = vertex[vertexIdx.x + 1, vertexIdx.y, vertexIdx.z] +
                    vertex[vertexIdx.x - 1, vertexIdx.y, vertexIdx.z];
    Scalar ring2  = vertex[vertexIdx.x + 2, vertexIdx.y, vertexIdx.z] +
                    vertex[vertexIdx.x - 2, vertexIdx.y, vertexIdx.z];
    Scalar plus3  = vertex[vertexIdx.x + 3, vertexIdx.y, vertexIdx.z];
    Scalar minus3 = vertex[vertexIdx.x - 3, vertexIdx.y, vertexIdx.z];

    // The outermost samples are added one at a time to keep the original
    // summation order.
    return (Scalar){Scalar(1.0 / 60.0) * AC_inv_dsx *
                    (-Scalar(20.0) * centre + Scalar(15.0) * ring1 - Scalar(6.0) * ring2 +
                     plus3 + minus3)};
}
|
||||||
|
|
||||||
|
// Upwinding stencil along y: a symmetric 7-point sum with weights
// (1, -6, 15, -20, 15, -6, 1) / 60, scaled by AC_inv_dsy.
Preprocessed Scalar
der6y_upwd(in ScalarField vertex)
{
    Scalar centre = vertex[vertexIdx.x, vertexIdx.y, vertexIdx.z];
    Scalar ring1  = vertex[vertexIdx.x, vertexIdx.y + 1, vertexIdx.z] +
                    vertex[vertexIdx.x, vertexIdx.y - 1, vertexIdx.z];
    Scalar ring2  = vertex[vertexIdx.x, vertexIdx.y + 2, vertexIdx.z] +
                    vertex[vertexIdx.x, vertexIdx.y - 2, vertexIdx.z];
    Scalar plus3  = vertex[vertexIdx.x, vertexIdx.y + 3, vertexIdx.z];
    Scalar minus3 = vertex[vertexIdx.x, vertexIdx.y - 3, vertexIdx.z];

    // The outermost samples are added one at a time to keep the original
    // summation order.
    return (Scalar){Scalar(1.0 / 60.0) * AC_inv_dsy *
                    (-Scalar(20.0) * centre + Scalar(15.0) * ring1 - Scalar(6.0) * ring2 +
                     plus3 + minus3)};
}
|
||||||
|
|
||||||
|
// Upwinding stencil along z: a symmetric 7-point sum with weights
// (1, -6, 15, -20, 15, -6, 1) / 60, scaled by AC_inv_dsz.
Preprocessed Scalar
der6z_upwd(in ScalarField vertex)
{
    Scalar centre = vertex[vertexIdx.x, vertexIdx.y, vertexIdx.z];
    Scalar ring1  = vertex[vertexIdx.x, vertexIdx.y, vertexIdx.z + 1] +
                    vertex[vertexIdx.x, vertexIdx.y, vertexIdx.z - 1];
    Scalar ring2  = vertex[vertexIdx.x, vertexIdx.y, vertexIdx.z + 2] +
                    vertex[vertexIdx.x, vertexIdx.y, vertexIdx.z - 2];
    Scalar plus3  = vertex[vertexIdx.x, vertexIdx.y, vertexIdx.z + 3];
    Scalar minus3 = vertex[vertexIdx.x, vertexIdx.y, vertexIdx.z - 3];

    // The outermost samples are added one at a time to keep the original
    // summation order.
    return (Scalar){Scalar(1.0 / 60.0) * AC_inv_dsz *
                    (-Scalar(20.0) * centre + Scalar(15.0) * ring1 - Scalar(6.0) * ring2 +
                     plus3 + minus3)};
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Builds the 3x3 matrix of second derivatives of `vertex` at the current
// vertex index. Only six derivatives are evaluated; the entries below the
// diagonal reuse the ones above it.
Preprocessed Matrix
hessian(in ScalarField vertex)
{
    Scalar dxx = derxx(vertexIdx, vertex);
    Scalar dxy = derxy(vertexIdx, vertex);
    Scalar dxz = derxz(vertexIdx, vertex);
    Scalar dyy = deryy(vertexIdx, vertex);
    Scalar dyz = deryz(vertexIdx, vertex);
    Scalar dzz = derzz(vertexIdx, vertex);

    Matrix result;
    result.row[0] = (Vector){dxx, dxy, dxz};
    result.row[1] = (Vector){dxy, dyy, dyz};
    result.row[2] = (Vector){dxz, dyz, dzz};

    return result;
}
|
||||||
|
|
||||||
|
// Returns the three components of `uu` sampled at the current vertex index.
Device Vector
value(in VectorField uu)
{
    Vector sample;
    sample.x = value(uu.x);
    sample.y = value(uu.y);
    sample.z = value(uu.z);
    return sample;
}
|
||||||
|
|
||||||
|
#if LUPWD
|
||||||
|
// Sum over the three axes of |u_i| * der6i_upwd(lnrho), where u is the value of
// `uu` at the current vertex.
Device Scalar
upwd_der6(in VectorField uu, in ScalarField lnrho)
{
    const Vector velocity = value(uu);

    const Scalar x_part = fabs(velocity.x) * der6x_upwd(lnrho);
    const Scalar y_part = fabs(velocity.y) * der6y_upwd(lnrho);
    const Scalar z_part = fabs(velocity.z) * der6z_upwd(lnrho);

    return (Scalar){x_part + y_part + z_part};
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Returns the matrix whose rows are the gradients of the x, y and z components
// of `uu`.
Device Matrix
gradients(in VectorField uu)
{
    const Vector grad_x = gradient(uu.x);
    const Vector grad_y = gradient(uu.y);
    const Vector grad_z = gradient(uu.z);

    return (Matrix){grad_x, grad_y, grad_z};
}
|
||||||
|
|
||||||
|
#if LSINK
|
||||||
|
// Gravity exerted by the sink particle on the vertex at globalVertexIdx.
// Returns the zero vector unless AC_switch_accretion, converted to int, is 1.
Device Vector
sink_gravity(int3 globalVertexIdx){
    const int sink_enabled = int(AC_switch_accretion);
    if (sink_enabled == 1){
        // Position of this vertex, measured from the (AC_nx_min, AC_ny_min,
        // AC_nz_min) vertex in units of the grid spacing.
        const Vector vertex_pos = (Vector){(globalVertexIdx.x - DCONST(AC_nx_min)) * AC_dsx,
                                           (globalVertexIdx.y - DCONST(AC_ny_min)) * AC_dsy,
                                           (globalVertexIdx.z - DCONST(AC_nz_min)) * AC_dsz};
        const Vector sink_pos = (Vector){AC_sink_pos_x,
                                         AC_sink_pos_y,
                                         AC_sink_pos_z};
        const Scalar r = length(vertex_pos - sink_pos);

        // MV: Commit 083ff59 had AC_G_const defined wrong here in the DSL,
        // MV: making gravity excessively strong. The expression below is correct.
        // The denominator is (r^2 + AC_soft^2)^(3/2).
        const Scalar strength = (AC_G_const * AC_M_sink) / pow(((r * r) + AC_soft*AC_soft), 1.5);

        // Unit vector pointing from the vertex towards the sink.
        // NOTE(review): r is zero for a vertex located exactly at the sink
        // position, which makes these divisions 0/0 -- confirm whether that
        // case can occur.
        const Vector towards_sink = (Vector){(sink_pos.x - vertex_pos.x) / r,
                                             (sink_pos.y - vertex_pos.y) / r,
                                             (sink_pos.z - vertex_pos.z) / r};

        return strength * towards_sink;
    } else {
        return (Vector){0.0, 0.0, 0.0};
    }
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
#if LSINK
|
||||||
|
// Returns the Truelove density used by sink_accretion() as a density floor.
Device Scalar
truelove_density(in ScalarField lnrho){
    const Scalar rho = exp(value(lnrho));
    const Scalar Jeans_length_squared = (M_PI * AC_cs2_sound) / (AC_G_const * rho);

    // TODO: AC_dsx will cancel out, deal with it later for optimization.
    // NOTE(review): once AC_dsx cancels, this expression reduces algebraically
    // to rho itself (pi * cs2 / (G * Jeans_length_squared)) -- confirm that this
    // is the intended criterion.
    const Scalar TJ_rho = ((M_PI) * ((AC_dsx * AC_dsx) / Jeans_length_squared) * AC_cs2_sound) / (AC_G_const * AC_dsx * AC_dsx);

    return TJ_rho;
}
|
||||||
|
|
||||||
|
// Rate at which density is transferred from this vertex to the sink particle.
// Returns zero unless AC_switch_accretion, converted to int, is 1.
Device Scalar
sink_accretion(int3 globalVertexIdx, in ScalarField lnrho, Scalar dt){
    const Vector vertex_pos = (Vector){(globalVertexIdx.x - DCONST(AC_nx_min)) * AC_dsx,
                                       (globalVertexIdx.y - DCONST(AC_ny_min)) * AC_dsy,
                                       (globalVertexIdx.z - DCONST(AC_nz_min)) * AC_dsz};
    const Vector sink_pos = (Vector){AC_sink_pos_x,
                                     AC_sink_pos_y,
                                     AC_sink_pos_z};
    const Scalar r = length(vertex_pos - sink_pos);
    const int sink_enabled = int(AC_switch_accretion);

    Scalar accretion_density = Scalar(0.0);

    if (sink_enabled == 1){
        // Spatial weight: zero outside AC_accretion_range, windowed inside it.
        // (Alternative considered by the authors: a constant weight of 1.)
        Scalar weight = Scalar(0.0);
        if (r <= AC_accretion_range){
            // Hann window function.
            // NOTE(review): as written the weight is 0 at r = 0, peaks at
            // r = AC_accretion_range / 2 and returns to 0 at the edge --
            // confirm that a window vanishing at the sink centre is intended.
            const Scalar window_ratio = r/AC_accretion_range;
            weight = Scalar(0.5)*(Scalar(1.0) - cos(Scalar(2.0)*M_PI*window_ratio));
        }

        // Truelove criterion is used as a kind of arbitrary density floor:
        // only the density in excess of the floor is accreted, over one dt.
        const Scalar lnrho_min = log(truelove_density(lnrho));
        Scalar rate = Scalar(0.0);
        if (value(lnrho) > lnrho_min) {
            rate = (exp(value(lnrho)) - exp(lnrho_min)) / dt;
        }

        accretion_density = weight * rate;
    }

    return accretion_density;
}
|
||||||
|
|
||||||
|
// Rate at which velocity is removed from this vertex by the sink particle.
// Returns the zero vector unless AC_switch_accretion, converted to int, is 1.
Device Vector
sink_accretion_velocity(int3 globalVertexIdx, in VectorField uu, Scalar dt) {
    const Vector vertex_pos = (Vector){(globalVertexIdx.x - DCONST(AC_nx_min)) * AC_dsx,
                                       (globalVertexIdx.y - DCONST(AC_ny_min)) * AC_dsy,
                                       (globalVertexIdx.z - DCONST(AC_nz_min)) * AC_dsz};
    const Vector sink_pos = (Vector){AC_sink_pos_x,
                                     AC_sink_pos_y,
                                     AC_sink_pos_z};
    const Scalar r = length(vertex_pos - sink_pos);
    const int sink_enabled = int(AC_switch_accretion);

    Vector accretion_velocity = (Vector){0.0, 0.0, 0.0};

    if (sink_enabled == 1){
        // Spatial weight: zero outside AC_accretion_range, windowed inside it.
        // Alternatives noted by the authors: a step function, the arch of a
        // cosine, or the cubic spline x^3 - x in the range [-0.5, 0.5].
        Scalar weight = Scalar(0.0);
        if (r <= AC_accretion_range){
            // Hann window function (same expression as in sink_accretion()).
            const Scalar window_ratio = r/AC_accretion_range;
            weight = Scalar(0.5)*(Scalar(1.0) - cos(Scalar(2.0)*M_PI*window_ratio));
        }

        // MV: Could we use the divergence here to emphasize velocities which are
        // MV: compressive, and avoid absorbing material that would not be
        // MV: accreted anyway?
        Vector rate = (Vector){0.0, 0.0, 0.0};
        if (length(value(uu)) > Scalar(0.0)) {
            rate = (Scalar(1.0)/dt) * value(uu);
        }

        accretion_velocity = weight * rate;
    }

    return accretion_velocity;
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
// Right-hand side of the continuity equation for lnrho:
// -u . grad(lnrho) [+ upwinding correction] [- sink accretion / rho] - div(u).
Device Scalar
continuity(int3 globalVertexIdx, in VectorField uu, in ScalarField lnrho, Scalar dt)
{
    Scalar rhs = -dot(value(uu), gradient(lnrho));

#if LUPWD
    // This is a corrective hyperdiffusion term for upwinding.
    rhs = rhs + upwd_der6(uu, lnrho);
#endif

#if LSINK
    // sink_accretion() is a density rate; dividing by rho converts it to a
    // rate for lnrho.
    rhs = rhs - sink_accretion(globalVertexIdx, lnrho, dt) / exp(value(lnrho));
#endif

    rhs = rhs - divergence(uu);
    return rhs;
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#if LENTROPY
|
||||||
|
// Right-hand side of the momentum equation when LENTROPY is enabled.
// Terms, in summation order: advection, pressure gradient, Lorentz force,
// viscosity, bulk viscosity and, with LSINK, sink gravity and the momentum
// carried away by accretion.
Device Vector
momentum(int3 globalVertexIdx, in VectorField uu, in ScalarField lnrho, in ScalarField ss, in VectorField aa, Scalar dt)
{
    const Matrix S = stress_tensor(uu);
    const Scalar cs2 = AC_cs2_sound * exp(AC_gamma * value(ss) / AC_cp_sound +
                                          (AC_gamma - 1) * (value(lnrho) - AC_lnrho0));
    // Current density
    const Vector j = (Scalar(1.0) / AC_mu0) *
                     (gradient_of_divergence(aa) - laplace_vec(aa));
    const Vector B = curl(aa);
    // TODO (original wording unclear): does the thermal version include the
    // magnetic field?
    const Scalar inv_rho = Scalar(1.0) / exp(value(lnrho));

    const Vector advection = mul(gradients(uu), value(uu));
    const Vector pressure = cs2 * ((Scalar(1.0) / AC_cp_sound) * gradient(ss) + gradient(lnrho));
    const Vector lorentz = inv_rho * cross(j, B);
    const Vector viscous = AC_nu_visc *
                           (laplace_vec(uu) + Scalar(1.0 / 3.0) * gradient_of_divergence(uu) +
                            Scalar(2.0) * mul(S, gradient(lnrho)));
    const Vector bulk = AC_zeta * gradient_of_divergence(uu);

    Vector mom = -advection - pressure + lorentz + viscous + bulk;

#if LSINK
    // Gravity term
    mom = mom + sink_gravity(globalVertexIdx);
    // Corresponding loss of momentum, as in Lee et al. (2014).
    // A correction factor by unit mass,
    // (Scalar(1.0) / Scalar((AC_dsx*AC_dsy*AC_dsz) * exp(value(lnrho)))),
    // was considered by the authors and left disabled.
    mom = mom - sink_accretion_velocity(globalVertexIdx, uu, dt);
#endif

    return mom;
}
|
||||||
|
#elif LTEMPERATURE
|
||||||
|
// Right-hand side of the momentum equation when LTEMPERATURE is enabled.
// Terms, in summation order: advection, pressure gradient, viscosity, bulk
// viscosity and, optionally, sink gravity and a constant gravity along z.
Device Vector
momentum(int3 globalVertexIdx, in VectorField uu, in ScalarField lnrho, in ScalarField tt)
{
    const Matrix S = stress_tensor(uu);

    const Vector pressure_term = (AC_cp_sound - AC_cv_sound) *
                                 (gradient(tt) + value(tt) * gradient(lnrho));

    const Vector advection = mul(gradients(uu), value(uu));
    const Vector viscous = AC_nu_visc * (laplace_vec(uu) + Scalar(1.0 / 3.0) * gradient_of_divergence(uu) +
                                         Scalar(2.0) * mul(S, gradient(lnrho)));
    const Vector bulk = AC_zeta * gradient_of_divergence(uu);

    Vector mom = -advection - pressure_term + viscous + bulk;

#if LSINK
    mom = mom + sink_gravity(globalVertexIdx);
#endif

#if LGRAVITY
    mom = mom - (Vector){0, 0, -10.0};
#endif

    return mom;
}
|
||||||
|
#else
|
||||||
|
// Right-hand side of the momentum equation when neither LENTROPY nor
// LTEMPERATURE is enabled. Terms, in summation order: advection, pressure
// gradient, viscosity, bulk viscosity and, optionally, sink gravity, the
// momentum carried away by accretion and a constant gravity along z.
Device Vector
momentum(int3 globalVertexIdx, in VectorField uu, in ScalarField lnrho, Scalar dt)
{
    const Matrix S = stress_tensor(uu);

    const Vector advection = mul(gradients(uu), value(uu));
    // Isothermal: we have constant speed of sound
    const Vector pressure = AC_cs2_sound * gradient(lnrho);
    const Vector viscous = AC_nu_visc * (laplace_vec(uu) + Scalar(1.0 / 3.0) * gradient_of_divergence(uu) +
                                         Scalar(2.0) * mul(S, gradient(lnrho)));
    const Vector bulk = AC_zeta * gradient_of_divergence(uu);

    Vector mom = -advection - pressure + viscous + bulk;

#if LSINK
    mom = mom + sink_gravity(globalVertexIdx);
    // Corresponding loss of momentum, as in Lee et al. (2014).
    // A correction factor by unit mass,
    // (Scalar(1.0) / Scalar((AC_dsx*AC_dsy*AC_dsz) * exp(value(lnrho)))),
    // was considered by the authors and left disabled.
    mom = mom - sink_accretion_velocity(globalVertexIdx, uu, dt);
#endif

#if LGRAVITY
    mom = mom - (Vector){0, 0, -10.0};
#endif

    return mom;
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Right-hand side of the induction equation for the vector potential:
// u x B - AC_eta * (grad(div A) - laplace A), with B = curl A.
Device Vector
induction(in VectorField uu, in VectorField aa)
{
    // Note: We do (-nabla^2 A + nabla(nabla dot A)) instead of
    // (nabla x (nabla x A)) in order to avoid taking the first derivative
    // twice (did the math, yes this actually works. See pg.28 in
    // arXiv:astro-ph/0109497):
    // u cross B - AC_eta * AC_mu0 * (AC_mu0^-1 * [- laplace A + grad div A ])
    const Vector B = curl(aa);

    // Note, AC_mu0 is cancelled out
    const Vector resistive = AC_eta * (gradient_of_divergence(aa) - laplace_vec(aa));

    return cross(value(uu), B) - resistive;
}
|
||||||
|
|
||||||
|
#if LENTROPY
|
||||||
|
// Returns AC_lnT0 + AC_gamma * ss / AC_cp_sound
//         + (AC_gamma - 1) * (lnrho - AC_lnrho0) at the current vertex.
Device Scalar
lnT(in ScalarField ss, in ScalarField lnrho)
{
    const Scalar entropy_part = AC_gamma * value(ss) / AC_cp_sound;
    const Scalar density_part = (AC_gamma - Scalar(1.0)) * (value(lnrho) - AC_lnrho0);

    return AC_lnT0 + entropy_part + density_part;
}
|
||||||
|
|
||||||
|
// Nabla dot (K nabla T) / (rho T)
//
// Evaluated from the entropy `ss` and the logarithmic density `lnrho` at the
// current vertex. Literals use the DSL cast Scalar(...), as in the rest of this
// file, rather than the implementation type AcReal.
Device Scalar
heat_conduction(in ScalarField ss, in ScalarField lnrho)
{
    const Scalar inv_AC_cp_sound = Scalar(1.0) / AC_cp_sound;

    const Vector grad_ln_chi = -gradient(lnrho);

    const Scalar first_term = AC_gamma * inv_AC_cp_sound * laplace(ss) +
                              (AC_gamma - Scalar(1.0)) * laplace(lnrho);
    const Vector second_term = AC_gamma * inv_AC_cp_sound * gradient(ss) +
                               (AC_gamma - Scalar(1.0)) * gradient(lnrho);
    const Vector third_term = AC_gamma * (inv_AC_cp_sound * gradient(ss) + gradient(lnrho)) +
                              grad_ln_chi;

    // chi = K / (rho * cp)
    const Scalar chi = AC_THERMAL_CONDUCTIVITY / (exp(value(lnrho)) * AC_cp_sound);

    return AC_cp_sound * chi * (first_term + dot(second_term, third_term));
}
|
||||||
|
|
||||||
|
// Placeholder heating function: returns 1 regardless of the indices i, j and k.
Device Scalar
heating(const int i, const int j, const int k)
{
    return Scalar(1.0);
}
|
||||||
|
|
||||||
|
// Right-hand side of the entropy equation:
// -u . grad(ss) + RHS / (rho * T) + heat_conduction(ss, lnrho),
// where RHS collects H_CONST - C_CONST, the Ohmic term and the two viscous
// terms.
// NOTE(review): H_CONST and C_CONST are not defined anywhere in this file;
// presumably they are provided by the surrounding build -- confirm.
Device Scalar
entropy(in ScalarField ss, in VectorField uu, in ScalarField lnrho, in VectorField aa)
{
    const Matrix S = stress_tensor(uu);
    const Scalar rho = exp(value(lnrho));
    const Scalar inv_pT = Scalar(1.0) / (rho * exp(lnT(ss, lnrho)));
    // Current density
    const Vector j = (Scalar(1.0) / AC_mu0) *
                     (gradient_of_divergence(aa) - laplace_vec(aa));
    const Scalar div_u = divergence(uu);

    const Scalar ohmic = AC_eta * (AC_mu0)*dot(j, j);
    const Scalar shear = Scalar(2.0) * rho * AC_nu_visc * contract(S);
    const Scalar bulk = AC_zeta * rho * div_u * div_u;

    const Scalar RHS = H_CONST - C_CONST + ohmic + shear + bulk;

    return -dot(value(uu), gradient(ss)) + inv_pT * RHS + heat_conduction(ss, lnrho);
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if LTEMPERATURE
|
||||||
|
// Right-hand side of the temperature equation when LTEMPERATURE is enabled.
// Terms, in summation order: advection, diffusion, coupling to the density
// gradient, viscous heating and compression.
Device Scalar
heat_transfer(in VectorField uu, in ScalarField lnrho, in ScalarField tt)
{
    const Matrix S = stress_tensor(uu);
    const Scalar heat_diffusivity_k = 0.0008; // 8e-4;

    const Scalar advection = dot(value(uu), gradient(tt));
    const Scalar diffusion = heat_diffusivity_k * laplace(tt);
    const Scalar density_coupling = heat_diffusivity_k * dot(gradient(lnrho), gradient(tt));
    const Scalar viscous_heating = AC_nu_visc * contract(S) * (Scalar(1.0) / AC_cv_sound);
    const Scalar compression = (AC_gamma - 1) * value(tt) * divergence(uu);

    return -advection + diffusion + density_coupling + viscous_heating - compression;
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if LFORCING
|
||||||
|
// Vortex forcing: `magnitude` times the cross product of the unit vector from
// `a` to `b` with the z axis. Returns the zero vector unless
// AC_switch_accretion, converted to int, is 0.
Device Vector
simple_vortex_forcing(Vector a, Vector b, Scalar magnitude){
    const int sink_enabled = int(AC_switch_accretion);

    if (sink_enabled == 0){
        const Vector z_axis = (Vector){ 0, 0, 1};
        return magnitude * cross(normalized(b - a), z_axis); // Vortex
    } else {
        return (Vector){0,0,0};
    }
}
|
||||||
|
// Outward-flow forcing: points from `a` towards `b` with strength
// `magnitude` / |b - a|. Returns the zero vector unless AC_switch_accretion,
// converted to int, is 0.
Device Vector
simple_outward_flow_forcing(Vector a, Vector b, Scalar magnitude){
    const int sink_enabled = int(AC_switch_accretion);

    if (sink_enabled == 0){
        const Vector offset = b - a;
        return magnitude * (1 / length(offset)) * normalized(offset); // Outward flow
    } else {
        return (Vector){0,0,0};
    }
}
|
||||||
|
|
||||||
|
// Forcing function with adjustable helicity, inspired by the Pencil Code
// forcing_hel_noshear() (manual Eq. 222).
//
// Returns ff_re * cos(k.x + phi) - ff_im * sin(k.x + phi), where x is `xx`
// rescaled so that the global grid extent maps to 2*pi along each axis.
// The `magnitude` parameter is not used inside this function.
Device Vector
helical_forcing(Scalar magnitude, Vector k_force, Vector xx, Vector ff_re, Vector ff_im, Scalar phi)
{
    // Review history:
    // JP: An earlier version looked wrong: 1) it used AC_dsx * AC_ny where
    //     AC_dsx * AC_nx was meant, and 2) it used the local n instead of
    //     globalGrid.n.
    // MV: You are right. Made a quickfix. I did not see the error because
    //     multi-GPU is split in the z direction, not the y direction.
    // JP: 3) Can we do this with vectors/quaternions instead? Trigonometric
    //     functions are much more expensive and inaccurate.
    // MV: Good idea. Not an immediate priority.
    // Fun related article:
    // https://randomascii.wordpress.com/2014/10/09/intel-underestimates-error-bounds-by-1-3-quintillion/
    xx.x = xx.x * (2.0 * M_PI / (AC_dsx * globalGridN.x));
    xx.y = xx.y * (2.0 * M_PI / (AC_dsy * globalGridN.y));
    xx.z = xx.z * (2.0 * M_PI / (AC_dsz * globalGridN.z));

    const Scalar k_dot_x = dot(k_force, xx);
    const Scalar cos_phi = cos(phi);
    const Scalar sin_phi = sin(phi);
    const Scalar cos_kx = cos(k_dot_x);
    const Scalar sin_kx = sin(k_dot_x);

    // Angle-sum identities: real and imaginary part of exp(i * (k.x + phi)).
    const Scalar re_phase = cos_kx * cos_phi - sin_kx * sin_phi;
    const Scalar im_phase = cos_kx * sin_phi + sin_kx * cos_phi;

    return (Vector){ff_re.x * re_phase - ff_im.x * im_phase,
                    ff_re.y * re_phase - ff_im.y * im_phase,
                    ff_re.z * re_phase - ff_im.z * im_phase};
}
|
||||||
|
|
||||||
|
// Forcing added to the velocity at the vertex globalVertexIdx.
// Returns the zero vector when AC_switch_accretion, converted to int, is not 0,
// or when the computed force fails is_valid().
Device Vector
forcing(int3 globalVertexIdx, Scalar dt)
{
    const int sink_enabled = int(AC_switch_accretion);
    if (sink_enabled == 0){
        // Source (origin): the centre of the global grid. Only needed by the
        // alternative forcing functions that are commented out below.
        Vector a = Scalar(0.5) * (Vector){globalGridN.x * AC_dsx,
                                          globalGridN.y * AC_dsy,
                                          globalGridN.z * AC_dsz};
        // Sink (current index)
        Vector xx = (Vector){(globalVertexIdx.x - DCONST(AC_nx_min)) * AC_dsx,
                             (globalVertexIdx.y - DCONST(AC_ny_min)) * AC_dsy,
                             (globalVertexIdx.z - DCONST(AC_nz_min)) * AC_dsz};
        const Scalar cs = sqrt(AC_cs2_sound);

        // Placeholders until determined properly
        Vector k_force = (Vector){AC_k_forcex, AC_k_forcey, AC_k_forcez};
        Vector ff_re = (Vector){AC_ff_hel_rex, AC_ff_hel_rey, AC_ff_hel_rez};
        Vector ff_im = (Vector){AC_ff_hel_imx, AC_ff_hel_imy, AC_ff_hel_imz};

        // Determine the forcing function type at this point.
        //Vector force = simple_vortex_forcing(a, xx, AC_forcing_magnitude);
        //Vector force = simple_outward_flow_forcing(a, xx, AC_forcing_magnitude);
        Vector force = helical_forcing(AC_forcing_magnitude, k_force, xx, ff_re, ff_im, AC_forcing_phase);

        // Scaling N = magnitude*cs*sqrt(k*cs/dt) * dt
        // MV: Like in the Pencil Code. I don't understand the logic here.
        const Scalar NN = cs*sqrt(AC_kaver*cs);
        const Scalar scale = sqrt(dt)*NN;
        force.x = scale*force.x;
        force.y = scale*force.y;
        force.z = scale*force.z;

        if (is_valid(force)) {
            return force;
        } else {
            return (Vector){0, 0, 0};
        }
    } else {
        return (Vector){0,0,0};
    }
}
|
||||||
|
#endif // LFORCING
|
||||||
|
|
||||||
|
// Bind the kernel's input (in) and output (out) fields to the vertex buffer
// handles declared earlier in this file. Each field is bound to the same
// handle(s) for reading and for writing.

// Logarithmic density
in ScalarField lnrho(VTXBUF_LNRHO);
out ScalarField out_lnrho(VTXBUF_LNRHO);

// Velocity
in VectorField uu(VTXBUF_UUX, VTXBUF_UUY, VTXBUF_UUZ);
out VectorField out_uu(VTXBUF_UUX, VTXBUF_UUY, VTXBUF_UUZ);

// Vector potential
#if LMAGNETIC
in VectorField aa(VTXBUF_AX, VTXBUF_AY, VTXBUF_AZ);
out VectorField out_aa(VTXBUF_AX, VTXBUF_AY, VTXBUF_AZ);
#endif

// Entropy
#if LENTROPY
in ScalarField ss(VTXBUF_ENTROPY);
out ScalarField out_ss(VTXBUF_ENTROPY);
#endif

// Temperature
#if LTEMPERATURE
in ScalarField tt(VTXBUF_TEMPERATURE);
out ScalarField out_tt(VTXBUF_TEMPERATURE);
#endif

// Accreted density/mass
#if LSINK
in ScalarField accretion(VTXBUF_ACCRETION);
out ScalarField out_accretion(VTXBUF_ACCRETION);
#endif
|
||||||
|
|
||||||
|
// Integration kernel: advances every enabled field by one rk3 substep of
// length AC_dt, using the right-hand sides defined earlier in this file.
Kernel void
solve()
{
    Scalar dt = AC_dt;
    out_lnrho = rk3(out_lnrho, lnrho, continuity(globalVertexIdx, uu, lnrho, dt), dt);

#if LMAGNETIC
    out_aa = rk3(out_aa, aa, induction(uu, aa), dt);
#endif

#if LENTROPY
    out_uu = rk3(out_uu, uu, momentum(globalVertexIdx, uu, lnrho, ss, aa, dt), dt);
    out_ss = rk3(out_ss, ss, entropy(ss, uu, lnrho, aa), dt);
#elif LTEMPERATURE
    // The LTEMPERATURE variant of momentum() takes no dt argument, so none is
    // passed here.
    out_uu = rk3(out_uu, uu, momentum(globalVertexIdx, uu, lnrho, tt), dt);
    out_tt = rk3(out_tt, tt, heat_transfer(uu, lnrho, tt), dt);
#else
    out_uu = rk3(out_uu, uu, momentum(globalVertexIdx, uu, lnrho, dt), dt);
#endif

#if LFORCING
    // Forcing is added only when step_number is 2.
    if (step_number == 2) {
        out_uu = out_uu + forcing(globalVertexIdx, dt);
    }
#endif

#if LSINK
    out_accretion = rk3(out_accretion, accretion, sink_accretion(globalVertexIdx, lnrho, dt), dt); // unit now is rho!

    if (step_number == 2) {
        // Multiply by the cell volume when step_number is 2.
        out_accretion = out_accretion * AC_dsx * AC_dsy * AC_dsz; // unit is now mass!
    }
#endif
}
|
@@ -22,6 +22,7 @@ L [a-zA-Z_]
|
|||||||
"ScalarArray" { return SCALARARRAY; }
|
"ScalarArray" { return SCALARARRAY; }
|
||||||
|
|
||||||
"Kernel" { return KERNEL; } /* Function specifiers */
|
"Kernel" { return KERNEL; } /* Function specifiers */
|
||||||
|
"Device" { return DEVICE; }
|
||||||
"Preprocessed" { return PREPROCESSED; }
|
"Preprocessed" { return PREPROCESSED; }
|
||||||
|
|
||||||
"const" { return CONSTANT; }
|
"const" { return CONSTANT; }
|
||||||
|
@@ -20,7 +20,7 @@ int yyget_lineno();
|
|||||||
%token VOID INT INT3 COMPLEX
|
%token VOID INT INT3 COMPLEX
|
||||||
%token IF ELSE FOR WHILE ELIF
|
%token IF ELSE FOR WHILE ELIF
|
||||||
%token LEQU LAND LOR LLEQU
|
%token LEQU LAND LOR LLEQU
|
||||||
%token KERNEL PREPROCESSED
|
%token KERNEL DEVICE PREPROCESSED
|
||||||
%token INPLACE_INC INPLACE_DEC
|
%token INPLACE_INC INPLACE_DEC
|
||||||
|
|
||||||
%%
|
%%
|
||||||
@@ -66,6 +66,7 @@ compound_statement: '{' '}'
|
|||||||
statement: selection_statement { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); }
|
statement: selection_statement { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); }
|
||||||
| iteration_statement { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); }
|
| iteration_statement { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); }
|
||||||
| exec_statement ';' { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); $$->postfix = ';'; }
|
| exec_statement ';' { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); $$->postfix = ';'; }
|
||||||
|
| compound_statement { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); }
|
||||||
;
|
;
|
||||||
|
|
||||||
selection_statement: IF expression else_selection_statement { $$ = astnode_create(NODE_UNKNOWN, $2, $3); $$->prefix = IF; }
|
selection_statement: IF expression else_selection_statement { $$ = astnode_create(NODE_UNKNOWN, $2, $3); $$->prefix = IF; }
|
||||||
@@ -115,8 +116,8 @@ return_statement: /* Empty */
|
|||||||
* =============================================================================
|
* =============================================================================
|
||||||
*/
|
*/
|
||||||
|
|
||||||
declaration_list: declaration { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); }
|
declaration_list: declaration { $$ = astnode_create(NODE_DECLARATION_LIST, $1, NULL); }
|
||||||
| declaration_list ',' declaration { $$ = astnode_create(NODE_UNKNOWN, $1, $3); $$->infix = ','; }
|
| declaration_list ',' declaration { $$ = astnode_create(NODE_DECLARATION_LIST, $1, $3); $$->infix = ','; }
|
||||||
;
|
;
|
||||||
|
|
||||||
declaration: type_declaration identifier { $$ = astnode_create(NODE_DECLARATION, $1, $2); } // Note: accepts only one type qualifier. Good or not?
|
declaration: type_declaration identifier { $$ = astnode_create(NODE_DECLARATION, $1, $2); } // Note: accepts only one type qualifier. Good or not?
|
||||||
@@ -127,8 +128,8 @@ array_declaration: identifier '[' ']'
|
|||||||
| identifier '[' expression ']' { $$ = astnode_create(NODE_UNKNOWN, $1, $3); $$->infix = '['; $$->postfix = ']'; }
|
| identifier '[' expression ']' { $$ = astnode_create(NODE_UNKNOWN, $1, $3); $$->infix = '['; $$->postfix = ']'; }
|
||||||
;
|
;
|
||||||
|
|
||||||
type_declaration: type_specifier { $$ = astnode_create(NODE_UNKNOWN, $1, NULL); }
|
type_declaration: type_specifier { $$ = astnode_create(NODE_TYPE_DECLARATION, $1, NULL); }
|
||||||
| type_qualifier type_specifier { $$ = astnode_create(NODE_UNKNOWN, $1, $2); }
|
| type_qualifier type_specifier { $$ = astnode_create(NODE_TYPE_DECLARATION, $1, $2); }
|
||||||
;
|
;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@@ -196,6 +197,7 @@ unary_operator: '-' /* C-style casts are disallowed, would otherwise be defined
|
|||||||
;
|
;
|
||||||
|
|
||||||
type_qualifier: KERNEL { $$ = astnode_create(NODE_TYPE_QUALIFIER, NULL, NULL); $$->token = KERNEL; }
|
type_qualifier: KERNEL { $$ = astnode_create(NODE_TYPE_QUALIFIER, NULL, NULL); $$->token = KERNEL; }
|
||||||
|
| DEVICE { $$ = astnode_create(NODE_TYPE_QUALIFIER, NULL, NULL); $$->token = DEVICE; }
|
||||||
| PREPROCESSED { $$ = astnode_create(NODE_TYPE_QUALIFIER, NULL, NULL); $$->token = PREPROCESSED; }
|
| PREPROCESSED { $$ = astnode_create(NODE_TYPE_QUALIFIER, NULL, NULL); $$->token = PREPROCESSED; }
|
||||||
| CONSTANT { $$ = astnode_create(NODE_TYPE_QUALIFIER, NULL, NULL); $$->token = CONSTANT; }
|
| CONSTANT { $$ = astnode_create(NODE_TYPE_QUALIFIER, NULL, NULL); $$->token = CONSTANT; }
|
||||||
| IN { $$ = astnode_create(NODE_TYPE_QUALIFIER, NULL, NULL); $$->token = IN; }
|
| IN { $$ = astnode_create(NODE_TYPE_QUALIFIER, NULL, NULL); $$->token = IN; }
|
||||||
|
@@ -21,6 +21,8 @@
|
|||||||
FUNC(NODE_DEFINITION), \
|
FUNC(NODE_DEFINITION), \
|
||||||
FUNC(NODE_GLOBAL_DEFINITION), \
|
FUNC(NODE_GLOBAL_DEFINITION), \
|
||||||
FUNC(NODE_DECLARATION), \
|
FUNC(NODE_DECLARATION), \
|
||||||
|
FUNC(NODE_DECLARATION_LIST), \
|
||||||
|
FUNC(NODE_TYPE_DECLARATION), \
|
||||||
FUNC(NODE_TYPE_QUALIFIER), \
|
FUNC(NODE_TYPE_QUALIFIER), \
|
||||||
FUNC(NODE_TYPE_SPECIFIER), \
|
FUNC(NODE_TYPE_SPECIFIER), \
|
||||||
FUNC(NODE_IDENTIFIER), \
|
FUNC(NODE_IDENTIFIER), \
|
||||||
@@ -32,34 +34,11 @@
|
|||||||
FUNC(NODE_REAL_NUMBER)
|
FUNC(NODE_REAL_NUMBER)
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
/*
|
|
||||||
// Recreating strdup is not needed when using the GNU compiler.
|
|
||||||
// Let's also just say that anything but the GNU
|
|
||||||
// compiler is NOT supported, since there are also
|
|
||||||
// some gcc-specific calls in the files generated
|
|
||||||
// by flex and being completely compiler-independent is
|
|
||||||
// not a priority right now
|
|
||||||
#ifndef strdup
|
|
||||||
static inline char*
|
|
||||||
strdup(const char* in)
|
|
||||||
{
|
|
||||||
const size_t len = strlen(in) + 1;
|
|
||||||
char* out = malloc(len);
|
|
||||||
|
|
||||||
if (out) {
|
|
||||||
memcpy(out, in, len);
|
|
||||||
return out;
|
|
||||||
} else {
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
*/
|
|
||||||
|
|
||||||
typedef enum { FOR_NODE_TYPES(GEN_ID), NUM_NODE_TYPES } NodeType;
|
typedef enum { FOR_NODE_TYPES(GEN_ID), NUM_NODE_TYPES } NodeType;
|
||||||
|
|
||||||
typedef struct astnode_s {
|
typedef struct astnode_s {
|
||||||
int id;
|
int id;
|
||||||
|
struct astnode_s* parent;
|
||||||
struct astnode_s* lhs;
|
struct astnode_s* lhs;
|
||||||
struct astnode_s* rhs;
|
struct astnode_s* rhs;
|
||||||
NodeType type; // Type of the AST node
|
NodeType type; // Type of the AST node
|
||||||
@@ -85,6 +64,12 @@ astnode_create(const NodeType type, ASTNode* lhs, ASTNode* rhs)
|
|||||||
|
|
||||||
node->prefix = node->infix = node->postfix = 0;
|
node->prefix = node->infix = node->postfix = 0;
|
||||||
|
|
||||||
|
if (lhs)
|
||||||
|
node->lhs->parent = node;
|
||||||
|
|
||||||
|
if (rhs)
|
||||||
|
node->rhs->parent = node;
|
||||||
|
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,19 +91,21 @@ astnode_destroy(ASTNode* node)
|
|||||||
free(node);
|
free(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Debug helper: prints every field of one AST node to stdout. `node` must not be NULL. */
|
||||||
|
static inline void
|
||||||
|
astnode_print(const ASTNode* node)
|
||||||
|
{
|
||||||
|
const char* node_type_names[] = {FOR_NODE_TYPES(GEN_STR)};
|
||||||
|
|
||||||
|
printf("%s (%p)\n", node_type_names[node->type], (const void*)node); // %p requires a void pointer
|
||||||
|
printf("\tid: %d\n", node->id);
|
||||||
|
printf("\tparent: %p\n", (const void*)node->parent);
|
||||||
|
printf("\tlhs: %p\n", (const void*)node->lhs);
|
||||||
|
printf("\trhs: %p\n", (const void*)node->rhs);
|
||||||
|
printf("\tbuffer: %s\n", node->buffer ? node->buffer : "(null)"); // passing NULL to %s is undefined
|
||||||
|
printf("\ttoken: %d\n", node->token);
|
||||||
|
printf("\tprefix: %d ('%c')\n", node->prefix, node->prefix);
|
||||||
|
printf("\tinfix: %d ('%c')\n", node->infix, node->infix);
|
||||||
|
printf("\tpostfix: %d ('%c')\n", node->postfix, node->postfix);
|
||||||
|
}
|
||||||
|
|
||||||
extern ASTNode* root;
|
extern ASTNode* root;
|
||||||
|
|
||||||
/*
|
|
||||||
typedef enum {
|
|
||||||
SCOPE_BLOCK
|
|
||||||
} ScopeType;
|
|
||||||
|
|
||||||
typedef struct symbol_s {
|
|
||||||
int type_specifier;
|
|
||||||
char* identifier;
|
|
||||||
int scope;
|
|
||||||
struct symbol_s* next;
|
|
||||||
} Symbol;
|
|
||||||
|
|
||||||
extern ASTNode* symbol_table;
|
|
||||||
*/
|
|
||||||
|
@@ -25,6 +25,7 @@
|
|||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
#include <stdbool.h>
|
#include <stdbool.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
@@ -35,9 +36,9 @@
|
|||||||
|
|
||||||
ASTNode* root = NULL;
|
ASTNode* root = NULL;
|
||||||
|
|
||||||
static const char inout_name_prefix[] = "handle_";
|
// Output files
|
||||||
typedef enum { STENCIL_ASSEMBLY, STENCIL_PROCESS, STENCIL_HEADER } CompilationType;
|
static FILE* DSLHEADER = NULL;
|
||||||
static CompilationType compilation_type;
|
static FILE* CUDAHEADER = NULL;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* =============================================================================
|
* =============================================================================
|
||||||
@@ -64,15 +65,13 @@ static const char* translation_table[TRANSLATION_TABLE_SIZE] = {
|
|||||||
[SCALARARRAY] = "const AcReal* __restrict__",
|
[SCALARARRAY] = "const AcReal* __restrict__",
|
||||||
[COMPLEX] = "acComplex",
|
[COMPLEX] = "acComplex",
|
||||||
// Type qualifiers
|
// Type qualifiers
|
||||||
[KERNEL] = "template <int step_number> static __global__",
|
[KERNEL] = "template <int step_number> static __global__",
|
||||||
//__launch_bounds__(RK_THREADBLOCK_SIZE,
|
[DEVICE] = "static __device__",
|
||||||
// RK_LAUNCH_BOUND_MIN_BLOCKS),
|
[PREPROCESSED] = "static __device__ __forceinline__",
|
||||||
[PREPROCESSED] = "static __device__ "
|
[CONSTANT] = "const",
|
||||||
"__forceinline__",
|
[IN] = "in",
|
||||||
[CONSTANT] = "const",
|
[OUT] = "out",
|
||||||
[IN] = "in",
|
[UNIFORM] = "uniform",
|
||||||
[OUT] = "out",
|
|
||||||
[UNIFORM] = "uniform",
|
|
||||||
// ETC
|
// ETC
|
||||||
[INPLACE_INC] = "++",
|
[INPLACE_INC] = "++",
|
||||||
[INPLACE_DEC] = "--",
|
[INPLACE_DEC] = "--",
|
||||||
@@ -121,7 +120,7 @@ typedef enum {
|
|||||||
NUM_SYMBOLTYPES
|
NUM_SYMBOLTYPES
|
||||||
} SymbolType;
|
} SymbolType;
|
||||||
|
|
||||||
#define MAX_ID_LEN (128)
|
#define MAX_ID_LEN (256)
|
||||||
typedef struct {
|
typedef struct {
|
||||||
SymbolType type;
|
SymbolType type;
|
||||||
int type_qualifier;
|
int type_qualifier;
|
||||||
@@ -129,135 +128,61 @@ typedef struct {
|
|||||||
char identifier[MAX_ID_LEN];
|
char identifier[MAX_ID_LEN];
|
||||||
} Symbol;
|
} Symbol;
|
||||||
|
|
||||||
#define SYMBOL_TABLE_SIZE (4096)
|
#define SYMBOL_TABLE_SIZE (65536)
|
||||||
static Symbol symbol_table[SYMBOL_TABLE_SIZE] = {};
|
static Symbol symbol_table[SYMBOL_TABLE_SIZE] = {};
|
||||||
static int num_symbols = 0;
|
|
||||||
|
|
||||||
static int
|
#define MAX_NESTS (32)
|
||||||
|
static size_t num_symbols[MAX_NESTS] = {};
|
||||||
|
static size_t current_nest = 0;
|
||||||
|
|
||||||
|
static Symbol*
|
||||||
symboltable_lookup(const char* identifier)
|
symboltable_lookup(const char* identifier)
|
||||||
{
|
{
|
||||||
if (!identifier)
|
if (!identifier)
|
||||||
return -1;
|
return NULL;
|
||||||
|
|
||||||
for (int i = 0; i < num_symbols; ++i)
|
for (size_t i = 0; i < num_symbols[current_nest]; ++i)
|
||||||
if (strcmp(identifier, symbol_table[i].identifier) == 0)
|
if (strcmp(identifier, symbol_table[i].identifier) == 0)
|
||||||
return i;
|
return &symbol_table[i];
|
||||||
|
|
||||||
return -1;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
add_symbol(const SymbolType type, const int tqualifier, const int tspecifier, const char* id)
|
add_symbol(const SymbolType type, const int tqualifier, const int tspecifier, const char* id)
|
||||||
{
|
{
|
||||||
assert(num_symbols < SYMBOL_TABLE_SIZE);
|
assert(num_symbols[current_nest] < SYMBOL_TABLE_SIZE);
|
||||||
|
|
||||||
symbol_table[num_symbols].type = type;
|
symbol_table[num_symbols[current_nest]].type = type;
|
||||||
symbol_table[num_symbols].type_qualifier = tqualifier;
|
symbol_table[num_symbols[current_nest]].type_qualifier = tqualifier;
|
||||||
symbol_table[num_symbols].type_specifier = tspecifier;
|
symbol_table[num_symbols[current_nest]].type_specifier = tspecifier;
|
||||||
strcpy(symbol_table[num_symbols].identifier, id);
|
strcpy(symbol_table[num_symbols[current_nest]].identifier, id);
|
||||||
|
|
||||||
++num_symbols;
|
++num_symbols[current_nest];
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
rm_symbol(const int handle)
|
print_symbol(const size_t handle)
|
||||||
{
|
|
||||||
assert(handle >= 0 && handle < num_symbols);
|
|
||||||
assert(num_symbols > 0);
|
|
||||||
|
|
||||||
if (&symbol_table[handle] != &symbol_table[num_symbols - 1])
|
|
||||||
memcpy(&symbol_table[handle], &symbol_table[num_symbols - 1], sizeof(Symbol));
|
|
||||||
--num_symbols;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void
|
|
||||||
print_symbol(const int handle)
|
|
||||||
{
|
{
|
||||||
assert(handle < SYMBOL_TABLE_SIZE);
|
assert(handle < SYMBOL_TABLE_SIZE);
|
||||||
|
|
||||||
const char* fields[] = {translate(symbol_table[handle].type_qualifier),
|
const char* fields[] = {
|
||||||
translate(symbol_table[handle].type_specifier),
|
translate(symbol_table[handle].type_qualifier),
|
||||||
symbol_table[handle].identifier};
|
translate(symbol_table[handle].type_specifier),
|
||||||
const size_t num_fields = sizeof(fields) / sizeof(fields[0]);
|
symbol_table[handle].identifier,
|
||||||
|
};
|
||||||
|
|
||||||
|
const size_t num_fields = sizeof(fields) / sizeof(fields[0]);
|
||||||
for (size_t i = 0; i < num_fields; ++i)
|
for (size_t i = 0; i < num_fields; ++i)
|
||||||
if (fields[i])
|
if (fields[i])
|
||||||
printf("%s ", fields[i]);
|
printf("%s ", fields[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
|
||||||
translate_latest_symbol(void)
|
|
||||||
{
|
|
||||||
const int handle = num_symbols - 1;
|
|
||||||
assert(handle < SYMBOL_TABLE_SIZE);
|
|
||||||
|
|
||||||
Symbol* symbol = &symbol_table[handle];
|
|
||||||
|
|
||||||
// FUNCTION
|
|
||||||
if (symbol->type == SYMBOLTYPE_FUNCTION) {
|
|
||||||
// KERNEL FUNCTION
|
|
||||||
if (symbol->type_qualifier == KERNEL) {
|
|
||||||
printf("%s %s\n%s", translate(symbol->type_qualifier),
|
|
||||||
translate(symbol->type_specifier), symbol->identifier);
|
|
||||||
}
|
|
||||||
// PREPROCESSED FUNCTION
|
|
||||||
else if (symbol->type_qualifier == PREPROCESSED) {
|
|
||||||
printf("%s %s\npreprocessed_%s", translate(symbol->type_qualifier),
|
|
||||||
translate(symbol->type_specifier), symbol->identifier);
|
|
||||||
}
|
|
||||||
// OTHER FUNCTION
|
|
||||||
else {
|
|
||||||
const char* regular_function_decorator = "static __device__ "
|
|
||||||
"__forceinline__";
|
|
||||||
printf("%s %s %s\n%s", regular_function_decorator,
|
|
||||||
translate(symbol->type_qualifier) ? translate(symbol->type_qualifier) : "",
|
|
||||||
translate(symbol->type_specifier), symbol->identifier);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// FUNCTION PARAMETER
|
|
||||||
else if (symbol->type == SYMBOLTYPE_FUNCTION_PARAMETER) {
|
|
||||||
if (symbol->type_qualifier == IN || symbol->type_qualifier == OUT) {
|
|
||||||
if (compilation_type == STENCIL_ASSEMBLY)
|
|
||||||
printf("const __restrict__ %s* %s", translate(symbol->type_specifier),
|
|
||||||
symbol->identifier);
|
|
||||||
else if (compilation_type == STENCIL_PROCESS)
|
|
||||||
printf("const %sData& %s", translate(symbol->type_specifier), symbol->identifier);
|
|
||||||
else
|
|
||||||
printf("Invalid compilation type %d, IN and OUT qualifiers not supported\n",
|
|
||||||
compilation_type);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
print_symbol(handle);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// UNIFORM
|
|
||||||
else if (symbol->type_qualifier == UNIFORM) {
|
|
||||||
// if (compilation_type != STENCIL_HEADER) {
|
|
||||||
// printf("ERROR: %s can only be used in stencil headers\n", translation_table[UNIFORM]);
|
|
||||||
//}
|
|
||||||
/* Do nothing */
|
|
||||||
}
|
|
||||||
// IN / OUT
|
|
||||||
else if (symbol->type != SYMBOLTYPE_FUNCTION_PARAMETER &&
|
|
||||||
(symbol->type_qualifier == IN || symbol->type_qualifier == OUT)) {
|
|
||||||
|
|
||||||
printf("static __device__ const %s %s%s",
|
|
||||||
symbol->type_specifier == SCALARFIELD ? "int" : "int3", inout_name_prefix,
|
|
||||||
symbol_table[handle].identifier);
|
|
||||||
if (symbol->type_specifier == VECTOR)
|
|
||||||
printf(" = make_int3");
|
|
||||||
}
|
|
||||||
// OTHER
|
|
||||||
else {
|
|
||||||
print_symbol(handle);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static inline void
|
static inline void
|
||||||
print_symbol_table(void)
|
print_symbol_table(void)
|
||||||
{
|
{
|
||||||
for (int i = 0; i < num_symbols; ++i) {
|
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||||
printf("%d: ", i);
|
printf("%zu: ", i); // i is size_t, so %zu (not %lu) is the matching conversion
|
||||||
const char* fields[] = {translate(symbol_table[i].type_qualifier),
|
const char* fields[] = {translate(symbol_table[i].type_qualifier),
|
||||||
translate(symbol_table[i].type_specifier),
|
translate(symbol_table[i].type_specifier),
|
||||||
symbol_table[i].identifier};
|
symbol_table[i].identifier};
|
||||||
@@ -279,377 +204,205 @@ print_symbol_table(void)
|
|||||||
|
|
||||||
/*
|
/*
|
||||||
* =============================================================================
|
* =============================================================================
|
||||||
* State
|
* Traversal state
|
||||||
* =============================================================================
|
* =============================================================================
|
||||||
*/
|
*/
|
||||||
static bool inside_declaration = false;
|
|
||||||
static bool inside_function_declaration = false;
|
|
||||||
static bool inside_function_parameter_declaration = false;
|
|
||||||
|
|
||||||
static bool inside_kernel = false;
|
|
||||||
static bool inside_preprocessed = false;
|
|
||||||
|
|
||||||
static int scope_start = 0;
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* =============================================================================
|
* =============================================================================
|
||||||
* AST traversal
|
* AST traversal
|
||||||
* =============================================================================
|
* =============================================================================
|
||||||
*/
|
*/
|
||||||
|
static void
|
||||||
static int compound_statement_nests = 0;
|
translate_latest_symbol(void)
|
||||||
|
{
|
||||||
|
// TODO
|
||||||
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
traverse(const ASTNode* node)
|
traverse(const ASTNode* node)
|
||||||
{
|
{
|
||||||
// Prefix logic %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
// Prefix translation
|
||||||
if (node->type == NODE_FUNCTION_DECLARATION)
|
if (translate(node->prefix))
|
||||||
inside_function_declaration = true;
|
fprintf(CUDAHEADER, "%s", translate(node->prefix));
|
||||||
if (node->type == NODE_FUNCTION_PARAMETER_DECLARATION)
|
|
||||||
inside_function_parameter_declaration = true;
|
|
||||||
if (node->type == NODE_DECLARATION)
|
|
||||||
inside_declaration = true;
|
|
||||||
|
|
||||||
if (!inside_declaration && translate(node->prefix))
|
// Prefix logic
|
||||||
printf("%s", translate(node->prefix));
|
if (node->type == NODE_COMPOUND_STATEMENT) {
|
||||||
|
assert(current_nest < MAX_NESTS);
|
||||||
|
|
||||||
if (node->type == NODE_COMPOUND_STATEMENT)
|
++current_nest;
|
||||||
++compound_statement_nests;
|
num_symbols[current_nest] = num_symbols[current_nest - 1];
|
||||||
|
|
||||||
// BOILERPLATE START////////////////////////////////////////////////////////
|
|
||||||
if (node->type == NODE_TYPE_QUALIFIER && node->token == KERNEL)
|
|
||||||
inside_kernel = true;
|
|
||||||
|
|
||||||
// Kernel parameter boilerplate
|
|
||||||
const char* kernel_parameter_boilerplate = "GEN_KERNEL_PARAM_BOILERPLATE";
|
|
||||||
if (inside_kernel && node->type == NODE_FUNCTION_PARAMETER_DECLARATION) {
|
|
||||||
printf("%s", kernel_parameter_boilerplate);
|
|
||||||
|
|
||||||
if (node->lhs != NULL) {
|
|
||||||
printf("Compilation error: function parameters for Kernel functions not allowed!\n");
|
|
||||||
exit(EXIT_FAILURE);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Kernel builtin variables boilerplate (read input/output arrays and setup
|
// Traverse LHS
|
||||||
// indices)
|
|
||||||
const char* kernel_builtin_variables_boilerplate = "GEN_KERNEL_BUILTIN_VARIABLES_"
|
|
||||||
"BOILERPLATE();";
|
|
||||||
if (inside_kernel && node->type == NODE_COMPOUND_STATEMENT && compound_statement_nests == 1) {
|
|
||||||
printf("%s ", kernel_builtin_variables_boilerplate);
|
|
||||||
|
|
||||||
for (int i = 0; i < num_symbols; ++i) {
|
|
||||||
if (symbol_table[i].type_qualifier == IN) {
|
|
||||||
printf("const %sData %s = READ(%s%s);\n", translate(symbol_table[i].type_specifier),
|
|
||||||
symbol_table[i].identifier, inout_name_prefix, symbol_table[i].identifier);
|
|
||||||
}
|
|
||||||
else if (symbol_table[i].type_qualifier == OUT) {
|
|
||||||
printf("%s %s = READ_OUT(%s%s);", translate(symbol_table[i].type_specifier),
|
|
||||||
symbol_table[i].identifier, inout_name_prefix, symbol_table[i].identifier);
|
|
||||||
// printf("%s %s = buffer.out[%s%s][IDX(vertexIdx.x, vertexIdx.y, vertexIdx.z)];\n",
|
|
||||||
// translate(symbol_table[i].type_specifier), symbol_table[i].identifier,
|
|
||||||
// inout_name_prefix, symbol_table[i].identifier);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Preprocessed parameter boilerplate
|
|
||||||
if (node->type == NODE_TYPE_QUALIFIER && node->token == PREPROCESSED)
|
|
||||||
inside_preprocessed = true;
|
|
||||||
static const char preprocessed_parameter_boilerplate
|
|
||||||
[] = "const int3& vertexIdx, const int3& globalVertexIdx, ";
|
|
||||||
if (inside_preprocessed && node->type == NODE_FUNCTION_PARAMETER_DECLARATION)
|
|
||||||
printf("%s ", preprocessed_parameter_boilerplate);
|
|
||||||
// BOILERPLATE END////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
// Enter LHS
|
|
||||||
if (node->lhs)
|
if (node->lhs)
|
||||||
traverse(node->lhs);
|
traverse(node->lhs);
|
||||||
|
|
||||||
// Infix logic %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
// Infix translation
|
||||||
if (!inside_declaration && translate(node->infix))
|
if (translate(node->infix))
|
||||||
printf("%s ", translate(node->infix));
|
fprintf(CUDAHEADER, "%s", translate(node->infix));
|
||||||
|
|
||||||
if (node->type == NODE_FUNCTION_DECLARATION)
|
// Infix logic
|
||||||
inside_function_declaration = false;
|
// TODO
|
||||||
|
|
||||||
// If the node is a subscript expression and the expression list inside it is not empty
|
// Traverse RHS
|
||||||
if (node->type == NODE_MULTIDIM_SUBSCRIPT_EXPRESSION && node->rhs)
|
|
||||||
printf("IDX(");
|
|
||||||
|
|
||||||
// Do a regular translation
|
|
||||||
if (!inside_declaration) {
|
|
||||||
const int handle = symboltable_lookup(node->buffer);
|
|
||||||
if (handle >= 0) { // The variable exists in the symbol table
|
|
||||||
const Symbol* symbol = &symbol_table[handle];
|
|
||||||
|
|
||||||
if (symbol->type_qualifier == UNIFORM) {
|
|
||||||
if (inside_kernel && symbol->type_specifier == SCALARARRAY) {
|
|
||||||
printf("buffer.profiles[%s] ", symbol->identifier);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
printf("DCONST(%s) ", symbol->identifier);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
// Do a regular translation
|
|
||||||
if (translate(node->token))
|
|
||||||
printf("%s ", translate(node->token));
|
|
||||||
if (node->buffer) {
|
|
||||||
if (node->type == NODE_REAL_NUMBER) {
|
|
||||||
printf("%s(%s) ", translate(SCALAR),
|
|
||||||
node->buffer); // Cast to correct precision
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
printf("%s ", node->buffer);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
// Do a regular translation
|
|
||||||
if (translate(node->token))
|
|
||||||
printf("%s ", translate(node->token));
|
|
||||||
if (node->buffer) {
|
|
||||||
if (node->type == NODE_REAL_NUMBER) {
|
|
||||||
printf("%s(%s) ", translate(SCALAR), node->buffer); // Cast to correct precision
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
printf("%s ", node->buffer);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (node->type == NODE_FUNCTION_DECLARATION) {
|
|
||||||
scope_start = num_symbols;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enter RHS
|
|
||||||
if (node->rhs)
|
if (node->rhs)
|
||||||
traverse(node->rhs);
|
traverse(node->rhs);
|
||||||
|
|
||||||
// Postfix logic %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
// Postfix translation
|
||||||
// If the node is a subscript expression and the expression list inside it is not empty
|
if (translate(node->postfix))
|
||||||
if (node->type == NODE_MULTIDIM_SUBSCRIPT_EXPRESSION && node->rhs)
|
fprintf(CUDAHEADER, "%s", translate(node->postfix));
|
||||||
printf(")"); // Closing bracket of IDX()
|
|
||||||
|
|
||||||
// Generate writeback boilerplate for OUT fields
|
// Translate existing symbols
|
||||||
if (inside_kernel && node->type == NODE_COMPOUND_STATEMENT && compound_statement_nests == 1) {
|
const Symbol* symbol = symboltable_lookup(node->buffer);
|
||||||
for (int i = 0; i < num_symbols; ++i) {
|
if (symbol) {
|
||||||
if (symbol_table[i].type_qualifier == OUT) {
|
// Uniforms
|
||||||
printf("WRITE_OUT(%s%s, %s);\n", inout_name_prefix, symbol_table[i].identifier,
|
if (symbol->type_qualifier == UNIFORM) {
|
||||||
symbol_table[i].identifier);
|
fprintf(CUDAHEADER, "DCONST(%s) ", symbol->identifier);
|
||||||
// printf("buffer.out[%s%s][IDX(vertexIdx.x, vertexIdx.y, vertexIdx.z)] = %s;\n",
|
|
||||||
// inout_name_prefix, symbol_table[i].identifier, symbol_table[i].identifier);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!inside_declaration && translate(node->postfix))
|
// Add new symbols to the symbol table
|
||||||
printf("%s", translate(node->postfix));
|
|
||||||
|
|
||||||
if (node->type == NODE_DECLARATION) {
|
if (node->type == NODE_DECLARATION) {
|
||||||
inside_declaration = false;
|
int stype;
|
||||||
|
ASTNode* tmp = node->parent;
|
||||||
|
while (tmp->type == NODE_DECLARATION_LIST)
|
||||||
|
tmp = tmp->parent;
|
||||||
|
|
||||||
int tqual = 0;
|
if (tmp->type == NODE_FUNCTION_DECLARATION)
|
||||||
int tspec = 0;
|
stype = SYMBOLTYPE_FUNCTION;
|
||||||
if (node->lhs && node->lhs->lhs) {
|
else if (tmp->type == NODE_FUNCTION_PARAMETER_DECLARATION)
|
||||||
if (node->lhs->lhs->type == NODE_TYPE_QUALIFIER)
|
stype = SYMBOLTYPE_FUNCTION_PARAMETER;
|
||||||
tqual = node->lhs->lhs->token;
|
else
|
||||||
else if (node->lhs->lhs->type == NODE_TYPE_SPECIFIER)
|
stype = SYMBOLTYPE_OTHER;
|
||||||
tspec = node->lhs->lhs->token;
|
|
||||||
}
|
|
||||||
if (node->lhs && node->lhs->rhs) {
|
|
||||||
if (node->lhs->rhs->type == NODE_TYPE_SPECIFIER)
|
|
||||||
tspec = node->lhs->rhs->token;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine symbol type
|
const ASTNode* tdeclaration = node->lhs;
|
||||||
SymbolType symboltype = SYMBOLTYPE_OTHER;
|
const int tqualifier = tdeclaration->rhs ? tdeclaration->lhs->token : 0;
|
||||||
if (inside_function_declaration)
|
const int tspecifier = tdeclaration->rhs ? tdeclaration->rhs->token
|
||||||
symboltype = SYMBOLTYPE_FUNCTION;
|
: tdeclaration->lhs->token;
|
||||||
else if (inside_function_parameter_declaration)
|
|
||||||
symboltype = SYMBOLTYPE_FUNCTION_PARAMETER;
|
|
||||||
|
|
||||||
// Determine identifier
|
const char* identifier = node->rhs->type == NODE_IDENTIFIER ? node->rhs->buffer
|
||||||
if (node->rhs->type == NODE_IDENTIFIER) {
|
: node->rhs->lhs->buffer;
|
||||||
add_symbol(symboltype, tqual, tspec, node->rhs->buffer); // Ordinary
|
add_symbol(stype, tqualifier, tspecifier, identifier);
|
||||||
translate_latest_symbol();
|
|
||||||
|
// Translate the new symbol
|
||||||
|
if (tqualifier == UNIFORM) {
|
||||||
|
// Do nothing
|
||||||
}
|
}
|
||||||
else {
|
else if (tqualifier == KERNEL) {
|
||||||
add_symbol(symboltype, tqual, tspec,
|
fprintf(CUDAHEADER, "%s %s\n%s", //
|
||||||
node->rhs->lhs->buffer); // Array
|
translate(tqualifier), translate(tspecifier), identifier);
|
||||||
translate_latest_symbol();
|
}
|
||||||
// Traverse the expression once again, this time with
|
else if (tqualifier == DEVICE) {
|
||||||
// "inside_declaration" flag off
|
fprintf(CUDAHEADER, "%s %s\n%s", //
|
||||||
printf("%s ", translate(node->rhs->infix));
|
translate(tqualifier), translate(tspecifier), identifier);
|
||||||
if (node->rhs->rhs)
|
}
|
||||||
traverse(node->rhs->rhs);
|
else if (tqualifier == PREPROCESSED) {
|
||||||
printf("%s ", translate(node->rhs->postfix));
|
fprintf(CUDAHEADER, "%s %s\npreprocessed_%s", //
|
||||||
|
translate(tqualifier), translate(tspecifier), identifier);
|
||||||
|
}
|
||||||
|
else if (stype == SYMBOLTYPE_FUNCTION_PARAMETER) {
|
||||||
|
tmp = tmp->parent;
|
||||||
|
assert(tmp->type == NODE_FUNCTION_DECLARATION); // compare, not assign: '=' overwrote the node type and made the check meaningless
|
||||||
|
const Symbol* parent_function = symboltable_lookup(tmp->lhs->rhs->buffer);
|
||||||
|
if (parent_function->type_qualifier == DEVICE)
|
||||||
|
fprintf(CUDAHEADER, "%s %s\ndeviceparam_%s", //
|
||||||
|
translate(tqualifier), translate(tspecifier), identifier);
|
||||||
|
else if (parent_function->type_qualifier == PREPROCESSED)
|
||||||
|
fprintf(CUDAHEADER, "%s %s\npreprocessedparam_%s", //
|
||||||
|
translate(tqualifier), translate(tspecifier), identifier);
|
||||||
|
else
|
||||||
|
fprintf(CUDAHEADER, "%s %s\notherparam_%s", //
|
||||||
|
translate(tqualifier), translate(tspecifier), identifier);
|
||||||
|
}
|
||||||
|
else { // Do a regular translation
|
||||||
|
// fprintf(CUDAHEADER, "%s %s %s", //
|
||||||
|
// translate(tqualifier), translate(tspecifier), identifier);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (node->type == NODE_COMPOUND_STATEMENT)
|
// Postfix logic
|
||||||
--compound_statement_nests;
|
if (node->type == NODE_COMPOUND_STATEMENT) {
|
||||||
|
assert(current_nest > 0);
|
||||||
if (node->type == NODE_FUNCTION_PARAMETER_DECLARATION)
|
--current_nest;
|
||||||
inside_function_parameter_declaration = false;
|
printf("Dropped rest of the symbol table, from %zu to %zu\n", num_symbols[current_nest + 1], // %zu: the counters are size_t
|
||||||
|
num_symbols[current_nest]);
|
||||||
if (node->type == NODE_FUNCTION_DEFINITION) {
|
|
||||||
while (num_symbols > scope_start)
|
|
||||||
rm_symbol(num_symbols - 1);
|
|
||||||
|
|
||||||
inside_kernel = false;
|
|
||||||
inside_preprocessed = false;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: these should use the generic type names SCALAR and VECTOR
|
|
||||||
static void
|
static void
|
||||||
generate_preprocessed_structures(void)
|
generate_preprocessed_structures(void)
|
||||||
{
|
{
|
||||||
// PREPROCESSED DATA STRUCT
|
// TODO
|
||||||
printf("\n");
|
|
||||||
printf("typedef struct {\n");
|
|
||||||
for (int i = 0; i < num_symbols; ++i) {
|
|
||||||
if (symbol_table[i].type_qualifier == PREPROCESSED)
|
|
||||||
printf("%s %s;\n", translate(symbol_table[i].type_specifier),
|
|
||||||
symbol_table[i].identifier);
|
|
||||||
}
|
|
||||||
printf("} %sData;\n", translate(SCALAR));
|
|
||||||
|
|
||||||
// FILLING THE DATA STRUCT
|
|
||||||
printf("static __device__ __forceinline__ AcRealData\
|
|
||||||
read_data(const int3& vertexIdx,\
|
|
||||||
const int3& globalVertexIdx,\
|
|
||||||
AcReal* __restrict__ buf[], const int handle)\
|
|
||||||
{\n\
|
|
||||||
%sData data;\n",
|
|
||||||
translate(SCALAR));
|
|
||||||
|
|
||||||
for (int i = 0; i < num_symbols; ++i) {
|
|
||||||
if (symbol_table[i].type_qualifier == PREPROCESSED)
|
|
||||||
printf("data.%s = preprocessed_%s(vertexIdx, globalVertexIdx, buf[handle]);\n",
|
|
||||||
symbol_table[i].identifier, symbol_table[i].identifier);
|
|
||||||
}
|
|
||||||
printf("return data;\n");
|
|
||||||
printf("}\n");
|
|
||||||
|
|
||||||
// FUNCTIONS FOR ACCESSING MEMBERS OF THE PREPROCESSED STRUCT
|
|
||||||
for (int i = 0; i < num_symbols; ++i) {
|
|
||||||
if (symbol_table[i].type_qualifier == PREPROCESSED)
|
|
||||||
printf("static __device__ __forceinline__ %s\
|
|
||||||
%s(const AcRealData& data)\
|
|
||||||
{\n\
|
|
||||||
return data.%s;\
|
|
||||||
}\n",
|
|
||||||
translate(symbol_table[i].type_specifier), symbol_table[i].identifier,
|
|
||||||
symbol_table[i].identifier);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Syntactic sugar: generate also a Vector data struct
|
|
||||||
printf("\
|
|
||||||
typedef struct {\
|
|
||||||
AcRealData x;\
|
|
||||||
AcRealData y;\
|
|
||||||
AcRealData z;\
|
|
||||||
} AcReal3Data;\
|
|
||||||
\
|
|
||||||
static __device__ __forceinline__ AcReal3Data\
|
|
||||||
read_data(const int3& vertexIdx,\
|
|
||||||
const int3& globalVertexIdx,\
|
|
||||||
AcReal* __restrict__ buf[], const int3& handle)\
|
|
||||||
{\
|
|
||||||
AcReal3Data data;\
|
|
||||||
\
|
|
||||||
data.x = read_data(vertexIdx, globalVertexIdx, buf, handle.x);\
|
|
||||||
data.y = read_data(vertexIdx, globalVertexIdx, buf, handle.y);\
|
|
||||||
data.z = read_data(vertexIdx, globalVertexIdx, buf, handle.z);\
|
|
||||||
\
|
|
||||||
return data;\
|
|
||||||
}\
|
|
||||||
");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
generate_header(void)
|
generate_header(void)
|
||||||
{
|
{
|
||||||
printf("\n#pragma once\n");
|
fprintf(DSLHEADER, "#pragma once\n");
|
||||||
|
|
||||||
// Int params
|
// Int params
|
||||||
printf("#define AC_FOR_USER_INT_PARAM_TYPES(FUNC)");
|
fprintf(DSLHEADER, "#define AC_FOR_USER_INT_PARAM_TYPES(FUNC)");
|
||||||
for (int i = 0; i < num_symbols; ++i) {
|
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||||
if (symbol_table[i].type_specifier == INT) {
|
if (symbol_table[i].type_specifier == INT && symbol_table[i].type_qualifier == UNIFORM) {
|
||||||
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
printf("\n\n");
|
fprintf(DSLHEADER, "\n\n");
|
||||||
|
|
||||||
// Int3 params
|
// Int3 params
|
||||||
printf("#define AC_FOR_USER_INT3_PARAM_TYPES(FUNC)");
|
fprintf(DSLHEADER, "#define AC_FOR_USER_INT3_PARAM_TYPES(FUNC)");
|
||||||
for (int i = 0; i < num_symbols; ++i) {
|
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||||
if (symbol_table[i].type_specifier == INT3) {
|
if (symbol_table[i].type_specifier == INT3 && symbol_table[i].type_qualifier == UNIFORM) {
|
||||||
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
printf("\n\n");
|
fprintf(DSLHEADER, "\n\n");
|
||||||
|
|
||||||
// Scalar params
|
// Scalar params
|
||||||
printf("#define AC_FOR_USER_REAL_PARAM_TYPES(FUNC)");
|
fprintf(DSLHEADER, "#define AC_FOR_USER_REAL_PARAM_TYPES(FUNC)");
|
||||||
for (int i = 0; i < num_symbols; ++i) {
|
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||||
if (symbol_table[i].type_specifier == SCALAR) {
|
if (symbol_table[i].type_specifier == SCALAR && symbol_table[i].type_qualifier == UNIFORM) {
|
||||||
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
printf("\n\n");
|
fprintf(DSLHEADER, "\n\n");
|
||||||
|
|
||||||
// Vector params
|
// Vector params
|
||||||
printf("#define AC_FOR_USER_REAL3_PARAM_TYPES(FUNC)");
|
fprintf(DSLHEADER, "#define AC_FOR_USER_REAL3_PARAM_TYPES(FUNC)");
|
||||||
for (int i = 0; i < num_symbols; ++i) {
|
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||||
if (symbol_table[i].type_specifier == VECTOR) {
|
if (symbol_table[i].type_specifier == VECTOR && symbol_table[i].type_qualifier == UNIFORM) {
|
||||||
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
printf("\n\n");
|
fprintf(DSLHEADER, "\n\n");
|
||||||
|
|
||||||
// Scalar fields
|
// Scalar fields
|
||||||
printf("#define AC_FOR_VTXBUF_HANDLES(FUNC)");
|
fprintf(DSLHEADER, "#define AC_FOR_VTXBUF_HANDLES(FUNC)");
|
||||||
for (int i = 0; i < num_symbols; ++i) {
|
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||||
if (symbol_table[i].type_specifier == SCALARFIELD) {
|
if (symbol_table[i].type_specifier == SCALARFIELD &&
|
||||||
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
symbol_table[i].type_qualifier == UNIFORM) {
|
||||||
|
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
printf("\n\n");
|
fprintf(DSLHEADER, "\n\n");
|
||||||
|
|
||||||
// Scalar arrays
|
// Scalar arrays
|
||||||
printf("#define AC_FOR_SCALARARRAY_HANDLES(FUNC)");
|
fprintf(DSLHEADER, "#define AC_FOR_SCALARARRAY_HANDLES(FUNC)");
|
||||||
for (int i = 0; i < num_symbols; ++i) {
|
for (size_t i = 0; i < num_symbols[current_nest]; ++i) {
|
||||||
if (symbol_table[i].type_specifier == SCALARARRAY) {
|
if (symbol_table[i].type_specifier == SCALARARRAY &&
|
||||||
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
symbol_table[i].type_qualifier == UNIFORM) {
|
||||||
|
fprintf(DSLHEADER, "\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
printf("\n\n");
|
fprintf(DSLHEADER, "\n\n");
|
||||||
|
|
||||||
/*
|
|
||||||
printf("\n");
|
|
||||||
printf("typedef struct {\n");
|
|
||||||
for (int i = 0; i < num_symbols; ++i) {
|
|
||||||
if (symbol_table[i].type_qualifier == PREPROCESSED)
|
|
||||||
printf("%s %s;\n", translate(symbol_table[i].type_specifier),
|
|
||||||
symbol_table[i].identifier);
|
|
||||||
}
|
|
||||||
printf("} %sData;\n", translate(SCALAR));
|
|
||||||
*/
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
generate_library_hooks(void)
|
generate_library_hooks(void)
|
||||||
{
|
{
|
||||||
for (int i = 0; i < num_symbols; ++i) {
|
for (int i = 0; i < num_symbols[current_nest]; ++i) {
|
||||||
if (symbol_table[i].type_qualifier == KERNEL) {
|
if (symbol_table[i].type_qualifier == KERNEL && symbol_table[i].type_qualifier == UNIFORM) {
|
||||||
printf("GEN_DEVICE_FUNC_HOOK(%s)\n", symbol_table[i].identifier);
|
fprintf(CUDAHEADER, "GEN_DEVICE_FUNC_HOOK(%s)\n", symbol_table[i].identifier);
|
||||||
// printf("GEN_NODE_FUNC_HOOK(%s)\n", symbol_table[i].identifier);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -657,49 +410,29 @@ generate_library_hooks(void)
|
|||||||
int
|
int
|
||||||
main(int argc, char** argv)
|
main(int argc, char** argv)
|
||||||
{
|
{
|
||||||
if (argc == 2) {
|
|
||||||
if (!strcmp(argv[1], "-sas"))
|
|
||||||
compilation_type = STENCIL_ASSEMBLY;
|
|
||||||
else if (!strcmp(argv[1], "-sps"))
|
|
||||||
compilation_type = STENCIL_PROCESS;
|
|
||||||
else if (!strcmp(argv[1], "-sdh"))
|
|
||||||
compilation_type = STENCIL_HEADER;
|
|
||||||
else {
|
|
||||||
printf("Unknown flag %s. Generating stencil assembly.\n", argv[1]);
|
|
||||||
return EXIT_FAILURE;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
printf("Usage: ./acc [flags]\n"
|
|
||||||
"Flags:\n"
|
|
||||||
"\t-sas - Generates code for the stencil assembly stage\n"
|
|
||||||
"\t-sps - Generates code for the stencil processing stage\n"
|
|
||||||
"\t-hh - Generates stencil definitions from a header file\n");
|
|
||||||
printf("\n");
|
|
||||||
return EXIT_FAILURE;
|
|
||||||
}
|
|
||||||
|
|
||||||
root = astnode_create(NODE_UNKNOWN, NULL, NULL);
|
root = astnode_create(NODE_UNKNOWN, NULL, NULL);
|
||||||
|
|
||||||
const int retval = yyparse();
|
const int retval = yyparse();
|
||||||
if (retval) {
|
if (retval) {
|
||||||
printf("COMPILATION FAILED\n");
|
fprintf(stderr, "COMPILATION FAILED\n");
|
||||||
return EXIT_FAILURE;
|
return EXIT_FAILURE;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Traverse
|
DSLHEADER = fopen("user_defines.h", "w+");
|
||||||
traverse(root);
|
CUDAHEADER = fopen("user_kernels.h", "w+");
|
||||||
if (compilation_type == STENCIL_ASSEMBLY)
|
assert(DSLHEADER);
|
||||||
generate_preprocessed_structures();
|
assert(CUDAHEADER);
|
||||||
else if (compilation_type == STENCIL_HEADER)
|
|
||||||
generate_header();
|
|
||||||
else if (compilation_type == STENCIL_PROCESS)
|
|
||||||
generate_library_hooks();
|
|
||||||
|
|
||||||
// print_symbol_table();
|
traverse(root);
|
||||||
|
generate_header();
|
||||||
|
generate_library_hooks();
|
||||||
|
|
||||||
|
print_symbol_table();
|
||||||
|
|
||||||
// Cleanup
|
// Cleanup
|
||||||
|
fclose(DSLHEADER);
|
||||||
|
fclose(CUDAHEADER);
|
||||||
astnode_destroy(root);
|
astnode_destroy(root);
|
||||||
// printf("COMPILATION SUCCESS\n");
|
fprintf(stdout, "COMPILATION SUCCESS\n");
|
||||||
return EXIT_SUCCESS;
|
return EXIT_SUCCESS;
|
||||||
}
|
}
|
||||||
|
705
acc/src/code_generator0.c
Normal file
705
acc/src/code_generator0.c
Normal file
@@ -0,0 +1,705 @@
|
|||||||
|
/*
|
||||||
|
Copyright (C) 2014-2019, Johannes Pekkilae, Miikka Vaeisalae.
|
||||||
|
|
||||||
|
This file is part of Astaroth.
|
||||||
|
|
||||||
|
Astaroth is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
Astaroth is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with Astaroth. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
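/**
 * @file
 * \brief Code generator of the Astaroth DSL compiler.
 *
 * Traverses the abstract syntax tree produced by yyparse() and prints the
 * translated code to stdout. The command-line flag given to main() (-sas, -sps
 * or -sdh) selects which additional output is generated after the traversal.
 *
 */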
|
||||||
|
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "acc.tab.h"
|
||||||
|
#include "ast.h"
|
||||||
|
|
||||||
|
// Root node of the abstract syntax tree; main() creates it before calling yyparse()
ASTNode* root = NULL;

// Prepended to the identifier of an in/out field to form the name of its handle
static const char inout_name_prefix[] = "handle_";

// Output stage selected by the command-line flag in main() (-sas, -sps or -sdh)
typedef enum {
    STENCIL_ASSEMBLY,
    STENCIL_PROCESS,
    STENCIL_HEADER
} CompilationType;

static CompilationType compilation_type;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* =============================================================================
|
||||||
|
* Translation
|
||||||
|
* =============================================================================
|
||||||
|
*/
|
||||||
|
#define TRANSLATION_TABLE_SIZE (1024)
|
||||||
|
static const char* translation_table[TRANSLATION_TABLE_SIZE] = {
|
||||||
|
[0] = NULL,
|
||||||
|
// Control flow
|
||||||
|
[IF] = "if",
|
||||||
|
[ELSE] = "else",
|
||||||
|
[ELIF] = "else if",
|
||||||
|
[WHILE] = "while",
|
||||||
|
[FOR] = "for",
|
||||||
|
// Type specifiers
|
||||||
|
[VOID] = "void",
|
||||||
|
[INT] = "int",
|
||||||
|
[INT3] = "int3",
|
||||||
|
[SCALAR] = "AcReal",
|
||||||
|
[VECTOR] = "AcReal3",
|
||||||
|
[MATRIX] = "AcMatrix",
|
||||||
|
[SCALARFIELD] = "AcReal",
|
||||||
|
[SCALARARRAY] = "const AcReal* __restrict__",
|
||||||
|
[COMPLEX] = "acComplex",
|
||||||
|
// Type qualifiers
|
||||||
|
[KERNEL] = "template <int step_number> static __global__",
|
||||||
|
//__launch_bounds__(RK_THREADBLOCK_SIZE,
|
||||||
|
// RK_LAUNCH_BOUND_MIN_BLOCKS),
|
||||||
|
[PREPROCESSED] = "static __device__ "
|
||||||
|
"__forceinline__",
|
||||||
|
[CONSTANT] = "const",
|
||||||
|
[IN] = "in",
|
||||||
|
[OUT] = "out",
|
||||||
|
[UNIFORM] = "uniform",
|
||||||
|
// ETC
|
||||||
|
[INPLACE_INC] = "++",
|
||||||
|
[INPLACE_DEC] = "--",
|
||||||
|
// Unary
|
||||||
|
[','] = ",",
|
||||||
|
[';'] = ";\n",
|
||||||
|
['('] = "(",
|
||||||
|
[')'] = ")",
|
||||||
|
['['] = "[",
|
||||||
|
[']'] = "]",
|
||||||
|
['{'] = "{\n",
|
||||||
|
['}'] = "}\n",
|
||||||
|
['='] = "=",
|
||||||
|
['+'] = "+",
|
||||||
|
['-'] = "-",
|
||||||
|
['/'] = "/",
|
||||||
|
['*'] = "*",
|
||||||
|
['<'] = "<",
|
||||||
|
['>'] = ">",
|
||||||
|
['!'] = "!",
|
||||||
|
['.'] = "."};
|
||||||
|
|
||||||
|
static const char*
|
||||||
|
translate(const int token)
|
||||||
|
{
|
||||||
|
assert(token >= 0);
|
||||||
|
assert(token < TRANSLATION_TABLE_SIZE);
|
||||||
|
if (token > 0) {
|
||||||
|
if (!translation_table[token])
|
||||||
|
printf("ERROR: unidentified token %d\n", token);
|
||||||
|
assert(translation_table[token]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return translation_table[token];
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* =============================================================================
|
||||||
|
* Symbols
|
||||||
|
* =============================================================================
|
||||||
|
*/
|
||||||
|
// What kind of entity a symbol table entry describes
typedef enum {
    SYMBOLTYPE_FUNCTION,
    SYMBOLTYPE_FUNCTION_PARAMETER,
    SYMBOLTYPE_OTHER,
    NUM_SYMBOLTYPES
} SymbolType;

// Size of the identifier buffer, including the terminating '\0'
#define MAX_ID_LEN (128)
typedef struct {
    SymbolType type;
    int type_qualifier; // Qualifier token, 0 if none
    int type_specifier; // Specifier token, 0 if none
    char identifier[MAX_ID_LEN];
} Symbol;

#define SYMBOL_TABLE_SIZE (4096)
// Objects with static storage duration are zero-initialized, so no initializer
// is needed (an empty "= {}" is not valid ISO C11).
static Symbol symbol_table[SYMBOL_TABLE_SIZE];
static int num_symbols = 0;

/*
 * Returns the index of the first symbol whose identifier equals `identifier`,
 * or -1 if there is no such symbol or `identifier` is NULL.
 */
static int
symboltable_lookup(const char* identifier)
{
    if (!identifier)
        return -1;

    for (int i = 0; i < num_symbols; ++i)
        if (strcmp(identifier, symbol_table[i].identifier) == 0)
            return i;

    return -1;
}

/*
 * Appends a symbol to the table.
 *
 * Terminates the program with an error message if the table is full, or if
 * `id` is NULL or does not fit into Symbol::identifier. Previously the
 * identifier was copied with an unbounded strcpy().
 */
static void
add_symbol(const SymbolType type, const int tqualifier, const int tspecifier, const char* id)
{
    if (num_symbols >= SYMBOL_TABLE_SIZE) {
        fprintf(stderr, "ERROR: symbol table full (max %d symbols)\n", SYMBOL_TABLE_SIZE);
        exit(EXIT_FAILURE);
    }
    if (!id) {
        fprintf(stderr, "ERROR: tried to add a symbol without an identifier\n");
        exit(EXIT_FAILURE);
    }

    const size_t id_len = strlen(id);
    if (id_len >= MAX_ID_LEN) {
        fprintf(stderr, "ERROR: identifier too long (max %d characters): %s\n", MAX_ID_LEN - 1,
                id);
        exit(EXIT_FAILURE);
    }

    Symbol* symbol         = &symbol_table[num_symbols];
    symbol->type           = type;
    symbol->type_qualifier = tqualifier;
    symbol->type_specifier = tspecifier;
    memcpy(symbol->identifier, id, id_len + 1); // +1 copies the terminating '\0'

    ++num_symbols;
}

/*
 * Removes the symbol at index `handle` by moving the last symbol into its slot.
 * The order of the remaining symbols is therefore not preserved.
 */
static void
rm_symbol(const int handle)
{
    assert(handle >= 0 && handle < num_symbols);
    assert(num_symbols > 0);

    const int last = num_symbols - 1;
    if (handle != last)
        symbol_table[handle] = symbol_table[last];

    --num_symbols;
}
|
||||||
|
|
||||||
|
static void
|
||||||
|
print_symbol(const int handle)
|
||||||
|
{
|
||||||
|
assert(handle < SYMBOL_TABLE_SIZE);
|
||||||
|
|
||||||
|
const char* fields[] = {translate(symbol_table[handle].type_qualifier),
|
||||||
|
translate(symbol_table[handle].type_specifier),
|
||||||
|
symbol_table[handle].identifier};
|
||||||
|
const size_t num_fields = sizeof(fields) / sizeof(fields[0]);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < num_fields; ++i)
|
||||||
|
if (fields[i])
|
||||||
|
printf("%s ", fields[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void
|
||||||
|
translate_latest_symbol(void)
|
||||||
|
{
|
||||||
|
const int handle = num_symbols - 1;
|
||||||
|
assert(handle < SYMBOL_TABLE_SIZE);
|
||||||
|
|
||||||
|
Symbol* symbol = &symbol_table[handle];
|
||||||
|
|
||||||
|
// FUNCTION
|
||||||
|
if (symbol->type == SYMBOLTYPE_FUNCTION) {
|
||||||
|
// KERNEL FUNCTION
|
||||||
|
if (symbol->type_qualifier == KERNEL) {
|
||||||
|
printf("%s %s\n%s", translate(symbol->type_qualifier),
|
||||||
|
translate(symbol->type_specifier), symbol->identifier);
|
||||||
|
}
|
||||||
|
// PREPROCESSED FUNCTION
|
||||||
|
else if (symbol->type_qualifier == PREPROCESSED) {
|
||||||
|
printf("%s %s\npreprocessed_%s", translate(symbol->type_qualifier),
|
||||||
|
translate(symbol->type_specifier), symbol->identifier);
|
||||||
|
}
|
||||||
|
// OTHER FUNCTION
|
||||||
|
else {
|
||||||
|
const char* regular_function_decorator = "static __device__ "
|
||||||
|
"__forceinline__";
|
||||||
|
printf("%s %s %s\n%s", regular_function_decorator,
|
||||||
|
translate(symbol->type_qualifier) ? translate(symbol->type_qualifier) : "",
|
||||||
|
translate(symbol->type_specifier), symbol->identifier);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// FUNCTION PARAMETER
|
||||||
|
else if (symbol->type == SYMBOLTYPE_FUNCTION_PARAMETER) {
|
||||||
|
if (symbol->type_qualifier == IN || symbol->type_qualifier == OUT) {
|
||||||
|
if (compilation_type == STENCIL_ASSEMBLY)
|
||||||
|
printf("const __restrict__ %s* %s", translate(symbol->type_specifier),
|
||||||
|
symbol->identifier);
|
||||||
|
else if (compilation_type == STENCIL_PROCESS)
|
||||||
|
printf("const %sData& %s", translate(symbol->type_specifier), symbol->identifier);
|
||||||
|
else
|
||||||
|
printf("Invalid compilation type %d, IN and OUT qualifiers not supported\n",
|
||||||
|
compilation_type);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
print_symbol(handle);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// UNIFORM
|
||||||
|
else if (symbol->type_qualifier == UNIFORM) {
|
||||||
|
// if (compilation_type != STENCIL_HEADER) {
|
||||||
|
// printf("ERROR: %s can only be used in stencil headers\n", translation_table[UNIFORM]);
|
||||||
|
//}
|
||||||
|
/* Do nothing */
|
||||||
|
}
|
||||||
|
// IN / OUT
|
||||||
|
else if (symbol->type != SYMBOLTYPE_FUNCTION_PARAMETER &&
|
||||||
|
(symbol->type_qualifier == IN || symbol->type_qualifier == OUT)) {
|
||||||
|
|
||||||
|
printf("static __device__ const %s %s%s",
|
||||||
|
symbol->type_specifier == SCALARFIELD ? "int" : "int3", inout_name_prefix,
|
||||||
|
symbol_table[handle].identifier);
|
||||||
|
if (symbol->type_specifier == VECTOR)
|
||||||
|
printf(" = make_int3");
|
||||||
|
}
|
||||||
|
// OTHER
|
||||||
|
else {
|
||||||
|
print_symbol(handle);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void
|
||||||
|
print_symbol_table(void)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < num_symbols; ++i) {
|
||||||
|
printf("%d: ", i);
|
||||||
|
const char* fields[] = {translate(symbol_table[i].type_qualifier),
|
||||||
|
translate(symbol_table[i].type_specifier),
|
||||||
|
symbol_table[i].identifier};
|
||||||
|
|
||||||
|
const size_t num_fields = sizeof(fields) / sizeof(fields[0]);
|
||||||
|
for (size_t j = 0; j < num_fields; ++j)
|
||||||
|
if (fields[j])
|
||||||
|
printf("%s ", fields[j]);
|
||||||
|
|
||||||
|
if (symbol_table[i].type == SYMBOLTYPE_FUNCTION)
|
||||||
|
printf("(function)");
|
||||||
|
else if (symbol_table[i].type == SYMBOLTYPE_FUNCTION_PARAMETER)
|
||||||
|
printf("(function parameter)");
|
||||||
|
else
|
||||||
|
printf("(other)");
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* =============================================================================
|
||||||
|
* State
|
||||||
|
* =============================================================================
|
||||||
|
*/
|
||||||
|
// Set by traverse() while it is inside the subtree of a node of the
// corresponding type
static bool inside_declaration                    = false;
static bool inside_function_declaration           = false;
static bool inside_function_parameter_declaration = false;

// Set by traverse() when it meets the KERNEL / PREPROCESSED qualifier and
// cleared again at the end of the function definition
static bool inside_kernel       = false;
static bool inside_preprocessed = false;

// Number of symbols that existed when the current function was declared;
// symbols added after that are removed at the end of the function definition
static int scope_start = 0;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* =============================================================================
|
||||||
|
* AST traversal
|
||||||
|
* =============================================================================
|
||||||
|
*/
|
||||||
|
|
||||||
|
static int compound_statement_nests = 0;
|
||||||
|
|
||||||
|
static void
|
||||||
|
traverse(const ASTNode* node)
|
||||||
|
{
|
||||||
|
// Prefix logic %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||||
|
if (node->type == NODE_FUNCTION_DECLARATION)
|
||||||
|
inside_function_declaration = true;
|
||||||
|
if (node->type == NODE_FUNCTION_PARAMETER_DECLARATION)
|
||||||
|
inside_function_parameter_declaration = true;
|
||||||
|
if (node->type == NODE_DECLARATION)
|
||||||
|
inside_declaration = true;
|
||||||
|
|
||||||
|
if (!inside_declaration && translate(node->prefix))
|
||||||
|
printf("%s", translate(node->prefix));
|
||||||
|
|
||||||
|
if (node->type == NODE_COMPOUND_STATEMENT)
|
||||||
|
++compound_statement_nests;
|
||||||
|
|
||||||
|
// BOILERPLATE START////////////////////////////////////////////////////////
|
||||||
|
if (node->type == NODE_TYPE_QUALIFIER && node->token == KERNEL)
|
||||||
|
inside_kernel = true;
|
||||||
|
|
||||||
|
// Kernel parameter boilerplate
|
||||||
|
const char* kernel_parameter_boilerplate = "GEN_KERNEL_PARAM_BOILERPLATE";
|
||||||
|
if (inside_kernel && node->type == NODE_FUNCTION_PARAMETER_DECLARATION) {
|
||||||
|
printf("%s", kernel_parameter_boilerplate);
|
||||||
|
|
||||||
|
if (node->lhs != NULL) {
|
||||||
|
printf("Compilation error: function parameters for Kernel functions not allowed!\n");
|
||||||
|
exit(EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Kernel builtin variables boilerplate (read input/output arrays and setup
|
||||||
|
// indices)
|
||||||
|
const char* kernel_builtin_variables_boilerplate = "GEN_KERNEL_BUILTIN_VARIABLES_"
|
||||||
|
"BOILERPLATE();";
|
||||||
|
if (inside_kernel && node->type == NODE_COMPOUND_STATEMENT && compound_statement_nests == 1) {
|
||||||
|
printf("%s ", kernel_builtin_variables_boilerplate);
|
||||||
|
|
||||||
|
for (int i = 0; i < num_symbols; ++i) {
|
||||||
|
if (symbol_table[i].type_qualifier == IN) {
|
||||||
|
printf("const %sData %s = READ(%s%s);\n", translate(symbol_table[i].type_specifier),
|
||||||
|
symbol_table[i].identifier, inout_name_prefix, symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
else if (symbol_table[i].type_qualifier == OUT) {
|
||||||
|
printf("%s %s = READ_OUT(%s%s);", translate(symbol_table[i].type_specifier),
|
||||||
|
symbol_table[i].identifier, inout_name_prefix, symbol_table[i].identifier);
|
||||||
|
// printf("%s %s = buffer.out[%s%s][IDX(vertexIdx.x, vertexIdx.y, vertexIdx.z)];\n",
|
||||||
|
// translate(symbol_table[i].type_specifier), symbol_table[i].identifier,
|
||||||
|
// inout_name_prefix, symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preprocessed parameter boilerplate
|
||||||
|
if (node->type == NODE_TYPE_QUALIFIER && node->token == PREPROCESSED)
|
||||||
|
inside_preprocessed = true;
|
||||||
|
static const char preprocessed_parameter_boilerplate
|
||||||
|
[] = "const int3& vertexIdx, const int3& globalVertexIdx, ";
|
||||||
|
if (inside_preprocessed && node->type == NODE_FUNCTION_PARAMETER_DECLARATION)
|
||||||
|
printf("%s ", preprocessed_parameter_boilerplate);
|
||||||
|
// BOILERPLATE END////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Enter LHS
|
||||||
|
if (node->lhs)
|
||||||
|
traverse(node->lhs);
|
||||||
|
|
||||||
|
// Infix logic %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||||
|
if (!inside_declaration && translate(node->infix))
|
||||||
|
printf("%s ", translate(node->infix));
|
||||||
|
|
||||||
|
if (node->type == NODE_FUNCTION_DECLARATION)
|
||||||
|
inside_function_declaration = false;
|
||||||
|
|
||||||
|
// If the node is a subscript expression and the expression list inside it is not empty
|
||||||
|
if (node->type == NODE_MULTIDIM_SUBSCRIPT_EXPRESSION && node->rhs)
|
||||||
|
printf("IDX(");
|
||||||
|
|
||||||
|
// Do a regular translation
|
||||||
|
if (!inside_declaration) {
|
||||||
|
const int handle = symboltable_lookup(node->buffer);
|
||||||
|
if (handle >= 0) { // The variable exists in the symbol table
|
||||||
|
const Symbol* symbol = &symbol_table[handle];
|
||||||
|
|
||||||
|
if (symbol->type_qualifier == UNIFORM) {
|
||||||
|
if (inside_kernel && symbol->type_specifier == SCALARARRAY) {
|
||||||
|
printf("buffer.profiles[%s] ", symbol->identifier);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
printf("DCONST(%s) ", symbol->identifier);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// Do a regular translation
|
||||||
|
if (translate(node->token))
|
||||||
|
printf("%s ", translate(node->token));
|
||||||
|
if (node->buffer) {
|
||||||
|
if (node->type == NODE_REAL_NUMBER) {
|
||||||
|
printf("%s(%s) ", translate(SCALAR),
|
||||||
|
node->buffer); // Cast to correct precision
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
printf("%s ", node->buffer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// Do a regular translation
|
||||||
|
if (translate(node->token))
|
||||||
|
printf("%s ", translate(node->token));
|
||||||
|
if (node->buffer) {
|
||||||
|
if (node->type == NODE_REAL_NUMBER) {
|
||||||
|
printf("%s(%s) ", translate(SCALAR), node->buffer); // Cast to correct precision
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
printf("%s ", node->buffer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (node->type == NODE_FUNCTION_DECLARATION) {
|
||||||
|
scope_start = num_symbols;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enter RHS
|
||||||
|
if (node->rhs)
|
||||||
|
traverse(node->rhs);
|
||||||
|
|
||||||
|
// Postfix logic %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||||
|
// If the node is a subscript expression and the expression list inside it is not empty
|
||||||
|
if (node->type == NODE_MULTIDIM_SUBSCRIPT_EXPRESSION && node->rhs)
|
||||||
|
printf(")"); // Closing bracket of IDX()
|
||||||
|
|
||||||
|
// Generate writeback boilerplate for OUT fields
|
||||||
|
if (inside_kernel && node->type == NODE_COMPOUND_STATEMENT && compound_statement_nests == 1) {
|
||||||
|
for (int i = 0; i < num_symbols; ++i) {
|
||||||
|
if (symbol_table[i].type_qualifier == OUT) {
|
||||||
|
printf("WRITE_OUT(%s%s, %s);\n", inout_name_prefix, symbol_table[i].identifier,
|
||||||
|
symbol_table[i].identifier);
|
||||||
|
// printf("buffer.out[%s%s][IDX(vertexIdx.x, vertexIdx.y, vertexIdx.z)] = %s;\n",
|
||||||
|
// inout_name_prefix, symbol_table[i].identifier, symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!inside_declaration && translate(node->postfix))
|
||||||
|
printf("%s", translate(node->postfix));
|
||||||
|
|
||||||
|
if (node->type == NODE_DECLARATION) {
|
||||||
|
inside_declaration = false;
|
||||||
|
|
||||||
|
int tqual = 0;
|
||||||
|
int tspec = 0;
|
||||||
|
if (node->lhs && node->lhs->lhs) {
|
||||||
|
if (node->lhs->lhs->type == NODE_TYPE_QUALIFIER)
|
||||||
|
tqual = node->lhs->lhs->token;
|
||||||
|
else if (node->lhs->lhs->type == NODE_TYPE_SPECIFIER)
|
||||||
|
tspec = node->lhs->lhs->token;
|
||||||
|
}
|
||||||
|
if (node->lhs && node->lhs->rhs) {
|
||||||
|
if (node->lhs->rhs->type == NODE_TYPE_SPECIFIER)
|
||||||
|
tspec = node->lhs->rhs->token;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine symbol type
|
||||||
|
SymbolType symboltype = SYMBOLTYPE_OTHER;
|
||||||
|
if (inside_function_declaration)
|
||||||
|
symboltype = SYMBOLTYPE_FUNCTION;
|
||||||
|
else if (inside_function_parameter_declaration)
|
||||||
|
symboltype = SYMBOLTYPE_FUNCTION_PARAMETER;
|
||||||
|
|
||||||
|
// Determine identifier
|
||||||
|
if (node->rhs->type == NODE_IDENTIFIER) {
|
||||||
|
add_symbol(symboltype, tqual, tspec, node->rhs->buffer); // Ordinary
|
||||||
|
translate_latest_symbol();
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
add_symbol(symboltype, tqual, tspec,
|
||||||
|
node->rhs->lhs->buffer); // Array
|
||||||
|
translate_latest_symbol();
|
||||||
|
// Traverse the expression once again, this time with
|
||||||
|
// "inside_declaration" flag off
|
||||||
|
printf("%s ", translate(node->rhs->infix));
|
||||||
|
if (node->rhs->rhs)
|
||||||
|
traverse(node->rhs->rhs);
|
||||||
|
printf("%s ", translate(node->rhs->postfix));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (node->type == NODE_COMPOUND_STATEMENT)
|
||||||
|
--compound_statement_nests;
|
||||||
|
|
||||||
|
if (node->type == NODE_FUNCTION_PARAMETER_DECLARATION)
|
||||||
|
inside_function_parameter_declaration = false;
|
||||||
|
|
||||||
|
if (node->type == NODE_FUNCTION_DEFINITION) {
|
||||||
|
while (num_symbols > scope_start)
|
||||||
|
rm_symbol(num_symbols - 1);
|
||||||
|
|
||||||
|
inside_kernel = false;
|
||||||
|
inside_preprocessed = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: these should use the generic type names SCALAR and VECTOR
|
||||||
|
static void
|
||||||
|
generate_preprocessed_structures(void)
|
||||||
|
{
|
||||||
|
// PREPROCESSED DATA STRUCT
|
||||||
|
printf("\n");
|
||||||
|
printf("typedef struct {\n");
|
||||||
|
for (int i = 0; i < num_symbols; ++i) {
|
||||||
|
if (symbol_table[i].type_qualifier == PREPROCESSED)
|
||||||
|
printf("%s %s;\n", translate(symbol_table[i].type_specifier),
|
||||||
|
symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
printf("} %sData;\n", translate(SCALAR));
|
||||||
|
|
||||||
|
// FILLING THE DATA STRUCT
|
||||||
|
printf("static __device__ __forceinline__ AcRealData\
|
||||||
|
read_data(const int3& vertexIdx,\
|
||||||
|
const int3& globalVertexIdx,\
|
||||||
|
AcReal* __restrict__ buf[], const int handle)\
|
||||||
|
{\n\
|
||||||
|
%sData data;\n",
|
||||||
|
translate(SCALAR));
|
||||||
|
|
||||||
|
for (int i = 0; i < num_symbols; ++i) {
|
||||||
|
if (symbol_table[i].type_qualifier == PREPROCESSED)
|
||||||
|
printf("data.%s = preprocessed_%s(vertexIdx, globalVertexIdx, buf[handle]);\n",
|
||||||
|
symbol_table[i].identifier, symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
printf("return data;\n");
|
||||||
|
printf("}\n");
|
||||||
|
|
||||||
|
// FUNCTIONS FOR ACCESSING MEMBERS OF THE PREPROCESSED STRUCT
|
||||||
|
for (int i = 0; i < num_symbols; ++i) {
|
||||||
|
if (symbol_table[i].type_qualifier == PREPROCESSED)
|
||||||
|
printf("static __device__ __forceinline__ %s\
|
||||||
|
%s(const AcRealData& data)\
|
||||||
|
{\n\
|
||||||
|
return data.%s;\
|
||||||
|
}\n",
|
||||||
|
translate(symbol_table[i].type_specifier), symbol_table[i].identifier,
|
||||||
|
symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Syntactic sugar: generate also a Vector data struct
|
||||||
|
printf("\
|
||||||
|
typedef struct {\
|
||||||
|
AcRealData x;\
|
||||||
|
AcRealData y;\
|
||||||
|
AcRealData z;\
|
||||||
|
} AcReal3Data;\
|
||||||
|
\
|
||||||
|
static __device__ __forceinline__ AcReal3Data\
|
||||||
|
read_data(const int3& vertexIdx,\
|
||||||
|
const int3& globalVertexIdx,\
|
||||||
|
AcReal* __restrict__ buf[], const int3& handle)\
|
||||||
|
{\
|
||||||
|
AcReal3Data data;\
|
||||||
|
\
|
||||||
|
data.x = read_data(vertexIdx, globalVertexIdx, buf, handle.x);\
|
||||||
|
data.y = read_data(vertexIdx, globalVertexIdx, buf, handle.y);\
|
||||||
|
data.z = read_data(vertexIdx, globalVertexIdx, buf, handle.z);\
|
||||||
|
\
|
||||||
|
return data;\
|
||||||
|
}\
|
||||||
|
");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void
|
||||||
|
generate_header(void)
|
||||||
|
{
|
||||||
|
printf("\n#pragma once\n");
|
||||||
|
|
||||||
|
// Int params
|
||||||
|
printf("#define AC_FOR_USER_INT_PARAM_TYPES(FUNC)");
|
||||||
|
for (int i = 0; i < num_symbols; ++i) {
|
||||||
|
if (symbol_table[i].type_specifier == INT) {
|
||||||
|
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
printf("\n\n");
|
||||||
|
|
||||||
|
// Int3 params
|
||||||
|
printf("#define AC_FOR_USER_INT3_PARAM_TYPES(FUNC)");
|
||||||
|
for (int i = 0; i < num_symbols; ++i) {
|
||||||
|
if (symbol_table[i].type_specifier == INT3) {
|
||||||
|
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
printf("\n\n");
|
||||||
|
|
||||||
|
// Scalar params
|
||||||
|
printf("#define AC_FOR_USER_REAL_PARAM_TYPES(FUNC)");
|
||||||
|
for (int i = 0; i < num_symbols; ++i) {
|
||||||
|
if (symbol_table[i].type_specifier == SCALAR) {
|
||||||
|
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
printf("\n\n");
|
||||||
|
|
||||||
|
// Vector params
|
||||||
|
printf("#define AC_FOR_USER_REAL3_PARAM_TYPES(FUNC)");
|
||||||
|
for (int i = 0; i < num_symbols; ++i) {
|
||||||
|
if (symbol_table[i].type_specifier == VECTOR) {
|
||||||
|
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
printf("\n\n");
|
||||||
|
|
||||||
|
// Scalar fields
|
||||||
|
printf("#define AC_FOR_VTXBUF_HANDLES(FUNC)");
|
||||||
|
for (int i = 0; i < num_symbols; ++i) {
|
||||||
|
if (symbol_table[i].type_specifier == SCALARFIELD) {
|
||||||
|
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
printf("\n\n");
|
||||||
|
|
||||||
|
// Scalar arrays
|
||||||
|
printf("#define AC_FOR_SCALARARRAY_HANDLES(FUNC)");
|
||||||
|
for (int i = 0; i < num_symbols; ++i) {
|
||||||
|
if (symbol_table[i].type_specifier == SCALARARRAY) {
|
||||||
|
printf("\\\nFUNC(%s),", symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
printf("\n\n");
|
||||||
|
|
||||||
|
/*
|
||||||
|
printf("\n");
|
||||||
|
printf("typedef struct {\n");
|
||||||
|
for (int i = 0; i < num_symbols; ++i) {
|
||||||
|
if (symbol_table[i].type_qualifier == PREPROCESSED)
|
||||||
|
printf("%s %s;\n", translate(symbol_table[i].type_specifier),
|
||||||
|
symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
printf("} %sData;\n", translate(SCALAR));
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
|
||||||
|
static void
|
||||||
|
generate_library_hooks(void)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < num_symbols; ++i) {
|
||||||
|
if (symbol_table[i].type_qualifier == KERNEL) {
|
||||||
|
printf("GEN_DEVICE_FUNC_HOOK(%s)\n", symbol_table[i].identifier);
|
||||||
|
// printf("GEN_NODE_FUNC_HOOK(%s)\n", symbol_table[i].identifier);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int
|
||||||
|
main(int argc, char** argv)
|
||||||
|
{
|
||||||
|
if (argc == 2) {
|
||||||
|
if (!strcmp(argv[1], "-sas"))
|
||||||
|
compilation_type = STENCIL_ASSEMBLY;
|
||||||
|
else if (!strcmp(argv[1], "-sps"))
|
||||||
|
compilation_type = STENCIL_PROCESS;
|
||||||
|
else if (!strcmp(argv[1], "-sdh"))
|
||||||
|
compilation_type = STENCIL_HEADER;
|
||||||
|
else {
|
||||||
|
printf("Unknown flag %s. Generating stencil assembly.\n", argv[1]);
|
||||||
|
return EXIT_FAILURE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
printf("Usage: ./acc [flags]\n"
|
||||||
|
"Flags:\n"
|
||||||
|
"\t-sas - Generates code for the stencil assembly stage\n"
|
||||||
|
"\t-sps - Generates code for the stencil processing stage\n"
|
||||||
|
"\t-hh - Generates stencil definitions from a header file\n");
|
||||||
|
printf("\n");
|
||||||
|
return EXIT_FAILURE;
|
||||||
|
}
|
||||||
|
|
||||||
|
root = astnode_create(NODE_UNKNOWN, NULL, NULL);
|
||||||
|
|
||||||
|
const int retval = yyparse();
|
||||||
|
if (retval) {
|
||||||
|
printf("COMPILATION FAILED\n");
|
||||||
|
return EXIT_FAILURE;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Traverse
|
||||||
|
traverse(root);
|
||||||
|
if (compilation_type == STENCIL_ASSEMBLY)
|
||||||
|
generate_preprocessed_structures();
|
||||||
|
else if (compilation_type == STENCIL_HEADER)
|
||||||
|
generate_header();
|
||||||
|
else if (compilation_type == STENCIL_PROCESS)
|
||||||
|
generate_library_hooks();
|
||||||
|
|
||||||
|
// print_symbol_table();
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
astnode_destroy(root);
|
||||||
|
// printf("COMPILATION SUCCESS\n");
|
||||||
|
return EXIT_SUCCESS;
|
||||||
|
}
|
Reference in New Issue
Block a user