xref: /petsc/src/sys/classes/random/impls/curand/curand.c (revision eb58fe77b0e5d125a3d55c974ef2bc19b73e6b04)
1a4af0ceeSJacob Faibussowitsch #include <petsc/private/deviceimpl.h>
2d6cc7855SJacob Faibussowitsch #include <petsc/private/randomimpl.h>
30e6b6b59SJacob Faibussowitsch #include <petscdevice_cuda.h>
4808ba619SStefano Zampini #include <curand.h>
5808ba619SStefano Zampini 
6808ba619SStefano Zampini typedef struct {
7808ba619SStefano Zampini   curandGenerator_t gen;
8808ba619SStefano Zampini } PetscRandom_CURAND;
9808ba619SStefano Zampini 
1066976f2fSJacob Faibussowitsch static PetscErrorCode PetscRandomSeed_CURAND(PetscRandom r)
11d71ae5a4SJacob Faibussowitsch {
12808ba619SStefano Zampini   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
13808ba619SStefano Zampini 
14808ba619SStefano Zampini   PetscFunctionBegin;
159566063dSJacob Faibussowitsch   PetscCallCURAND(curandSetPseudoRandomGeneratorSeed(curand->gen, r->seed));
163ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
17808ba619SStefano Zampini }
18808ba619SStefano Zampini 
19808ba619SStefano Zampini PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom, size_t, PetscReal *, PetscBool);
20808ba619SStefano Zampini 
2166976f2fSJacob Faibussowitsch static PetscErrorCode PetscRandomGetValuesReal_CURAND(PetscRandom r, PetscInt n, PetscReal *val)
22d71ae5a4SJacob Faibussowitsch {
23808ba619SStefano Zampini   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
24bcdedc73SJed Brown   size_t              nn     = n < 0 ? (size_t)(-2 * n) : (size_t)n; /* handle complex case */
25808ba619SStefano Zampini 
26808ba619SStefano Zampini   PetscFunctionBegin;
27808ba619SStefano Zampini #if defined(PETSC_USE_REAL_SINGLE)
289566063dSJacob Faibussowitsch   PetscCallCURAND(curandGenerateUniform(curand->gen, val, nn));
29808ba619SStefano Zampini #else
309566063dSJacob Faibussowitsch   PetscCallCURAND(curandGenerateUniformDouble(curand->gen, val, nn));
31808ba619SStefano Zampini #endif
3248a46eb9SPierre Jolivet   if (r->iset) PetscCall(PetscRandomCurandScale_Private(r, nn, val, (PetscBool)(n < 0)));
333ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
34808ba619SStefano Zampini }
35808ba619SStefano Zampini 
3666976f2fSJacob Faibussowitsch static PetscErrorCode PetscRandomGetValues_CURAND(PetscRandom r, PetscInt n, PetscScalar *val)
37d71ae5a4SJacob Faibussowitsch {
38808ba619SStefano Zampini   PetscFunctionBegin;
39808ba619SStefano Zampini #if defined(PETSC_USE_COMPLEX)
40808ba619SStefano Zampini   /* pass negative size to flag complex scaling (if needed) */
419566063dSJacob Faibussowitsch   PetscCall(PetscRandomGetValuesReal_CURAND(r, -n, (PetscReal *)val));
42808ba619SStefano Zampini #else
439566063dSJacob Faibussowitsch   PetscCall(PetscRandomGetValuesReal_CURAND(r, n, val));
44808ba619SStefano Zampini #endif
453ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
46808ba619SStefano Zampini }
47808ba619SStefano Zampini 
4866976f2fSJacob Faibussowitsch static PetscErrorCode PetscRandomDestroy_CURAND(PetscRandom r)
49d71ae5a4SJacob Faibussowitsch {
50808ba619SStefano Zampini   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
51808ba619SStefano Zampini 
52808ba619SStefano Zampini   PetscFunctionBegin;
539566063dSJacob Faibussowitsch   PetscCallCURAND(curandDestroyGenerator(curand->gen));
549566063dSJacob Faibussowitsch   PetscCall(PetscFree(r->data));
553ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
56808ba619SStefano Zampini }
57808ba619SStefano Zampini 
58808ba619SStefano Zampini static struct _PetscRandomOps PetscRandomOps_Values = {
59267267bdSJacob Faibussowitsch   PetscDesignatedInitializer(seed, PetscRandomSeed_CURAND),
60267267bdSJacob Faibussowitsch   PetscDesignatedInitializer(getvalue, NULL),
61267267bdSJacob Faibussowitsch   PetscDesignatedInitializer(getvaluereal, NULL),
62267267bdSJacob Faibussowitsch   PetscDesignatedInitializer(getvalues, PetscRandomGetValues_CURAND),
63267267bdSJacob Faibussowitsch   PetscDesignatedInitializer(getvaluesreal, PetscRandomGetValuesReal_CURAND),
64267267bdSJacob Faibussowitsch   PetscDesignatedInitializer(destroy, PetscRandomDestroy_CURAND),
65808ba619SStefano Zampini };
66808ba619SStefano Zampini 
67808ba619SStefano Zampini /*MC
68811af0c4SBarry Smith    PETSCCURAND - access to the CUDA random number generator from a `PetscRandom` object
69808ba619SStefano Zampini 
70808ba619SStefano Zampini   Level: beginner
71808ba619SStefano Zampini 
72c31d2375SBarry Smith   Note:
73c31d2375SBarry Smith   This random number generator is available when PETSc is configured with ``./configure --with-cuda=1``
74c31d2375SBarry Smith 
75811af0c4SBarry Smith .seealso: `PetscRandomCreate()`, `PetscRandomSetType()`, `PetscRandomType`
76808ba619SStefano Zampini M*/
77808ba619SStefano Zampini 
78d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode PetscRandomCreate_CURAND(PetscRandom r)
79d71ae5a4SJacob Faibussowitsch {
80808ba619SStefano Zampini   PetscRandom_CURAND *curand;
81*eb58fe77SHansol Suh   PetscDeviceContext  dctx;
82*eb58fe77SHansol Suh   cudaStream_t       *stream;
83808ba619SStefano Zampini 
84808ba619SStefano Zampini   PetscFunctionBegin;
859566063dSJacob Faibussowitsch   PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA));
86*eb58fe77SHansol Suh   PetscCall(PetscDeviceContextGetCurrentContextAssertType_Internal(&dctx, PETSC_DEVICE_CUDA));
87*eb58fe77SHansol Suh   PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
884dfa11a4SJacob Faibussowitsch   PetscCall(PetscNew(&curand));
899566063dSJacob Faibussowitsch   PetscCallCURAND(curandCreateGenerator(&curand->gen, CURAND_RNG_PSEUDO_DEFAULT));
90*eb58fe77SHansol Suh   PetscCallCURAND(curandSetStream(curand->gen, *stream));
91808ba619SStefano Zampini   /* https://docs.nvidia.com/cuda/curand/host-api-overview.html#performance-notes2 */
929566063dSJacob Faibussowitsch   PetscCallCURAND(curandSetGeneratorOrdering(curand->gen, CURAND_ORDERING_PSEUDO_SEEDED));
93aea10558SJacob Faibussowitsch   r->ops[0] = PetscRandomOps_Values;
949566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)r, PETSCCURAND));
95808ba619SStefano Zampini   r->data = curand;
969566063dSJacob Faibussowitsch   PetscCall(PetscRandomSeed_CURAND(r));
973ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
98808ba619SStefano Zampini }
99