xref: /petsc/src/ksp/pc/impls/bjacobi/bjkokkos/bjkokkos.kokkos.cxx (revision b45e3bf4ff73d80a20c3202b6cd9f79d2f2d3efe)
1 #include <petscvec_kokkos.hpp>
2 #include <petsc/private/pcimpl.h>
3 #include <petsc/private/kspimpl.h>
4 #include <petscksp.h>            /*I "petscksp.h" I*/
5 #include "petscsection.h"
6 #include <petscdmcomposite.h>
7 #include <Kokkos_Core.hpp>
8 
9 typedef Kokkos::TeamPolicy<>::member_type team_member;
10 
11 #include <../src/mat/impls/aij/seq/aij.h>
12 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
13 
/* Compile-time knobs for the batched block solver */
#define PCBJKOKKOS_SHARED_LEVEL  1  /* scratch level used for set_scratch_size()/team_scratch() */
#define PCBJKOKKOS_VEC_SIZE      16 /* vector length given to the TeamPolicy */
#define PCBJKOKKOS_TEAM_SIZE     16 /* team size given to the TeamPolicy (unless a team size of 1 is selected at run time) */
#define PCBJKOKKOS_VERBOSE_LEVEL 0  /* gates the extra diagnostic printing in this file */

/* Selects which batched Krylov kernel is launched for every block */
typedef enum {
  BATCH_KSP_BICG_IDX,
  BATCH_KSP_TFQMR_IDX,
  BATCH_KSP_GMRES_IDX,
  NUM_BATCH_TYPES
} KSPIndex;
20 typedef struct {
21   Vec                                              vec_diag;
22   PetscInt                                         nBlocks; /* total number of blocks */
23   PetscInt                                         n; // cache host version of d_bid_eqOffset_k[nBlocks]
24   KSP                                              ksp; // Used just for options. Should have one for each block
25   Kokkos::View<PetscInt*, Kokkos::LayoutRight>     *d_bid_eqOffset_k;
26   Kokkos::View<PetscScalar*, Kokkos::LayoutRight>  *d_idiag_k;
27   Kokkos::View<PetscInt*>                          *d_isrow_k;
28   Kokkos::View<PetscInt*>                          *d_isicol_k;
29   KSPIndex                                         ksp_type_idx;
30   PetscInt                                         nwork;
31   PetscInt                                         const_block_size; // used to decide to use shared memory for work vectors
32   PetscInt                                         *dm_Nf;  // Number of fields in each DM
33   PetscInt                                         num_dms;
34   // diagnostics
35   PetscBool                                        reason;
36   PetscBool                                        monitor;
37   PetscInt                                         batch_target;
38 } PC_PCBJKOKKOS;
39 
40 static PetscErrorCode  PCBJKOKKOSCreateKSP_BJKOKKOS(PC pc)
41 {
42   PetscErrorCode ierr;
43   const char     *prefix;
44   PC_PCBJKOKKOS   *jac = (PC_PCBJKOKKOS*)pc->data;
45   DM             dm;
46 
47   PetscFunctionBegin;
48   ierr = KSPCreate(PetscObjectComm((PetscObject)pc),&jac->ksp);CHKERRQ(ierr);
49   ierr = KSPSetErrorIfNotConverged(jac->ksp,pc->erroriffailure);CHKERRQ(ierr);
50   ierr = PetscObjectIncrementTabLevel((PetscObject)jac->ksp,(PetscObject)pc,1);CHKERRQ(ierr);
51   ierr = PCGetOptionsPrefix(pc,&prefix);CHKERRQ(ierr);
52   ierr = KSPSetOptionsPrefix(jac->ksp,prefix);CHKERRQ(ierr);
53   ierr = KSPAppendOptionsPrefix(jac->ksp,"pc_bjkokkos_");CHKERRQ(ierr);
54   ierr = PCGetDM(pc, &dm);CHKERRQ(ierr);
55   if (dm) {
56     ierr = KSPSetDM(jac->ksp, dm);CHKERRQ(ierr);
57     ierr = KSPSetDMActive(jac->ksp, PETSC_FALSE);CHKERRQ(ierr);
58   }
59   jac->reason = PETSC_FALSE;
60   jac->monitor = PETSC_FALSE;
61   jac->batch_target = 0;
62 
63   PetscFunctionReturn(0);
64 }
65 
66 // y <-- Ax
67 KOKKOS_INLINE_FUNCTION PetscErrorCode MatMult(const team_member team,  const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
68 {
69   Kokkos::parallel_for(Kokkos::TeamThreadRange(team,start,end), [=] (const int rowb) {
70       int rowa = ic[rowb];
71       int n = glb_Aai[rowa+1] - glb_Aai[rowa];
72       const PetscInt    *aj  = glb_Aaj + glb_Aai[rowa];
73       const PetscScalar *aa  = glb_Aaa + glb_Aai[rowa];
74       PetscScalar sum;
75       Kokkos::parallel_reduce(Kokkos::ThreadVectorRange (team, n), [=] (const int i, PetscScalar& lsum) {
76           lsum += aa[i] * x_loc[r[aj[i]]-start];
77         }, sum);
78       Kokkos::single(Kokkos::PerThread (team),[=]() {y_loc[rowb-start] = sum;});
79     });
80   team.team_barrier();
81   return 0;
82 }
83 
84 // temp buffer per thread with reduction at end?
85 KOKKOS_INLINE_FUNCTION PetscErrorCode MatMultTranspose(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
86 {
87   Kokkos::parallel_for(Kokkos::TeamVectorRange(team,end-start), [=] (int i) { y_loc[i] = 0;});
88   team.team_barrier();
89   Kokkos::parallel_for(Kokkos::TeamThreadRange(team,start,end), [=] (const int rowb) {
90       int rowa = ic[rowb];
91       int n = glb_Aai[rowa+1] - glb_Aai[rowa];
92       const PetscInt    *aj  = glb_Aaj + glb_Aai[rowa];
93       const PetscScalar *aa  = glb_Aaa + glb_Aai[rowa];
94       const PetscScalar xx = x_loc[rowb-start]; // rowb = ic[rowa] = ic[r[rowb]]
95       Kokkos::parallel_for(Kokkos::ThreadVectorRange(team,n), [=] (const int &i) {
96           PetscScalar val = aa[i] * xx;
97           Kokkos::atomic_fetch_add(&y_loc[r[aj[i]]-start], val);
98         });
99     });
100   team.team_barrier();
101   return 0;
102 }
103 
104 typedef struct Batch_MetaData_TAG
105 {
106   PetscInt           flops;
107   PetscInt           its;
108   KSPConvergedReason reason;
109 }Batch_MetaData;
110 
111 // Solve A(BB^-1)x = y with TFQMR. Right preconditioned to get un-preconditioned residual
112 KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_TFQMR(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space, const PetscInt stride, PetscReal rtol, PetscReal atol, PetscReal dtol,PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
113 {
114   using Kokkos::parallel_reduce;
115   using Kokkos::parallel_for;
116   int               Nblk = end-start, i,m;
117   PetscReal         dp,dpold,w,dpest,tau,psi,cm,r0;
118   PetscScalar       *ptr = work_space, rho,rhoold,a,s,b,eta,etaold,psiold,cf,dpi;
119   const PetscScalar *Diag = &glb_idiag[start];
120   PetscScalar       *XX = ptr; ptr += stride;
121   PetscScalar       *R = ptr; ptr += stride;
122   PetscScalar       *RP = ptr; ptr += stride;
123   PetscScalar       *V = ptr; ptr += stride;
124   PetscScalar       *T = ptr; ptr += stride;
125   PetscScalar       *Q = ptr; ptr += stride;
126   PetscScalar       *P = ptr; ptr += stride;
127   PetscScalar       *U = ptr; ptr += stride;
128   PetscScalar       *D = ptr; ptr += stride;
129   PetscScalar       *AUQ = V;
130 
131   // init: get b, zero x
132   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
133       int rowa = ic[rowb];
134       R[rowb-start] = glb_b[rowa];
135       XX[rowb-start] = 0;
136     });
137   team.team_barrier();
138   parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += R[idx]*PetscConj(R[idx]);}, dpi);
139   team.team_barrier();
140   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
141   // diagnostics
142   if (monitor) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", 0, (double)dp);});
143 
144   if (dp < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; return 0;}
145   if (0 == maxit) {metad->reason = KSP_DIVERGED_ITS; return 0;}
146 
147   /* Make the initial Rp = R */
148   parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {RP[idx] = R[idx];});
149   team.team_barrier();
150   /* Set the initial conditions */
151   etaold = 0.0;
152   psiold = 0.0;
153   tau    = dp;
154   dpold  = dp;
155 
156   /* rhoold = (r,rp)     */
157   parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += R[idx]*PetscConj(RP[idx]);}, rhoold);
158   team.team_barrier();
159   parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {U[idx] = R[idx]; P[idx] = R[idx]; T[idx] = Diag[idx]*P[idx]; D[idx] = 0;});
160   team.team_barrier();
161   MatMult         (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,T,V);
162 
163   i=0;
164   do {
165     /* s <- (v,rp)          */
166     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += V[idx]*PetscConj(RP[idx]);}, s);
167     team.team_barrier();
168     a    = rhoold / s;                              /* a <- rho / s         */
169     /* q <- u - a v    VecWAXPY(w,alpha,x,y): w = alpha x + y.     */
170     /* t <- u + q           */
171     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Q[idx] = U[idx] - a*V[idx]; T[idx] = U[idx] + Q[idx];});
172     team.team_barrier();
173     // KSP_PCApplyBAorAB
174     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {T[idx] = Diag[idx]*T[idx]; });
175     team.team_barrier();
176     MatMult         (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,T,AUQ);
177     /* r <- r - a K (u + q) */
178     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {R[idx] = R[idx] - a*AUQ[idx]; });
179     team.team_barrier();
180     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += R[idx]*PetscConj(R[idx]);}, dpi);
181     team.team_barrier();
182     dp = PetscSqrtReal(PetscRealPart(dpi));
183     for (m=0; m<2; m++) {
184       if (!m) w = PetscSqrtReal(dp*dpold);
185       else w = dp;
186       psi = w / tau;
187       cm  = 1.0 / PetscSqrtReal(1.0 + psi * psi);
188       tau = tau * psi * cm;
189       eta = cm * cm * a;
190       cf  = psiold * psiold * etaold / a;
191       if (!m) {
192         /* D = U + cf D */
193         parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {D[idx] = U[idx] + cf*D[idx]; });
194       } else {
195         /* D = Q + cf D */
196         parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {D[idx] = Q[idx] + cf*D[idx]; });
197       }
198       team.team_barrier();
199       parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {XX[idx] = XX[idx] + eta*D[idx]; });
200       team.team_barrier();
201       dpest = PetscSqrtReal(2*i + m + 2.0) * tau;
202       if (monitor && m==1) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", i+1, (double)dpest);});
203 
204       if (dpest < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; goto done;}
205       if (dpest/r0 < rtol) {metad->reason = KSP_CONVERGED_RTOL_NORMAL; goto done;}
206 #if defined(PETSC_USE_DEBUG)
207       if (dpest/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; Kokkos::single (Kokkos::PerTeam (team), [=] () {printf("ERROR block %d diverged: %d it, res=%e, r_0=%e\n",team.league_rank(),i,dpest,r0);}); goto done;}
208 #else
209       if (dpest/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; goto done;}
210 #endif
211       if (i+1 == maxit) {metad->reason = KSP_DIVERGED_ITS; goto done;}
212 
213       etaold = eta;
214       psiold = psi;
215     }
216 
217     /* rho <- (r,rp)       */
218     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += R[idx]*PetscConj(RP[idx]);}, rho);
219     team.team_barrier();
220     b    = rho / rhoold;                            /* b <- rho / rhoold   */
221     /* u <- r + b q        */
222     /* p <- u + b(q + b p) */
223     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {U[idx] = R[idx] + b*Q[idx]; Q[idx] = Q[idx] + b*P[idx]; P[idx] = U[idx] + b*Q[idx];});
224     /* v <- K p  */
225     team.team_barrier();
226     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {T[idx] = Diag[idx]*P[idx]; });
227     team.team_barrier();
228     MatMult         (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,T,V);
229 
230     rhoold = rho;
231     dpold  = dp;
232 
233     i++;
234   } while (i<maxit);
235   done:
236   // KSPUnwindPreconditioner
237   parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {XX[idx] = Diag[idx]*XX[idx]; });
238   team.team_barrier();
239   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
240       int rowa = ic[rowb];
241       glb_x[rowa] = XX[rowb-start];
242     });
243   metad->its = i+1;
244   if (1) {
245     int nnz;
246     parallel_reduce(Kokkos::TeamVectorRange (team, start, end), [=] (const int idx, int& lsum) {lsum += (glb_Aai[idx+1] - glb_Aai[idx]);}, nnz);
247     metad->flops = 2*(metad->its*(10*Nblk + 2*nnz) + 5*Nblk);
248   } else {
249     metad->flops = 2*(metad->its*(10*Nblk + 2*50*Nblk) + 5*Nblk); // guess
250   }
251   return 0;
252 }
253 
254 // Solve Ax = y with biCG
255 KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_BICG(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space, const PetscInt stride, PetscReal rtol, PetscReal atol, PetscReal dtol,PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
256 {
257   using Kokkos::parallel_reduce;
258   using Kokkos::parallel_for;
259   int               Nblk = end-start, i;
260   PetscReal         dp, r0;
261   PetscScalar       *ptr = work_space, dpi, a=1.0, beta, betaold=1.0, b, b2, ma, mac;
262   const PetscScalar *Di = &glb_idiag[start];
263   PetscScalar       *XX = ptr; ptr += stride;
264   PetscScalar       *Rl = ptr; ptr += stride;
265   PetscScalar       *Zl = ptr; ptr += stride;
266   PetscScalar       *Pl = ptr; ptr += stride;
267   PetscScalar       *Rr = ptr; ptr += stride;
268   PetscScalar       *Zr = ptr; ptr += stride;
269   PetscScalar       *Pr = ptr; ptr += stride;
270 
271   /*     r <- b (x is 0) */
272   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
273       int rowa = ic[rowb];
274       //ierr = VecCopy(Rr,Rl);CHKERRQ(ierr);
275       Rl[rowb-start] = Rr[rowb-start] = glb_b[rowa];
276       XX[rowb-start] = 0;
277     });
278   team.team_barrier();
279   /*     z <- Br         */
280   parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Zr[idx] = Di[idx]*Rr[idx]; Zl[idx] = Di[idx]*Rl[idx]; });
281   team.team_barrier();
282   /*    dp <- r'*r       */
283   parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += Rr[idx]*PetscConj(Rr[idx]);}, dpi);
284   team.team_barrier();
285   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
286   if (monitor) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", 0, (double)dp);});
287 
288   if (dp < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; return 0;}
289   if (0 == maxit) {metad->reason = KSP_DIVERGED_ITS; return 0;}
290   i = 0;
291   do {
292     /*     beta <- r'z     */
293     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += Zr[idx]*PetscConj(Rl[idx]);}, beta);
294     team.team_barrier();
295 #if PCBJKOKKOS_VERBOSE_LEVEL >= 6
296     Kokkos::single (Kokkos::PerTeam (team), [=] () {printf("%7d beta = Z.R = %22.14e \n",i,(double)beta);});
297 #endif
298     if (!i) {
299       if (beta == 0.0) {
300         metad->reason = KSP_DIVERGED_BREAKDOWN_BICG;
301         goto done;
302       }
303       /*     p <- z          */
304       parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Pr[idx] = Zr[idx]; Pl[idx] = Zl[idx];});
305     } else {
306       b    = beta/betaold;
307       /*     p <- z + b* p   */
308       b2    = PetscConj(b);
309       parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Pr[idx] = b*Pr[idx] + Zr[idx]; Pl[idx] = b2*Pl[idx] + Zl[idx];});
310     }
311     team.team_barrier();
312     betaold = beta;
313     /*     z <- Kp         */
314     MatMult         (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,Pr,Zr);
315     MatMultTranspose(team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,Pl,Zl);
316     /*     dpi <- z'p      */
317     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += Zr[idx]*PetscConj(Pl[idx]);}, dpi);
318     team.team_barrier();
319     //
320     a       = beta/dpi;                           /*     a = beta/p'z    */
321     ma      = -a;
322     mac      = PetscConj(ma);
323     /*     x <- x + ap     */
324     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {XX[idx] = XX[idx] + a*Pr[idx]; Rr[idx] = Rr[idx] + ma*Zr[idx]; Rl[idx] = Rl[idx] + mac*Zl[idx];});team.team_barrier();
325     team.team_barrier();
326     /*    dp <- r'*r       */
327     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum +=  Rr[idx]*PetscConj(Rr[idx]);}, dpi);
328     team.team_barrier();
329     dp = PetscSqrtReal(PetscRealPart(dpi));
330     if (monitor) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", i+1, (double)dp);});
331 
332     if (dp < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; goto done;}
333     if (dp/r0 < rtol) {metad->reason = KSP_CONVERGED_RTOL_NORMAL; goto done;}
334 #if defined(PETSC_USE_DEBUG) || 1
335     if (dp/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; Kokkos::single (Kokkos::PerTeam (team), [=] () {printf("ERROR block %d diverged: %d it, res=%e, r_0=%e\n",team.league_rank(),i,dp,r0);}); goto done;}
336 #else
337     if (dp/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; goto done;}
338 #endif
339     if (i+1 == maxit) {metad->reason = KSP_DIVERGED_ITS; goto done;}
340     /* z <- Br  */
341     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Zr[idx] = Di[idx]*Rr[idx]; Zl[idx] = Di[idx]*Rl[idx];});
342     i++;
343     team.team_barrier();
344   } while (i<maxit);
345  done:
346   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
347       int rowa = ic[rowb];
348       glb_x[rowa] = XX[rowb-start];
349     });
350   metad->its = i+1;
351   if (1) {
352     int nnz;
353     parallel_reduce(Kokkos::TeamVectorRange (team, start, end), [=] (const int idx, int& lsum) {lsum += (glb_Aai[idx+1] - glb_Aai[idx]);}, nnz);
354     metad->flops = 2*(metad->its*(10*Nblk + 2*nnz) + 5*Nblk);
355   } else {
356     metad->flops = 2*(metad->its*(10*Nblk + 2*50*Nblk) + 5*Nblk); // guess
357   }
358   return 0;
359 }
360 
361 // KSP solver solve Ax = b; x is output, bin is input
362 static PetscErrorCode PCApply_BJKOKKOS(PC pc,Vec bin,Vec xout)
363 {
364   PetscErrorCode      ierr;
365   PC_PCBJKOKKOS       *jac = (PC_PCBJKOKKOS*)pc->data;
366   Mat                 A = pc->pmat;
367   Mat_SeqAIJKokkos    *aijkok;
368 
369   PetscFunctionBegin;
370   PetscCheck(jac->vec_diag && A,PetscObjectComm((PetscObject)pc),PETSC_ERR_USER,"Not setup???? %p %p",jac->vec_diag,A);
371   aijkok = static_cast<Mat_SeqAIJKokkos*>(A->spptr);
372   if (!aijkok) {
373     SETERRQ(PetscObjectComm((PetscObject)pc),PETSC_ERR_USER,"No aijkok");
374   } else {
375     using scr_mem_t  = Kokkos::DefaultExecutionSpace::scratch_memory_space;
376     using vect2D_scr_t = Kokkos::View<PetscScalar**, Kokkos::LayoutLeft, scr_mem_t>;
377     PetscInt          *d_bid_eqOffset, maxit = jac->ksp->max_it, scr_bytes_team, stride, global_buff_size;
378     const PetscInt    conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp==0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
379     const PetscInt    nwork = jac->nwork, nBlk = jac->nBlocks;
380     PetscScalar       *glb_xdata=NULL;
381     PetscReal         rtol = jac->ksp->rtol, atol = jac->ksp->abstol, dtol = jac->ksp->divtol;
382     const PetscScalar *glb_idiag =jac->d_idiag_k->data(), *glb_bdata=NULL;
383     const PetscInt    *glb_Aai = aijkok->i_device_data(), *glb_Aaj = aijkok->j_device_data();
384     const PetscScalar *glb_Aaa = aijkok->a_device_data();
385     Kokkos::View<Batch_MetaData*, Kokkos::DefaultExecutionSpace> d_metadata("solver meta data", nBlk);
386     PCFailedReason    pcreason;
387     KSPIndex          ksp_type_idx = jac->ksp_type_idx;
388     PetscMemType      mtype;
389     PetscContainer    container;
390     PetscInt          batch_sz;
391     VecScatter        plex_batch=NULL;
392     Vec               bvec;
393     PetscBool         monitor = jac->monitor; // captured
394     PetscInt          view_bid = jac->batch_target;
395     // get field major is to map plex IO to/from block/field major
396     ierr = PetscObjectQuery((PetscObject) A, "plex_batch_is", (PetscObject *) &container);CHKERRQ(ierr);
397     ierr = VecDuplicate(bin,&bvec);CHKERRQ(ierr);
398     if (container) {
399       ierr = PetscContainerGetPointer(container, (void **) &plex_batch);CHKERRQ(ierr);
400       ierr = VecScatterBegin(plex_batch,bin,bvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
401       ierr = VecScatterEnd(plex_batch,bin,bvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
402     } else {
403       ierr = VecCopy(bin, bvec);CHKERRQ(ierr);
404     }
405     // get x
406     ierr = VecGetArrayAndMemType(xout,&glb_xdata,&mtype);CHKERRQ(ierr);
407 #if defined(PETSC_HAVE_CUDA)
408     PetscCheck(PetscMemTypeDevice(mtype),PetscObjectComm((PetscObject) pc),PETSC_ERR_ARG_WRONG,"No GPU data for x %" PetscInt_FMT " != %" PetscInt_FMT "",mtype,PETSC_MEMTYPE_DEVICE);
409 #endif
410     ierr = VecGetArrayReadAndMemType(bvec,&glb_bdata,&mtype);CHKERRQ(ierr);
411 #if defined(PETSC_HAVE_CUDA)
412     PetscCheck(PetscMemTypeDevice(mtype),PetscObjectComm((PetscObject) pc),PETSC_ERR_ARG_WRONG,"No GPU data for b");
413 #endif
414     // get batch size
415     ierr = PetscObjectQuery((PetscObject) A, "batch size", (PetscObject *) &container);CHKERRQ(ierr);
416     if (container) {
417       PetscInt *pNf=NULL;
418       ierr = PetscContainerGetPointer(container, (void **) &pNf);CHKERRQ(ierr);
419       batch_sz = *pNf;
420     } else batch_sz = 1;
421     PetscCheck(nBlk%batch_sz == 0,PetscObjectComm((PetscObject) pc),PETSC_ERR_ARG_WRONG,"batch_sz = %" PetscInt_FMT ", nBlk = %" PetscInt_FMT,batch_sz,nBlk);
422     d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
423     // solve each block independently
424     if (jac->const_block_size) { // use shared memory for work vectors only if constant block size - todo: test efficiency loss
425       scr_bytes_team = jac->const_block_size*nwork*sizeof(PetscScalar);
426       stride = jac->const_block_size; // captured
427       global_buff_size = 0;
428     } else {
429       scr_bytes_team = 0;
430       stride = jac->n; // captured
431       global_buff_size = jac->n*nwork;
432     }
433     Kokkos::View<PetscScalar*, Kokkos::DefaultExecutionSpace> d_work_vecs_k("workvectors", global_buff_size); // global work vectors
434     PetscInfo(pc,"\tn = %" PetscInt_FMT ". %d shared mem words/team. %" PetscInt_FMT " global mem words, rtol=%e, num blocks %" PetscInt_FMT ", team_size=%" PetscInt_FMT ", %" PetscInt_FMT " vector threads\n",jac->n,scr_bytes_team/sizeof(PetscScalar),global_buff_size,rtol,nBlk,
435                team_size, PCBJKOKKOS_VEC_SIZE);
436     PetscScalar  *d_work_vecs = scr_bytes_team ? NULL : d_work_vecs_k.data();
437     const PetscInt *d_isicol = jac->d_isicol_k->data(), *d_isrow = jac->d_isrow_k->data();
438     Kokkos::parallel_for("Solve", Kokkos::TeamPolicy<>(nBlk, team_size, PCBJKOKKOS_VEC_SIZE).set_scratch_size(PCBJKOKKOS_SHARED_LEVEL, Kokkos::PerTeam(scr_bytes_team)),
439         KOKKOS_LAMBDA (const team_member team) {
440         const int    blkID = team.league_rank(), start = d_bid_eqOffset[blkID], end = d_bid_eqOffset[blkID+1];
441         vect2D_scr_t work_vecs(team.team_scratch(PCBJKOKKOS_SHARED_LEVEL), scr_bytes_team ? (end-start) : 0, nwork);
442         PetscScalar *work_buff = (scr_bytes_team) ? work_vecs.data() : &d_work_vecs[start];
443         bool        print = monitor && (blkID==view_bid);
444         switch (ksp_type_idx) {
445         case BATCH_KSP_BICG_IDX:
446           BJSolve_BICG(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff, stride, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print);
447           break;
448         case BATCH_KSP_TFQMR_IDX:
449           BJSolve_TFQMR(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff, stride, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print);
450           break;
451         case BATCH_KSP_GMRES_IDX:
452 #if defined(PETSC_USE_DEBUG)
453           printf("GMRES not implemented %d\n",ksp_type_idx);
454 #else
455           /* void */
456 #endif
457           break;
458         default:
459 #if defined(PETSC_USE_DEBUG)
460           printf("Unknown KSP type %d\n",ksp_type_idx);
461 #else
462           /* void */;
463 #endif
464         }
465     });
466     auto h_metadata = Kokkos::create_mirror(Kokkos::HostSpace::memory_space(), d_metadata);
467     Kokkos::fence();
468     Kokkos::deep_copy (h_metadata, d_metadata);
469 #if PCBJKOKKOS_VERBOSE_LEVEL >= 3
470 #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
471     ierr = PetscPrintf(PETSC_COMM_WORLD,"Iterations\n");CHKERRQ(ierr);
472 #endif
473     // assume species major
474 #if PCBJKOKKOS_VERBOSE_LEVEL < 4
475     ierr = PetscPrintf(PETSC_COMM_WORLD,"max iterations per species (%s) :",ksp_type_idx==BATCH_KSP_BICG_IDX ? "bicg" : "tfqmr");CHKERRQ(ierr);
476 #endif
477     for (PetscInt dmIdx=0, s=0, head=0 ; dmIdx < jac->num_dms; dmIdx += batch_sz) {
478       for (PetscInt f=0, idx=head ; f < jac->dm_Nf[dmIdx] ; f++,s++,idx++) {
479 #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
480         ierr = PetscPrintf(PETSC_COMM_WORLD,"%2D:", s);CHKERRQ(ierr);
481         for (int bid=0 ; bid<batch_sz ; bid++) {
482          ierr = PetscPrintf(PETSC_COMM_WORLD,"%3D ", h_metadata[idx + bid*jac->dm_Nf[dmIdx]].its);CHKERRQ(ierr);
483         }
484         ierr = PetscPrintf(PETSC_COMM_WORLD,"\n");CHKERRQ(ierr);
485 #else
486         PetscInt count=0;
487         for (int bid=0 ; bid<batch_sz ; bid++) {
488           if (h_metadata[idx + bid*jac->dm_Nf[dmIdx]].its > count) count = h_metadata[idx + bid*jac->dm_Nf[dmIdx]].its;
489         }
490         ierr = PetscPrintf(PETSC_COMM_WORLD,"%3D ", count);CHKERRQ(ierr);
491 #endif
492       }
493       head += batch_sz*jac->dm_Nf[dmIdx];
494     }
495 #if PCBJKOKKOS_VERBOSE_LEVEL < 4
496     ierr = PetscPrintf(PETSC_COMM_WORLD,"\n");CHKERRQ(ierr);
497 #endif
498 #endif
499     PetscInt count=0, mbid=0;
500     for (int blkID=0;blkID<nBlk;blkID++) {
501       ierr = PetscLogGpuFlops((PetscLogDouble)h_metadata[blkID].flops);CHKERRQ(ierr);
502       if (jac->reason) {
503         if (jac->batch_target==blkID) {
504           ierr = PetscPrintf(PETSC_COMM_SELF,  "    Linear solve converged due to %s iterations %d, batch %" PetscInt_FMT ", species %" PetscInt_FMT "\n", KSPConvergedReasons[h_metadata[blkID].reason], h_metadata[blkID].its, blkID%batch_sz, blkID/batch_sz);CHKERRQ(ierr);
505         } else if (jac->batch_target==-1 && h_metadata[blkID].its > count) {
506           count = h_metadata[blkID].its;
507           mbid = blkID;
508         }
509         if (h_metadata[blkID].reason < 0) {
510           ierr = PetscPrintf(PETSC_COMM_SELF, "ERROR reason=%s, its=%" PetscInt_FMT ". species %" PetscInt_FMT ", batch %" PetscInt_FMT "\n",
511                              KSPConvergedReasons[h_metadata[blkID].reason],h_metadata[blkID].its,blkID/batch_sz,blkID%batch_sz);CHKERRQ(ierr);
512         }
513       }
514     }
515     if (jac->batch_target==-1 && jac->reason) {
516       ierr = PetscPrintf(PETSC_COMM_SELF,  "    Linear solve converged due to %s iterations %d, batch %" PetscInt_FMT ", specie %" PetscInt_FMT "\n", KSPConvergedReasons[h_metadata[mbid].reason], h_metadata[mbid].its,mbid%batch_sz,mbid/batch_sz);CHKERRQ(ierr);
517     }
518     ierr = VecRestoreArrayAndMemType(xout,&glb_xdata);CHKERRQ(ierr);
519     ierr = VecRestoreArrayReadAndMemType(bvec,&glb_bdata);CHKERRQ(ierr);
520     {
521       int errsum;
522       Kokkos::parallel_reduce(nBlk, KOKKOS_LAMBDA (const int idx, int& lsum) {
523           if (d_metadata[idx].reason < 0) ++lsum;
524         }, errsum);
525       pcreason = errsum ? PC_SUBPC_ERROR : PC_NOERROR;
526     }
527     ierr = PCSetFailedReason(pc,pcreason);CHKERRQ(ierr);
528     // map back to Plex space
529     if (plex_batch) {
530       ierr = VecCopy(xout, bvec);CHKERRQ(ierr);
531       ierr = VecScatterBegin(plex_batch,bvec,xout,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
532       ierr = VecScatterEnd(plex_batch,bvec,xout,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
533     }
534     ierr = VecDestroy(&bvec);CHKERRQ(ierr);
535   }
536 
537   PetscFunctionReturn(0);
538 }
539 
540 static PetscErrorCode PCSetUp_BJKOKKOS(PC pc)
541 {
542   PetscErrorCode    ierr;
543   PC_PCBJKOKKOS     *jac = (PC_PCBJKOKKOS*)pc->data;
544   Mat               A = pc->pmat;
545   Mat_SeqAIJKokkos  *aijkok;
546   PetscBool         flg;
547 
548   PetscFunctionBegin;
549   PetscCheck(!pc->useAmat,PetscObjectComm((PetscObject)pc),PETSC_ERR_SUP,"No support for using 'use_amat'");
550   PetscCheck(A,PetscObjectComm((PetscObject)A),PETSC_ERR_USER,"No matrix - A is used above");
551   ierr = PetscObjectTypeCompareAny((PetscObject)A,&flg,MATSEQAIJKOKKOS,MATMPIAIJKOKKOS,MATAIJKOKKOS,"");CHKERRQ(ierr);
552   PetscCheck(flg,PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONG,"must use '-dm_mat_type aijkokkos -dm_vec_type kokkos' for -pc_type bjkokkos");
553   if (!(aijkok = static_cast<Mat_SeqAIJKokkos*>(A->spptr))) {
554     SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_USER,"No aijkok");
555   } else {
556     if (!jac->vec_diag) {
557       Vec               *subX;
558       DM                pack,*subDM;
559       PetscInt          nDMs, n;
560       PetscContainer    container;
561       ierr = PetscObjectQuery((PetscObject) A, "plex_batch_is", (PetscObject *) &container);CHKERRQ(ierr);
562       { // Permute the matrix to get a block diagonal system: d_isrow_k, d_isicol_k
563         MatOrderingType   rtype;
564         IS                isrow,isicol;
565         const PetscInt    *rowindices,*icolindices;
566 
567         if (container) rtype = MATORDERINGNATURAL; // if we have a vecscatter then don't reorder here (all the reorder stuff goes away in future)
568         else rtype = MATORDERINGRCM;
569         // get permutation. Not what I expect so inverted here
570         ierr = MatGetOrdering(A,rtype,&isrow,&isicol);CHKERRQ(ierr);
571         ierr = ISDestroy(&isrow);CHKERRQ(ierr);
572         ierr = ISInvertPermutation(isicol,PETSC_DECIDE,&isrow);CHKERRQ(ierr);
573         ierr = ISGetIndices(isrow,&rowindices);CHKERRQ(ierr);
574         ierr = ISGetIndices(isicol,&icolindices);CHKERRQ(ierr);
575         const Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_isrow_k((PetscInt*)rowindices,A->rmap->n);
576         const Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_isicol_k ((PetscInt*)icolindices,A->rmap->n);
577         jac->d_isrow_k = new Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_isrow_k));
578         jac->d_isicol_k = new Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_isicol_k));
579         Kokkos::deep_copy (*jac->d_isrow_k, h_isrow_k);
580         Kokkos::deep_copy (*jac->d_isicol_k, h_isicol_k);
581         ierr = ISRestoreIndices(isrow,&rowindices);CHKERRQ(ierr);
582         ierr = ISRestoreIndices(isicol,&icolindices);CHKERRQ(ierr);
583         ierr = ISDestroy(&isrow);CHKERRQ(ierr);
584         ierr = ISDestroy(&isicol);CHKERRQ(ierr);
585       }
586       // get block sizes
587       ierr = PCGetDM(pc, &pack);CHKERRQ(ierr);
588       PetscCheck(pack,PetscObjectComm((PetscObject)A),PETSC_ERR_USER,"no DM. Requires a composite DM");
589       ierr = PetscObjectTypeCompare((PetscObject)pack,DMCOMPOSITE,&flg);CHKERRQ(ierr);
590       PetscCheck(flg,PetscObjectComm((PetscObject)pack),PETSC_ERR_USER,"Not for type %s",((PetscObject)pack)->type_name);
591       ierr = DMCompositeGetNumberDM(pack,&nDMs);CHKERRQ(ierr);
592       jac->num_dms = nDMs;
593       ierr = DMCreateGlobalVector(pack, &jac->vec_diag);CHKERRQ(ierr);
594       ierr = VecGetLocalSize(jac->vec_diag,&n);CHKERRQ(ierr);
595       jac->n = n;
596       jac->d_idiag_k = new Kokkos::View<PetscScalar*, Kokkos::LayoutRight>("idiag", n);
597       // options
598       ierr = PCBJKOKKOSCreateKSP_BJKOKKOS(pc);CHKERRQ(ierr);
599       ierr = KSPSetFromOptions(jac->ksp);CHKERRQ(ierr);
600       ierr = PetscObjectTypeCompareAny((PetscObject)jac->ksp,&flg,KSPBICG,"");CHKERRQ(ierr);
601       if (flg) {jac->ksp_type_idx = BATCH_KSP_BICG_IDX; jac->nwork = 7;}
602       else {
603         ierr = PetscObjectTypeCompareAny((PetscObject)jac->ksp,&flg,KSPTFQMR,"");CHKERRQ(ierr);
604         if (flg) {jac->ksp_type_idx = BATCH_KSP_TFQMR_IDX; jac->nwork = 10;}
605         else {
606           ierr = PetscObjectTypeCompareAny((PetscObject)jac->ksp,&flg,KSPGMRES,"");CHKERRQ(ierr);
607           if (flg) {jac->ksp_type_idx = BATCH_KSP_GMRES_IDX; jac->nwork = 0;}
608           SETERRQ(PetscObjectComm((PetscObject)jac->ksp),PETSC_ERR_ARG_WRONG,"unsupported type %s", ((PetscObject)jac->ksp)->type_name);
609         }
610       }
611       {
612         PetscViewer       viewer;
613         PetscBool         flg;
614         PetscViewerFormat format;
615         ierr   = PetscOptionsGetViewer(PetscObjectComm((PetscObject)jac->ksp),((PetscObject)jac->ksp)->options,((PetscObject)jac->ksp)->prefix,"-ksp_converged_reason",&viewer,&format,&flg);CHKERRQ(ierr);
616         jac->reason = flg;
617         ierr = PetscViewerDestroy(&viewer);CHKERRQ(ierr);
618         ierr   = PetscOptionsGetViewer(PetscObjectComm((PetscObject)jac->ksp),((PetscObject)jac->ksp)->options,((PetscObject)jac->ksp)->prefix,"-ksp_monitor",&viewer,&format,&flg);CHKERRQ(ierr);
619         jac->monitor = flg;
620         ierr = PetscViewerDestroy(&viewer);CHKERRQ(ierr);
621         ierr = PetscOptionsGetInt(((PetscObject)jac->ksp)->options,((PetscObject)jac->ksp)->prefix,"-ksp_batch_target",&jac->batch_target,&flg);CHKERRQ(ierr);
622         PetscCheckFalse(jac->batch_target >= jac->num_dms,PETSC_COMM_WORLD,PETSC_ERR_ARG_WRONG,"-ksp_batch_target (%" PetscInt_FMT ") >= number of DMs (%" PetscInt_FMT ")",jac->batch_target,jac->num_dms);
623         if (!jac->monitor && !flg) jac->batch_target = -1; // turn it off
624       }
625       // get blocks - jac->d_bid_eqOffset_k
626       ierr = PetscMalloc(sizeof(*subX)*nDMs, &subX);CHKERRQ(ierr);
627       ierr = PetscMalloc(sizeof(*subDM)*nDMs, &subDM);CHKERRQ(ierr);
628       ierr = PetscMalloc(sizeof(*jac->dm_Nf)*nDMs, &jac->dm_Nf);CHKERRQ(ierr);
629       ierr = PetscInfo(pc, "Have %" PetscInt_FMT " DMs, n=%" PetscInt_FMT " rtol=%g type = %s\n", nDMs, n, jac->ksp->rtol, ((PetscObject)jac->ksp)->type_name);CHKERRQ(ierr);
630       ierr = DMCompositeGetEntriesArray(pack,subDM);CHKERRQ(ierr);
631       jac->nBlocks = 0;
632       for (PetscInt ii=0;ii<nDMs;ii++) {
633         PetscSection section;
634         PetscInt Nf;
635         DM dm = subDM[ii];
636         ierr = DMGetLocalSection(dm, &section);CHKERRQ(ierr);
637         ierr = PetscSectionGetNumFields(section, &Nf);CHKERRQ(ierr);
638         jac->nBlocks += Nf;
639 #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
640         if (ii==0) { ierr = PetscInfo(pc,"%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n",ii,Nf,jac->nBlocks); }
641 #else
642         ierr = PetscInfo(pc,"%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n",ii,Nf,jac->nBlocks);
643 #endif
644         jac->dm_Nf[ii] = Nf;
645       }
646       { // d_bid_eqOffset_k
647         Kokkos::View<PetscInt*, Kokkos::LayoutRight, Kokkos::HostSpace> h_block_offsets("block_offsets", jac->nBlocks+1);
648         ierr = DMCompositeGetAccessArray(pack, jac->vec_diag, nDMs, NULL, subX);CHKERRQ(ierr);
649         h_block_offsets[0] = 0;
650         jac->const_block_size = -1;
651         for (PetscInt ii=0, idx = 0;ii<nDMs;ii++) {
652           PetscInt nloc,nblk;
653           ierr = VecGetSize(subX[ii],&nloc);CHKERRQ(ierr);
654           nblk = nloc/jac->dm_Nf[ii];
655           PetscCheck(nloc%jac->dm_Nf[ii] == 0,PetscObjectComm((PetscObject)pc),PETSC_ERR_USER,"nloc%jac->dm_Nf[ii] DMs",nloc,jac->dm_Nf[ii]);
656           for (PetscInt jj=0;jj<jac->dm_Nf[ii];jj++, idx++) {
657             h_block_offsets[idx+1] = h_block_offsets[idx] + nblk;
658 #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
659             if (idx==0) {ierr = PetscInfo(pc,"\t%" PetscInt_FMT ") Add block with %" PetscInt_FMT " equations of %" PetscInt_FMT "\n",idx+1,nblk,jac->nBlocks);CHKERRQ(ierr);}
660 #else
661             ierr = PetscInfo(pc,"\t%" PetscInt_FMT ") Add block with %" PetscInt_FMT " equations of %" PetscInt_FMT "\n",idx+1,nblk,jac->nBlocks);CHKERRQ(ierr);
662 #endif
663             if (jac->const_block_size == -1) jac->const_block_size = nblk;
664             else if (jac->const_block_size > 0 && jac->const_block_size != nblk) jac->const_block_size = 0;
665           }
666         }
667         ierr = DMCompositeRestoreAccessArray(pack, jac->vec_diag, jac->nBlocks, NULL, subX);CHKERRQ(ierr);
668         ierr = PetscFree(subX);CHKERRQ(ierr);
669         ierr = PetscFree(subDM);CHKERRQ(ierr);
670         jac->d_bid_eqOffset_k = new Kokkos::View<PetscInt*, Kokkos::LayoutRight>(Kokkos::create_mirror(Kokkos::DefaultExecutionSpace::memory_space(),h_block_offsets));
671         Kokkos::deep_copy (*jac->d_bid_eqOffset_k, h_block_offsets);
672       }
673     }
674     { // get jac->d_idiag_k (PC setup),
675       const PetscInt    *d_ai=aijkok->i_device_data(), *d_aj=aijkok->j_device_data();
676       const PetscScalar *d_aa = aijkok->a_device_data();
677       const PetscInt    conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp==0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
678       PetscInt          *d_bid_eqOffset = jac->d_bid_eqOffset_k->data(), *r = jac->d_isrow_k->data(), *ic = jac->d_isicol_k->data();
679       PetscScalar       *d_idiag = jac->d_idiag_k->data();
680       Kokkos::parallel_for("Diag", Kokkos::TeamPolicy<>(jac->nBlocks, team_size, PCBJKOKKOS_VEC_SIZE), KOKKOS_LAMBDA (const team_member team) {
681           const PetscInt blkID = team.league_rank();
682           Kokkos::parallel_for
683             (Kokkos::TeamThreadRange(team,d_bid_eqOffset[blkID],d_bid_eqOffset[blkID+1]),
684              [=] (const int rowb) {
685                const PetscInt    rowa = ic[rowb], ai = d_ai[rowa], *aj = d_aj + ai; // grab original data
686                const PetscScalar *aa  = d_aa + ai;
687                const PetscInt    nrow = d_ai[rowa + 1] - ai;
688                int found;
689                Kokkos::parallel_reduce
690                  (Kokkos::ThreadVectorRange (team, nrow),
691                   [=] (const int& j, int &count) {
692                     const PetscInt colb = r[aj[j]];
693                     if (colb==rowb) {
694                       d_idiag[rowb] = 1./aa[j];
695                       count++;
696                     }}, found);
697                if (found!=1) Kokkos::single (Kokkos::PerThread (team), [=] () {printf("ERRORrow %d) found = %d\n",rowb,found);});
698              });
699         });
700     }
701   }
702   PetscFunctionReturn(0);
703 }
704 
705 /* Default destroy, if it has never been setup */
706 static PetscErrorCode PCReset_BJKOKKOS(PC pc)
707 {
708   PC_PCBJKOKKOS   *jac = (PC_PCBJKOKKOS*)pc->data;
709   PetscErrorCode ierr;
710 
711   PetscFunctionBegin;
712   ierr = KSPDestroy(&jac->ksp);CHKERRQ(ierr);
713   ierr = VecDestroy(&jac->vec_diag);CHKERRQ(ierr);
714   if (jac->d_bid_eqOffset_k) delete jac->d_bid_eqOffset_k;
715   if (jac->d_idiag_k) delete jac->d_idiag_k;
716   if (jac->d_isrow_k) delete jac->d_isrow_k;
717   if (jac->d_isicol_k) delete jac->d_isicol_k;
718   jac->d_bid_eqOffset_k = NULL;
719   jac->d_idiag_k = NULL;
720   jac->d_isrow_k = NULL;
721   jac->d_isicol_k = NULL;
722   ierr = PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSGetKSP_C",NULL);CHKERRQ(ierr); // not published now (causes configure errors)
723   ierr = PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSSetKSP_C",NULL);CHKERRQ(ierr);
724   ierr = PetscFree(jac->dm_Nf);CHKERRQ(ierr);
725   jac->dm_Nf = NULL;
726   PetscFunctionReturn(0);
727 }
728 
729 static PetscErrorCode PCDestroy_BJKOKKOS(PC pc)
730 {
731   PetscErrorCode ierr;
732 
733   PetscFunctionBegin;
734   ierr = PCReset_BJKOKKOS(pc);CHKERRQ(ierr);
735   ierr = PetscFree(pc->data);CHKERRQ(ierr);
736   PetscFunctionReturn(0);
737 }
738 
739 static PetscErrorCode PCView_BJKOKKOS(PC pc,PetscViewer viewer)
740 {
741   PC_PCBJKOKKOS   *jac = (PC_PCBJKOKKOS*)pc->data;
742   PetscErrorCode ierr;
743   PetscBool      iascii;
744 
745   PetscFunctionBegin;
746   if (!jac->ksp) {ierr = PCBJKOKKOSCreateKSP_BJKOKKOS(pc);CHKERRQ(ierr);}
747   ierr = PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&iascii);CHKERRQ(ierr);
748   if (iascii) {
749     ierr = PetscViewerASCIIPrintf(viewer,"  Batched device linear solver: Krylov (KSP) method with Jacobi preconditioning\n");CHKERRQ(ierr);
750     ierr = PetscViewerASCIIPrintf(viewer,"\t\tnwork = %" PetscInt_FMT ", rel tol = %e, abs tol = %e, div tol = %e, max it =%" PetscInt_FMT ", type = %s\n",jac->nwork,jac->ksp->rtol,
751                                   jac->ksp->abstol, jac->ksp->divtol, jac->ksp->max_it,
752                                   ((PetscObject)jac->ksp)->type_name);CHKERRQ(ierr);
753   }
754   PetscFunctionReturn(0);
755 }
756 
757 static PetscErrorCode PCSetFromOptions_BJKOKKOS(PetscOptionItems *PetscOptionsObject,PC pc)
758 {
759   PetscErrorCode ierr;
760 
761   PetscFunctionBegin;
762   ierr = PetscOptionsHead(PetscOptionsObject,"PC BJKOKKOS options");CHKERRQ(ierr);
763   ierr = PetscOptionsTail();CHKERRQ(ierr);
764   PetscFunctionReturn(0);
765 }
766 
767 static PetscErrorCode  PCBJKOKKOSSetKSP_BJKOKKOS(PC pc,KSP ksp)
768 {
769   PC_PCBJKOKKOS         *jac = (PC_PCBJKOKKOS*)pc->data;
770   PetscErrorCode ierr;
771 
772   PetscFunctionBegin;
773   ierr = PetscObjectReference((PetscObject)ksp);CHKERRQ(ierr);
774   ierr = KSPDestroy(&jac->ksp);CHKERRQ(ierr);
775   jac->ksp = ksp;
776   PetscFunctionReturn(0);
777 }
778 
779 /*@C
780    PCBJKOKKOSSetKSP - Sets the KSP context for a KSP PC.
781 
782    Collective on PC
783 
784    Input Parameters:
785 +  pc - the preconditioner context
786 -  ksp - the KSP solver
787 
788    Notes:
789    The PC and the KSP must have the same communicator
790 
791    Level: advanced
792 
793 @*/
794 PetscErrorCode  PCBJKOKKOSSetKSP(PC pc,KSP ksp)
795 {
796   PetscErrorCode ierr;
797 
798   PetscFunctionBegin;
799   PetscValidHeaderSpecific(pc,PC_CLASSID,1);
800   PetscValidHeaderSpecific(ksp,KSP_CLASSID,2);
801   PetscCheckSameComm(pc,1,ksp,2);
802   ierr = PetscTryMethod(pc,"PCBJKOKKOSSetKSP_C",(PC,KSP),(pc,ksp));CHKERRQ(ierr);
803   PetscFunctionReturn(0);
804 }
805 
806 static PetscErrorCode  PCBJKOKKOSGetKSP_BJKOKKOS(PC pc,KSP *ksp)
807 {
808   PC_PCBJKOKKOS         *jac = (PC_PCBJKOKKOS*)pc->data;
809   PetscErrorCode ierr;
810 
811   PetscFunctionBegin;
812   if (!jac->ksp) {ierr = PCBJKOKKOSCreateKSP_BJKOKKOS(pc);CHKERRQ(ierr);}
813   *ksp = jac->ksp;
814   PetscFunctionReturn(0);
815 }
816 
817 /*@C
818    PCBJKOKKOSGetKSP - Gets the KSP context for a KSP PC.
819 
820    Not Collective but KSP returned is parallel if PC was parallel
821 
822    Input Parameter:
823 .  pc - the preconditioner context
824 
825    Output Parameters:
826 .  ksp - the KSP solver
827 
828    Notes:
829    You must call KSPSetUp() before calling PCBJKOKKOSGetKSP().
830 
831    If the PC is not a PCBJKOKKOS object it raises an error
832 
833    Level: advanced
834 
835 @*/
836 PetscErrorCode  PCBJKOKKOSGetKSP(PC pc,KSP *ksp)
837 {
838   PetscErrorCode ierr;
839 
840   PetscFunctionBegin;
841   PetscValidHeaderSpecific(pc,PC_CLASSID,1);
842   PetscValidPointer(ksp,2);
843   ierr = PetscUseMethod(pc,"PCBJKOKKOSGetKSP_C",(PC,KSP*),(pc,ksp));CHKERRQ(ierr);
844   PetscFunctionReturn(0);
845 }
846 
847 /* ----------------------------------------------------------------------------------*/
848 
849 /*MC
850      PCBJKOKKOS -  Defines a preconditioner that applies a Krylov solver and preconditioner to the blocks in a AIJASeq matrix on the GPU.
851 
852    Options Database Key:
853 .     -pc_bjkokkos_
854 
855    Level: intermediate
856 
857    Notes:
858     For use with -ksp_type preonly to bypass any CPU work
859 
860    Developer Notes:
861 
862 .seealso:  PCCreate(), PCSetType(), PCType (for list of available types), PC,
863            PCSHELL, PCCOMPOSITE, PCSetUseAmat(), PCBJKOKKOSGetKSP()
864 
865 M*/
866 
867 PETSC_EXTERN PetscErrorCode PCCreate_BJKOKKOS(PC pc)
868 {
869   PetscErrorCode ierr;
870   PC_PCBJKOKKOS   *jac;
871 
872   PetscFunctionBegin;
873   ierr = PetscNewLog(pc,&jac);CHKERRQ(ierr);
874   pc->data = (void*)jac;
875 
876   jac->ksp = NULL;
877   jac->vec_diag = NULL;
878   jac->d_bid_eqOffset_k = NULL;
879   jac->d_idiag_k = NULL;
880   jac->d_isrow_k = NULL;
881   jac->d_isicol_k = NULL;
882   jac->nBlocks = 1;
883 
884   ierr = PetscMemzero(pc->ops,sizeof(struct _PCOps));CHKERRQ(ierr);
885   pc->ops->apply           = PCApply_BJKOKKOS;
886   pc->ops->applytranspose  = NULL;
887   pc->ops->setup           = PCSetUp_BJKOKKOS;
888   pc->ops->reset           = PCReset_BJKOKKOS;
889   pc->ops->destroy         = PCDestroy_BJKOKKOS;
890   pc->ops->setfromoptions  = PCSetFromOptions_BJKOKKOS;
891   pc->ops->view            = PCView_BJKOKKOS;
892 
893   ierr = PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSGetKSP_C",PCBJKOKKOSGetKSP_BJKOKKOS);CHKERRQ(ierr);
894   ierr = PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSSetKSP_C",PCBJKOKKOSSetKSP_BJKOKKOS);CHKERRQ(ierr);
895   PetscFunctionReturn(0);
896 }
897