1d6cc7855SJacob Faibussowitsch #include <petsc/private/randomimpl.h> 2808ba619SStefano Zampini #include <thrust/transform.h> 3808ba619SStefano Zampini #include <thrust/device_ptr.h> 4808ba619SStefano Zampini #include <thrust/iterator/counting_iterator.h> 5808ba619SStefano Zampini 6808ba619SStefano Zampini #if defined(PETSC_USE_COMPLEX) 7*cc6e31f1SJunchao Zhang struct complexscalelw 8*cc6e31f1SJunchao Zhang #if PETSC_PKG_CUDA_VERSION_LT(12, 8, 0) 9*cc6e31f1SJunchao Zhang : 10*cc6e31f1SJunchao Zhang public thrust::unary_function<thrust::tuple<PetscReal, size_t>, PetscReal> 11*cc6e31f1SJunchao Zhang #endif 12*cc6e31f1SJunchao Zhang { 13808ba619SStefano Zampini PetscReal rl, rw; 14808ba619SStefano Zampini PetscReal il, iw; 15808ba619SStefano Zampini 16d71ae5a4SJacob Faibussowitsch complexscalelw(PetscScalar low, PetscScalar width) 17d71ae5a4SJacob Faibussowitsch { 18808ba619SStefano Zampini rl = PetscRealPart(low); 19808ba619SStefano Zampini il = PetscImaginaryPart(low); 20808ba619SStefano Zampini rw = PetscRealPart(width); 21808ba619SStefano Zampini iw = PetscImaginaryPart(width); 22808ba619SStefano Zampini } 23808ba619SStefano Zampini 248f2bb370SSebastian Grimberg __host__ __device__ PetscReal operator()(thrust::tuple<PetscReal, size_t> x) { return thrust::get<1>(x) % 2 ? thrust::get<0>(x) * iw + il : thrust::get<0>(x) * rw + rl; } 25808ba619SStefano Zampini }; 26808ba619SStefano Zampini #endif 27808ba619SStefano Zampini 28*cc6e31f1SJunchao Zhang struct realscalelw 29*cc6e31f1SJunchao Zhang #if PETSC_PKG_CUDA_VERSION_LT(12, 8, 0) // To suppress the warning "thrust::THRUST_200700_860_NS::unary_function is deprecated" 30*cc6e31f1SJunchao Zhang : 31*cc6e31f1SJunchao Zhang public thrust::unary_function<PetscReal, PetscReal> 32*cc6e31f1SJunchao Zhang #endif 33*cc6e31f1SJunchao Zhang { 34808ba619SStefano Zampini PetscReal l, w; 35808ba619SStefano Zampini 36808ba619SStefano Zampini realscalelw(PetscReal low, PetscReal width) : l(low), w(width) { } 37808ba619SStefano Zampini 389371c9d4SSatish Balay __host__ __device__ PetscReal operator()(PetscReal x) { return x * w + l; } 39808ba619SStefano Zampini }; 40808ba619SStefano Zampini 41d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom r, size_t n, PetscReal *val, PetscBool isneg) 42d71ae5a4SJacob Faibussowitsch { 43808ba619SStefano Zampini PetscFunctionBegin; 443ba16761SJacob Faibussowitsch if (!r->iset) PetscFunctionReturn(PETSC_SUCCESS); 45808ba619SStefano Zampini if (isneg) { /* complex case, need to scale differently */ 46808ba619SStefano Zampini #if defined(PETSC_USE_COMPLEX) 47808ba619SStefano Zampini thrust::device_ptr<PetscReal> pval = thrust::device_pointer_cast(val); 48808ba619SStefano Zampini auto zibit = thrust::make_zip_iterator(thrust::make_tuple(pval, thrust::counting_iterator<size_t>(0))); 49808ba619SStefano Zampini thrust::transform(zibit, zibit + n, pval, complexscalelw(r->low, r->width)); 50808ba619SStefano Zampini #else 5198921bdaSJacob Faibussowitsch SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Negative array size %" PetscInt_FMT, (PetscInt)n); 52808ba619SStefano Zampini #endif 53808ba619SStefano Zampini } else { 54808ba619SStefano Zampini PetscReal rl = PetscRealPart(r->low); 55808ba619SStefano Zampini PetscReal rw = PetscRealPart(r->width); 56808ba619SStefano Zampini thrust::device_ptr<PetscReal> pval = thrust::device_pointer_cast(val); 57808ba619SStefano Zampini thrust::transform(pval, pval + n, pval, realscalelw(rl, rw)); 58808ba619SStefano Zampini } 593ba16761SJacob Faibussowitsch PetscFunctionReturn(PETSC_SUCCESS); 60808ba619SStefano Zampini } 61