xref: /petsc/src/sys/classes/random/impls/curand/curand.c (revision a4af0ceea8a251db97ee0dc5c0d52d4adf50264a)
1*a4af0ceeSJacob Faibussowitsch #include <petsc/private/deviceimpl.h>
2d6cc7855SJacob Faibussowitsch #include <petsc/private/randomimpl.h>
3808ba619SStefano Zampini #include <curand.h>
4808ba619SStefano Zampini 
5808ba619SStefano Zampini typedef struct {
6808ba619SStefano Zampini   curandGenerator_t gen;
7808ba619SStefano Zampini } PetscRandom_CURAND;
8808ba619SStefano Zampini 
9808ba619SStefano Zampini PetscErrorCode PetscRandomSeed_CURAND(PetscRandom r)
10808ba619SStefano Zampini {
11808ba619SStefano Zampini   curandStatus_t     cerr;
12808ba619SStefano Zampini   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
13808ba619SStefano Zampini 
14808ba619SStefano Zampini   PetscFunctionBegin;
15808ba619SStefano Zampini   cerr = curandSetPseudoRandomGeneratorSeed(curand->gen,r->seed);CHKERRCURAND(cerr);
16808ba619SStefano Zampini   PetscFunctionReturn(0);
17808ba619SStefano Zampini }
18808ba619SStefano Zampini 
19808ba619SStefano Zampini PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom,size_t,PetscReal*,PetscBool);
20808ba619SStefano Zampini 
21808ba619SStefano Zampini PetscErrorCode  PetscRandomGetValuesReal_CURAND(PetscRandom r, PetscInt n, PetscReal *val)
22808ba619SStefano Zampini {
23808ba619SStefano Zampini   curandStatus_t     cerr;
24808ba619SStefano Zampini   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
255162e2cfSBarry Smith   size_t             nn = n < 0 ? (size_t)(-2*n) : n; /* handle complex case */
26808ba619SStefano Zampini 
27808ba619SStefano Zampini   PetscFunctionBegin;
28808ba619SStefano Zampini #if defined(PETSC_USE_REAL_SINGLE)
29808ba619SStefano Zampini   cerr = curandGenerateUniform(curand->gen,val,nn);CHKERRCURAND(cerr);
30808ba619SStefano Zampini #else
31808ba619SStefano Zampini   cerr = curandGenerateUniformDouble(curand->gen,val,nn);CHKERRCURAND(cerr);
32808ba619SStefano Zampini #endif
33808ba619SStefano Zampini   if (r->iset) {
34808ba619SStefano Zampini     PetscErrorCode ierr = PetscRandomCurandScale_Private(r,nn,val,(PetscBool)(n<0));CHKERRQ(ierr);
35808ba619SStefano Zampini   }
36808ba619SStefano Zampini   PetscFunctionReturn(0);
37808ba619SStefano Zampini }
38808ba619SStefano Zampini 
39808ba619SStefano Zampini PetscErrorCode PetscRandomGetValues_CURAND(PetscRandom r, PetscInt n, PetscScalar *val)
40808ba619SStefano Zampini {
41808ba619SStefano Zampini   PetscErrorCode ierr;
42808ba619SStefano Zampini 
43808ba619SStefano Zampini   PetscFunctionBegin;
44808ba619SStefano Zampini #if defined(PETSC_USE_COMPLEX)
45808ba619SStefano Zampini   /* pass negative size to flag complex scaling (if needed) */
46808ba619SStefano Zampini   ierr = PetscRandomGetValuesReal_CURAND(r,-n,(PetscReal*)val);CHKERRQ(ierr);
47808ba619SStefano Zampini #else
48808ba619SStefano Zampini   ierr = PetscRandomGetValuesReal_CURAND(r,n,val);CHKERRQ(ierr);
49808ba619SStefano Zampini #endif
50808ba619SStefano Zampini   PetscFunctionReturn(0);
51808ba619SStefano Zampini }
52808ba619SStefano Zampini 
53808ba619SStefano Zampini PetscErrorCode PetscRandomDestroy_CURAND(PetscRandom r)
54808ba619SStefano Zampini {
55808ba619SStefano Zampini   PetscErrorCode     ierr;
56808ba619SStefano Zampini   curandStatus_t     cerr;
57808ba619SStefano Zampini   PetscRandom_CURAND *curand = (PetscRandom_CURAND*)r->data;
58808ba619SStefano Zampini 
59808ba619SStefano Zampini   PetscFunctionBegin;
60808ba619SStefano Zampini   cerr = curandDestroyGenerator(curand->gen);CHKERRCURAND(cerr);
61808ba619SStefano Zampini   ierr = PetscFree(r->data);CHKERRQ(ierr);
62808ba619SStefano Zampini   PetscFunctionReturn(0);
63808ba619SStefano Zampini }
64808ba619SStefano Zampini 
65808ba619SStefano Zampini static struct _PetscRandomOps PetscRandomOps_Values = {
66808ba619SStefano Zampini   PetscRandomSeed_CURAND,
67808ba619SStefano Zampini   NULL,
68808ba619SStefano Zampini   NULL,
69808ba619SStefano Zampini   PetscRandomGetValues_CURAND,
70808ba619SStefano Zampini   PetscRandomGetValuesReal_CURAND,
71808ba619SStefano Zampini   PetscRandomDestroy_CURAND,
72808ba619SStefano Zampini   NULL
73808ba619SStefano Zampini };
74808ba619SStefano Zampini 
/*MC
   PETSCCURAND - access to the CUDA random number generator library (cuRAND)

  Level: beginner

.seealso: PetscRandomCreate(), PetscRandomSetType()
M*/
82808ba619SStefano Zampini 
83808ba619SStefano Zampini PETSC_EXTERN PetscErrorCode PetscRandomCreate_CURAND(PetscRandom r)
84808ba619SStefano Zampini {
85808ba619SStefano Zampini   PetscErrorCode     ierr;
86808ba619SStefano Zampini   curandStatus_t     cerr;
87808ba619SStefano Zampini   PetscRandom_CURAND *curand;
88808ba619SStefano Zampini 
89808ba619SStefano Zampini   PetscFunctionBegin;
90*a4af0ceeSJacob Faibussowitsch   ierr = PetscDeviceInitialize(PETSC_DEVICE_CUDA);CHKERRQ(ierr);
91808ba619SStefano Zampini   ierr = PetscNewLog(r,&curand);CHKERRQ(ierr);
92808ba619SStefano Zampini   cerr = curandCreateGenerator(&curand->gen,CURAND_RNG_PSEUDO_DEFAULT);CHKERRCURAND(cerr);
93808ba619SStefano Zampini   /* https://docs.nvidia.com/cuda/curand/host-api-overview.html#performance-notes2 */
94808ba619SStefano Zampini   cerr = curandSetGeneratorOrdering(curand->gen,CURAND_ORDERING_PSEUDO_SEEDED);CHKERRCURAND(cerr);
95808ba619SStefano Zampini   ierr = PetscMemcpy(r->ops,&PetscRandomOps_Values,sizeof(PetscRandomOps_Values));CHKERRQ(ierr);
96808ba619SStefano Zampini   ierr = PetscObjectChangeTypeName((PetscObject)r,PETSCCURAND);CHKERRQ(ierr);
97808ba619SStefano Zampini   r->data = curand;
98808ba619SStefano Zampini   r->seed = 1234ULL; /* taken from example */
99808ba619SStefano Zampini   ierr = PetscRandomSeed_CURAND(r);CHKERRQ(ierr);
100808ba619SStefano Zampini   PetscFunctionReturn(0);
101808ba619SStefano Zampini }
102