xref: /petsc/src/sys/classes/random/impls/curand/curand.c (revision 808ba619248ec38c62492271953950a08e6d0b8c)
1*808ba619SStefano Zampini #include <../src/sys/classes/random/randomimpl.h>
2*808ba619SStefano Zampini #include <curand.h>
3*808ba619SStefano Zampini 
/* Check a curandStatus_t and raise a PETSc error if it is not CURAND_STATUS_SUCCESS.
   The status expression is evaluated exactly once; it is captured in a local so that a
   function call or compound expression passed as the argument is neither re-invoked
   nor mis-parsed (the previous version expanded the bare argument up to three times).
   INITIALIZATION_FAILED / ALLOCATION_FAILED reported after CUDA has been initialized
   (PetscCUDAInitialized) are raised as PETSC_ERR_GPU_RESOURCE; any other failure is
   raised as PETSC_ERR_GPU. The error is raised through SETERRQ1, so this macro is meant
   to be used inside PETSc functions returning PetscErrorCode. */
#define CHKERRCURAND(stat) \
do { \
   curandStatus_t chkerrcurand_stat_ = (stat); \
   if (PetscUnlikely(chkerrcurand_stat_ != CURAND_STATUS_SUCCESS)) { \
     if (((chkerrcurand_stat_ == CURAND_STATUS_INITIALIZATION_FAILED) || (chkerrcurand_stat_ == CURAND_STATUS_ALLOCATION_FAILED)) && PetscCUDAInitialized) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_GPU_RESOURCE,"cuRAND error %d. Reports not initialized or alloc failed; this indicates the GPU has run out resources",(int)chkerrcurand_stat_); \
     else SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_GPU,"cuRand error %d",(int)chkerrcurand_stat_); \
   } \
} while (0)
11*808ba619SStefano Zampini 
12*808ba619SStefano Zampini typedef struct {
13*808ba619SStefano Zampini   curandGenerator_t gen;
14*808ba619SStefano Zampini } PetscRandom_CURAND;
15*808ba619SStefano Zampini 
16*808ba619SStefano Zampini PetscErrorCode PetscRandomSeed_CURAND(PetscRandom r)
17*808ba619SStefano Zampini {
18*808ba619SStefano Zampini   curandStatus_t     cerr;
19*808ba619SStefano Zampini   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
20*808ba619SStefano Zampini 
21*808ba619SStefano Zampini   PetscFunctionBegin;
22*808ba619SStefano Zampini   cerr = curandSetPseudoRandomGeneratorSeed(curand->gen,r->seed);CHKERRCURAND(cerr);
23*808ba619SStefano Zampini   PetscFunctionReturn(0);
24*808ba619SStefano Zampini }
25*808ba619SStefano Zampini 
26*808ba619SStefano Zampini PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom,size_t,PetscReal*,PetscBool);
27*808ba619SStefano Zampini 
28*808ba619SStefano Zampini PetscErrorCode  PetscRandomGetValuesReal_CURAND(PetscRandom r, PetscInt n, PetscReal *val)
29*808ba619SStefano Zampini {
30*808ba619SStefano Zampini   curandStatus_t     cerr;
31*808ba619SStefano Zampini   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
32*808ba619SStefano Zampini   size_t             nn = n < 0 ? -2*(size_t)n : n; /* handle complex case */
33*808ba619SStefano Zampini 
34*808ba619SStefano Zampini   PetscFunctionBegin;
35*808ba619SStefano Zampini #if defined(PETSC_USE_REAL_SINGLE)
36*808ba619SStefano Zampini   cerr = curandGenerateUniform(curand->gen,val,nn);CHKERRCURAND(cerr);
37*808ba619SStefano Zampini #else
38*808ba619SStefano Zampini   cerr = curandGenerateUniformDouble(curand->gen,val,nn);CHKERRCURAND(cerr);
39*808ba619SStefano Zampini #endif
40*808ba619SStefano Zampini   if (r->iset) {
41*808ba619SStefano Zampini     PetscErrorCode ierr = PetscRandomCurandScale_Private(r,nn,val,(PetscBool)(n<0));CHKERRQ(ierr);
42*808ba619SStefano Zampini   }
43*808ba619SStefano Zampini   PetscFunctionReturn(0);
44*808ba619SStefano Zampini }
45*808ba619SStefano Zampini 
46*808ba619SStefano Zampini PetscErrorCode PetscRandomGetValues_CURAND(PetscRandom r, PetscInt n, PetscScalar *val)
47*808ba619SStefano Zampini {
48*808ba619SStefano Zampini   PetscErrorCode ierr;
49*808ba619SStefano Zampini 
50*808ba619SStefano Zampini   PetscFunctionBegin;
51*808ba619SStefano Zampini #if defined(PETSC_USE_COMPLEX)
52*808ba619SStefano Zampini   /* pass negative size to flag complex scaling (if needed) */
53*808ba619SStefano Zampini   ierr = PetscRandomGetValuesReal_CURAND(r,-n,(PetscReal*)val);CHKERRQ(ierr);
54*808ba619SStefano Zampini #else
55*808ba619SStefano Zampini   ierr = PetscRandomGetValuesReal_CURAND(r,n,val);CHKERRQ(ierr);
56*808ba619SStefano Zampini #endif
57*808ba619SStefano Zampini   PetscFunctionReturn(0);
58*808ba619SStefano Zampini }
59*808ba619SStefano Zampini 
60*808ba619SStefano Zampini PetscErrorCode PetscRandomDestroy_CURAND(PetscRandom r)
61*808ba619SStefano Zampini {
62*808ba619SStefano Zampini   PetscErrorCode     ierr;
63*808ba619SStefano Zampini   curandStatus_t     cerr;
64*808ba619SStefano Zampini   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
65*808ba619SStefano Zampini 
66*808ba619SStefano Zampini   PetscFunctionBegin;
67*808ba619SStefano Zampini   cerr = curandDestroyGenerator(curand->gen);CHKERRCURAND(cerr);
68*808ba619SStefano Zampini   ierr = PetscFree(r->data);CHKERRQ(ierr);
69*808ba619SStefano Zampini   PetscFunctionReturn(0);
70*808ba619SStefano Zampini }
71*808ba619SStefano Zampini 
72*808ba619SStefano Zampini static struct _PetscRandomOps PetscRandomOps_Values = {
73*808ba619SStefano Zampini   PetscRandomSeed_CURAND,
74*808ba619SStefano Zampini   NULL,
75*808ba619SStefano Zampini   NULL,
76*808ba619SStefano Zampini   PetscRandomGetValues_CURAND,
77*808ba619SStefano Zampini   PetscRandomGetValuesReal_CURAND,
78*808ba619SStefano Zampini   PetscRandomDestroy_CURAND,
79*808ba619SStefano Zampini   NULL
80*808ba619SStefano Zampini };
81*808ba619SStefano Zampini 
/*MC
   PETSCCURAND - access to the NVIDIA cuRAND random number generator (CUDA)

  Level: beginner

.seealso: PetscRandomCreate(), PetscRandomSetType()
M*/
89*808ba619SStefano Zampini 
90*808ba619SStefano Zampini PETSC_EXTERN PetscErrorCode PetscRandomCreate_CURAND(PetscRandom r)
91*808ba619SStefano Zampini {
92*808ba619SStefano Zampini   PetscErrorCode     ierr;
93*808ba619SStefano Zampini   curandStatus_t     cerr;
94*808ba619SStefano Zampini   PetscRandom_CURAND *curand;
95*808ba619SStefano Zampini 
96*808ba619SStefano Zampini   PetscFunctionBegin;
97*808ba619SStefano Zampini   ierr = PetscCUDAInitializeCheck();CHKERRQ(ierr);
98*808ba619SStefano Zampini   ierr = PetscNewLog(r,&curand);CHKERRQ(ierr);
99*808ba619SStefano Zampini   cerr = curandCreateGenerator(&curand->gen,CURAND_RNG_PSEUDO_DEFAULT);CHKERRCURAND(cerr);
100*808ba619SStefano Zampini   /* https://docs.nvidia.com/cuda/curand/host-api-overview.html#performance-notes2 */
101*808ba619SStefano Zampini   cerr = curandSetGeneratorOrdering(curand->gen,CURAND_ORDERING_PSEUDO_SEEDED);CHKERRCURAND(cerr);
102*808ba619SStefano Zampini   ierr = PetscMemcpy(r->ops,&PetscRandomOps_Values,sizeof(PetscRandomOps_Values));CHKERRQ(ierr);
103*808ba619SStefano Zampini   ierr = PetscObjectChangeTypeName((PetscObject)r,PETSCCURAND);CHKERRQ(ierr);
104*808ba619SStefano Zampini   r->data = curand;
105*808ba619SStefano Zampini   r->seed = 1234ULL; /* taken from example */
106*808ba619SStefano Zampini   ierr = PetscRandomSeed_CURAND(r);CHKERRQ(ierr);
107*808ba619SStefano Zampini   PetscFunctionReturn(0);
108*808ba619SStefano Zampini }
109