xref: /petsc/src/ksp/pc/impls/pbjacobi/cuda/pbjacobi_cuda.cu (revision 12facf1b2b728ba534ad2f7a1cbdf48236a5076e)
1*12facf1bSJunchao Zhang #include <petscdevice_cuda.h>
2*12facf1bSJunchao Zhang #include <petsc/private/petsclegacycupmblas.h>
3*12facf1bSJunchao Zhang #include <../src/ksp/pc/impls/pbjacobi/pbjacobi.h>
4*12facf1bSJunchao Zhang 
5*12facf1bSJunchao Zhang #if PETSC_PKG_CUDA_VERSION_LT(11, 7, 0)
6*12facf1bSJunchao Zhang __global__ static void MatMultBatched(PetscInt bs, PetscInt mbs, const PetscScalar *A, const PetscScalar *x, PetscScalar *y, PetscBool transpose)
7*12facf1bSJunchao Zhang {
8*12facf1bSJunchao Zhang   const PetscInt gridSize = gridDim.x * blockDim.x;
9*12facf1bSJunchao Zhang   PetscInt       row      = blockIdx.x * blockDim.x + threadIdx.x;
10*12facf1bSJunchao Zhang   const PetscInt bs2      = bs * bs;
11*12facf1bSJunchao Zhang 
12*12facf1bSJunchao Zhang   /* One row per thread. The blocks are stored in column-major order */
13*12facf1bSJunchao Zhang   for (; row < bs * mbs; row += gridSize) {
14*12facf1bSJunchao Zhang     const PetscScalar *Ap, *xp;
15*12facf1bSJunchao Zhang     PetscScalar       *yp;
16*12facf1bSJunchao Zhang     PetscInt           i, j, k;
17*12facf1bSJunchao Zhang 
18*12facf1bSJunchao Zhang     k  = row / bs;                               /* k-th block */
19*12facf1bSJunchao Zhang     i  = row % bs;                               /* this thread deals with i-th row of the block */
20*12facf1bSJunchao Zhang     Ap = &A[bs2 * k + i * (transpose ? bs : 1)]; /* Ap points to the first entry of i-th row */
21*12facf1bSJunchao Zhang     xp = &x[bs * k];
22*12facf1bSJunchao Zhang     yp = &y[bs * k];
23*12facf1bSJunchao Zhang     /* multiply i-th row (column) with x */
24*12facf1bSJunchao Zhang     yp[i] = 0.0;
25*12facf1bSJunchao Zhang     for (j = 0; j < bs; j++) {
26*12facf1bSJunchao Zhang       yp[i] += Ap[0] * xp[j];
27*12facf1bSJunchao Zhang       Ap += (transpose ? 1 : bs); /* block is in column major order */
28*12facf1bSJunchao Zhang     }
29*12facf1bSJunchao Zhang   }
30*12facf1bSJunchao Zhang }
31*12facf1bSJunchao Zhang #endif
32*12facf1bSJunchao Zhang 
33*12facf1bSJunchao Zhang static PetscErrorCode PCApplyOrTranspose_PBJacobi_CUDA(PC pc, Vec x, Vec y, cublasOperation_t op)
34*12facf1bSJunchao Zhang {
35*12facf1bSJunchao Zhang   const PetscScalar *xx;
36*12facf1bSJunchao Zhang   PetscScalar       *yy;
37*12facf1bSJunchao Zhang   cublasHandle_t     handle;
38*12facf1bSJunchao Zhang   PC_PBJacobi       *jac = (PC_PBJacobi *)pc->data;
39*12facf1bSJunchao Zhang   const PetscScalar *A   = (const PetscScalar *)jac->spptr;
40*12facf1bSJunchao Zhang   const PetscInt     bs = jac->bs, mbs = jac->mbs;
41*12facf1bSJunchao Zhang 
42*12facf1bSJunchao Zhang   PetscFunctionBegin;
43*12facf1bSJunchao Zhang   PetscCall(VecCUDAGetArrayRead(x, &xx));
44*12facf1bSJunchao Zhang   PetscCall(VecCUDAGetArrayWrite(y, &yy));
45*12facf1bSJunchao Zhang   PetscCall(PetscCUBLASGetHandle(&handle));
46*12facf1bSJunchao Zhang   PetscCallCUBLAS(cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST)); /* alpha, beta are on host */
47*12facf1bSJunchao Zhang 
48*12facf1bSJunchao Zhang #if PETSC_PKG_CUDA_VERSION_GE(11, 7, 0)
49*12facf1bSJunchao Zhang   /* y = alpha op(A) x + beta y */
50*12facf1bSJunchao Zhang   const PetscScalar alpha = 1.0, beta = 0.0;
51*12facf1bSJunchao Zhang   PetscCallCUBLAS(cublasXgemvStridedBatched(handle, op, bs, bs, &alpha, A, bs, bs * bs, xx, 1, bs, &beta, yy, 1, bs, mbs));
52*12facf1bSJunchao Zhang #else
53*12facf1bSJunchao Zhang   PetscInt gridSize = PetscMin((bs * mbs + 255) / 256, 2147483647); /* <= 2^31-1 */
54*12facf1bSJunchao Zhang   MatMultBatched<<<gridSize, 256>>>(bs, mbs, A, xx, yy, (op == CUBLAS_OP_T ? PETSC_TRUE : PETSC_FALSE));
55*12facf1bSJunchao Zhang   PetscCallCUDA(cudaGetLastError());
56*12facf1bSJunchao Zhang #endif
57*12facf1bSJunchao Zhang   PetscCall(VecCUDARestoreArrayRead(x, &xx));
58*12facf1bSJunchao Zhang   PetscCall(VecCUDARestoreArrayWrite(y, &yy));
59*12facf1bSJunchao Zhang   PetscCall(PetscLogGpuFlops(bs * bs * mbs * 2));
60*12facf1bSJunchao Zhang   PetscFunctionReturn(0);
61*12facf1bSJunchao Zhang }
62*12facf1bSJunchao Zhang 
63*12facf1bSJunchao Zhang static PetscErrorCode PCApply_PBJacobi_CUDA(PC pc, Vec x, Vec y)
64*12facf1bSJunchao Zhang {
65*12facf1bSJunchao Zhang   PetscFunctionBegin;
66*12facf1bSJunchao Zhang   PetscCall(PCApplyOrTranspose_PBJacobi_CUDA(pc, x, y, CUBLAS_OP_N)); // No transpose
67*12facf1bSJunchao Zhang   PetscFunctionReturn(0);
68*12facf1bSJunchao Zhang }
69*12facf1bSJunchao Zhang 
70*12facf1bSJunchao Zhang static PetscErrorCode PCApplyTranspose_PBJacobi_CUDA(PC pc, Vec x, Vec y)
71*12facf1bSJunchao Zhang {
72*12facf1bSJunchao Zhang   PetscFunctionBegin;
73*12facf1bSJunchao Zhang   PetscCall(PCApplyOrTranspose_PBJacobi_CUDA(pc, x, y, CUBLAS_OP_T)); // Transpose
74*12facf1bSJunchao Zhang   PetscFunctionReturn(0);
75*12facf1bSJunchao Zhang }
76*12facf1bSJunchao Zhang 
77*12facf1bSJunchao Zhang static PetscErrorCode PCDestroy_PBJacobi_CUDA(PC pc)
78*12facf1bSJunchao Zhang {
79*12facf1bSJunchao Zhang   PC_PBJacobi *jac = (PC_PBJacobi *)pc->data;
80*12facf1bSJunchao Zhang 
81*12facf1bSJunchao Zhang   PetscFunctionBegin;
82*12facf1bSJunchao Zhang   PetscCallCUDA(cudaFree(jac->spptr));
83*12facf1bSJunchao Zhang   PetscCall(PCDestroy_PBJacobi(pc));
84*12facf1bSJunchao Zhang   PetscFunctionReturn(0);
85*12facf1bSJunchao Zhang }
86*12facf1bSJunchao Zhang 
87*12facf1bSJunchao Zhang PETSC_INTERN PetscErrorCode PCSetUp_PBJacobi_CUDA(PC pc)
88*12facf1bSJunchao Zhang {
89*12facf1bSJunchao Zhang   PC_PBJacobi *jac = (PC_PBJacobi *)pc->data;
90*12facf1bSJunchao Zhang   size_t       size;
91*12facf1bSJunchao Zhang 
92*12facf1bSJunchao Zhang   PetscFunctionBegin;
93*12facf1bSJunchao Zhang   PetscCall(PCSetUp_PBJacobi_Host(pc)); /* Compute the inverse on host now. Might worth doing it on device directly */
94*12facf1bSJunchao Zhang   size = sizeof(PetscScalar) * jac->bs * jac->bs * jac->mbs;
95*12facf1bSJunchao Zhang 
96*12facf1bSJunchao Zhang   /* PBJacobi_CUDA is simple so that we use jac->spptr as if it is diag_d */
97*12facf1bSJunchao Zhang   if (!jac->spptr) PetscCallCUDAVoid(cudaMalloc(&jac->spptr, size));
98*12facf1bSJunchao Zhang   PetscCallCUDAVoid(cudaMemcpy(jac->spptr, jac->diag, size, cudaMemcpyHostToDevice));
99*12facf1bSJunchao Zhang   PetscCall(PetscLogCpuToGpu(size));
100*12facf1bSJunchao Zhang 
101*12facf1bSJunchao Zhang   pc->ops->apply          = PCApply_PBJacobi_CUDA;
102*12facf1bSJunchao Zhang   pc->ops->applytranspose = PCApplyTranspose_PBJacobi_CUDA;
103*12facf1bSJunchao Zhang   pc->ops->destroy        = PCDestroy_PBJacobi_CUDA;
104*12facf1bSJunchao Zhang   PetscFunctionReturn(0);
105*12facf1bSJunchao Zhang }
106