xref: /petsc/src/sys/classes/random/impls/curand/curand.c (revision 5162e2cff6525a9b2e011550902b85eb10a0c994)
1d6cc7855SJacob Faibussowitsch #include <petsc/private/randomimpl.h>
2808ba619SStefano Zampini #include <curand.h>
3808ba619SStefano Zampini 
/*
  CHKERRCURAND - Raises a PETSc error from the enclosing function when a cuRAND call did not
  return CURAND_STATUS_SUCCESS.

  - CURAND_STATUS_INITIALIZATION_FAILED or CURAND_STATUS_ALLOCATION_FAILED, when PetscCUDAInitialized
    is set, is reported as PETSC_ERR_GPU_RESOURCE (the GPU is most likely out of resources).
  - Any other failure status is reported as PETSC_ERR_GPU.

  The status argument is parenthesized and evaluated exactly once, so passing an expression
  (e.g. a cuRAND call) is safe. Intended for use inside functions returning PetscErrorCode,
  since SETERRQ1 is expected to return the error code from the caller.
*/
#define CHKERRCURAND(stat) \
do { \
   const curandStatus_t chkerrcurand_stat_ = (stat); \
   if (PetscUnlikely(chkerrcurand_stat_ != CURAND_STATUS_SUCCESS)) { \
     if (((chkerrcurand_stat_ == CURAND_STATUS_INITIALIZATION_FAILED) || (chkerrcurand_stat_ == CURAND_STATUS_ALLOCATION_FAILED)) && PetscCUDAInitialized) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_GPU_RESOURCE,"cuRAND error %d. Reports not initialized or alloc failed; this indicates the GPU has run out resources",(int)chkerrcurand_stat_); \
     else SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_GPU,"cuRand error %d",(int)chkerrcurand_stat_); \
   } \
} while (0)
11808ba619SStefano Zampini 
12808ba619SStefano Zampini typedef struct {
13808ba619SStefano Zampini   curandGenerator_t gen;
14808ba619SStefano Zampini } PetscRandom_CURAND;
15808ba619SStefano Zampini 
16808ba619SStefano Zampini PetscErrorCode PetscRandomSeed_CURAND(PetscRandom r)
17808ba619SStefano Zampini {
18808ba619SStefano Zampini   curandStatus_t     cerr;
19808ba619SStefano Zampini   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
20808ba619SStefano Zampini 
21808ba619SStefano Zampini   PetscFunctionBegin;
22808ba619SStefano Zampini   cerr = curandSetPseudoRandomGeneratorSeed(curand->gen,r->seed);CHKERRCURAND(cerr);
23808ba619SStefano Zampini   PetscFunctionReturn(0);
24808ba619SStefano Zampini }
25808ba619SStefano Zampini 
26808ba619SStefano Zampini PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom,size_t,PetscReal*,PetscBool);
27808ba619SStefano Zampini 
28808ba619SStefano Zampini PetscErrorCode  PetscRandomGetValuesReal_CURAND(PetscRandom r, PetscInt n, PetscReal *val)
29808ba619SStefano Zampini {
30808ba619SStefano Zampini   curandStatus_t     cerr;
31808ba619SStefano Zampini   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
32*5162e2cfSBarry Smith   size_t             nn = n < 0 ? (size_t)(-2*n) : n; /* handle complex case */
33808ba619SStefano Zampini 
34808ba619SStefano Zampini   PetscFunctionBegin;
35808ba619SStefano Zampini #if defined(PETSC_USE_REAL_SINGLE)
36808ba619SStefano Zampini   cerr = curandGenerateUniform(curand->gen,val,nn);CHKERRCURAND(cerr);
37808ba619SStefano Zampini #else
38808ba619SStefano Zampini   cerr = curandGenerateUniformDouble(curand->gen,val,nn);CHKERRCURAND(cerr);
39808ba619SStefano Zampini #endif
40808ba619SStefano Zampini   if (r->iset) {
41808ba619SStefano Zampini     PetscErrorCode ierr = PetscRandomCurandScale_Private(r,nn,val,(PetscBool)(n<0));CHKERRQ(ierr);
42808ba619SStefano Zampini   }
43808ba619SStefano Zampini   PetscFunctionReturn(0);
44808ba619SStefano Zampini }
45808ba619SStefano Zampini 
46808ba619SStefano Zampini PetscErrorCode PetscRandomGetValues_CURAND(PetscRandom r, PetscInt n, PetscScalar *val)
47808ba619SStefano Zampini {
48808ba619SStefano Zampini   PetscErrorCode ierr;
49808ba619SStefano Zampini 
50808ba619SStefano Zampini   PetscFunctionBegin;
51808ba619SStefano Zampini #if defined(PETSC_USE_COMPLEX)
52808ba619SStefano Zampini   /* pass negative size to flag complex scaling (if needed) */
53808ba619SStefano Zampini   ierr = PetscRandomGetValuesReal_CURAND(r,-n,(PetscReal*)val);CHKERRQ(ierr);
54808ba619SStefano Zampini #else
55808ba619SStefano Zampini   ierr = PetscRandomGetValuesReal_CURAND(r,n,val);CHKERRQ(ierr);
56808ba619SStefano Zampini #endif
57808ba619SStefano Zampini   PetscFunctionReturn(0);
58808ba619SStefano Zampini }
59808ba619SStefano Zampini 
60808ba619SStefano Zampini PetscErrorCode PetscRandomDestroy_CURAND(PetscRandom r)
61808ba619SStefano Zampini {
62808ba619SStefano Zampini   PetscErrorCode     ierr;
63808ba619SStefano Zampini   curandStatus_t     cerr;
64808ba619SStefano Zampini   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
65808ba619SStefano Zampini 
66808ba619SStefano Zampini   PetscFunctionBegin;
67808ba619SStefano Zampini   cerr = curandDestroyGenerator(curand->gen);CHKERRCURAND(cerr);
68808ba619SStefano Zampini   ierr = PetscFree(r->data);CHKERRQ(ierr);
69808ba619SStefano Zampini   PetscFunctionReturn(0);
70808ba619SStefano Zampini }
71808ba619SStefano Zampini 
/*
  Method table for PETSCCURAND, copied into r->ops by PetscRandomCreate_CURAND().
  The initializer is positional, so the entries must follow the member order of
  struct _PetscRandomOps (presumably declared in petsc/private/randomimpl.h - not visible here).
*/
static struct _PetscRandomOps PetscRandomOps_Values = {
  PetscRandomSeed_CURAND,
  NULL, /* NOTE(review): presumably the single-value getters, not provided by this type - confirm slot names */
  NULL,
  PetscRandomGetValues_CURAND,
  PetscRandomGetValuesReal_CURAND,
  PetscRandomDestroy_CURAND,
  NULL  /* NOTE(review): presumably the set-from-options hook - confirm */
};
81808ba619SStefano Zampini 
82808ba619SStefano Zampini /*MC
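   PETSCCURAND - access to the CUDA cuRAND pseudo-random number generator

  Notes:
  The generator is created with curandCreateGenerator(), so cuRAND writes its results to GPU memory. The arrays passed to PetscRandomGetValues() and PetscRandomGetValuesReal() must therefore be device arrays.
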
85808ba619SStefano Zampini   Level: beginner
86808ba619SStefano Zampini 
87808ba619SStefano Zampini .seealso: PetscRandomCreate(), PetscRandomSetType()
88808ba619SStefano Zampini M*/
89808ba619SStefano Zampini 
90808ba619SStefano Zampini PETSC_EXTERN PetscErrorCode PetscRandomCreate_CURAND(PetscRandom r)
91808ba619SStefano Zampini {
92808ba619SStefano Zampini   PetscErrorCode     ierr;
93808ba619SStefano Zampini   curandStatus_t     cerr;
94808ba619SStefano Zampini   PetscRandom_CURAND *curand;
95808ba619SStefano Zampini 
96808ba619SStefano Zampini   PetscFunctionBegin;
97808ba619SStefano Zampini   ierr = PetscCUDAInitializeCheck();CHKERRQ(ierr);
98808ba619SStefano Zampini   ierr = PetscNewLog(r,&curand);CHKERRQ(ierr);
99808ba619SStefano Zampini   cerr = curandCreateGenerator(&curand->gen,CURAND_RNG_PSEUDO_DEFAULT);CHKERRCURAND(cerr);
100808ba619SStefano Zampini   /* https://docs.nvidia.com/cuda/curand/host-api-overview.html#performance-notes2 */
101808ba619SStefano Zampini   cerr = curandSetGeneratorOrdering(curand->gen,CURAND_ORDERING_PSEUDO_SEEDED);CHKERRCURAND(cerr);
102808ba619SStefano Zampini   ierr = PetscMemcpy(r->ops,&PetscRandomOps_Values,sizeof(PetscRandomOps_Values));CHKERRQ(ierr);
103808ba619SStefano Zampini   ierr = PetscObjectChangeTypeName((PetscObject)r,PETSCCURAND);CHKERRQ(ierr);
104808ba619SStefano Zampini   r->data = curand;
105808ba619SStefano Zampini   r->seed = 1234ULL; /* taken from example */
106808ba619SStefano Zampini   ierr = PetscRandomSeed_CURAND(r);CHKERRQ(ierr);
107808ba619SStefano Zampini   PetscFunctionReturn(0);
108808ba619SStefano Zampini }
109