xref: /petsc/src/ksp/pc/impls/bjacobi/bjkokkos/bjkokkos.kokkos.cxx (revision a4d8cd18abfb070563a2b0c5ba002b8df2a61504)
1 #include <petscvec_kokkos.hpp>
2 #include <petsc/private/pcimpl.h>
3 #include <petsc/private/kspimpl.h>
4 #include <petscksp.h>            /*I "petscksp.h" I*/
5 #include "petscsection.h"
6 #include <petscdmcomposite.h>
7 #include <Kokkos_Core.hpp>
8 
9 typedef Kokkos::TeamPolicy<>::member_type team_member;
10 
11 #include <../src/mat/impls/aij/seq/aij.h>
12 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
13 
// Compile-time settings for the batched block solve launched in PCApply_BJKOKKOS.
// Scratch-memory level handed to set_scratch_size() and team_scratch().
#define PCBJKOKKOS_SHARED_LEVEL 1
// Vector length given to the TeamPolicy of the solve kernel.
#define PCBJKOKKOS_VEC_SIZE 16
// Team size used when the execution space reports a concurrency of 1000 or more (otherwise 1 is used).
#define PCBJKOKKOS_TEAM_SIZE 16
// Diagnostic printing: >= 2 prints iteration counts after each apply, >= 6 also prints beta inside BiCG.
#define PCBJKOKKOS_VERBOSE_LEVEL 0
//#define PCBJKOKKOS_MONITOR
19 
// Selects which batched Krylov kernel (BJSolve_BICG or BJSolve_TFQMR) the solve kernel dispatches to.
typedef enum {BATCH_KSP_BICG_IDX=0,BATCH_KSP_TFQMR_IDX=1,NUM_BATCH_TYPES} KSPIndex;
// Private data of the PC (stored in pc->data).
typedef struct {
  Vec                                              vec_diag; // global vector of the PC's DM, created in PCSetUp; a NULL value means "not set up yet"
  PetscInt                                         nBlocks; /* total number of blocks */
  PetscInt                                         n; // cache host version of d_bid_eqOffset_k[nBlocks]; set to the local size of vec_diag
  KSP                                              ksp; // Used just for options. Should have one for each block
  Kokkos::View<PetscInt*, Kokkos::LayoutRight>     *d_bid_eqOffset_k; // device array; entries [b] and [b+1] bound the equations of block b
  Kokkos::View<PetscScalar*, Kokkos::LayoutRight>  *d_idiag_k; // device array of length n, applied pointwise as the preconditioner inside the block solvers
  Kokkos::View<PetscInt*>                          *d_isrow_k; // device copy of the inverse of the ordering IS 'isicol' (passed to the solvers as 'r')
  Kokkos::View<PetscInt*>                          *d_isicol_k; // device copy of the ordering IS 'isicol' (passed to the solvers as 'ic')
  KSPIndex                                         ksp_type_idx; // which batched solver to run
  PetscInt                                         nwork; // number of work vectors reserved per block (7 for BiCG, 10 for TFQMR)
  PetscInt                                         const_block_size; // used to decide to use shared memory for work vectors
  PetscInt                                         *dm_Nf;  // Number of fields in each DM
  PetscInt                                         num_dms; // number of DMs in the composite DM
  // diagnostics
  PetscBool                                        reason; // set when -ksp_converged_reason was given for the inner KSP
  PetscBool                                        monitor; // set when -ksp_monitor was given for the inner KSP
  PetscInt                                         batch_target; // block whose residuals/convergence are printed (-ksp_batch_target)
} PC_PCBJKOKKOS;
40 
41 static PetscErrorCode  PCBJKOKKOSCreateKSP_BJKOKKOS(PC pc)
42 {
43   PetscErrorCode ierr;
44   const char     *prefix;
45   PC_PCBJKOKKOS   *jac = (PC_PCBJKOKKOS*)pc->data;
46   DM             dm;
47 
48   PetscFunctionBegin;
49   ierr = KSPCreate(PetscObjectComm((PetscObject)pc),&jac->ksp);CHKERRQ(ierr);
50   ierr = KSPSetErrorIfNotConverged(jac->ksp,pc->erroriffailure);CHKERRQ(ierr);
51   ierr = PetscObjectIncrementTabLevel((PetscObject)jac->ksp,(PetscObject)pc,1);CHKERRQ(ierr);
52   ierr = PCGetOptionsPrefix(pc,&prefix);CHKERRQ(ierr);
53   ierr = KSPSetOptionsPrefix(jac->ksp,prefix);CHKERRQ(ierr);
54   ierr = KSPAppendOptionsPrefix(jac->ksp,"pc_bjkokkos_");CHKERRQ(ierr);
55   ierr = PCGetDM(pc, &dm);CHKERRQ(ierr);
56   if (dm) {
57     ierr = KSPSetDM(jac->ksp, dm);CHKERRQ(ierr);
58     ierr = KSPSetDMActive(jac->ksp, PETSC_FALSE);CHKERRQ(ierr);
59   }
60   jac->reason = PETSC_FALSE;
61   jac->monitor = PETSC_FALSE;
62   jac->batch_target = 0;
63 
64   PetscFunctionReturn(0);
65 }
66 
67 // y <-- Ax
68 KOKKOS_INLINE_FUNCTION PetscErrorCode MatMult(const team_member team,  const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
69 {
70   Kokkos::parallel_for(Kokkos::TeamThreadRange(team,start,end), [=] (const int rowb) {
71       int rowa = ic[rowb];
72       int n = glb_Aai[rowa+1] - glb_Aai[rowa];
73       const PetscInt    *aj  = glb_Aaj + glb_Aai[rowa];
74       const PetscScalar *aa  = glb_Aaa + glb_Aai[rowa];
75       PetscScalar sum;
76       Kokkos::parallel_reduce(Kokkos::ThreadVectorRange (team, n), [=] (const int i, PetscScalar& lsum) {
77           lsum += aa[i] * x_loc[r[aj[i]]-start];
78         }, sum);
79       Kokkos::single(Kokkos::PerThread (team),[=]() {y_loc[rowb-start] = sum;});
80     });
81   team.team_barrier();
82   return 0;
83 }
84 
85 // temp buffer per thread with reduction at end?
86 KOKKOS_INLINE_FUNCTION PetscErrorCode MatMultTranspose(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
87 {
88   Kokkos::parallel_for(Kokkos::TeamVectorRange(team,end-start), [=] (int i) { y_loc[i] = 0;});
89   team.team_barrier();
90   Kokkos::parallel_for(Kokkos::TeamThreadRange(team,start,end), [=] (const int rowb) {
91       int rowa = ic[rowb];
92       int n = glb_Aai[rowa+1] - glb_Aai[rowa];
93       const PetscInt    *aj  = glb_Aaj + glb_Aai[rowa];
94       const PetscScalar *aa  = glb_Aaa + glb_Aai[rowa];
95       const PetscScalar xx = x_loc[rowb-start]; // rowb = ic[rowa] = ic[r[rowb]]
96       Kokkos::parallel_for(Kokkos::ThreadVectorRange(team,n), [=] (const int &i) {
97           PetscScalar val = aa[i] * xx;
98           Kokkos::atomic_fetch_add(&y_loc[r[aj[i]]-start], val);
99         });
100     });
101   team.team_barrier();
102   return 0;
103 }
104 
// Per-block result record filled in by the block solvers and copied back to the host in PCApply_BJKOKKOS.
typedef struct Batch_MetaData_TAG
{
  PetscInt           flops;  // flop estimate for the block solve, logged with PetscLogGpuFlops
  PetscInt           its;    // iteration count reported by the block solver
  KSPConvergedReason reason; // why the block solver stopped; negative values are treated as divergence
}Batch_MetaData;
111 
112 // Solve A(BB^-1)x = y with TFQMR. Right preconditioned to get un-preconditioned residual
113 KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_TFQMR(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space, const PetscInt stride, PetscReal rtol, PetscReal atol, PetscReal dtol,PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
114 {
115   using Kokkos::parallel_reduce;
116   using Kokkos::parallel_for;
117   int               Nblk = end-start, i,m;
118   PetscReal         dp,dpold,w,dpest,tau,psi,cm,r0;
119   PetscScalar       *ptr = work_space, rho,rhoold,a,s,b,eta,etaold,psiold,cf,dpi;
120   const PetscScalar *Diag = &glb_idiag[start];
121   PetscScalar       *XX = ptr; ptr += stride;
122   PetscScalar       *R = ptr; ptr += stride;
123   PetscScalar       *RP = ptr; ptr += stride;
124   PetscScalar       *V = ptr; ptr += stride;
125   PetscScalar       *T = ptr; ptr += stride;
126   PetscScalar       *Q = ptr; ptr += stride;
127   PetscScalar       *P = ptr; ptr += stride;
128   PetscScalar       *U = ptr; ptr += stride;
129   PetscScalar       *D = ptr; ptr += stride;
130   PetscScalar       *AUQ = V;
131 
132   // init: get b, zero x
133   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
134       int rowa = ic[rowb];
135       R[rowb-start] = glb_b[rowa];
136       XX[rowb-start] = 0;
137     });
138   team.team_barrier();
139   parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += R[idx]*PetscConj(R[idx]);}, dpi);
140   team.team_barrier();
141   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
142   // diagnostics
143   if (monitor) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", 0, (double)dp);});
144 
145   if (dp < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; return 0;}
146   if (0 == maxit) {metad->reason = KSP_DIVERGED_ITS; return 0;}
147 
148   /* Make the initial Rp = R */
149   parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {RP[idx] = R[idx];});
150   team.team_barrier();
151   /* Set the initial conditions */
152   etaold = 0.0;
153   psiold = 0.0;
154   tau    = dp;
155   dpold  = dp;
156 
157   /* rhoold = (r,rp)     */
158   parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += R[idx]*PetscConj(RP[idx]);}, rhoold);
159   team.team_barrier();
160   parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {U[idx] = R[idx]; P[idx] = R[idx]; T[idx] = Diag[idx]*P[idx]; D[idx] = 0;});
161   team.team_barrier();
162   MatMult         (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,T,V);
163 
164   i=0;
165   do {
166     /* s <- (v,rp)          */
167     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += V[idx]*PetscConj(RP[idx]);}, s);
168     team.team_barrier();
169     a    = rhoold / s;                              /* a <- rho / s         */
170     /* q <- u - a v    VecWAXPY(w,alpha,x,y): w = alpha x + y.     */
171     /* t <- u + q           */
172     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Q[idx] = U[idx] - a*V[idx]; T[idx] = U[idx] + Q[idx];});
173     team.team_barrier();
174     // KSP_PCApplyBAorAB
175     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {T[idx] = Diag[idx]*T[idx]; });
176     team.team_barrier();
177     MatMult         (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,T,AUQ);
178     /* r <- r - a K (u + q) */
179     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {R[idx] = R[idx] - a*AUQ[idx]; });
180     team.team_barrier();
181     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += R[idx]*PetscConj(R[idx]);}, dpi);
182     team.team_barrier();
183     dp = PetscSqrtReal(PetscRealPart(dpi));
184     for (m=0; m<2; m++) {
185       if (!m) w = PetscSqrtReal(dp*dpold);
186       else w = dp;
187       psi = w / tau;
188       cm  = 1.0 / PetscSqrtReal(1.0 + psi * psi);
189       tau = tau * psi * cm;
190       eta = cm * cm * a;
191       cf  = psiold * psiold * etaold / a;
192       if (!m) {
193         /* D = U + cf D */
194         parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {D[idx] = U[idx] + cf*D[idx]; });
195       } else {
196         /* D = Q + cf D */
197         parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {D[idx] = Q[idx] + cf*D[idx]; });
198       }
199       team.team_barrier();
200       parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {XX[idx] = XX[idx] + eta*D[idx]; });
201       team.team_barrier();
202       dpest = PetscSqrtReal(2*i + m + 2.0) * tau;
203       if (monitor) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", i+1, (double)dpest);});
204 
205       if (dpest < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; goto done;}
206       if (dpest/r0 < rtol) {metad->reason = KSP_CONVERGED_RTOL_NORMAL; goto done;}
207 #if defined(PETSC_USE_DEBUG) || 1
208       if (dpest/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; Kokkos::single (Kokkos::PerTeam (team), [=] () {printf("ERROR block %d diverged: %d it, res=%e, r_0=%e\n",team.league_rank(),i,dpest,r0);}); goto done;}
209 #else
210       if (dpest/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; goto done;}
211 #endif
212       if (i+1 == maxit) {metad->reason = KSP_DIVERGED_ITS; goto done;}
213 
214       etaold = eta;
215       psiold = psi;
216     }
217 
218     /* rho <- (r,rp)       */
219     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += R[idx]*PetscConj(RP[idx]);}, rho);
220     team.team_barrier();
221     b    = rho / rhoold;                            /* b <- rho / rhoold   */
222     /* u <- r + b q        */
223     /* p <- u + b(q + b p) */
224     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {U[idx] = R[idx] + b*Q[idx]; Q[idx] = Q[idx] + b*P[idx]; P[idx] = U[idx] + b*Q[idx];});
225     /* v <- K p  */
226     team.team_barrier();
227     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {T[idx] = Diag[idx]*P[idx]; });
228     team.team_barrier();
229     MatMult         (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,T,V);
230 
231     rhoold = rho;
232     dpold  = dp;
233 
234     i++;
235   } while (i<maxit);
236   done:
237   // KSPUnwindPreconditioner
238   parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {XX[idx] = Diag[idx]*XX[idx]; });
239   team.team_barrier();
240   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
241       int rowa = ic[rowb];
242       glb_x[rowa] = XX[rowb-start];
243     });
244   metad->its = i+1;
245   if (1) {
246     int nnz;
247     parallel_reduce(Kokkos::TeamVectorRange (team, start, end), [=] (const int idx, int& lsum) {lsum += (glb_Aai[idx+1] - glb_Aai[idx]);}, nnz);
248     metad->flops = 2*(metad->its*(10*Nblk + 2*nnz) + 5*Nblk);
249   } else {
250     metad->flops = 2*(metad->its*(10*Nblk + 2*50*Nblk) + 5*Nblk); // guess
251   }
252   return 0;
253 }
254 
255 // Solve Ax = y with biCG
256 KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_BICG(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space, const PetscInt stride, PetscReal rtol, PetscReal atol, PetscReal dtol,PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
257 {
258   using Kokkos::parallel_reduce;
259   using Kokkos::parallel_for;
260   int               Nblk = end-start, i;
261   PetscReal         dp, r0;
262   PetscScalar       *ptr = work_space, dpi, a=1.0, beta, betaold=1.0, b, b2, ma, mac;
263   const PetscScalar *Di = &glb_idiag[start];
264   PetscScalar       *XX = ptr; ptr += stride;
265   PetscScalar       *Rl = ptr; ptr += stride;
266   PetscScalar       *Zl = ptr; ptr += stride;
267   PetscScalar       *Pl = ptr; ptr += stride;
268   PetscScalar       *Rr = ptr; ptr += stride;
269   PetscScalar       *Zr = ptr; ptr += stride;
270   PetscScalar       *Pr = ptr; ptr += stride;
271 
272   /*     r <- b (x is 0) */
273   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
274       int rowa = ic[rowb];
275       //ierr = VecCopy(Rr,Rl);CHKERRQ(ierr);
276       Rl[rowb-start] = Rr[rowb-start] = glb_b[rowa];
277       XX[rowb-start] = 0;
278     });
279   team.team_barrier();
280   /*     z <- Br         */
281   parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Zr[idx] = Di[idx]*Rr[idx]; Zl[idx] = Di[idx]*Rl[idx]; });
282   team.team_barrier();
283   /*    dp <- r'*r       */
284   parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += Rr[idx]*PetscConj(Rr[idx]);}, dpi);
285   team.team_barrier();
286   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
287   if (monitor) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", 0, (double)dp);});
288 
289   if (dp < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; return 0;}
290   if (0 == maxit) {metad->reason = KSP_DIVERGED_ITS; return 0;}
291   i = 0;
292   do {
293     /*     beta <- r'z     */
294     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& dot) {dot += Zr[idx]*PetscConj(Rl[idx]);}, beta);
295     team.team_barrier();
296 #if PCBJKOKKOS_VERBOSE_LEVEL >= 6
297     Kokkos::single (Kokkos::PerTeam (team), [=] () {printf("%7d beta = Z.R = %22.14e \n",i,(double)beta);});
298 #endif
299     if (!i) {
300       if (beta == 0.0) {
301         metad->reason = KSP_DIVERGED_BREAKDOWN_BICG;
302         goto done;
303       }
304       /*     p <- z          */
305       parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Pr[idx] = Zr[idx]; Pl[idx] = Zl[idx];});
306     } else {
307       b    = beta/betaold;
308       /*     p <- z + b* p   */
309       b2    = PetscConj(b);
310       parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Pr[idx] = b*Pr[idx] + Zr[idx]; Pl[idx] = b2*Pl[idx] + Zl[idx];});
311     }
312     team.team_barrier();
313     betaold = beta;
314     /*     z <- Kp         */
315     MatMult         (team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,Pr,Zr);
316     MatMultTranspose(team,glb_Aai,glb_Aaj,glb_Aaa,r,ic,start,end,Pl,Zl);
317     /*     dpi <- z'p      */
318     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum += Zr[idx]*PetscConj(Pl[idx]);}, dpi);
319     team.team_barrier();
320     //
321     a       = beta/dpi;                           /*     a = beta/p'z    */
322     ma      = -a;
323     mac      = PetscConj(ma);
324     /*     x <- x + ap     */
325     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {XX[idx] = XX[idx] + a*Pr[idx]; Rr[idx] = Rr[idx] + ma*Zr[idx]; Rl[idx] = Rl[idx] + mac*Zl[idx];});team.team_barrier();
326     team.team_barrier();
327     /*    dp <- r'*r       */
328     parallel_reduce(Kokkos::TeamVectorRange (team, Nblk), [=] (const int idx, PetscScalar& lsum) {lsum +=  Rr[idx]*PetscConj(Rr[idx]);}, dpi);
329     team.team_barrier();
330     dp = PetscSqrtReal(PetscRealPart(dpi));
331     if (monitor) Kokkos::single (Kokkos::PerTeam (team), [=] () { printf("%3d KSP Residual norm %14.12e \n", i+1, (double)dp);});
332 
333     if (dp < atol) {metad->reason = KSP_CONVERGED_ATOL_NORMAL; goto done;}
334     if (dp/r0 < rtol) {metad->reason = KSP_CONVERGED_RTOL_NORMAL; goto done;}
335 #if defined(PETSC_USE_DEBUG) || 1
336     if (dp/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; Kokkos::single (Kokkos::PerTeam (team), [=] () {printf("ERROR block %d diverged: %d it, res=%e, r_0=%e\n",team.league_rank(),i,dp,r0);}); goto done;}
337 #else
338     if (dp/r0 > dtol) {metad->reason = KSP_DIVERGED_DTOL; goto done;}
339 #endif
340     if (i+1 == maxit) {metad->reason = KSP_DIVERGED_ITS; goto done;}
341     /* z <- Br  */
342     parallel_for(Kokkos::TeamVectorRange(team,Nblk), [=] (int idx) {Zr[idx] = Di[idx]*Rr[idx]; Zl[idx] = Di[idx]*Rl[idx];});
343     i++;
344     team.team_barrier();
345   } while (i<maxit);
346  done:
347   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=] (int rowb) {
348       int rowa = ic[rowb];
349       glb_x[rowa] = XX[rowb-start];
350     });
351   metad->its = i+1;
352   if (1) {
353     int nnz;
354     parallel_reduce(Kokkos::TeamVectorRange (team, start, end), [=] (const int idx, int& lsum) {lsum += (glb_Aai[idx+1] - glb_Aai[idx]);}, nnz);
355     metad->flops = 2*(metad->its*(10*Nblk + 2*nnz) + 5*Nblk);
356   } else {
357     metad->flops = 2*(metad->its*(10*Nblk + 2*50*Nblk) + 5*Nblk); // guess
358   }
359   return 0;
360 }
361 
362 // KSP solver solve Ax = b; x is output, bin is input
363 static PetscErrorCode PCApply_BJKOKKOS(PC pc,Vec bin,Vec xout)
364 {
365   PetscErrorCode      ierr;
366   PC_PCBJKOKKOS       *jac = (PC_PCBJKOKKOS*)pc->data;
367   Mat                 A = pc->pmat;
368   Mat_SeqAIJKokkos    *aijkok;
369 
370   PetscFunctionBegin;
371   if (!jac->vec_diag || !A) SETERRQ(PetscObjectComm((PetscObject)pc),PETSC_ERR_USER,"Not setup???? %p %p",jac->vec_diag,A);
372   if (!(aijkok = static_cast<Mat_SeqAIJKokkos*>(A->spptr))) SETERRQ(PetscObjectComm((PetscObject)pc),PETSC_ERR_USER,"No aijkok");
373   else {
374     using scr_mem_t  = Kokkos::DefaultExecutionSpace::scratch_memory_space;
375     using vect2D_scr_t = Kokkos::View<PetscScalar**, Kokkos::LayoutLeft, scr_mem_t>;
376     PetscInt          *d_bid_eqOffset, maxit = jac->ksp->max_it, scr_bytes_team, stride, global_buff_size;
377     const PetscInt    conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp==0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
378     const PetscInt    nwork = jac->nwork, nBlk = jac->nBlocks;
379     PetscScalar       *glb_xdata=NULL;
380     PetscReal         rtol = jac->ksp->rtol, atol = jac->ksp->abstol, dtol = jac->ksp->divtol;
381     const PetscScalar *glb_idiag =jac->d_idiag_k->data(), *glb_bdata=NULL;
382     const PetscInt    *glb_Aai = aijkok->i_device_data(), *glb_Aaj = aijkok->j_device_data();
383     const PetscScalar *glb_Aaa = aijkok->a_device_data();
384     Kokkos::View<Batch_MetaData*, Kokkos::DefaultExecutionSpace> d_metadata("solver meta data", nBlk);
385     PCFailedReason    pcreason;
386     KSPIndex          ksp_type_idx = jac->ksp_type_idx;
387     PetscMemType      mtype;
388     PetscContainer    container;
389     PetscInt          batch_sz;
390     VecScatter        plex_batch=NULL;
391     Vec               bvec;
392     PetscBool         monitor = jac->monitor; // captured
393     PetscInt          view_bid = jac->batch_target;
394     // get field major is to map plex IO to/from block/field major
395     ierr = PetscObjectQuery((PetscObject) A, "plex_batch_is", (PetscObject *) &container);CHKERRQ(ierr);
396     ierr = VecDuplicate(bin,&bvec);CHKERRQ(ierr);
397     if (container) {
398       ierr = PetscContainerGetPointer(container, (void **) &plex_batch);CHKERRQ(ierr);
399       ierr = VecScatterBegin(plex_batch,bin,bvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
400       ierr = VecScatterEnd(plex_batch,bin,bvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
401     } else {
402       ierr = VecCopy(bin, bvec);CHKERRQ(ierr);
403     }
404     // get x
405     ierr = VecGetArrayAndMemType(xout,&glb_xdata,&mtype);CHKERRQ(ierr);
406 #if defined(PETSC_HAVE_CUDA)
407     if (mtype!=PETSC_MEMTYPE_DEVICE) SETERRQ(PetscObjectComm((PetscObject) pc),PETSC_ERR_ARG_WRONG,"No GPU data for x %D != %D",mtype,PETSC_MEMTYPE_DEVICE);
408 #endif
409     ierr = VecGetArrayReadAndMemType(bvec,&glb_bdata,&mtype);CHKERRQ(ierr);
410 #if defined(PETSC_HAVE_CUDA)
411     if (mtype!=PETSC_MEMTYPE_DEVICE) SETERRQ(PetscObjectComm((PetscObject) pc),PETSC_ERR_ARG_WRONG,"No GPU data for b");
412 #endif
413     // get batch size
414     ierr = PetscObjectQuery((PetscObject) A, "batch size", (PetscObject *) &container);CHKERRQ(ierr);
415     if (container) {
416       PetscInt *pNf=NULL;
417       ierr = PetscContainerGetPointer(container, (void **) &pNf);CHKERRQ(ierr);
418       batch_sz = *pNf;
419     } else batch_sz = 1;
420     if (nBlk%batch_sz) SETERRQ(PetscObjectComm((PetscObject) pc),PETSC_ERR_ARG_WRONG,"batch_sz = %D, nBlk = %D",batch_sz,nBlk);
421     d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
422     // solve each block independently
423     if (jac->const_block_size) { // use shared memory for work vectors only if constant block size - todo: test efficiency loss
424       scr_bytes_team = jac->const_block_size*nwork*sizeof(PetscScalar);
425       stride = jac->const_block_size; // captured
426       global_buff_size = 0;
427     } else {
428       scr_bytes_team = 0;
429       stride = jac->n; // captured
430       global_buff_size = jac->n*nwork;
431     }
432     Kokkos::View<PetscScalar*, Kokkos::DefaultExecutionSpace> d_work_vecs_k("workvectors", global_buff_size); // global work vectors
433     PetscInfo(pc,"\tn = %D. %d shared mem words/team. %D global mem words, rtol=%e, num blocks %D, team_size=%D, %D vector threads\n",jac->n,scr_bytes_team/sizeof(PetscScalar),global_buff_size,rtol,nBlk,
434                team_size, PCBJKOKKOS_VEC_SIZE);
435     PetscScalar  *d_work_vecs = scr_bytes_team ? NULL : d_work_vecs_k.data();
436     const PetscInt *d_isicol = jac->d_isicol_k->data(), *d_isrow = jac->d_isrow_k->data();
437     Kokkos::parallel_for("Solve", Kokkos::TeamPolicy<>(nBlk, team_size, PCBJKOKKOS_VEC_SIZE).set_scratch_size(PCBJKOKKOS_SHARED_LEVEL, Kokkos::PerTeam(scr_bytes_team)),
438         KOKKOS_LAMBDA (const team_member team) {
439         const int    blkID = team.league_rank(), start = d_bid_eqOffset[blkID], end = d_bid_eqOffset[blkID+1];
440         vect2D_scr_t work_vecs(team.team_scratch(PCBJKOKKOS_SHARED_LEVEL), scr_bytes_team ? (end-start) : 0, nwork);
441         PetscScalar *work_buff = (scr_bytes_team) ? work_vecs.data() : &d_work_vecs[start];
442         bool        print = monitor && (blkID==view_bid);
443         switch (ksp_type_idx) {
444         case BATCH_KSP_BICG_IDX:
445           BJSolve_BICG(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff, stride, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print);
446           break;
447         case BATCH_KSP_TFQMR_IDX:
448           BJSolve_TFQMR(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff, stride, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print);
449           break;
450         default:
451 #if defined(PETSC_USE_DEBUG)
452           printf("Unknown KSP type %d\n",ksp_type_idx);
453 #else
454           /* void */;
455 #endif
456         }
457 #if defined(PETSC_USE_DEBUG)
458         if (d_metadata[blkID].reason<0) printf("Solver diverged %d\n",d_metadata[blkID].reason);
459 #endif
460       });
461     auto h_metadata = Kokkos::create_mirror(Kokkos::HostSpace::memory_space(), d_metadata);
462     Kokkos::fence();
463     Kokkos::deep_copy (h_metadata, d_metadata);
464 #if PCBJKOKKOS_VERBOSE_LEVEL >= 3
465 #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
466     ierr = PetscPrintf(PETSC_COMM_WORLD,"Iterations\n");CHKERRQ(ierr);
467 #endif
468     // assume species major
469 #if PCBJKOKKOS_VERBOSE_LEVEL < 4
470     ierr = PetscPrintf(PETSC_COMM_WORLD,"max iterations per species (%s) :",ksp_type_idx==BATCH_KSP_BICG_IDX ? "bicg" : "tfqmr");CHKERRQ(ierr);
471 #endif
472     for (PetscInt dmIdx=0, s=0, head=0 ; dmIdx < jac->num_dms; dmIdx += batch_sz) {
473       for (PetscInt f=0, idx=head ; f < jac->dm_Nf[dmIdx] ; f++,s++,idx++) {
474 #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
475         ierr = PetscPrintf(PETSC_COMM_WORLD,"%2D:", s);CHKERRQ(ierr);
476         for (int bid=0 ; bid<batch_sz ; bid++) {
477          ierr = PetscPrintf(PETSC_COMM_WORLD,"%3D ", h_metadata[idx + bid*jac->dm_Nf[dmIdx]].its);CHKERRQ(ierr);
478         }
479         ierr = PetscPrintf(PETSC_COMM_WORLD,"\n");CHKERRQ(ierr);
480 #else
481         PetscInt count=0;
482         for (int bid=0 ; bid<batch_sz ; bid++) {
483           if (h_metadata[idx + bid*jac->dm_Nf[dmIdx]].its > count) count = h_metadata[idx + bid*jac->dm_Nf[dmIdx]].its;
484         }
485         ierr = PetscPrintf(PETSC_COMM_WORLD,"%3D ", count);CHKERRQ(ierr);
486 #endif
487       }
488       head += batch_sz*jac->dm_Nf[dmIdx];
489     }
490 #if PCBJKOKKOS_VERBOSE_LEVEL < 4
491     ierr = PetscPrintf(PETSC_COMM_WORLD,"\n");CHKERRQ(ierr);
492 #endif
493 #elif PCBJKOKKOS_VERBOSE_LEVEL >= 2
494     PetscInt count=0;
495     for (PetscInt dmIdx=0, s=0, head=0 ; dmIdx < jac->num_dms; dmIdx += batch_sz) {
496       for (PetscInt f=0, idx=head ; f < jac->dm_Nf[dmIdx] ; f++,s++,idx++) {
497         if (h_metadata[idx + bid*jac->dm_Nf[dmIdx]].its > count) count = h_metadata[idx + bid*jac->dm_Nf[dmIdx]].its;
498       }
499     }
500     ierr = PetscPrintf(PETSC_COMM_WORLD,"%D max iterations", count);CHKERRQ(ierr);
501 #endif
502     for (int blkID=0;blkID<nBlk;blkID++) {
503       ierr = PetscLogGpuFlops((PetscLogDouble)h_metadata[blkID].flops);CHKERRQ(ierr);
504       if (jac->reason) {
505         if (jac->batch_target==blkID || (jac->batch_target==-1 && h_metadata[blkID].its > 100)) {
506           ierr = PetscPrintf(PETSC_COMM_SELF,  "    Linear solve converged due to %s iterations %d\n", KSPConvergedReasons[h_metadata[blkID].reason], h_metadata[blkID].its);CHKERRQ(ierr);
507         }
508         if (jac->batch_target==-1 && h_metadata[blkID].reason < 0) {
509           ierr = PetscPrintf(PETSC_COMM_SELF, "ERROR reason=%s, its=%D. species %D, batch %D, %D species\n",
510                              KSPConvergedReasons[h_metadata[blkID].reason],h_metadata[blkID].its,blkID/batch_sz,blkID%batch_sz,nBlk/batch_sz);CHKERRQ(ierr);
511         }
512       }
513     }
514     ierr = VecRestoreArrayAndMemType(xout,&glb_xdata);CHKERRQ(ierr);
515     ierr = VecRestoreArrayReadAndMemType(bvec,&glb_bdata);CHKERRQ(ierr);
516     {
517       int errsum;
518       Kokkos::parallel_reduce(nBlk, KOKKOS_LAMBDA (const int idx, int& lsum) {
519           if (d_metadata[idx].reason < 0 && d_metadata[idx].reason != KSP_DIVERGED_ITS && d_metadata[idx].reason != KSP_CONVERGED_ITS) lsum += 1;
520         }, errsum);
521       if (!errsum) pcreason = PC_NOERROR;
522       else pcreason = PC_SUBPC_ERROR;
523     }
524     ierr = PCSetFailedReason(pc,pcreason);CHKERRQ(ierr);
525     // map back to Plex space
526     if (plex_batch) {
527       ierr = VecCopy(xout, bvec);CHKERRQ(ierr);
528       ierr = VecScatterBegin(plex_batch,bvec,xout,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
529       ierr = VecScatterEnd(plex_batch,bvec,xout,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
530     }
531     ierr = VecDestroy(&bvec);CHKERRQ(ierr);
532   }
533 
534   PetscFunctionReturn(0);
535 }
536 
537 static PetscErrorCode PCSetUp_BJKOKKOS(PC pc)
538 {
539   PetscErrorCode    ierr;
540   PC_PCBJKOKKOS     *jac = (PC_PCBJKOKKOS*)pc->data;
541   Mat               A = pc->pmat;
542   Mat_SeqAIJKokkos  *aijkok;
543   PetscBool         flg;
544 
545   PetscFunctionBegin;
546   if (pc->useAmat) SETERRQ(PetscObjectComm((PetscObject)pc),PETSC_ERR_SUP,"No support for using 'use_amat'");
547   if (!A) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_USER,"No matrix - A is used above");
548   ierr = PetscObjectTypeCompareAny((PetscObject)A,&flg,MATSEQAIJKOKKOS,MATMPIAIJKOKKOS,MATAIJKOKKOS,"");CHKERRQ(ierr);
549   if (!flg) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONG,"must use '-dm_mat_type aijkokkos -dm_vec_type kokkos' for -pc_type bjkokkos");
550   if (!(aijkok = static_cast<Mat_SeqAIJKokkos*>(A->spptr))) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_USER,"No aijkok");
551   else {
552     if (!jac->vec_diag) {
553       Vec               *subX;
554       DM                pack,*subDM;
555       PetscInt          nDMs, n;
556       PetscContainer    container;
557       ierr = PetscObjectQuery((PetscObject) A, "plex_batch_is", (PetscObject *) &container);CHKERRQ(ierr);
558       { // Permute the matrix to get a block diagonal system: d_isrow_k, d_isicol_k
559         MatOrderingType   rtype;
560         IS                isrow,isicol;
561         const PetscInt    *rowindices,*icolindices;
562 
563         if (container) rtype = MATORDERINGNATURAL; // if we have a vecscatter then don't reorder here (all the reorder stuff goes away in future)
564         else rtype = MATORDERINGRCM;
565         // get permutation. Not what I expect so inverted here
566         ierr = MatGetOrdering(A,rtype,&isrow,&isicol);CHKERRQ(ierr);
567         ierr = ISDestroy(&isrow);CHKERRQ(ierr);
568         ierr = ISInvertPermutation(isicol,PETSC_DECIDE,&isrow);CHKERRQ(ierr);
569         ierr = ISGetIndices(isrow,&rowindices);CHKERRQ(ierr);
570         ierr = ISGetIndices(isicol,&icolindices);CHKERRQ(ierr);
571         const Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_isrow_k((PetscInt*)rowindices,A->rmap->n);
572         const Kokkos::View<PetscInt*, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged> > h_isicol_k ((PetscInt*)icolindices,A->rmap->n);
573         jac->d_isrow_k = new Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_isrow_k));
574         jac->d_isicol_k = new Kokkos::View<PetscInt*>(Kokkos::create_mirror(DefaultMemorySpace(),h_isicol_k));
575         Kokkos::deep_copy (*jac->d_isrow_k, h_isrow_k);
576         Kokkos::deep_copy (*jac->d_isicol_k, h_isicol_k);
577         ierr = ISRestoreIndices(isrow,&rowindices);CHKERRQ(ierr);
578         ierr = ISRestoreIndices(isicol,&icolindices);CHKERRQ(ierr);
579         ierr = ISDestroy(&isrow);CHKERRQ(ierr);
580         ierr = ISDestroy(&isicol);CHKERRQ(ierr);
581       }
582       // get block sizes
583       ierr = PCGetDM(pc, &pack);CHKERRQ(ierr);
584       if (!pack) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_USER,"no DM. Requires a composite DM");
585       ierr = PetscObjectTypeCompare((PetscObject)pack,DMCOMPOSITE,&flg);CHKERRQ(ierr);
586       if (!flg) SETERRQ(PetscObjectComm((PetscObject)pack),PETSC_ERR_USER,"Not for type %s",((PetscObject)pack)->type_name);
587       ierr = DMCompositeGetNumberDM(pack,&nDMs);CHKERRQ(ierr);
588       jac->num_dms = nDMs;
589       ierr = DMCreateGlobalVector(pack, &jac->vec_diag);CHKERRQ(ierr);
590       ierr = VecGetLocalSize(jac->vec_diag,&n);CHKERRQ(ierr);
591       jac->n = n;
592       jac->d_idiag_k = new Kokkos::View<PetscScalar*, Kokkos::LayoutRight>("idiag", n);
593       // options
594       ierr = PCBJKOKKOSCreateKSP_BJKOKKOS(pc);CHKERRQ(ierr);
595       ierr = KSPSetFromOptions(jac->ksp);CHKERRQ(ierr);
596       ierr = PetscObjectTypeCompareAny((PetscObject)jac->ksp,&flg,KSPBICG,"");CHKERRQ(ierr);
597       if (flg) {jac->ksp_type_idx = BATCH_KSP_BICG_IDX; jac->nwork = 7;}
598       else {
599         ierr = PetscObjectTypeCompareAny((PetscObject)jac->ksp,&flg,KSPTFQMR,"");CHKERRQ(ierr);
600         if (flg) {jac->ksp_type_idx = BATCH_KSP_TFQMR_IDX; jac->nwork = 10;}
601         else SETERRQ(PetscObjectComm((PetscObject)jac->ksp),PETSC_ERR_ARG_WRONG,"unsupported type %s", ((PetscObject)jac->ksp)->type_name);
602       }
603       {
604         PetscViewer       viewer;
605         PetscBool         flg;
606         PetscViewerFormat format;
607         ierr   = PetscOptionsGetViewer(PetscObjectComm((PetscObject)jac->ksp),((PetscObject)jac->ksp)->options,((PetscObject)jac->ksp)->prefix,"-ksp_converged_reason",&viewer,&format,&flg);CHKERRQ(ierr);
608         jac->reason = flg;
609         ierr = PetscViewerDestroy(&viewer);CHKERRQ(ierr);
610         ierr   = PetscOptionsGetViewer(PetscObjectComm((PetscObject)jac->ksp),((PetscObject)jac->ksp)->options,((PetscObject)jac->ksp)->prefix,"-ksp_monitor",&viewer,&format,&flg);CHKERRQ(ierr);
611         jac->monitor = flg;
612         ierr = PetscViewerDestroy(&viewer);CHKERRQ(ierr);
613         ierr = PetscOptionsGetInt(((PetscObject)jac->ksp)->options,((PetscObject)jac->ksp)->prefix,"-ksp_batch_target",&jac->batch_target,NULL);CHKERRQ(ierr);
614         PetscCheckFalse(jac->batch_target >= jac->num_dms,PETSC_COMM_WORLD,PETSC_ERR_ARG_WRONG,"-ksp_batch_target (%D) >= number of DMs (%D)",jac->batch_target,jac->num_dms);
615       }
616       // get blocks - jac->d_bid_eqOffset_k
617       ierr = PetscMalloc(sizeof(*subX)*nDMs, &subX);CHKERRQ(ierr);
618       ierr = PetscMalloc(sizeof(*subDM)*nDMs, &subDM);CHKERRQ(ierr);
619       ierr = PetscMalloc(sizeof(*jac->dm_Nf)*nDMs, &jac->dm_Nf);CHKERRQ(ierr);
620       ierr = PetscInfo(pc, "Have %D DMs, n=%D rtol=%g type = %s\n", nDMs, n, jac->ksp->rtol, ((PetscObject)jac->ksp)->type_name);CHKERRQ(ierr);
621       ierr = DMCompositeGetEntriesArray(pack,subDM);CHKERRQ(ierr);
622       jac->nBlocks = 0;
623       for (PetscInt ii=0;ii<nDMs;ii++) {
624         PetscSection section;
625         PetscInt Nf;
626         DM dm = subDM[ii];
627         ierr = DMGetLocalSection(dm, &section);CHKERRQ(ierr);
628         ierr = PetscSectionGetNumFields(section, &Nf);CHKERRQ(ierr);
629         jac->nBlocks += Nf;
630 #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
631         if (ii==0) { ierr = PetscInfo(pc,"%D) %D blocks (%D total)\n",ii,Nf,jac->nBlocks); }
632 #else
633         ierr = PetscInfo(pc,"%D) %D blocks (%D total)\n",ii,Nf,jac->nBlocks);
634 #endif
635         jac->dm_Nf[ii] = Nf;
636       }
637       { // d_bid_eqOffset_k
638         Kokkos::View<PetscInt*, Kokkos::LayoutRight, Kokkos::HostSpace> h_block_offsets("block_offsets", jac->nBlocks+1);
639         ierr = DMCompositeGetAccessArray(pack, jac->vec_diag, nDMs, NULL, subX);CHKERRQ(ierr);
640         h_block_offsets[0] = 0;
641         jac->const_block_size = -1;
642         for (PetscInt ii=0, idx = 0;ii<nDMs;ii++) {
643           PetscInt nloc,nblk;
644           ierr = VecGetSize(subX[ii],&nloc);CHKERRQ(ierr);
645           nblk = nloc/jac->dm_Nf[ii];
646           if (nloc%jac->dm_Nf[ii]) SETERRQ(PetscObjectComm((PetscObject)pc),PETSC_ERR_USER,"nloc%jac->dm_Nf[ii] DMs",nloc,jac->dm_Nf[ii]);
647           for (PetscInt jj=0;jj<jac->dm_Nf[ii];jj++, idx++) {
648             h_block_offsets[idx+1] = h_block_offsets[idx] + nblk;
649 #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
650             if (idx==0) {ierr = PetscInfo(pc,"\t%D) Add block with %D equations of %D\n",idx+1,nblk,jac->nBlocks);CHKERRQ(ierr);}
651 #else
652             ierr = PetscInfo(pc,"\t%D) Add block with %D equations of %D\n",idx+1,nblk,jac->nBlocks);CHKERRQ(ierr);
653 #endif
654             if (jac->const_block_size == -1) jac->const_block_size = nblk;
655             else if (jac->const_block_size > 0 && jac->const_block_size != nblk) jac->const_block_size = 0;
656           }
657         }
658         ierr = DMCompositeRestoreAccessArray(pack, jac->vec_diag, jac->nBlocks, NULL, subX);CHKERRQ(ierr);
659         ierr = PetscFree(subX);CHKERRQ(ierr);
660         ierr = PetscFree(subDM);CHKERRQ(ierr);
661         jac->d_bid_eqOffset_k = new Kokkos::View<PetscInt*, Kokkos::LayoutRight>(Kokkos::create_mirror(Kokkos::DefaultExecutionSpace::memory_space(),h_block_offsets));
662         Kokkos::deep_copy (*jac->d_bid_eqOffset_k, h_block_offsets);
663       }
664     }
665     { // get jac->d_idiag_k (PC setup),
666       const PetscInt    *d_ai=aijkok->i_device_data(), *d_aj=aijkok->j_device_data();
667       const PetscScalar *d_aa = aijkok->a_device_data();
668       const PetscInt    conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp==0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
669       PetscInt          *d_bid_eqOffset = jac->d_bid_eqOffset_k->data(), *r = jac->d_isrow_k->data(), *ic = jac->d_isicol_k->data();
670       PetscScalar       *d_idiag = jac->d_idiag_k->data();
671       Kokkos::parallel_for("Diag", Kokkos::TeamPolicy<>(jac->nBlocks, team_size, PCBJKOKKOS_VEC_SIZE), KOKKOS_LAMBDA (const team_member team) {
672           const PetscInt blkID = team.league_rank();
673           Kokkos::parallel_for
674             (Kokkos::TeamThreadRange(team,d_bid_eqOffset[blkID],d_bid_eqOffset[blkID+1]),
675              [=] (const int rowb) {
676                const PetscInt    rowa = ic[rowb], ai = d_ai[rowa], *aj = d_aj + ai; // grab original data
677                const PetscScalar *aa  = d_aa + ai;
678                const PetscInt    nrow = d_ai[rowa + 1] - ai;
679                int found;
680                Kokkos::parallel_reduce
681                  (Kokkos::ThreadVectorRange (team, nrow),
682                   [=] (const int& j, int &count) {
683                     const PetscInt colb = r[aj[j]];
684                     if (colb==rowb) {
685                       d_idiag[rowb] = 1./aa[j];
686                       count++;
687                     }}, found);
688                if (found!=1) Kokkos::single (Kokkos::PerThread (team), [=] () {printf("ERRORrow %d) found = %d\n",rowb,found);});
689              });
690         });
691     }
692   }
693   PetscFunctionReturn(0);
694 }
695 
696 /* Default destroy, if it has never been setup */
697 static PetscErrorCode PCReset_BJKOKKOS(PC pc)
698 {
699   PC_PCBJKOKKOS   *jac = (PC_PCBJKOKKOS*)pc->data;
700   PetscErrorCode ierr;
701 
702   PetscFunctionBegin;
703   ierr = KSPDestroy(&jac->ksp);CHKERRQ(ierr);
704   ierr = VecDestroy(&jac->vec_diag);CHKERRQ(ierr);
705   if (jac->d_bid_eqOffset_k) delete jac->d_bid_eqOffset_k;
706   if (jac->d_idiag_k) delete jac->d_idiag_k;
707   if (jac->d_isrow_k) delete jac->d_isrow_k;
708   if (jac->d_isicol_k) delete jac->d_isicol_k;
709   jac->d_bid_eqOffset_k = NULL;
710   jac->d_idiag_k = NULL;
711   jac->d_isrow_k = NULL;
712   jac->d_isicol_k = NULL;
713   ierr = PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSGetKSP_C",NULL);CHKERRQ(ierr); // not published now (causes configure errors)
714   ierr = PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSSetKSP_C",NULL);CHKERRQ(ierr);
715   ierr = PetscFree(jac->dm_Nf);CHKERRQ(ierr);
716   jac->dm_Nf = NULL;
717   PetscFunctionReturn(0);
718 }
719 
720 static PetscErrorCode PCDestroy_BJKOKKOS(PC pc)
721 {
722   PetscErrorCode ierr;
723 
724   PetscFunctionBegin;
725   ierr = PCReset_BJKOKKOS(pc);CHKERRQ(ierr);
726   ierr = PetscFree(pc->data);CHKERRQ(ierr);
727   PetscFunctionReturn(0);
728 }
729 
730 static PetscErrorCode PCView_BJKOKKOS(PC pc,PetscViewer viewer)
731 {
732   PC_PCBJKOKKOS   *jac = (PC_PCBJKOKKOS*)pc->data;
733   PetscErrorCode ierr;
734   PetscBool      iascii;
735 
736   PetscFunctionBegin;
737   if (!jac->ksp) {ierr = PCBJKOKKOSCreateKSP_BJKOKKOS(pc);CHKERRQ(ierr);}
738   ierr = PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&iascii);CHKERRQ(ierr);
739   if (iascii) {
740     ierr = PetscViewerASCIIPrintf(viewer,"  Batched device linear solver: Krylov (KSP) method with Jacobi preconditioning\n");CHKERRQ(ierr);
741     ierr = PetscViewerASCIIPrintf(viewer,"\t\tnwork = %D, rel tol = %e, abs tol = %e, div tol = %e, max it =%D, type = %s\n",jac->nwork,jac->ksp->rtol,
742                                   jac->ksp->abstol, jac->ksp->divtol, jac->ksp->max_it,
743                                   ((PetscObject)jac->ksp)->type_name);CHKERRQ(ierr);
744   }
745   PetscFunctionReturn(0);
746 }
747 
748 static PetscErrorCode PCSetFromOptions_BJKOKKOS(PetscOptionItems *PetscOptionsObject,PC pc)
749 {
750   PetscErrorCode ierr;
751 
752   PetscFunctionBegin;
753   ierr = PetscOptionsHead(PetscOptionsObject,"PC BJKOKKOS options");CHKERRQ(ierr);
754   ierr = PetscOptionsTail();CHKERRQ(ierr);
755   PetscFunctionReturn(0);
756 }
757 
758 static PetscErrorCode  PCBJKOKKOSSetKSP_BJKOKKOS(PC pc,KSP ksp)
759 {
760   PC_PCBJKOKKOS         *jac = (PC_PCBJKOKKOS*)pc->data;
761   PetscErrorCode ierr;
762 
763   PetscFunctionBegin;
764   ierr = PetscObjectReference((PetscObject)ksp);CHKERRQ(ierr);
765   ierr = KSPDestroy(&jac->ksp);CHKERRQ(ierr);
766   jac->ksp = ksp;
767   PetscFunctionReturn(0);
768 }
769 
770 /*@C
771    PCBJKOKKOSSetKSP - Sets the KSP context for a KSP PC.
772 
773    Collective on PC
774 
775    Input Parameters:
776 +  pc - the preconditioner context
777 -  ksp - the KSP solver
778 
779    Notes:
780    The PC and the KSP must have the same communicator
781 
782    Level: advanced
783 
784 @*/
785 PetscErrorCode  PCBJKOKKOSSetKSP(PC pc,KSP ksp)
786 {
787   PetscErrorCode ierr;
788 
789   PetscFunctionBegin;
790   PetscValidHeaderSpecific(pc,PC_CLASSID,1);
791   PetscValidHeaderSpecific(ksp,KSP_CLASSID,2);
792   PetscCheckSameComm(pc,1,ksp,2);
793   ierr = PetscTryMethod(pc,"PCBJKOKKOSSetKSP_C",(PC,KSP),(pc,ksp));CHKERRQ(ierr);
794   PetscFunctionReturn(0);
795 }
796 
797 static PetscErrorCode  PCBJKOKKOSGetKSP_BJKOKKOS(PC pc,KSP *ksp)
798 {
799   PC_PCBJKOKKOS         *jac = (PC_PCBJKOKKOS*)pc->data;
800   PetscErrorCode ierr;
801 
802   PetscFunctionBegin;
803   if (!jac->ksp) {ierr = PCBJKOKKOSCreateKSP_BJKOKKOS(pc);CHKERRQ(ierr);}
804   *ksp = jac->ksp;
805   PetscFunctionReturn(0);
806 }
807 
808 /*@C
809    PCBJKOKKOSGetKSP - Gets the KSP context for a KSP PC.
810 
811    Not Collective but KSP returned is parallel if PC was parallel
812 
813    Input Parameter:
814 .  pc - the preconditioner context
815 
816    Output Parameters:
817 .  ksp - the KSP solver
818 
819    Notes:
820    You must call KSPSetUp() before calling PCBJKOKKOSGetKSP().
821 
822    If the PC is not a PCBJKOKKOS object it raises an error
823 
824    Level: advanced
825 
826 @*/
827 PetscErrorCode  PCBJKOKKOSGetKSP(PC pc,KSP *ksp)
828 {
829   PetscErrorCode ierr;
830 
831   PetscFunctionBegin;
832   PetscValidHeaderSpecific(pc,PC_CLASSID,1);
833   PetscValidPointer(ksp,2);
834   ierr = PetscUseMethod(pc,"PCBJKOKKOSGetKSP_C",(PC,KSP*),(pc,ksp));CHKERRQ(ierr);
835   PetscFunctionReturn(0);
836 }
837 
838 /* ----------------------------------------------------------------------------------*/
839 
840 /*MC
     PCBJKOKKOS -  Defines a preconditioner that applies a Krylov solver and preconditioner to the blocks in a SeqAIJ matrix on the GPU.
842 
843    Options Database Key:
844 .     -pc_bjkokkos_
845 
846    Level: intermediate
847 
848    Notes:
849     For use with -ksp_type preonly to bypass any CPU work
850 
851    Developer Notes:
852 
853 .seealso:  PCCreate(), PCSetType(), PCType (for list of available types), PC,
854            PCSHELL, PCCOMPOSITE, PCSetUseAmat(), PCBJKOKKOSGetKSP()
855 
856 M*/
857 
858 PETSC_EXTERN PetscErrorCode PCCreate_BJKOKKOS(PC pc)
859 {
860   PetscErrorCode ierr;
861   PC_PCBJKOKKOS   *jac;
862 
863   PetscFunctionBegin;
864   ierr = PetscNewLog(pc,&jac);CHKERRQ(ierr);
865   pc->data = (void*)jac;
866 
867   jac->ksp = NULL;
868   jac->vec_diag = NULL;
869   jac->d_bid_eqOffset_k = NULL;
870   jac->d_idiag_k = NULL;
871   jac->d_isrow_k = NULL;
872   jac->d_isicol_k = NULL;
873   jac->nBlocks = 1;
874 
875   ierr = PetscMemzero(pc->ops,sizeof(struct _PCOps));CHKERRQ(ierr);
876   pc->ops->apply           = PCApply_BJKOKKOS;
877   pc->ops->applytranspose  = NULL;
878   pc->ops->setup           = PCSetUp_BJKOKKOS;
879   pc->ops->reset           = PCReset_BJKOKKOS;
880   pc->ops->destroy         = PCDestroy_BJKOKKOS;
881   pc->ops->setfromoptions  = PCSetFromOptions_BJKOKKOS;
882   pc->ops->view            = PCView_BJKOKKOS;
883 
884   ierr = PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSGetKSP_C",PCBJKOKKOSGetKSP_BJKOKKOS);CHKERRQ(ierr);
885   ierr = PetscObjectComposeFunction((PetscObject)pc,"PCBJKOKKOSSetKSP_C",PCBJKOKKOSSetKSP_BJKOKKOS);CHKERRQ(ierr);
886   PetscFunctionReturn(0);
887 }
888