xref: /petsc/src/sys/classes/random/impls/curand/curand2.cu (revision 8f2bb3700d11cbe40742ed1ec05e6dcf55cdea91)
#include <petsc/private/randomimpl.h>
#include <thrust/transform.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/tuple.h>
#include <exception>
5808ba619SStefano Zampini 
#if defined(PETSC_USE_COMPLEX)
/*
  complexscalelw - Thrust functor that applies a shift and scale to one PetscReal of a (value, index) pair.

  The index decides which pair of coefficients is used: an odd index uses the imaginary parts of
  low and width, an even index uses their real parts.
  NOTE(review): this presumably matches a complex array viewed as interleaved (real, imaginary)
  PetscReal entries; confirm against the caller.

  The functor deliberately does not derive from thrust::unary_function: that helper is deprecated
  in recent Thrust releases and nothing here needs the typedefs it provides.
*/
struct complexscalelw {
  PetscReal rl, rw; /* shift and scale from the real parts of low and width */
  PetscReal il, iw; /* shift and scale from the imaginary parts of low and width */

  /* Host-side constructor: splits the complex interval description into its real and imaginary parts */
  complexscalelw(PetscScalar low, PetscScalar width) : rl(PetscRealPart(low)), rw(PetscRealPart(width)), il(PetscImaginaryPart(low)), iw(PetscImaginaryPart(width)) { }

  /* x = (value, index); returns value * width + low using the coefficients selected by the parity of index */
  __host__ __device__ PetscReal operator()(thrust::tuple<PetscReal, size_t> x) const
  {
    const PetscReal v = thrust::get<0>(x);

    return thrust::get<1>(x) % 2 ? v * iw + il : v * rw + rl;
  }
};
#endif
22808ba619SStefano Zampini 
/*
  realscalelw - Thrust functor that maps x to x * w + l.

  l and w are the shift and scale supplied at construction. The functor deliberately does not derive
  from thrust::unary_function: that helper is deprecated in recent Thrust releases and nothing here
  needs the typedefs it provides.
*/
struct realscalelw {
  PetscReal l, w; /* shift (low) and scale (width) */

  realscalelw(PetscReal low, PetscReal width) : l(low), w(width) { }

  /* Returns x scaled by w and shifted by l */
  __host__ __device__ PetscReal operator()(PetscReal x) const { return x * w + l; }
};
30808ba619SStefano Zampini 
/*
  PetscRandomCurandScale_Private - Shifts and scales n PetscReal values in place using the low and width of r.

  Input Parameters:
+ r     - the random object; the routine returns immediately unless r->iset is set
. n     - number of PetscReal entries of val to transform
. val   - array transformed in place; it is wrapped with thrust::device_pointer_cast()
          (assumes val points to device memory - TODO confirm against callers)
- isneg - PETSC_TRUE selects the complex path, where entries with an even index are transformed with
          the real parts of low/width and entries with an odd index with the imaginary parts;
          PETSC_FALSE transforms every entry with the real parts only

  Notes:
  Requesting the complex path in a build without PETSC_USE_COMPLEX raises PETSC_ERR_PLIB.

  Thrust reports failures by throwing; those exceptions are caught here and converted to
  PETSC_ERR_LIB so that none escapes this error-code-returning routine.
*/
PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom r, size_t n, PetscReal *val, PetscBool isneg)
{
  PetscFunctionBegin;
  if (!r->iset) PetscFunctionReturn(PETSC_SUCCESS);
  if (isneg) { /* complex case, need to scale differently */
#if defined(PETSC_USE_COMPLEX)
    try {
      thrust::device_ptr<PetscReal> pval = thrust::device_pointer_cast(val);
      /* pair every entry with its position so the functor can tell real from imaginary slots */
      auto zibit = thrust::make_zip_iterator(thrust::make_tuple(pval, thrust::counting_iterator<size_t>(0)));

      thrust::transform(zibit, zibit + n, pval, complexscalelw(r->low, r->width));
    } catch (const std::exception &ex) {
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Error in Thrust %s", ex.what());
    }
#else
    SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Negative array size %" PetscInt_FMT, (PetscInt)n);
#endif
  } else {
    const PetscReal rl = PetscRealPart(r->low);
    const PetscReal rw = PetscRealPart(r->width);

    try {
      thrust::device_ptr<PetscReal> pval = thrust::device_pointer_cast(val);

      thrust::transform(pval, pval + n, pval, realscalelw(rl, rw));
    } catch (const std::exception &ex) {
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Error in Thrust %s", ex.what());
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
51