xref: /petsc/src/sys/classes/random/impls/curand/curand2.cu (revision 98921bda46e76d7aaed9e0138c5ff9d0ce93f355)
#include <petsc/private/randomimpl.h>
#include <exception>
#include <thrust/transform.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
5808ba619SStefano Zampini 
#if defined(PETSC_USE_COMPLEX)
/*
  complexscalelw - Thrust functor that affinely rescales one entry of an array that
  interleaves the real and imaginary parts of complex numbers.

  The functor takes a (value, position) tuple. Entries at odd positions are mapped to
  value*Im(width) + Im(low). Entries at even positions are mapped to
  value*Re(width) + Re(low).

  The position is expected to be the entry's index in the array (the caller zips the
  data with a thrust::counting_iterator starting at 0).

  No thrust::unary_function base class is used: it is deprecated in Thrust 2.x and is
  not required by thrust::transform.
*/
struct complexscalelw {
  PetscReal rl, rw; /* real parts of low and width */
  PetscReal il, iw; /* imaginary parts of low and width */

  /* Split the complex interval parameters into real/imaginary components on the host */
  complexscalelw(PetscScalar low, PetscScalar width) : rl(PetscRealPart(low)), rw(PetscRealPart(width)), il(PetscImaginaryPart(low)), iw(PetscImaginaryPart(width)) { }

  __host__ __device__ PetscReal operator()(thrust::tuple<PetscReal, size_t> x) const
  {
    const PetscReal v = thrust::get<0>(x);
    /* odd position -> imaginary component, even position -> real component */
    return thrust::get<1>(x) % 2 ? v * iw + il : v * rw + rl;
  }
};
#endif
25808ba619SStefano Zampini 
/*
  realscalelw - Thrust functor computing the affine map x -> x*w + l.

  l and w are the (real) lower bound and width supplied at construction.

  No thrust::unary_function base class is used: it is deprecated in Thrust 2.x and is
  not required by thrust::transform.
*/
struct realscalelw {
  PetscReal l, w; /* offset (low) and scale (width) */

  realscalelw(PetscReal low, PetscReal width) : l(low), w(width) { }

  __host__ __device__ PetscReal operator()(PetscReal x) const { return x * w + l; }
};
37808ba619SStefano Zampini 
/*
  PetscRandomCurandScale_Private - Rescales, in place, n values stored in device memory
  using the interval (r->low, r->width) of the random context.

  Input Parameters:
+ r     - the random context; nothing is done unless an interval was set (r->iset)
. n     - number of PetscReal entries in val
. val   - pointer to the entries; it is wrapped with thrust::device_pointer_cast, so it
          must be a device pointer
- isneg - if PETSC_TRUE, val is treated as interleaved real/imaginary parts: even
          entries are scaled by the real parts of low/width, odd entries by the
          imaginary parts. This is only valid in complex builds; in real builds it is
          reported as an error.

  Notes:
  Each value v becomes v*width + low. The values presumably come from a curand uniform
  generator (TODO confirm against the caller).

  Thrust reports failures by throwing C++ exceptions. They are caught here and turned
  into a PETSc error so they do not propagate through PETSc's C error-code call chain.
*/
PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom r, size_t n, PetscReal *val, PetscBool isneg)
{
  PetscFunctionBegin;
  if (!r->iset) PetscFunctionReturn(0);
  try {
    if (isneg) { /* complex case, need to scale differently */
#if defined(PETSC_USE_COMPLEX)
      thrust::device_ptr<PetscReal> pval  = thrust::device_pointer_cast(val);
      /* pair each entry with its index so the functor can tell real from imaginary parts */
      auto                          zibit = thrust::make_zip_iterator(thrust::make_tuple(pval, thrust::counting_iterator<size_t>(0)));
      thrust::transform(zibit, zibit + n, pval, complexscalelw(r->low, r->width));
#else
      SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Negative array size %" PetscInt_FMT,(PetscInt)n);
#endif
    } else {
      PetscReal                     rl   = PetscRealPart(r->low);
      PetscReal                     rw   = PetscRealPart(r->width);
      thrust::device_ptr<PetscReal> pval = thrust::device_pointer_cast(val);
      thrust::transform(pval, pval + n, pval, realscalelw(rl, rw));
    }
  } catch (const std::exception &ex) {
    /* thrust::system_error and std::bad_alloc both derive from std::exception */
    SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what());
  }
  PetscFunctionReturn(0);
}
58