xref: /petsc/src/sys/classes/random/impls/curand/curand2.cu (revision 808ba619248ec38c62492271953950a08e6d0b8c)
#include <../src/sys/classes/random/randomimpl.h>
#include <thrust/transform.h>
#include <thrust/device_ptr.h>
#include <thrust/tuple.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <exception>
5*808ba619SStefano Zampini 
#if defined(PETSC_USE_COMPLEX)
/*
  complexscalelw - Thrust functor that rescales an array of PetscReal samples whose entries
  alternate between real and imaginary components.

  Each input element is a tuple (value, position). Entries at even positions are mapped to
  value*rw + rl (real parts of width/low); entries at odd positions to value*iw + il
  (imaginary parts of width/low).

  No longer derives from thrust::unary_function: that base is deprecated (and removed in recent
  CCCL releases). The nested typedefs it used to supply are declared explicitly instead so that
  any code relying on them keeps compiling.
*/
struct complexscalelw
{
  typedef thrust::tuple<PetscReal, size_t> argument_type;
  typedef PetscReal                        result_type;

  PetscReal rl,rw; /* real parts of low and width */
  PetscReal il,iw; /* imaginary parts of low and width */

  /* Built on the host; members are listed in declaration order to match initialization order */
  complexscalelw(PetscScalar low, PetscScalar width)
    : rl(PetscRealPart(low)), rw(PetscRealPart(width)), il(PetscImaginaryPart(low)), iw(PetscImaginaryPart(width)) {}

  __host__ __device__
  PetscReal operator()(thrust::tuple<PetscReal, size_t> x) const {
    /* thrust::get (not the member x.get<N>()) works whether thrust::tuple is Thrust's own
       tuple or an alias of cuda::std::tuple, which has no member get */
    const PetscReal v   = thrust::get<0>(x);
    const size_t    pos = thrust::get<1>(x);
    return (pos % 2) ? v*iw + il : v*rw + rl;
  }
};
#endif
25*808ba619SStefano Zampini 
/*
  realscalelw - Thrust functor that maps a real sample x to x*w + l.

  No longer derives from thrust::unary_function: that base is deprecated (and removed in recent
  CCCL releases). The nested typedefs it used to supply are declared explicitly instead so that
  any code relying on them keeps compiling.
*/
struct realscalelw
{
  typedef PetscReal argument_type;
  typedef PetscReal result_type;

  PetscReal l,w; /* offset (low) and scale factor (width) */

  realscalelw(PetscReal low, PetscReal width) : l(low), w(width) {}

  /* const call operator: the functor holds no mutable state, and Thrust may invoke it through a const copy */
  __host__ __device__
  PetscReal operator()(PetscReal x) const {
    return x*w + l;
  }
};
37*808ba619SStefano Zampini 
/*
  PetscRandomCurandScale_Private - Rescales, in place, n samples held in val using the
  interval (low, width) stored in the PetscRandom, i.e. x -> x*width + low.

  Input Parameters:
+ r     - the random context; nothing is done unless r->iset is set
. n     - number of PetscReal entries in val
. val   - the samples; the pointer is wrapped with thrust::device_pointer_cast, so it must be
          device-accessible memory
- isneg - when PETSC_TRUE, val is treated as alternating real/imaginary components and each is
          scaled with the matching component of low/width. This is only supported when
          PETSC_USE_COMPLEX is defined; otherwise it is an error.

  Thrust signals CUDA/allocation failures by throwing C++ exceptions. Those must not propagate
  into PETSc's C callers, so they are caught here and converted to a PETSc error code.
*/
PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom r, size_t n, PetscReal *val, PetscBool isneg)
{
  PetscFunctionBegin;
  if (!r->iset) PetscFunctionReturn(0);
  try {
    if (isneg) { /* complex case, need to scale differently */
#if defined(PETSC_USE_COMPLEX)
      thrust::device_ptr<PetscReal> pval = thrust::device_pointer_cast(val);
      /* pair each entry with its index so the functor can tell real (even) from imaginary (odd) parts */
      auto zibit = thrust::make_zip_iterator(thrust::make_tuple(pval,thrust::counting_iterator<size_t>(0)));
      thrust::transform(zibit,zibit+n,pval,complexscalelw(r->low,r->width));
#else
      SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Negative array size %D",(PetscInt)n);
#endif
    } else {
      PetscReal rl = PetscRealPart(r->low);
      PetscReal rw = PetscRealPart(r->width);
      thrust::device_ptr<PetscReal> pval = thrust::device_pointer_cast(val);
      thrust::transform(pval,pval+n,pval,realscalelw(rl,rw));
    }
  } catch (const std::exception &ex) {
    SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_LIB,"Thrust error: %s",ex.what());
  }
  PetscFunctionReturn(0);
}
58