diff --git a/include/astaroth.h b/include/astaroth.h index b2b3fa2..9ccdf85 100644 --- a/include/astaroth.h +++ b/include/astaroth.h @@ -99,6 +99,8 @@ AcResult acLoadWithOffset(const AcMesh host_mesh, const int3 src, const int num_ /** */ int acGetNumDevicesPerNode(void); +Node acGetNode(void); + #ifdef __cplusplus } // extern "C" #endif diff --git a/src/core/astaroth.cc b/src/core/astaroth.cc index 75cd57d..659a9e0 100644 --- a/src/core/astaroth.cc +++ b/src/core/astaroth.cc @@ -34,8 +34,10 @@ const char* scalararray_names[] = {AC_FOR_SCALARARRAY_HANDLES(AC_GEN_STR)}; const char* vtxbuf_names[] = {AC_FOR_VTXBUF_HANDLES(AC_GEN_STR)}; #undef AC_GEN_STR -static const int num_nodes = 1; -static Node nodes[num_nodes]; +static const int max_num_nodes = 1; +static Node nodes[max_num_nodes] = {0}; + +static int num_nodes = 0; void acPrintMeshInfo(const AcMeshInfo config) @@ -55,12 +57,14 @@ acPrintMeshInfo(const AcMeshInfo config) AcResult acInit(const AcMeshInfo mesh_info) { + num_nodes = 1; return acNodeCreate(0, mesh_info, &nodes[0]); } AcResult acQuit(void) { + num_nodes = 0; return acNodeDestroy(nodes[0]); } @@ -176,3 +180,10 @@ acGetNumDevicesPerNode(void) ERRCHK_CUDA_ALWAYS(cudaGetDeviceCount(&num_devices)); return num_devices; } + +Node +acGetNode(void) +{ + ERRCHK_ALWAYS(num_nodes > 0); + return nodes[0]; +}