diff --git a/include/astaroth_device.h b/include/astaroth_device.h index 442b1b1..95cbb83 100644 --- a/include/astaroth_device.h +++ b/include/astaroth_device.h @@ -49,6 +49,10 @@ AcResult acDeviceSwapBuffers(const Device device); AcResult acDeviceLoadConstant(const Device device, const Stream stream, const AcRealParam param, const AcReal value); +/** */ +AcResult acDeviceLoadMeshInfo(const Device device, const Stream stream, + const AcMeshInfo device_config); + /** */ AcResult acDeviceLoadVertexBufferWithOffset(const Device device, const Stream stream, const AcMesh host_mesh, diff --git a/src/core/device.cu b/src/core/device.cu index a19dee0..0482f93 100644 --- a/src/core/device.cu +++ b/src/core/device.cu @@ -149,8 +149,7 @@ acDeviceCreate(const int id, const AcMeshInfo device_config, Device* device_hand #endif // Device constants - ERRCHK_CUDA_ALWAYS(cudaMemcpyToSymbol(d_mesh_info, &device_config, sizeof(device_config), 0, - cudaMemcpyHostToDevice)); + acDeviceLoadMeshInfo(device, STREAM_DEFAULT, device_config); printf("Created device %d (%p)\n", device->id, device); *device_handle = device; @@ -367,6 +366,15 @@ acDeviceLoadConstant(const Device device, const Stream stream, const AcRealParam return AC_SUCCESS; } +AcResult +acDeviceLoadMeshInfo(const Device device, const Stream stream, const AcMeshInfo device_config) +{ + cudaSetDevice(device->id); + ERRCHK_CUDA_ALWAYS(cudaMemcpyToSymbolAsync(d_mesh_info, &device_config, sizeof(device_config), + 0, cudaMemcpyHostToDevice, device->streams[stream])); + return AC_SUCCESS; +} + AcResult acDeviceLoadVertexBufferWithOffset(const Device device, const Stream stream, const AcMesh host_mesh, const VertexBufferHandle vtxbuf_handle, const int3 src,