xref: /petsc/src/ksp/pc/impls/vpbjacobi/kokkos/vpbjacobi_kok.kokkos.cxx (revision 9a56b47496cad4948bae971196c35464b04e1dcc)
1f1be3500SJunchao Zhang #include <petscvec_kokkos.hpp>
2f1be3500SJunchao Zhang #include <../src/vec/vec/impls/seq/kokkos/veckokkosimpl.hpp>
3f1be3500SJunchao Zhang #include <petscdevice.h>
4f1be3500SJunchao Zhang #include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h>
5f1be3500SJunchao Zhang 
6f1be3500SJunchao Zhang /* A class that manages helper arrays assisting parallel PCApply() with Kokkos */
7f1be3500SJunchao Zhang struct PC_VPBJacobi_Kokkos {
8f1be3500SJunchao Zhang   /* Cache the old sizes to check if we need realloc */
9f1be3500SJunchao Zhang   PetscInt  n; /* number of rows of the local matrix */
10f1be3500SJunchao Zhang   PetscInt  nblocks; /* number of point blocks */
11f1be3500SJunchao Zhang   PetscInt  nsize; /* sum of sizes of the point blocks */
12f1be3500SJunchao Zhang 
13f1be3500SJunchao Zhang   /* Helper arrays that are pre-computed on host and then copied to device.
14f1be3500SJunchao Zhang     bs:     [nblocks+1], "csr" version of bsizes[]
15f1be3500SJunchao Zhang     bs2:    [nblocks+1], "csr" version of squares of bsizes[]
16f1be3500SJunchao Zhang     matIdx: [n], row i of the local matrix belongs to the matIdx_d[i] block
17f1be3500SJunchao Zhang   */
18f1be3500SJunchao Zhang   PetscIntKokkosDualView     bs_dual, bs2_dual, matIdx_dual;
19f1be3500SJunchao Zhang   PetscScalarKokkosDualView  diag_dual;
20f1be3500SJunchao Zhang 
21f1be3500SJunchao Zhang   PC_VPBJacobi_Kokkos(PetscInt n,PetscInt nblocks,PetscInt nsize,const PetscInt *bsizes,MatScalar *diag_ptr_h)
22f1be3500SJunchao Zhang     : n(n),nblocks(nblocks),nsize(nsize),
23f1be3500SJunchao Zhang       bs_dual("bs_dual",nblocks+1),bs2_dual("bs2_dual",nblocks+1),matIdx_dual("matIdx_dual",n)
24f1be3500SJunchao Zhang   {
25f1be3500SJunchao Zhang     PetscScalarKokkosViewHost    diag_h(diag_ptr_h,nsize);
26f1be3500SJunchao Zhang 
27f1be3500SJunchao Zhang     auto diag_d = Kokkos::create_mirror_view(DefaultMemorySpace(),diag_h);
28f1be3500SJunchao Zhang     diag_dual = PetscScalarKokkosDualView(diag_d,diag_h);
29f1be3500SJunchao Zhang     PetscCallVoid(UpdateOffsetsOnDevice(bsizes,diag_ptr_h));
30f1be3500SJunchao Zhang   }
31f1be3500SJunchao Zhang 
32f1be3500SJunchao Zhang   PetscErrorCode UpdateOffsetsOnDevice(const PetscInt *bsizes,MatScalar *diag_ptr_h)
33f1be3500SJunchao Zhang   {
34f1be3500SJunchao Zhang     PetscFunctionBegin;
35f1be3500SJunchao Zhang     PetscCheck(diag_dual.view_host().data() == diag_ptr_h, PETSC_COMM_SELF,PETSC_ERR_PLIB,"Host pointer has changed since last call");
36f1be3500SJunchao Zhang     PetscCall(ComputeOffsetsOnHost(bsizes));
37f1be3500SJunchao Zhang 
38f1be3500SJunchao Zhang     PetscCallCXX(bs_dual.modify_host());
39f1be3500SJunchao Zhang     PetscCallCXX(bs2_dual.modify_host());
40f1be3500SJunchao Zhang     PetscCallCXX(matIdx_dual.modify_host());
41f1be3500SJunchao Zhang     PetscCallCXX(diag_dual.modify_host());
42f1be3500SJunchao Zhang 
43f1be3500SJunchao Zhang     PetscCallCXX(bs_dual.sync_device());
44f1be3500SJunchao Zhang     PetscCallCXX(bs2_dual.sync_device());
45f1be3500SJunchao Zhang     PetscCallCXX(matIdx_dual.sync_device());
46f1be3500SJunchao Zhang     PetscCallCXX(diag_dual.sync_device());
47*9a56b474SJunchao Zhang     PetscCall(PetscLogCpuToGpu(sizeof(PetscInt)*(2*nblocks+2+n) + sizeof(MatScalar)*nsize));
48f1be3500SJunchao Zhang     PetscFunctionReturn(0);
49f1be3500SJunchao Zhang   }
50f1be3500SJunchao Zhang 
51f1be3500SJunchao Zhang private:
52f1be3500SJunchao Zhang   PetscErrorCode ComputeOffsetsOnHost(const PetscInt *bsizes)
53f1be3500SJunchao Zhang   {
54f1be3500SJunchao Zhang     PetscInt *bs_h     = bs_dual.view_host().data();
55f1be3500SJunchao Zhang     PetscInt *bs2_h    = bs2_dual.view_host().data();
56f1be3500SJunchao Zhang     PetscInt *matIdx_h = matIdx_dual.view_host().data();
57f1be3500SJunchao Zhang 
58f1be3500SJunchao Zhang     PetscFunctionBegin;
59f1be3500SJunchao Zhang     bs_h[0] = bs2_h[0] = 0;
60f1be3500SJunchao Zhang     for (PetscInt i=0; i<nblocks; i++) {
61f1be3500SJunchao Zhang       bs_h[i+1]  = bs_h[i]  + bsizes[i];
62f1be3500SJunchao Zhang       bs2_h[i+1] = bs2_h[i] + bsizes[i]*bsizes[i];
63f1be3500SJunchao Zhang       for (PetscInt j=0; j<bsizes[i]; j++) matIdx_h[bs_h[i]+j] = i;
64f1be3500SJunchao Zhang     }
65f1be3500SJunchao Zhang     PetscFunctionReturn(0);
66f1be3500SJunchao Zhang   }
67f1be3500SJunchao Zhang };
68f1be3500SJunchao Zhang 
69f1be3500SJunchao Zhang static PetscErrorCode PCApply_VPBJacobi_Kokkos(PC pc,Vec x,Vec y)
70f1be3500SJunchao Zhang {
71f1be3500SJunchao Zhang   PC_VPBJacobi               *jac = (PC_VPBJacobi*)pc->data;
72f1be3500SJunchao Zhang   PC_VPBJacobi_Kokkos        *pckok = static_cast<PC_VPBJacobi_Kokkos*>(jac->spptr);
73f1be3500SJunchao Zhang   ConstPetscScalarKokkosView xv;
74f1be3500SJunchao Zhang   PetscScalarKokkosView      yv;
75f1be3500SJunchao Zhang   PetscScalarKokkosView      diag = pckok->diag_dual.view_device();
76f1be3500SJunchao Zhang   PetscIntKokkosView         bs = pckok->bs_dual.view_device();
77f1be3500SJunchao Zhang   PetscIntKokkosView         bs2 = pckok->bs2_dual.view_device();
78f1be3500SJunchao Zhang   PetscIntKokkosView         matIdx = pckok->matIdx_dual.view_device();
79f1be3500SJunchao Zhang 
80f1be3500SJunchao Zhang   PetscFunctionBegin;
81*9a56b474SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
82f1be3500SJunchao Zhang   VecErrorIfNotKokkos(x);
83f1be3500SJunchao Zhang   VecErrorIfNotKokkos(y);
84f1be3500SJunchao Zhang   PetscCall(VecGetKokkosView(x,&xv));
85f1be3500SJunchao Zhang   PetscCall(VecGetKokkosViewWrite(y,&yv));
86f1be3500SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for("PCApply_VPBJacobi_Kokkos",pckok->n,KOKKOS_LAMBDA(PetscInt row) {
87f1be3500SJunchao Zhang     const PetscScalar *Ap,*xp;
88f1be3500SJunchao Zhang     PetscScalar       *yp;
89f1be3500SJunchao Zhang     PetscInt          i,j,k,m;
90f1be3500SJunchao Zhang 
91f1be3500SJunchao Zhang     k  = matIdx(row);        /* k-th block/matrix */
92f1be3500SJunchao Zhang     m  = bs(k+1) - bs(k);    /* block size of the k-th block */
93f1be3500SJunchao Zhang     i  = row - bs(k);        /* i-th row of the block */
94f1be3500SJunchao Zhang     Ap = &diag(bs2(k) + i);  /* Ap points to the first entry of i-th row */
95f1be3500SJunchao Zhang     xp = &xv(bs(k));
96f1be3500SJunchao Zhang     yp = &yv(bs(k));
97f1be3500SJunchao Zhang 
98f1be3500SJunchao Zhang     yp[i] = 0.0;
99f1be3500SJunchao Zhang     for (j=0; j<m; j++) {yp[i] += Ap[0]*xp[j]; Ap += m;}
100f1be3500SJunchao Zhang   }));
101f1be3500SJunchao Zhang   PetscCall(VecRestoreKokkosView(x,&xv));
102f1be3500SJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(y,&yv));
103*9a56b474SJunchao Zhang   PetscCall(PetscLogGpuFlops(pckok->nsize*2)); /* FMA on entries in all blocks */
104*9a56b474SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
105f1be3500SJunchao Zhang   PetscFunctionReturn(0);
106f1be3500SJunchao Zhang }
107f1be3500SJunchao Zhang 
108f1be3500SJunchao Zhang static PetscErrorCode PCDestroy_VPBJacobi_Kokkos(PC pc)
109f1be3500SJunchao Zhang {
110f1be3500SJunchao Zhang   PC_VPBJacobi       *jac = (PC_VPBJacobi*)pc->data;
111f1be3500SJunchao Zhang 
112f1be3500SJunchao Zhang   PetscFunctionBegin;
113f1be3500SJunchao Zhang   PetscCallCXX(delete static_cast<PC_VPBJacobi_Kokkos*>(jac->spptr));
114f1be3500SJunchao Zhang   PetscCall(PCDestroy_VPBJacobi(pc));
115f1be3500SJunchao Zhang   PetscFunctionReturn(0);
116f1be3500SJunchao Zhang }
117f1be3500SJunchao Zhang 
118f1be3500SJunchao Zhang PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Kokkos(PC pc)
119f1be3500SJunchao Zhang {
120f1be3500SJunchao Zhang   PC_VPBJacobi        *jac = (PC_VPBJacobi*)pc->data;
121f1be3500SJunchao Zhang   PC_VPBJacobi_Kokkos *pckok = static_cast<PC_VPBJacobi_Kokkos*>(jac->spptr);
122f1be3500SJunchao Zhang   PetscInt            i,n,nblocks,nsize = 0;
123f1be3500SJunchao Zhang   const PetscInt      *bsizes;
124f1be3500SJunchao Zhang 
125f1be3500SJunchao Zhang   PetscFunctionBegin;
126f1be3500SJunchao Zhang   PetscCall(PCSetUp_VPBJacobi_Host(pc)); /* Compute the inverse on host now. Might worth doing it on device directly */
127f1be3500SJunchao Zhang   PetscCall(MatGetVariableBlockSizes(pc->pmat,&nblocks,&bsizes));
128f1be3500SJunchao Zhang   for (i=0; i<nblocks; i++) nsize += bsizes[i]*bsizes[i];
129f1be3500SJunchao Zhang   PetscCall(MatGetLocalSize(pc->pmat,&n,NULL));
130f1be3500SJunchao Zhang 
131f1be3500SJunchao Zhang   /* If one calls MatSetVariableBlockSizes() multiple times and sizes have been changed (is it allowed?), we delete the old and rebuild anyway */
132f1be3500SJunchao Zhang   if (pckok && (pckok->n != n || pckok->nblocks != nblocks || pckok->nsize != nsize)) {
133f1be3500SJunchao Zhang     PetscCallCXX(delete pckok);
134f1be3500SJunchao Zhang     pckok = nullptr;
135f1be3500SJunchao Zhang   }
136f1be3500SJunchao Zhang 
137f1be3500SJunchao Zhang   if (!pckok) { /* allocate the struct along with the helper arrays from the scatch */
138f1be3500SJunchao Zhang     PetscCallCXX(jac->spptr = new PC_VPBJacobi_Kokkos(n,nblocks,nsize,bsizes,jac->diag));
139f1be3500SJunchao Zhang   } else { /* update the value only */
140f1be3500SJunchao Zhang     PetscCall(pckok->UpdateOffsetsOnDevice(bsizes,jac->diag));
141f1be3500SJunchao Zhang   }
142f1be3500SJunchao Zhang 
143f1be3500SJunchao Zhang   pc->ops->apply   = PCApply_VPBJacobi_Kokkos;
144f1be3500SJunchao Zhang   pc->ops->destroy = PCDestroy_VPBJacobi_Kokkos;
145f1be3500SJunchao Zhang   PetscFunctionReturn(0);
146f1be3500SJunchao Zhang }
147