diff --git a/include/astaroth.h b/include/astaroth.h index c3320dc..c73879d 100644 --- a/include/astaroth.h +++ b/include/astaroth.h @@ -282,6 +282,13 @@ AcResult acGridQuit(void); /** */ AcResult acGridSynchronizeStream(const Stream stream); +/** */ +AcResult acGridLoadScalarUniform(const Stream stream, const AcRealParam param, const AcReal value); + +/** */ +AcResult acGridLoadVectorUniform(const Stream stream, const AcReal3Param param, + const AcReal3 value); + /** */ AcResult acGridLoadMesh(const Stream stream, const AcMesh host_mesh); diff --git a/src/core/device.cc b/src/core/device.cc index 2721c59..28344fb 100644 --- a/src/core/device.cc +++ b/src/core/device.cc @@ -1325,6 +1325,47 @@ acGridQuit(void) return AC_SUCCESS; } +AcResult +acGridLoadScalarUniform(const Stream stream, const AcRealParam param, const AcReal value) +{ + ERRCHK(grid.initialized); + acGridSynchronizeStream(stream); + +#if AC_DOUBLE_PRECISION == 1 + MPI_Datatype datatype = MPI_DOUBLE; +#else + MPI_Datatype datatype = MPI_FLOAT; +#endif + + const int root_proc = 0; + AcReal buffer = value; + MPI_Bcast(&buffer, 1, datatype, root_proc, MPI_COMM_WORLD); + + acDeviceLoadScalarUniform(grid.device, stream, param, buffer); + return AC_SUCCESS; +} + +/** */ +AcResult +acGridLoadVectorUniform(const Stream stream, const AcReal3Param param, const AcReal3 value) +{ + ERRCHK(grid.initialized); + acGridSynchronizeStream(stream); + +#if AC_DOUBLE_PRECISION == 1 + MPI_Datatype datatype = MPI_DOUBLE; +#else + MPI_Datatype datatype = MPI_FLOAT; +#endif + + const int root_proc = 0; + AcReal3 buffer = value; + MPI_Bcast(&buffer, 3, datatype, root_proc, MPI_COMM_WORLD); + + acDeviceLoadVectorUniform(grid.device, stream, param, buffer); + return AC_SUCCESS; +} + AcResult acGridLoadMesh(const Stream stream, const AcMesh host_mesh) {