xref: /petsc/src/ksp/pc/impls/bjacobi/bjkokkos/bjkokkos.kokkos.cxx (revision 63a3b9bc7a1f24f247904ccba9383635fe6abade)
1e607c864SMark Adams #include <petscvec_kokkos.hpp>
2e607c864SMark Adams #include <petsc/private/pcimpl.h>
3e607c864SMark Adams #include <petsc/private/kspimpl.h>
4e607c864SMark Adams #include <petscksp.h>            /*I "petscksp.h" I*/
5e607c864SMark Adams #include "petscsection.h"
6e607c864SMark Adams #include <petscdmcomposite.h>
7e607c864SMark Adams #include <Kokkos_Core.hpp>
8e607c864SMark Adams 
9e607c864SMark Adams typedef Kokkos::TeamPolicy<>::member_type team_member;
10e607c864SMark Adams 
11e607c864SMark Adams #include <../src/mat/impls/aij/seq/aij.h>
12e607c864SMark Adams #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
13e607c864SMark Adams 
/* Compile-time tuning knobs for the batched block solver. */
/* Kokkos scratch-memory level requested for the per-team work vectors. */
#define PCBJKOKKOS_SHARED_LEVEL 1
/* Vector length handed to the TeamPolicy (reported as "vector threads"). */
#define PCBJKOKKOS_VEC_SIZE 16
/* Team size used when the execution space is not treated as OpenMP-like. */
#define PCBJKOKKOS_TEAM_SIZE 16
/* Extra in-kernel printing is compiled in only when this is >= 6. */
#define PCBJKOKKOS_VERBOSE_LEVEL 0
18e607c864SMark Adams 
/* Selects which batched Krylov kernel a team runs; NUM_BATCH_TYPES is the count of entries. */
typedef enum {
  BATCH_KSP_BICG_IDX,   /* dispatches to BJSolve_BICG  */
  BATCH_KSP_TFQMR_IDX,  /* dispatches to BJSolve_TFQMR */
  BATCH_KSP_GMRES_IDX,  /* recognised, but the kernel only reports "not implemented" */
  NUM_BATCH_TYPES
} KSPIndex;
20e607c864SMark Adams typedef struct {
21e607c864SMark Adams   Vec                                              vec_diag;
22e607c864SMark Adams   PetscInt                                         nBlocks; /* total number of blocks */
23e607c864SMark Adams   PetscInt                                         n; // cache host version of d_bid_eqOffset_k[nBlocks]
24e607c864SMark Adams   KSP                                              ksp; // Used just for options. Should have one for each block
25e607c864SMark Adams   Kokkos::View<PetscInt*, Kokkos::LayoutRight>     *d_bid_eqOffset_k;
26e607c864SMark Adams   Kokkos::View<PetscScalar*, Kokkos::LayoutRight>  *d_idiag_k;
27e607c864SMark Adams   Kokkos::View<PetscInt*>                          *d_isrow_k;
28e607c864SMark Adams   Kokkos::View<PetscInt*>                          *d_isicol_k;
29e607c864SMark Adams   KSPIndex                                         ksp_type_idx;
30e607c864SMark Adams   PetscInt                                         nwork;
31e607c864SMark Adams   PetscInt                                         const_block_size; // used to decide to use shared memory for work vectors
32e607c864SMark Adams   PetscInt                                         *dm_Nf;  // Number of fields in each DM
33e607c864SMark Adams   PetscInt                                         num_dms;
34e607c864SMark Adams   // diagnostics
35e607c864SMark Adams   PetscBool                                        reason;
36e607c864SMark Adams   PetscBool                                        monitor;
37e607c864SMark Adams   PetscInt                                         batch_target;
38e607c864SMark Adams } PC_PCBJKOKKOS;
39e607c864SMark Adams 
40e607c864SMark Adams static PetscErrorCode  PCBJKOKKOSCreateKSP_BJKOKKOS(PC pc)
41e607c864SMark Adams {
42e607c864SMark Adams   const char    *prefix;
43e607c864SMark Adams   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS*)pc->data;
44e607c864SMark Adams   DM             dm;
45e607c864SMark Adams 
46e607c864SMark Adams   PetscFunctionBegin;
479566063dSJacob Faibussowitsch   PetscCall(KSPCreate(PetscObjectComm((PetscObject)pc),&jac->ksp));
489566063dSJacob Faibussowitsch   PetscCall(KSPSetErrorIfNotConverged(jac->ksp,pc->erroriffailure));
499566063dSJacob Faibussowitsch   PetscCall(PetscObjectIncrementTabLevel((PetscObject)jac->ksp,(PetscObject)pc,1));
509566063dSJacob Faibussowitsch   PetscCall(PCGetOptionsPrefix(pc,&prefix));
519566063dSJacob Faibussowitsch   PetscCall(KSPSetOptionsPrefix(jac->ksp,prefix));
529566063dSJacob Faibussowitsch   PetscCall(KSPAppendOptionsPrefix(jac->ksp,"pc_bjkokkos_"));
539566063dSJacob Faibussowitsch   PetscCall(PCGetDM(pc,&dm));
54e607c864SMark Adams   if (dm) {
559566063dSJacob Faibussowitsch     PetscCall(KSPSetDM(jac->ksp, dm));
569566063dSJacob Faibussowitsch     PetscCall(KSPSetDMActive(jac->ksp, PETSC_FALSE));
57e607c864SMark Adams   }
58e607c864SMark Adams   jac->reason       = PETSC_FALSE;
59e607c864SMark Adams   jac->monitor      = PETSC_FALSE;
60e607c864SMark Adams   jac->batch_target = 0;
61e607c864SMark Adams   PetscFunctionReturn(0);
62e607c864SMark Adams }
63e607c864SMark Adams 
64e607c864SMark Adams // y <-- Ax
65e607c864SMark Adams KOKKOS_INLINE_FUNCTION PetscErrorCode MatMult(const team_member team,  const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
66e607c864SMark Adams {
67e607c864SMark Adams   Kokkos::parallel_for(Kokkos::TeamThreadRange(team,start,end), [=] (const int rowb) {
68e607c864SMark Adams       int rowa = ic[rowb];
69e607c864SMark Adams       int n = glb_Aai[rowa+1] - glb_Aai[rowa];
70e607c864SMark Adams       const PetscInt    *aj  = glb_Aaj + glb_Aai[rowa];
71e607c864SMark Adams       const PetscScalar *aa  = glb_Aaa + glb_Aai[rowa];
72e607c864SMark Adams       PetscScalar sum;
73e607c864SMark Adams       Kokkos::parallel_reduce(Kokkos::ThreadVectorRange (team, n), [=] (const int i, PetscScalar& lsum) {
74e607c864SMark Adams           lsum += aa[i] * x_loc[r[aj[i]]-start];
75e607c864SMark Adams         }, sum);
76e607c864SMark Adams       Kokkos::single(Kokkos::PerThread (team),[=]() {y_loc[rowb-start] = sum;});
77e607c864SMark Adams     });
78e607c864SMark Adams   team.team_barrier();
79e607c864SMark Adams   return 0;
80e607c864SMark Adams }
81e607c864SMark Adams 
82e607c864SMark Adams // temp buffer per thread with reduction at end?
83e607c864SMark Adams KOKKOS_INLINE_FUNCTION PetscErrorCode MatMultTranspose(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
84e607c864SMark Adams {
85e607c864SMark Adams   Kokkos::parallel_for(Kokkos::TeamVectorRange(team,end-start), [=] (int i) { y_loc[i] = 0;});
86e607c864SMark Adams   team.team_barrier();
87e607c864SMark Adams   Kokkos::parallel_for(Kokkos::TeamThreadRange(team,start,end), [=] (const int rowb) {
88e607c864SMark Adams       int rowa = ic[rowb];
89e607c864SMark Adams       int n = glb_Aai[rowa+1] - glb_Aai[rowa];
90e607c864SMark Adams       const PetscInt    *aj  = glb_Aaj + glb_Aai[rowa];
91e607c864SMark Adams       const PetscScalar *aa  = glb_Aaa + glb_Aai[rowa];
92e607c864SMark Adams       const PetscScalar xx = x_loc[rowb-start]; // rowb = ic[rowa] = ic[r[rowb]]
93e607c864SMark Adams       Kokkos::parallel_for(Kokkos::ThreadVectorRange(team,n), [=] (const int &i) {
94e607c864SMark Adams           PetscScalar val = aa[i] * xx;
95e607c864SMark Adams           Kokkos::atomic_fetch_add(&y_loc[r[aj[i]]-start], val);
96e607c864SMark Adams         });
97e607c864SMark Adams     });
98e607c864SMark Adams   team.team_barrier();
99e607c864SMark Adams   return 0;
100e607c864SMark Adams }
101e607c864SMark Adams 
102e607c864SMark Adams typedef struct Batch_MetaData_TAG
103e607c864SMark Adams {
104e607c864SMark Adams   PetscInt           flops;
105e607c864SMark Adams   PetscInt           its;
106e607c864SMark Adams   KSPConvergedReason reason;
107e607c864SMark Adams }Batch_MetaData;
108e607c864SMark Adams 
109e607c864SMark Adams // Solve A(BB^-1)x = y with TFQMR. Right preconditioned to get un-preconditioned residual
110e607c864SMark Adams KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_TFQMR(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space, const PetscInt stride, PetscReal rtol, PetscReal atol, PetscReal dtol,PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
111e607c864SMark Adams {
112e607c864SMark Adams   using Kokkos::parallel_reduce;
113e607c864SMark Adams   using Kokkos::parallel_for;
114e607c864SMark Adams   int               Nblk = end-start, i,m;
115e607c864SMark Adams   PetscReal         dp,dpold,w,dpest,tau,psi,cm,r0;
116e607c864SMark Adams   PetscScalar       *ptr = work_space, rho,rhoold,a,s,b,eta,etaold,psiold,cf,dpi;
117e607c864SMark Adams   const PetscScalar *Diag = &glb_idiag[start];
118e607c864SMark Adams   PetscScalar       *XX = ptr; ptr += stride;
119e607c864SMark Adams   PetscScalar       *R = ptr; ptr += stride;
120e607c864SMark Adams   PetscScalar       *RP = ptr; ptr += stride;
121e607c864SMark Adams   PetscScalar       *V = ptr; ptr += stride;
122e607c864SMark Adams   PetscScalar       *T = ptr; ptr += stride;
123e607c864SMark Adams   PetscScalar       *Q = ptr; ptr += stride;
124e607c864SMark Adams   PetscScalar       *P = ptr; ptr += stride;
125e607c864SMark Adams   PetscScalar       *U = ptr; ptr += stride;
126e607c864SMark Adams   PetscScalar       *D = ptr; ptr += stride;
127e607c864SMark Adams   PetscScalar       *AUQ = V;
128e607c864SMark Adams 
129e607c864SMark Adams   // init: get b, zero x
130e607c864SMark Adams   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
131e607c864SMark Adams       int rowa = ic[rowb];
132e607c864SMark Adams       R[rowb-start] = glb_b[rowa];
133e607c864SMark Adams       XX[rowb-start] = 0;
134e607c864SMark Adams     });
135e607c864SMark Adams   team.team_barrier();
136e607c864SMark Adams   parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += R[idx]*PetscConj(R[idx]);}, dpi);
137e607c864SMark Adams   team.team_barrier();
138e607c864SMark Adams   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
139e607c864SMark Adams   // diagnostics
140a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
141e607c864SMark Adams   if (monitor) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", 0, (double)dp);});
142a25f047fSMark Adams #endif
143e607c864SMark Adams   if (dp < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; return 0;}
144e607c864SMark Adams   if (0 == maxit) {metad->reason = KSP_DIVERGED_ITS; return 0;}
145e607c864SMark Adams 
146e607c864SMark Adams   /* Make the initial Rp = R */
147e607c864SMark Adams   parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {RP[idx] = R[idx];});
148e607c864SMark Adams   team.team_barrier();
149e607c864SMark Adams   /* Set the initial conditions */
150e607c864SMark Adams   etaold = 0.0;
151e607c864SMark Adams   psiold = 0.0;
152e607c864SMark Adams   tau    = dp;
153e607c864SMark Adams   dpold  = dp;
154e607c864SMark Adams 
155e607c864SMark Adams   /* rhoold = (r,rp)     */
156e607c864SMark Adams   parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += R[idx]*PetscConj(RP[idx]);}, rhoold);
157e607c864SMark Adams   team.team_barrier();
158e607c864SMark Adams   parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {U[idx] = R[idx]; P[idx] = R[idx]; T[idx] = Diag[idx]*P[idx]; D[idx] = 0;});
159e607c864SMark Adams   team.team_barrier();
160e607c864SMark Adams   MatMult         (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,T,V);
161e607c864SMark Adams 
162e607c864SMark Adams   i=0;
163e607c864SMark Adams   do {
164e607c864SMark Adams     /* s <- (v,rp)          */
165e607c864SMark Adams     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += V[idx]*PetscConj(RP[idx]);}, s);
166e607c864SMark Adams     team.team_barrier();
167e607c864SMark Adams     a    = rhoold / s;                              /* a <- rho / s         */
168e607c864SMark Adams     /* q <- u - a v    VecWAXPY(w,alpha,x,y): w = alpha x + y.     */
169e607c864SMark Adams     /* t <- u + q           */
170e607c864SMark Adams     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Q[idx] = U[idx] - a*V[idx]; T[idx] = U[idx] + Q[idx];});
171e607c864SMark Adams     team.team_barrier();
172e607c864SMark Adams     // KSP_PCApplyBAorAB
173e607c864SMark Adams     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {T[idx] = Diag[idx]*T[idx]; });
174e607c864SMark Adams     team.team_barrier();
175e607c864SMark Adams     MatMult         (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,T,AUQ);
176e607c864SMark Adams     /* r <- r - a K (u + q) */
177e607c864SMark Adams     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {R[idx] = R[idx] - a*AUQ[idx]; });
178e607c864SMark Adams     team.team_barrier();
179e607c864SMark Adams     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += R[idx]*PetscConj(R[idx]);}, dpi);
180e607c864SMark Adams     team.team_barrier();
181e607c864SMark Adams     dp = PetscSqrtReal(PetscRealPart(dpi));
182e607c864SMark Adams     for (m=0; m<2; m++) {
183e607c864SMark Adams       if (!m) w = PetscSqrtReal(dp*dpold);
184e607c864SMark Adams       else w = dp;
185e607c864SMark Adams       psi = w / tau;
186e607c864SMark Adams       cm  = 1.0 / PetscSqrtReal(1.0 + psi * psi);
187e607c864SMark Adams       tau = tau * psi * cm;
188e607c864SMark Adams       eta = cm * cm * a;
189e607c864SMark Adams       cf  = psiold * psiold * etaold / a;
190e607c864SMark Adams       if (!m) {
191e607c864SMark Adams         /* D = U + cf D */
192e607c864SMark Adams         parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {D[idx] = U[idx] + cf*D[idx]; });
193e607c864SMark Adams       } else {
194e607c864SMark Adams         /* D = Q + cf D */
195e607c864SMark Adams         parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {D[idx] = Q[idx] + cf*D[idx]; });
196e607c864SMark Adams       }
197e607c864SMark Adams       team.team_barrier();
198e607c864SMark Adams       parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {XX[idx] = XX[idx] + eta*D[idx]; });
199e607c864SMark Adams       team.team_barrier();
200e607c864SMark Adams       dpest = PetscSqrtReal(2*i + m + 2.0) * tau;
201a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
20226f882f0SMark Adams       if (monitor && m==1) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", i+1, (double)dpest);});
203a25f047fSMark Adams #endif
204e607c864SMark Adams       if (dpest < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; goto done;}
205e607c864SMark Adams       if (dpest/r0 < rtol) {metad->reason = KSP_CONVERGED_RTOL_NORMAL; goto done;}
206a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
207e607c864SMark Adams       if (dpest/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; Kokkos::single (Kokkos::PerTeam (team), [=] () {printf("ERROR block %d diverged: %d it, res=%e, r_0=%e\n",team.league_rank(),i,dpest,r0);}); goto done;}
208e607c864SMark Adams #else
209e607c864SMark Adams       if (dpest/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; goto done;}
210e607c864SMark Adams #endif
211e607c864SMark Adams       if (i+1 == maxit) {metad->reason = KSP_DIVERGED_ITS; goto done;}
212e607c864SMark Adams 
213e607c864SMark Adams       etaold = eta;
214e607c864SMark Adams       psiold = psi;
215e607c864SMark Adams     }
216e607c864SMark Adams 
217e607c864SMark Adams     /* rho <- (r,rp)       */
218e607c864SMark Adams     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += R[idx]*PetscConj(RP[idx]);}, rho);
219e607c864SMark Adams     team.team_barrier();
220e607c864SMark Adams     b    = rho / rhoold;                            /* b <- rho / rhoold   */
221e607c864SMark Adams     /* u <- r + b q        */
222e607c864SMark Adams     /* p <- u + b(q + b p) */
223e607c864SMark Adams     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {U[idx] = R[idx] + b*Q[idx]; Q[idx] = Q[idx] + b*P[idx]; P[idx] = U[idx] + b*Q[idx];});
224e607c864SMark Adams     /* v <- K p  */
225e607c864SMark Adams     team.team_barrier();
226e607c864SMark Adams     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {T[idx] = Diag[idx]*P[idx]; });
227e607c864SMark Adams     team.team_barrier();
228e607c864SMark Adams     MatMult         (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,T,V);
229e607c864SMark Adams 
230e607c864SMark Adams     rhoold = rho;
231e607c864SMark Adams     dpold  = dp;
232e607c864SMark Adams 
233e607c864SMark Adams     i++;
234e607c864SMark Adams   } while (i<maxit);
235e607c864SMark Adams   done:
236e607c864SMark Adams   // KSPUnwindPreconditioner
237e607c864SMark Adams   parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {XX[idx] = Diag[idx]*XX[idx]; });
238e607c864SMark Adams   team.team_barrier();
239e607c864SMark Adams   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
240e607c864SMark Adams       int rowa = ic[rowb];
241e607c864SMark Adams       glb_x[rowa] = XX[rowb-start];
242e607c864SMark Adams     });
243e607c864SMark Adams   metad->its = i+1;
244e607c864SMark Adams   if (1) {
245e607c864SMark Adams     int nnz;
246e607c864SMark Adams     parallel_reduce(Kokkos::TeamVectorRange (team, start, end), [=] (const int idx, int& lsum) {lsum += (glb_Aai[idx+1] - glb_Aai[idx]);}, nnz);
247e607c864SMark Adams     metad->flops = 2*(metad->its*(10*Nblk + 2*nnz) + 5*Nblk);
248e607c864SMark Adams   } else {
249e607c864SMark Adams     metad->flops = 2*(metad->its*(10*Nblk + 2*50*Nblk) + 5*Nblk); // guess
250e607c864SMark Adams   }
251e607c864SMark Adams   return 0;
252e607c864SMark Adams }
253e607c864SMark Adams 
254e607c864SMark Adams // Solve Ax = y with biCG
255e607c864SMark Adams KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_BICG(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space, const PetscInt stride, PetscReal rtol, PetscReal atol, PetscReal dtol,PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
256e607c864SMark Adams {
257e607c864SMark Adams   using Kokkos::parallel_reduce;
258e607c864SMark Adams   using Kokkos::parallel_for;
259e607c864SMark Adams   int               Nblk = end-start, i;
260e607c864SMark Adams   PetscReal         dp, r0;
261e607c864SMark Adams   PetscScalar       *ptr = work_space, dpi, a=1.0, beta, betaold=1.0, b, b2, ma, mac;
262e607c864SMark Adams   const PetscScalar *Di = &glb_idiag[start];
263e607c864SMark Adams   PetscScalar       *XX = ptr; ptr += stride;
264e607c864SMark Adams   PetscScalar       *Rl = ptr; ptr += stride;
265e607c864SMark Adams   PetscScalar       *Zl = ptr; ptr += stride;
266e607c864SMark Adams   PetscScalar       *Pl = ptr; ptr += stride;
267e607c864SMark Adams   PetscScalar       *Rr = ptr; ptr += stride;
268e607c864SMark Adams   PetscScalar       *Zr = ptr; ptr += stride;
269e607c864SMark Adams   PetscScalar       *Pr = ptr; ptr += stride;
270e607c864SMark Adams 
271e607c864SMark Adams   /*     r <- b (x is 0) */
272e607c864SMark Adams   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
273e607c864SMark Adams       int rowa = ic[rowb];
2749566063dSJacob Faibussowitsch       //PetscCall(VecCopy(Rr,Rl));
275e607c864SMark Adams       Rl[rowb-start] = Rr[rowb-start] = glb_b[rowa];
276e607c864SMark Adams       XX[rowb-start] = 0;
277e607c864SMark Adams     });
278e607c864SMark Adams   team.team_barrier();
279e607c864SMark Adams   /*     z <- Br         */
280e607c864SMark Adams   parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Zr[idx] = Di[idx]*Rr[idx]; Zl[idx] = Di[idx]*Rl[idx]; });
281e607c864SMark Adams   team.team_barrier();
282e607c864SMark Adams   /*    dp <- r'*r       */
283e607c864SMark Adams   parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += Rr[idx]*PetscConj(Rr[idx]);}, dpi);
284e607c864SMark Adams   team.team_barrier();
285e607c864SMark Adams   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
286a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
287e607c864SMark Adams   if (monitor) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", 0, (double)dp);});
288a25f047fSMark Adams #endif
289e607c864SMark Adams   if (dp < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; return 0;}
290e607c864SMark Adams   if (0 == maxit) {metad->reason = KSP_DIVERGED_ITS; return 0;}
291e607c864SMark Adams   i = 0;
292e607c864SMark Adams   do {
293e607c864SMark Adams     /*     beta <- r'z     */
294e607c864SMark Adams     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += Zr[idx]*PetscConj(Rl[idx]);}, beta);
295e607c864SMark Adams     team.team_barrier();
296e607c864SMark Adams #if PCBJKOKKOS_VERBOSE_LEVEL >= 6
297a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
298e607c864SMark Adams     Kokkos::single (Kokkos::PerTeam (team), [=] () {printf("%7d beta = Z.R = %22.14e \n",i,(double)beta);});
299e607c864SMark Adams #endif
300a25f047fSMark Adams #endif
301e607c864SMark Adams     if (!i) {
302e607c864SMark Adams       if (beta == 0.0) {
303e607c864SMark Adams         metad->reason = KSP_DIVERGED_BREAKDOWN_BICG;
304e607c864SMark Adams         goto done;
305e607c864SMark Adams       }
306e607c864SMark Adams       /*     p <- z          */
307e607c864SMark Adams       parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Pr[idx] = Zr[idx]; Pl[idx] = Zl[idx];});
308e607c864SMark Adams     } else {
309e607c864SMark Adams       b    = beta/betaold;
310e607c864SMark Adams       /*     p <- z + b* p   */
311e607c864SMark Adams       b2    = PetscConj(b);
312e607c864SMark Adams       parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Pr[idx] = b*Pr[idx] + Zr[idx]; Pl[idx] = b2*Pl[idx] + Zl[idx];});
313e607c864SMark Adams     }
314e607c864SMark Adams     team.team_barrier();
315e607c864SMark Adams     betaold = beta;
316e607c864SMark Adams     /*     z <- Kp         */
317e607c864SMark Adams     MatMult         (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,Pr,Zr);
318e607c864SMark Adams     MatMultTranspose(team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,Pl,Zl);
319e607c864SMark Adams     /*     dpi <- z'p      */
320e607c864SMark Adams     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += Zr[idx]*PetscConj(Pl[idx]);}, dpi);
321e607c864SMark Adams     team.team_barrier();
322e607c864SMark Adams     //
323e607c864SMark Adams     a       = beta/dpi;                           /*     a = beta/p'z    */
324e607c864SMark Adams     ma      = -a;
325e607c864SMark Adams     mac      = PetscConj(ma);
326e607c864SMark Adams     /*     x <- x + ap     */
327e607c864SMark Adams     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {XX[idx] = XX[idx] + a*Pr[idx]; Rr[idx] = Rr[idx] + ma*Zr[idx]; Rl[idx] = Rl[idx] + mac*Zl[idx];});team.team_barrier();
328e607c864SMark Adams     team.team_barrier();
329e607c864SMark Adams     /*    dp <- r'*r       */
330e607c864SMark Adams     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum +=  Rr[idx]*PetscConj(Rr[idx]);}, dpi);
331e607c864SMark Adams     team.team_barrier();
332e607c864SMark Adams     dp = PetscSqrtReal(PetscRealPart(dpi));
333a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
334e607c864SMark Adams     if (monitor) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", i+1, (double)dp);});
335a25f047fSMark Adams #endif
336e607c864SMark Adams     if (dp < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; goto done;}
337e607c864SMark Adams     if (dp/r0 < rtol) {metad->reason = KSP_CONVERGED_RTOL_NORMAL; goto done;}
338a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
339e607c864SMark Adams     if (dp/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; Kokkos::single (Kokkos::PerTeam (team), [=] () {printf("ERROR block %d diverged: %d it, res=%e, r_0=%e\n",team.league_rank(),i,dp,r0);}); goto done;}
340e607c864SMark Adams #else
341e607c864SMark Adams     if (dp/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; goto done;}
342e607c864SMark Adams #endif
343e607c864SMark Adams     if (i+1 == maxit) {metad->reason = KSP_DIVERGED_ITS; goto done;}
344e607c864SMark Adams     /* z <- Br  */
345e607c864SMark Adams     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Zr[idx] = Di[idx]*Rr[idx]; Zl[idx] = Di[idx]*Rl[idx];});
346e607c864SMark Adams     i++;
347e607c864SMark Adams     team.team_barrier();
348e607c864SMark Adams   } while (i<maxit);
349e607c864SMark Adams  done:
350e607c864SMark Adams   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
351e607c864SMark Adams       int rowa = ic[rowb];
352e607c864SMark Adams       glb_x[rowa] = XX[rowb-start];
353e607c864SMark Adams     });
354e607c864SMark Adams   metad->its = i+1;
355e607c864SMark Adams   if (1) {
356e607c864SMark Adams     int nnz;
357e607c864SMark Adams     parallel_reduce(Kokkos::TeamVectorRange (team, start, end), [=] (const int idx, int& lsum) {lsum += (glb_Aai[idx+1] - glb_Aai[idx]);}, nnz);
358e607c864SMark Adams     metad->flops = 2*(metad->its*(10*Nblk + 2*nnz) + 5*Nblk);
359e607c864SMark Adams   } else {
360e607c864SMark Adams     metad->flops = 2*(metad->its*(10*Nblk + 2*50*Nblk) + 5*Nblk); // guess
361e607c864SMark Adams   }
362e607c864SMark Adams   return 0;
363e607c864SMark Adams }
364e607c864SMark Adams 
365e607c864SMark Adams // KSP solver solve Ax = b; x is output, bin is input
366e607c864SMark Adams static PetscErrorCode PCApply_BJKOKKOS(PC pc,Vec bin,Vec xout)
367e607c864SMark Adams {
368e607c864SMark Adams   PC_PCBJKOKKOS    *jac = (PC_PCBJKOKKOS*)pc->data;
369e607c864SMark Adams   Mat               A   = pc->pmat;
370e607c864SMark Adams   Mat_SeqAIJKokkos *aijkok;
371e607c864SMark Adams 
372e607c864SMark Adams   PetscFunctionBegin;
3739effc8a0SJed Brown   PetscCheck(jac->vec_diag && A,PetscObjectComm((PetscObject)pc),PETSC_ERR_USER,"Not setup???? %p %p",jac->vec_diag,A);
3749effc8a0SJed Brown   aijkok = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
3759effc8a0SJed Brown   if (!aijkok) {
3769effc8a0SJed Brown     SETERRQ(PetscObjectComm((PetscObject)pc),PETSC_ERR_USER,"No aijkok");
3779effc8a0SJed Brown   } else {
378e607c864SMark Adams     using scr_mem_t  = Kokkos::DefaultExecutionSpace::scratch_memory_space;
379e607c864SMark Adams     using vect2D_scr_t = Kokkos::View<PetscScalar**, Kokkos::LayoutLeft, scr_mem_t>;
380e607c864SMark Adams     PetscInt          *d_bid_eqOffset, maxit = jac->ksp->max_it, scr_bytes_team, stride, global_buff_size;
381e607c864SMark Adams     const PetscInt    conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp==0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
382e607c864SMark Adams     const PetscInt    nwork = jac->nwork, nBlk = jac->nBlocks;
383e607c864SMark Adams     PetscScalar       *glb_xdata=NULL;
384e607c864SMark Adams     PetscReal         rtol = jac->ksp->rtol, atol = jac->ksp->abstol, dtol = jac->ksp->divtol;
385e607c864SMark Adams     const PetscScalar *glb_idiag =jac->d_idiag_k->data(), *glb_bdata=NULL;
386e607c864SMark Adams     const PetscInt    *glb_Aai = aijkok->i_device_data(), *glb_Aaj = aijkok->j_device_data();
387e607c864SMark Adams     const PetscScalar *glb_Aaa = aijkok->a_device_data();
388e607c864SMark Adams     Kokkos::View<Batch_MetaData*, Kokkos::DefaultExecutionSpace> d_metadata("solver meta data", nBlk);
389e607c864SMark Adams     PCFailedReason    pcreason;
390e607c864SMark Adams     KSPIndex          ksp_type_idx = jac->ksp_type_idx;
391e607c864SMark Adams     PetscMemType      mtype;
392e607c864SMark Adams     PetscContainer    container;
393e607c864SMark Adams     PetscInt          batch_sz;
394e607c864SMark Adams     VecScatter        plex_batch=NULL;
395e607c864SMark Adams     Vec               bvec;
396e607c864SMark Adams     PetscBool         monitor = jac->monitor; // captured
397e607c864SMark Adams     PetscInt          view_bid = jac->batch_target;
398e607c864SMark Adams     // get field major is to map plex IO to/from block/field major
3999566063dSJacob Faibussowitsch     PetscCall(PetscObjectQuery((PetscObject) A, "plex_batch_is", (PetscObject *) &container));
4009566063dSJacob Faibussowitsch     PetscCall(VecDuplicate(bin,&bvec));
401e607c864SMark Adams     if (container) {
4029566063dSJacob Faibussowitsch       PetscCall(PetscContainerGetPointer(container, (void **) &plex_batch));
4039566063dSJacob Faibussowitsch       PetscCall(VecScatterBegin(plex_batch,bin,bvec,INSERT_VALUES,SCATTER_FORWARD));
4049566063dSJacob Faibussowitsch       PetscCall(VecScatterEnd(plex_batch,bin,bvec,INSERT_VALUES,SCATTER_FORWARD));
405e607c864SMark Adams     } else {
4069566063dSJacob Faibussowitsch       PetscCall(VecCopy(bin, bvec));
407e607c864SMark Adams     }
408e607c864SMark Adams     // get x
4099566063dSJacob Faibussowitsch     PetscCall(VecGetArrayAndMemType(xout,&glb_xdata,&mtype));
410e607c864SMark Adams #if defined(PETSC_HAVE_CUDA)
411*63a3b9bcSJacob Faibussowitsch     PetscCheck(PetscMemTypeDevice(mtype),PetscObjectComm((PetscObject) pc),PETSC_ERR_ARG_WRONG,"No GPU data for x %" PetscInt_FMT " != %" PetscInt_FMT,mtype,PETSC_MEMTYPE_DEVICE);
412e607c864SMark Adams #endif
4139566063dSJacob Faibussowitsch     PetscCall(VecGetArrayReadAndMemType(bvec,&glb_bdata,&mtype));
414e607c864SMark Adams #if defined(PETSC_HAVE_CUDA)
4159effc8a0SJed Brown     PetscCheck(PetscMemTypeDevice(mtype),PetscObjectComm((PetscObject) pc),PETSC_ERR_ARG_WRONG,"No GPU data for b");
416e607c864SMark Adams #endif
417e607c864SMark Adams     // get batch size
4189566063dSJacob Faibussowitsch     PetscCall(PetscObjectQuery((PetscObject) A, "batch size", (PetscObject *) &container));
419e607c864SMark Adams     if (container) {
420e607c864SMark Adams       PetscInt *pNf=NULL;
4219566063dSJacob Faibussowitsch       PetscCall(PetscContainerGetPointer(container, (void **) &pNf));
422e607c864SMark Adams       batch_sz = *pNf;
423e607c864SMark Adams     } else batch_sz = 1;
424d618d24fSMark Adams     PetscCheck(nBlk%batch_sz == 0,PetscObjectComm((PetscObject) pc),PETSC_ERR_ARG_WRONG,"batch_sz = %" PetscInt_FMT ", nBlk = %" PetscInt_FMT,batch_sz,nBlk);
425e607c864SMark Adams     d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
426e607c864SMark Adams     // solve each block independently
427e607c864SMark Adams     if (jac->const_block_size) { // use shared memory for work vectors only if constant block size - todo: test efficiency loss
428e607c864SMark Adams       scr_bytes_team = jac->const_block_size*nwork*sizeof(PetscScalar);
429e607c864SMark Adams       stride = jac->const_block_size; // captured
430e607c864SMark Adams       global_buff_size = 0;
431e607c864SMark Adams     } else {
432e607c864SMark Adams       scr_bytes_team = 0;
433e607c864SMark Adams       stride = jac->n; // captured
434e607c864SMark Adams       global_buff_size = jac->n*nwork;
435e607c864SMark Adams     }
436e607c864SMark Adams     Kokkos::View<PetscScalar*, Kokkos::DefaultExecutionSpace> d_work_vecs_k("workvectors", global_buff_size); // global work vectors
437*63a3b9bcSJacob Faibussowitsch     PetscCall(PetscInfo(pc,"\tn = %" PetscInt_FMT ". %d shared mem words/team. %" PetscInt_FMT " global mem words, rtol=%e, num blocks %" PetscInt_FMT ", team_size=%" PetscInt_FMT ", %d vector threads\n",jac->n,(PetscInt)(scr_bytes_team/sizeof(PetscScalar)),global_buff_size,(double)rtol,nBlk,team_size, PCBJKOKKOS_VEC_SIZE));
438e607c864SMark Adams     PetscScalar  *d_work_vecs = scr_bytes_team ? NULL : d_work_vecs_k.data();
439e607c864SMark Adams     const PetscInt *d_isicol = jac->d_isicol_k->data(), *d_isrow = jac->d_isrow_k->data();
440e607c864SMark Adams     Kokkos::parallel_for("Solve", Kokkos::TeamPolicy<>(nBlk, team_size, PCBJKOKKOS_VEC_SIZE).set_scratch_size(PCBJKOKKOS_SHARED_LEVEL, Kokkos::PerTeam(scr_bytes_team)),
441e607c864SMark Adams         KOKKOS_LAMBDA (const team_member team) {
442e607c864SMark Adams         const int    blkID = team.league_rank(), start = d_bid_eqOffset[blkID], end = d_bid_eqOffset[blkID+1];
443e607c864SMark Adams         vect2D_scr_t work_vecs(team.team_scratch(PCBJKOKKOS_SHARED_LEVEL), scr_bytes_team ? (end-start) : 0, nwork);
444e607c864SMark Adams         PetscScalar *work_buff = (scr_bytes_team) ? work_vecs.data() : &d_work_vecs[start];
445e607c864SMark Adams         bool        print = monitor && (blkID==view_bid);
446e607c864SMark Adams         switch (ksp_type_idx) {
447e607c864SMark Adams         case BATCH_KSP_BICG_IDX:
448e607c864SMark Adams           BJSolve_BICG(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff, stride, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print);
449e607c864SMark Adams           break;
450e607c864SMark Adams         case BATCH_KSP_TFQMR_IDX:
451e607c864SMark Adams           BJSolve_TFQMR(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff, stride, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print);
452e607c864SMark Adams           break;
45326f882f0SMark Adams         case BATCH_KSP_GMRES_IDX:
454a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
45526f882f0SMark Adams           printf("GMRES not implemented %d\n",ksp_type_idx);
45626f882f0SMark Adams #else
45726f882f0SMark Adams           /* void */
45826f882f0SMark Adams #endif
45926f882f0SMark Adams           break;
460e607c864SMark Adams         default:
461a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
462e607c864SMark Adams           printf("Unknown KSP type %d\n",ksp_type_idx);
463e607c864SMark Adams #else
464e607c864SMark Adams           /* void */;
465e607c864SMark Adams #endif
466e607c864SMark Adams         }
467e607c864SMark Adams     });
468e607c864SMark Adams     auto h_metadata = Kokkos::create_mirror(Kokkos::HostSpace::memory_space(), d_metadata);
469e607c864SMark Adams     Kokkos::fence();
470e607c864SMark Adams     Kokkos::deep_copy (h_metadata, d_metadata);
471e607c864SMark Adams #if PCBJKOKKOS_VERBOSE_LEVEL >= 3
472e607c864SMark Adams #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
4739566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_WORLD,"Iterations\n"));
474e607c864SMark Adams #endif
475e607c864SMark Adams     // assume species major
476e607c864SMark Adams #if PCBJKOKKOS_VERBOSE_LEVEL < 4
4779566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_WORLD,"max iterations per species (%s) :",ksp_type_idx==BATCH_KSP_BICG_IDX ? "bicg" : "tfqmr"));
478e607c864SMark Adams #endif
479e607c864SMark Adams     for (PetscInt dmIdx=0, s=0, head=0 ; dmIdx < jac->num_dms; dmIdx += batch_sz) {
480e607c864SMark Adams       for (PetscInt f=0, idx=head ; f < jac->dm_Nf[dmIdx] ; f++,s++,idx++) {
481e607c864SMark Adams #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
482*63a3b9bcSJacob Faibussowitsch         PetscCall(PetscPrintf(PETSC_COMM_WORLD,"%2" PetscInt_FMT ":", s));
483e607c864SMark Adams         for (int bid=0 ; bid<batch_sz ; bid++) {
484*63a3b9bcSJacob Faibussowitsch          PetscCall(PetscPrintf(PETSC_COMM_WORLD,"%3" PetscInt_FMT " ", h_metadata[idx + bid*jac->dm_Nf[dmIdx]].its));
485e607c864SMark Adams         }
4869566063dSJacob Faibussowitsch         PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n"));
487e607c864SMark Adams #else
488e607c864SMark Adams         PetscInt count=0;
489e607c864SMark Adams         for (int bid=0 ; bid<batch_sz ; bid++) {
490e607c864SMark Adams           if (h_metadata[idx + bid*jac->dm_Nf[dmIdx]].its > count) count = h_metadata[idx + bid*jac->dm_Nf[dmIdx]].its;
491e607c864SMark Adams         }
492*63a3b9bcSJacob Faibussowitsch         PetscCall(PetscPrintf(PETSC_COMM_WORLD,"%3" PetscInt_FMT " ", count));
493e607c864SMark Adams #endif
494e607c864SMark Adams       }
495e607c864SMark Adams       head += batch_sz*jac->dm_Nf[dmIdx];
496e607c864SMark Adams     }
497e607c864SMark Adams #if PCBJKOKKOS_VERBOSE_LEVEL < 4
4989566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_WORLD,"\n"));
499e607c864SMark Adams #endif
500e607c864SMark Adams #endif
50126f882f0SMark Adams     PetscInt count=0, mbid=0;
502e607c864SMark Adams     for (int blkID=0;blkID<nBlk;blkID++) {
5039566063dSJacob Faibussowitsch       PetscCall(PetscLogGpuFlops((PetscLogDouble)h_metadata[blkID].flops));
504e607c864SMark Adams       if (jac->reason) {
50526f882f0SMark Adams         if (jac->batch_target==blkID) {
5069566063dSJacob Faibussowitsch           PetscCall(PetscPrintf(PETSC_COMM_SELF,  "    Linear solve converged due to %s iterations %d, batch %" PetscInt_FMT ", species %" PetscInt_FMT "\n", KSPConvergedReasons[h_metadata[blkID].reason], h_metadata[blkID].its, blkID%batch_sz, blkID/batch_sz));
50726f882f0SMark Adams         } else if (jac->batch_target==-1 && h_metadata[blkID].its > count) {
50826f882f0SMark Adams           count = h_metadata[blkID].its;
50926f882f0SMark Adams           mbid = blkID;
510e607c864SMark Adams         }
51126f882f0SMark Adams         if (h_metadata[blkID].reason < 0) {
5129566063dSJacob Faibussowitsch           PetscCall(PetscPrintf(PETSC_COMM_SELF, "ERROR reason=%s, its=%" PetscInt_FMT ". species %" PetscInt_FMT ", batch %" PetscInt_FMT "\n",
5135f80ce2aSJacob Faibussowitsch                               KSPConvergedReasons[h_metadata[blkID].reason],h_metadata[blkID].its,blkID/batch_sz,blkID%batch_sz));
514e607c864SMark Adams         }
515e607c864SMark Adams       }
516e607c864SMark Adams     }
51726f882f0SMark Adams     if (jac->batch_target==-1 && jac->reason) {
5189566063dSJacob Faibussowitsch       PetscCall(PetscPrintf(PETSC_COMM_SELF,  "    Linear solve converged due to %s iterations %d, batch %" PetscInt_FMT ", specie %" PetscInt_FMT "\n", KSPConvergedReasons[h_metadata[mbid].reason], h_metadata[mbid].its,mbid%batch_sz,mbid/batch_sz));
51926f882f0SMark Adams     }
5209566063dSJacob Faibussowitsch     PetscCall(VecRestoreArrayAndMemType(xout,&glb_xdata));
5219566063dSJacob Faibussowitsch     PetscCall(VecRestoreArrayReadAndMemType(bvec,&glb_bdata));
522e607c864SMark Adams     {
523e607c864SMark Adams       int errsum;
524e607c864SMark Adams       Kokkos::parallel_reduce(nBlk, KOKKOS_LAMBDA (const int idx, int& lsum) {
52526f882f0SMark Adams           if (d_metadata[idx].reason < 0) ++lsum;
526e607c864SMark Adams         }, errsum);
52726f882f0SMark Adams       pcreason = errsum ? PC_SUBPC_ERROR : PC_NOERROR;
528e607c864SMark Adams     }
5299566063dSJacob Faibussowitsch     PetscCall(PCSetFailedReason(pc,pcreason));
530e607c864SMark Adams     // map back to Plex space
531e607c864SMark Adams     if (plex_batch) {
5329566063dSJacob Faibussowitsch       PetscCall(VecCopy(xout, bvec));
5339566063dSJacob Faibussowitsch       PetscCall(VecScatterBegin(plex_batch,bvec,xout,INSERT_VALUES,SCATTER_REVERSE));
5349566063dSJacob Faibussowitsch       PetscCall(VecScatterEnd(plex_batch,bvec,xout,INSERT_VALUES,SCATTER_REVERSE));
535e607c864SMark Adams     }
5369566063dSJacob Faibussowitsch     PetscCall(VecDestroy(&bvec));
537e607c864SMark Adams   }
538e607c864SMark Adams   PetscFunctionReturn(0);
539e607c864SMark Adams }
540e607c864SMark Adams 
541e607c864SMark Adams static PetscErrorCode PCSetUp_BJKOKKOS(PC pc)
542e607c864SMark Adams {
543e607c864SMark Adams   PC_PCBJKOKKOS    *jac = (PC_PCBJKOKKOS*)pc->data;
544e607c864SMark Adams   Mat               A   = pc->pmat;
545e607c864SMark Adams   Mat_SeqAIJKokkos *aijkok;
546e607c864SMark Adams   PetscBool         flg;
547e607c864SMark Adams 
548e607c864SMark Adams   PetscFunctionBegin;
5499effc8a0SJed Brown   PetscCheck(!pc->useAmat,PetscObjectComm((PetscObject)pc),PETSC_ERR_SUP,"No support for using 'use_amat'");
5509effc8a0SJed Brown   PetscCheck(A,PetscObjectComm((PetscObject)A),PETSC_ERR_USER,"No matrix - A is used above");
5519566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompareAny((PetscObject)A,&flg,MATSEQAIJKOKKOS,MATMPIAIJKOKKOS,MATAIJKOKKOS,""));
5529effc8a0SJed Brown   PetscCheck(flg,PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONG,"must use '-dm_mat_type aijkokkos -dm_vec_type kokkos' for -pc_type bjkokkos");
5539effc8a0SJed Brown   if (!(aijkok = static_cast<Mat_SeqAIJKokkos*>(A->spptr))) {
5549effc8a0SJed Brown     SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_USER,"No aijkok");
5559effc8a0SJed Brown   } else {
556e607c864SMark Adams     if (!jac->vec_diag) {
557e607c864SMark Adams       Vec               *subX;
558e607c864SMark Adams       DM                pack,*subDM;
559e607c864SMark Adams       PetscInt          nDMs, n;
560e607c864SMark Adams       PetscContainer    container;
5619566063dSJacob Faibussowitsch       PetscCall(PetscObjectQuery((PetscObject) A, "plex_batch_is", (PetscObject *) &container));
562e607c864SMark Adams       { // Permute the matrix to get a block diagonal system: d_isrow_k, d_isicol_k
563e607c864SMark Adams         MatOrderingType   rtype;
564e607c864SMark Adams         IS                isrow,isicol;
565e607c864SMark Adams         const PetscInt    *rowindices,*icolindices;
566e607c864SMark Adams 
567e607c864SMark Adams         if (container) rtype = MATORDERINGNATURAL; // if we have a vecscatter then don't reorder here (all the reorder stuff goes away in future)
568e607c864SMark Adams         else rtype = MATORDERINGRCM;
569e607c864SMark Adams         // get permutation. Not what I expect so inverted here
5709566063dSJacob Faibussowitsch         PetscCall(MatGetOrdering(A,rtype,&isrow,&isicol));
5719566063dSJacob Faibussowitsch         PetscCall(ISDestroy(&isrow));
5729566063dSJacob Faibussowitsch         PetscCall(ISInvertPermutation(isicol,PETSC_DECIDE,&isrow));
5739566063dSJacob Faibussowitsch         PetscCall(ISGetIndices(isrow,&rowindices));
5749566063dSJacob Faibussowitsch         PetscCall(ISGetIndices(isicol,&icolindices));
575e607c864SMark Adams         const Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_isrow_k((PetscInt*)rowindices,A->rmap->n);
576e607c864SMark Adams         const Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_isicol_k ((PetscInt*)icolindices,A->rmap->n);
577e607c864SMark Adams         jac->d_isrow_k = new Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_isrow_k));
578e607c864SMark Adams         jac->d_isicol_k = new Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_isicol_k));
579e607c864SMark Adams         Kokkos::deep_copy (*jac->d_isrow_k, h_isrow_k);
580e607c864SMark Adams         Kokkos::deep_copy (*jac->d_isicol_k, h_isicol_k);
5819566063dSJacob Faibussowitsch         PetscCall(ISRestoreIndices(isrow,&rowindices));
5829566063dSJacob Faibussowitsch         PetscCall(ISRestoreIndices(isicol,&icolindices));
5839566063dSJacob Faibussowitsch         PetscCall(ISDestroy(&isrow));
5849566063dSJacob Faibussowitsch         PetscCall(ISDestroy(&isicol));
585e607c864SMark Adams       }
586e607c864SMark Adams       // get block sizes
5879566063dSJacob Faibussowitsch       PetscCall(PCGetDM(pc, &pack));
5889effc8a0SJed Brown       PetscCheck(pack,PetscObjectComm((PetscObject)A),PETSC_ERR_USER,"no DM. Requires a composite DM");
5899566063dSJacob Faibussowitsch       PetscCall(PetscObjectTypeCompare((PetscObject)pack,DMCOMPOSITE,&flg));
5909effc8a0SJed Brown       PetscCheck(flg,PetscObjectComm((PetscObject)pack),PETSC_ERR_USER,"Not for type %s",((PetscObject)pack)->type_name);
5919566063dSJacob Faibussowitsch       PetscCall(DMCompositeGetNumberDM(pack,&nDMs));
592e607c864SMark Adams       jac->num_dms = nDMs;
5939566063dSJacob Faibussowitsch       PetscCall(DMCreateGlobalVector(pack, &jac->vec_diag));
5949566063dSJacob Faibussowitsch       PetscCall(VecGetLocalSize(jac->vec_diag,&n));
595e607c864SMark Adams       jac->n = n;
596e607c864SMark Adams       jac->d_idiag_k = new Kokkos::View<PetscScalar*, Kokkos::LayoutRight>("idiag", n);
597e607c864SMark Adams       // options
5989566063dSJacob Faibussowitsch       PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
5999566063dSJacob Faibussowitsch       PetscCall(KSPSetFromOptions(jac->ksp));
6009566063dSJacob Faibussowitsch       PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp,&flg,KSPBICG,""));
601e607c864SMark Adams       if (flg) {jac->ksp_type_idx = BATCH_KSP_BICG_IDX; jac->nwork = 7;}
602e607c864SMark Adams       else {
6039566063dSJacob Faibussowitsch         PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp,&flg,KSPTFQMR,""));
604e607c864SMark Adams         if (flg) {jac->ksp_type_idx = BATCH_KSP_TFQMR_IDX; jac->nwork = 10;}
60526f882f0SMark Adams         else {
6069566063dSJacob Faibussowitsch           PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp,&flg,KSPGMRES,""));
60726f882f0SMark Adams           if (flg) {jac->ksp_type_idx = BATCH_KSP_GMRES_IDX; jac->nwork = 0;}
60826f882f0SMark Adams           SETERRQ(PetscObjectComm((PetscObject)jac->ksp),PETSC_ERR_ARG_WRONG,"unsupported type %s", ((PetscObject)jac->ksp)->type_name);
60926f882f0SMark Adams         }
610e607c864SMark Adams       }
611e607c864SMark Adams       {
612e607c864SMark Adams         PetscViewer       viewer;
613e607c864SMark Adams         PetscBool         flg;
614e607c864SMark Adams         PetscViewerFormat format;
6159566063dSJacob Faibussowitsch         PetscCall(PetscOptionsGetViewer(PetscObjectComm((PetscObject)jac->ksp),((PetscObject)jac->ksp)->options,((PetscObject)jac->ksp)->prefix,"-ksp_converged_reason",&viewer,&format,&flg));
616e607c864SMark Adams         jac->reason = flg;
6179566063dSJacob Faibussowitsch         PetscCall(PetscViewerDestroy(&viewer));
6189566063dSJacob Faibussowitsch         PetscCall(PetscOptionsGetViewer(PetscObjectComm((PetscObject)jac->ksp),((PetscObject)jac->ksp)->options,((PetscObject)jac->ksp)->prefix,"-ksp_monitor",&viewer,&format,&flg));
619e607c864SMark Adams         jac->monitor = flg;
6209566063dSJacob Faibussowitsch         PetscCall(PetscViewerDestroy(&viewer));
6219566063dSJacob Faibussowitsch         PetscCall(PetscOptionsGetInt(((PetscObject)jac->ksp)->options,((PetscObject)jac->ksp)->prefix,"-ksp_batch_target",&jac->batch_target,&flg));
6225f80ce2aSJacob Faibussowitsch         PetscCheck(jac->batch_target < jac->num_dms,PETSC_COMM_WORLD,PETSC_ERR_ARG_WRONG,"-ksp_batch_target (%" PetscInt_FMT ") >= number of DMs (%" PetscInt_FMT ")",jac->batch_target,jac->num_dms);
62326f882f0SMark Adams         if (!jac->monitor && !flg) jac->batch_target = -1; // turn it off
624e607c864SMark Adams       }
625e607c864SMark Adams       // get blocks - jac->d_bid_eqOffset_k
6269566063dSJacob Faibussowitsch       PetscCall(PetscMalloc(sizeof(*subX)*nDMs, &subX));
6279566063dSJacob Faibussowitsch       PetscCall(PetscMalloc(sizeof(*subDM)*nDMs, &subDM));
6289566063dSJacob Faibussowitsch       PetscCall(PetscMalloc(sizeof(*jac->dm_Nf)*nDMs, &jac->dm_Nf));
629*63a3b9bcSJacob Faibussowitsch       PetscCall(PetscInfo(pc, "Have %" PetscInt_FMT " DMs, n=%" PetscInt_FMT " rtol=%g type = %s\n", nDMs, n, (double)jac->ksp->rtol, ((PetscObject)jac->ksp)->type_name));
6309566063dSJacob Faibussowitsch       PetscCall(DMCompositeGetEntriesArray(pack,subDM));
631e607c864SMark Adams       jac->nBlocks = 0;
632e607c864SMark Adams       for (PetscInt ii=0;ii<nDMs;ii++) {
633e607c864SMark Adams         PetscSection section;
634e607c864SMark Adams         PetscInt Nf;
635e607c864SMark Adams         DM dm = subDM[ii];
6369566063dSJacob Faibussowitsch         PetscCall(DMGetLocalSection(dm, &section));
6379566063dSJacob Faibussowitsch         PetscCall(PetscSectionGetNumFields(section, &Nf));
638e607c864SMark Adams         jac->nBlocks += Nf;
639e607c864SMark Adams #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
6409566063dSJacob Faibussowitsch         if (ii==0) PetscCall(PetscInfo(pc,"%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n",ii,Nf,jac->nBlocks));
641e607c864SMark Adams #else
6429566063dSJacob Faibussowitsch         PetscCall(PetscInfo(pc,"%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n",ii,Nf,jac->nBlocks));
643e607c864SMark Adams #endif
644e607c864SMark Adams         jac->dm_Nf[ii] = Nf;
645e607c864SMark Adams       }
646e607c864SMark Adams       { // d_bid_eqOffset_k
647e607c864SMark Adams         Kokkos::View<PetscInt*, Kokkos::LayoutRight, Kokkos::HostSpace> h_block_offsets("block_offsets", jac->nBlocks+1);
6489566063dSJacob Faibussowitsch         PetscCall(DMCompositeGetAccessArray(pack, jac->vec_diag, nDMs, NULL, subX));
649e607c864SMark Adams         h_block_offsets[0] = 0;
650e607c864SMark Adams         jac->const_block_size = -1;
651e607c864SMark Adams         for (PetscInt ii=0, idx = 0;ii<nDMs;ii++) {
652e607c864SMark Adams           PetscInt nloc,nblk;
6539566063dSJacob Faibussowitsch           PetscCall(VecGetSize(subX[ii],&nloc));
654e607c864SMark Adams           nblk = nloc/jac->dm_Nf[ii];
655*63a3b9bcSJacob Faibussowitsch           PetscCheck(nloc%jac->dm_Nf[ii] == 0,PetscObjectComm((PetscObject)pc),PETSC_ERR_USER,"nloc%%jac->dm_Nf[ii] (%" PetscInt_FMT ") != 0 DMs",nloc%jac->dm_Nf[ii]);
656e607c864SMark Adams           for (PetscInt jj=0;jj<jac->dm_Nf[ii];jj++, idx++) {
657e607c864SMark Adams             h_block_offsets[idx+1] = h_block_offsets[idx] + nblk;
658e607c864SMark Adams #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
6599566063dSJacob Faibussowitsch             if (idx==0) PetscCall(PetscInfo(pc,"\t%" PetscInt_FMT ") Add block with %" PetscInt_FMT " equations of %" PetscInt_FMT "\n",idx+1,nblk,jac->nBlocks));
660e607c864SMark Adams #else
6619566063dSJacob Faibussowitsch             PetscCall(PetscInfo(pc,"\t%" PetscInt_FMT ") Add block with %" PetscInt_FMT " equations of %" PetscInt_FMT "\n",idx+1,nblk,jac->nBlocks));
662e607c864SMark Adams #endif
663e607c864SMark Adams             if (jac->const_block_size == -1) jac->const_block_size = nblk;
664e607c864SMark Adams             else if (jac->const_block_size > 0 && jac->const_block_size != nblk) jac->const_block_size = 0;
665e607c864SMark Adams           }
666e607c864SMark Adams         }
6679566063dSJacob Faibussowitsch         PetscCall(DMCompositeRestoreAccessArray(pack, jac->vec_diag, jac->nBlocks, NULL, subX));
6689566063dSJacob Faibussowitsch         PetscCall(PetscFree(subX));
6699566063dSJacob Faibussowitsch         PetscCall(PetscFree(subDM));
670e607c864SMark Adams         jac->d_bid_eqOffset_k = new Kokkos::View<PetscInt*, Kokkos::LayoutRight>(Kokkos::create_mirror(Kokkos::DefaultExecutionSpace::memory_space(),h_block_offsets));
671e607c864SMark Adams         Kokkos::deep_copy (*jac->d_bid_eqOffset_k, h_block_offsets);
672e607c864SMark Adams       }
673e607c864SMark Adams     }
674e607c864SMark Adams     { // get jac->d_idiag_k (PC setup),
675e607c864SMark Adams       const PetscInt    *d_ai=aijkok->i_device_data(), *d_aj=aijkok->j_device_data();
676e607c864SMark Adams       const PetscScalar *d_aa = aijkok->a_device_data();
677e607c864SMark Adams       const PetscInt    conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp==0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
678e607c864SMark Adams       PetscInt          *d_bid_eqOffset = jac->d_bid_eqOffset_k->data(), *r = jac->d_isrow_k->data(), *ic = jac->d_isicol_k->data();
679e607c864SMark Adams       PetscScalar       *d_idiag = jac->d_idiag_k->data();
680e607c864SMark Adams       Kokkos::parallel_for("Diag", Kokkos::TeamPolicy<>(jac->nBlocks, team_size, PCBJKOKKOS_VEC_SIZE), KOKKOS_LAMBDA (const team_member team) {
681e607c864SMark Adams           const PetscInt blkID = team.league_rank();
682e607c864SMark Adams           Kokkos::parallel_for
683e607c864SMark Adams             (Kokkos::TeamThreadRange(team,d_bid_eqOffset[blkID],d_bid_eqOffset[blkID+1]),
684e607c864SMark Adams              [=] (const int rowb) {
685e607c864SMark Adams                const PetscInt    rowa = ic[rowb], ai = d_ai[rowa], *aj = d_aj + ai; // grab original data
686e607c864SMark Adams                const PetscScalar *aa  = d_aa + ai;
687e607c864SMark Adams                const PetscInt    nrow = d_ai[rowa + 1] - ai;
688e607c864SMark Adams                int found;
689e607c864SMark Adams                Kokkos::parallel_reduce
690e607c864SMark Adams                  (Kokkos::ThreadVectorRange (team, nrow),
691e607c864SMark Adams                   [=] (const int& j, int &count) {
692e607c864SMark Adams                     const PetscInt colb = r[aj[j]];
693e607c864SMark Adams                     if (colb==rowb) {
694e607c864SMark Adams                       d_idiag[rowb] = 1./aa[j];
695e607c864SMark Adams                       count++;
696e607c864SMark Adams                     }}, found);
697a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
698e607c864SMark Adams                if (found!=1) Kokkos::single (Kokkos::PerThread (team), [=] () {printf("ERRORrow %d) found = %d\n",rowb,found);});
699a25f047fSMark Adams #endif
700e607c864SMark Adams              });
701e607c864SMark Adams         });
702e607c864SMark Adams     }
703e607c864SMark Adams   }
704e607c864SMark Adams   PetscFunctionReturn(0);
705e607c864SMark Adams }
706e607c864SMark Adams 
707e607c864SMark Adams /* Default destroy, if it has never been setup */
708e607c864SMark Adams static PetscErrorCode PCReset_BJKOKKOS(PC pc)
709e607c864SMark Adams {
710e607c864SMark Adams   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS*)pc->data;
711e607c864SMark Adams 
712e607c864SMark Adams   PetscFunctionBegin;
7139566063dSJacob Faibussowitsch   PetscCall(KSPDestroy(&jac->ksp));
7149566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&jac->vec_diag));
715e607c864SMark Adams   if (jac->d_bid_eqOffset_k) delete jac->d_bid_eqOffset_k;
716e607c864SMark Adams   if (jac->d_idiag_k) delete jac->d_idiag_k;
717e607c864SMark Adams   if (jac->d_isrow_k) delete jac->d_isrow_k;
718e607c864SMark Adams   if (jac->d_isicol_k) delete jac->d_isicol_k;
719e607c864SMark Adams   jac->d_bid_eqOffset_k = NULL;
720e607c864SMark Adams   jac->d_idiag_k = NULL;
721e607c864SMark Adams   jac->d_isrow_k = NULL;
722e607c864SMark Adams   jac->d_isicol_k = NULL;
7239566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSGetKSP_C",NULL)); // not published now (causes configure errors)
7249566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSSetKSP_C",NULL));
7259566063dSJacob Faibussowitsch   PetscCall(PetscFree(jac->dm_Nf));
726e607c864SMark Adams   jac->dm_Nf = NULL;
727e607c864SMark Adams   PetscFunctionReturn(0);
728e607c864SMark Adams }
729e607c864SMark Adams 
730e607c864SMark Adams static PetscErrorCode PCDestroy_BJKOKKOS(PC pc)
731e607c864SMark Adams {
732e607c864SMark Adams   PetscFunctionBegin;
7339566063dSJacob Faibussowitsch   PetscCall(PCReset_BJKOKKOS(pc));
7349566063dSJacob Faibussowitsch   PetscCall(PetscFree(pc->data));
735e607c864SMark Adams   PetscFunctionReturn(0);
736e607c864SMark Adams }
737e607c864SMark Adams 
738e607c864SMark Adams static PetscErrorCode PCView_BJKOKKOS(PC pc,PetscViewer viewer)
739e607c864SMark Adams {
740e607c864SMark Adams   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS*)pc->data;
741e607c864SMark Adams   PetscBool      iascii;
742e607c864SMark Adams 
743e607c864SMark Adams   PetscFunctionBegin;
7449566063dSJacob Faibussowitsch   if (!jac->ksp) PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
7459566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&iascii));
746e607c864SMark Adams   if (iascii) {
7479566063dSJacob Faibussowitsch     PetscCall(PetscViewerASCIIPrintf(viewer,"  Batched device linear solver: Krylov (KSP) method with Jacobi preconditioning\n"));
7489566063dSJacob Faibussowitsch     PetscCall(PetscViewerASCIIPrintf(viewer,"\t\tnwork = %" PetscInt_FMT ", rel tol = %e, abs tol = %e, div tol = %e, max it =%" PetscInt_FMT ", type = %s\n",jac->nwork,jac->ksp->rtol,
749e607c864SMark Adams                                    jac->ksp->abstol, jac->ksp->divtol, jac->ksp->max_it,
7505f80ce2aSJacob Faibussowitsch                                    ((PetscObject)jac->ksp)->type_name));
751e607c864SMark Adams   }
752e607c864SMark Adams   PetscFunctionReturn(0);
753e607c864SMark Adams }
754e607c864SMark Adams 
755e607c864SMark Adams static PetscErrorCode PCSetFromOptions_BJKOKKOS(PetscOptionItems *PetscOptionsObject,PC pc)
756e607c864SMark Adams {
757e607c864SMark Adams   PetscFunctionBegin;
758d0609cedSBarry Smith   PetscOptionsHeadBegin(PetscOptionsObject,"PC BJKOKKOS options");
759d0609cedSBarry Smith   PetscOptionsHeadEnd();
760e607c864SMark Adams   PetscFunctionReturn(0);
761e607c864SMark Adams }
762e607c864SMark Adams 
763e607c864SMark Adams static PetscErrorCode  PCBJKOKKOSSetKSP_BJKOKKOS(PC pc,KSP ksp)
764e607c864SMark Adams {
765e607c864SMark Adams   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS*)pc->data;
766e607c864SMark Adams 
767e607c864SMark Adams   PetscFunctionBegin;
7689566063dSJacob Faibussowitsch   PetscCall(PetscObjectReference((PetscObject)ksp));
7699566063dSJacob Faibussowitsch   PetscCall(KSPDestroy(&jac->ksp));
770e607c864SMark Adams   jac->ksp = ksp;
771e607c864SMark Adams   PetscFunctionReturn(0);
772e607c864SMark Adams }
773e607c864SMark Adams 
774e607c864SMark Adams /*@C
775e607c864SMark Adams    PCBJKOKKOSSetKSP - Sets the KSP context for a KSP PC.
776e607c864SMark Adams 
777e607c864SMark Adams    Collective on PC
778e607c864SMark Adams 
779e607c864SMark Adams    Input Parameters:
780e607c864SMark Adams +  pc - the preconditioner context
781e607c864SMark Adams -  ksp - the KSP solver
782e607c864SMark Adams 
783e607c864SMark Adams    Notes:
784e607c864SMark Adams    The PC and the KSP must have the same communicator
785e607c864SMark Adams 
786e607c864SMark Adams    Level: advanced
787e607c864SMark Adams 
788e607c864SMark Adams @*/
789e607c864SMark Adams PetscErrorCode  PCBJKOKKOSSetKSP(PC pc,KSP ksp)
790e607c864SMark Adams {
791e607c864SMark Adams   PetscFunctionBegin;
792e607c864SMark Adams   PetscValidHeaderSpecific(pc,PC_CLASSID,1);
793e607c864SMark Adams   PetscValidHeaderSpecific(ksp,KSP_CLASSID,2);
794e607c864SMark Adams   PetscCheckSameComm(pc,1,ksp,2);
795cac4c232SBarry Smith   PetscTryMethod(pc,"PCBJKOKKOSSetKSP_C",(PC,KSP),(pc,ksp));
796e607c864SMark Adams   PetscFunctionReturn(0);
797e607c864SMark Adams }
798e607c864SMark Adams 
799e607c864SMark Adams static PetscErrorCode  PCBJKOKKOSGetKSP_BJKOKKOS(PC pc,KSP *ksp)
800e607c864SMark Adams {
801e607c864SMark Adams   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS*)pc->data;
802e607c864SMark Adams 
803e607c864SMark Adams   PetscFunctionBegin;
8049566063dSJacob Faibussowitsch   if (!jac->ksp) PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
805e607c864SMark Adams   *ksp = jac->ksp;
806e607c864SMark Adams   PetscFunctionReturn(0);
807e607c864SMark Adams }
808e607c864SMark Adams 
809e607c864SMark Adams /*@C
810e607c864SMark Adams    PCBJKOKKOSGetKSP - Gets the KSP context for a KSP PC.
811e607c864SMark Adams 
812e607c864SMark Adams    Not Collective but KSP returned is parallel if PC was parallel
813e607c864SMark Adams 
814e607c864SMark Adams    Input Parameter:
815e607c864SMark Adams .  pc - the preconditioner context
816e607c864SMark Adams 
817e607c864SMark Adams    Output Parameters:
818e607c864SMark Adams .  ksp - the KSP solver
819e607c864SMark Adams 
820e607c864SMark Adams    Notes:
821e607c864SMark Adams    You must call KSPSetUp() before calling PCBJKOKKOSGetKSP().
822e607c864SMark Adams 
823e607c864SMark Adams    If the PC is not a PCBJKOKKOS object it raises an error
824e607c864SMark Adams 
825e607c864SMark Adams    Level: advanced
826e607c864SMark Adams 
827e607c864SMark Adams @*/
828e607c864SMark Adams PetscErrorCode  PCBJKOKKOSGetKSP(PC pc,KSP *ksp)
829e607c864SMark Adams {
830e607c864SMark Adams   PetscFunctionBegin;
831e607c864SMark Adams   PetscValidHeaderSpecific(pc,PC_CLASSID,1);
832e607c864SMark Adams   PetscValidPointer(ksp,2);
833cac4c232SBarry Smith   PetscUseMethod(pc,"PCBJKOKKOSGetKSP_C",(PC,KSP*),(pc,ksp));
834e607c864SMark Adams   PetscFunctionReturn(0);
835e607c864SMark Adams }
836e607c864SMark Adams 
837e607c864SMark Adams /* ----------------------------------------------------------------------------------*/
838e607c864SMark Adams 
/*MC
     PCBJKOKKOS -  Defines a preconditioner that applies a Krylov solver and preconditioner to the blocks in an AIJ Kokkos matrix (MATSEQAIJKOKKOS, MATMPIAIJKOKKOS or MATAIJKOKKOS) on the GPU.

   Options Database Key:
.     -pc_bjkokkos_ - options prefix with ksp options

   Level: intermediate

   Notes:
    For use with -ksp_type preonly to bypass any CPU work

   Developer Notes:

.seealso:  PCCreate(), PCSetType(), PCType (for list of available types), PC,
           PCSHELL, PCCOMPOSITE, PCSetUseAmat(), PCBJKOKKOSGetKSP()

M*/
856e607c864SMark Adams 
857e607c864SMark Adams PETSC_EXTERN PetscErrorCode PCCreate_BJKOKKOS(PC pc)
858e607c864SMark Adams {
859e607c864SMark Adams   PC_PCBJKOKKOS *jac;
860e607c864SMark Adams 
861e607c864SMark Adams   PetscFunctionBegin;
8629566063dSJacob Faibussowitsch   PetscCall(PetscNewLog(pc,&jac));
863e607c864SMark Adams   pc->data = (void*)jac;
864e607c864SMark Adams 
865e607c864SMark Adams   jac->ksp              = NULL;
866e607c864SMark Adams   jac->vec_diag         = NULL;
867e607c864SMark Adams   jac->d_bid_eqOffset_k = NULL;
868e607c864SMark Adams   jac->d_idiag_k        = NULL;
869e607c864SMark Adams   jac->d_isrow_k        = NULL;
870e607c864SMark Adams   jac->d_isicol_k       = NULL;
871e607c864SMark Adams   jac->nBlocks          = 1;
872e607c864SMark Adams 
8739566063dSJacob Faibussowitsch   PetscCall(PetscMemzero(pc->ops,sizeof(struct _PCOps)));
874e607c864SMark Adams   pc->ops->apply           = PCApply_BJKOKKOS;
875e607c864SMark Adams   pc->ops->applytranspose  = NULL;
876e607c864SMark Adams   pc->ops->setup           = PCSetUp_BJKOKKOS;
877e607c864SMark Adams   pc->ops->reset           = PCReset_BJKOKKOS;
878e607c864SMark Adams   pc->ops->destroy         = PCDestroy_BJKOKKOS;
879e607c864SMark Adams   pc->ops->setfromoptions  = PCSetFromOptions_BJKOKKOS;
880e607c864SMark Adams   pc->ops->view            = PCView_BJKOKKOS;
881e607c864SMark Adams 
8829566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSGetKSP_C",PCBJKOKKOSGetKSP_BJKOKKOS));
8839566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSSetKSP_C",PCBJKOKKOSSetKSP_BJKOKKOS));
884e607c864SMark Adams   PetscFunctionReturn(0);
885e607c864SMark Adams }
886