xref: /petsc/src/ksp/pc/impls/bjacobi/bjkokkos/bjkokkos.kokkos.cxx (revision a4e35b1925eceef64945ea472b84f2bf06a67b5e)
1 #include <petsc/private/pcbjkokkosimpl.h>
2 
3 #include <petsc/private/kspimpl.h>
4 #include <petscksp.h> /*I "petscksp.h" I*/
5 #include <../src/mat/impls/aij/mpi/mpiaij.h>
6 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
7 #include "petscsection.h"
8 #include <petscdmcomposite.h>
9 
10 #include <../src/mat/impls/aij/seq/aij.h>
11 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
12 
13 #include <petscdevice_cupm.h>
14 
15 static PetscErrorCode PCBJKOKKOSCreateKSP_BJKOKKOS(PC pc)
16 {
17   const char    *prefix;
18   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
19   DM             dm;
20 
21   PetscFunctionBegin;
22   PetscCall(KSPCreate(PetscObjectComm((PetscObject)pc), &jac->ksp));
23   PetscCall(KSPSetNestLevel(jac->ksp, pc->kspnestlevel));
24   PetscCall(KSPSetErrorIfNotConverged(jac->ksp, pc->erroriffailure));
25   PetscCall(PetscObjectIncrementTabLevel((PetscObject)jac->ksp, (PetscObject)pc, 1));
26   PetscCall(PCGetOptionsPrefix(pc, &prefix));
27   PetscCall(KSPSetOptionsPrefix(jac->ksp, prefix));
28   PetscCall(KSPAppendOptionsPrefix(jac->ksp, "pc_bjkokkos_"));
29   PetscCall(PCGetDM(pc, &dm));
30   if (dm) {
31     PetscCall(KSPSetDM(jac->ksp, dm));
32     PetscCall(KSPSetDMActive(jac->ksp, PETSC_FALSE));
33   }
34   jac->reason       = PETSC_FALSE;
35   jac->monitor      = PETSC_FALSE;
36   jac->batch_target = 0;
37   jac->rank_target  = 0;
38   jac->nsolves_team = 1;
39   jac->ksp->max_it  = 50; // this is really for GMRES w/o restarts
40   PetscFunctionReturn(PETSC_SUCCESS);
41 }
42 
43 // y <-- Ax
44 KOKKOS_INLINE_FUNCTION PetscErrorCode MatMult(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
45 {
46   Kokkos::parallel_for(Kokkos::TeamThreadRange(team, start, end), [=](const int rowb) {
47     int                rowa = ic[rowb];
48     int                n    = glb_Aai[rowa + 1] - glb_Aai[rowa];
49     const PetscInt    *aj   = glb_Aaj + glb_Aai[rowa]; // global
50     const PetscScalar *aa   = glb_Aaa + glb_Aai[rowa];
51     PetscScalar        sum;
52     Kokkos::parallel_reduce(
53       Kokkos::ThreadVectorRange(team, n), [=](const int i, PetscScalar &lsum) { lsum += aa[i] * x_loc[r[aj[i]] - start]; }, sum);
54     Kokkos::single(Kokkos::PerThread(team), [=]() { y_loc[rowb - start] = sum; });
55   });
56   team.team_barrier();
57   return PETSC_SUCCESS;
58 }
59 
60 // temp buffer per thread with reduction at end?
61 KOKKOS_INLINE_FUNCTION PetscErrorCode MatMultTranspose(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
62 {
63   Kokkos::parallel_for(Kokkos::TeamVectorRange(team, end - start), [=](int i) { y_loc[i] = 0; });
64   team.team_barrier();
65   Kokkos::parallel_for(Kokkos::TeamThreadRange(team, start, end), [=](const int rowb) {
66     int                rowa = ic[rowb];
67     int                n    = glb_Aai[rowa + 1] - glb_Aai[rowa];
68     const PetscInt    *aj   = glb_Aaj + glb_Aai[rowa]; // global
69     const PetscScalar *aa   = glb_Aaa + glb_Aai[rowa];
70     const PetscScalar  xx   = x_loc[rowb - start]; // rowb = ic[rowa] = ic[r[rowb]]
71     Kokkos::parallel_for(Kokkos::ThreadVectorRange(team, n), [=](const int &i) {
72       PetscScalar val = aa[i] * xx;
73       Kokkos::atomic_fetch_add(&y_loc[r[aj[i]] - start], val);
74     });
75   });
76   team.team_barrier();
77   return PETSC_SUCCESS;
78 }
79 
80 typedef struct Batch_MetaData_TAG {
81   PetscInt           flops;
82   PetscInt           its;
83   KSPConvergedReason reason;
84 } Batch_MetaData;
85 
86 // Solve A(BB^-1)x = y with TFQMR. Right preconditioned to get un-preconditioned residual
87 KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_TFQMR(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space_global, const int stride_global, const int nShareVec, PetscScalar *work_space_shared, const int stride_shared, PetscReal rtol, PetscReal atol, PetscReal dtol, PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
88 {
89   using Kokkos::parallel_for;
90   using Kokkos::parallel_reduce;
91   int                Nblk = end - start, it, m, stride = stride_shared, idx = 0;
92   PetscReal          dp, dpold, w, dpest, tau, psi, cm, r0;
93   const PetscScalar *Diag = &glb_idiag[start];
94   PetscScalar       *ptr  = work_space_shared, rho, rhoold, a, s, b, eta, etaold, psiold, cf, dpi;
95 
96   if (idx++ == nShareVec) {
97     ptr    = work_space_global;
98     stride = stride_global;
99   }
100   PetscScalar *XX = ptr;
101   ptr += stride;
102   if (idx++ == nShareVec) {
103     ptr    = work_space_global;
104     stride = stride_global;
105   }
106   PetscScalar *R = ptr;
107   ptr += stride;
108   if (idx++ == nShareVec) {
109     ptr    = work_space_global;
110     stride = stride_global;
111   }
112   PetscScalar *RP = ptr;
113   ptr += stride;
114   if (idx++ == nShareVec) {
115     ptr    = work_space_global;
116     stride = stride_global;
117   }
118   PetscScalar *V = ptr;
119   ptr += stride;
120   if (idx++ == nShareVec) {
121     ptr    = work_space_global;
122     stride = stride_global;
123   }
124   PetscScalar *T = ptr;
125   ptr += stride;
126   if (idx++ == nShareVec) {
127     ptr    = work_space_global;
128     stride = stride_global;
129   }
130   PetscScalar *Q = ptr;
131   ptr += stride;
132   if (idx++ == nShareVec) {
133     ptr    = work_space_global;
134     stride = stride_global;
135   }
136   PetscScalar *P = ptr;
137   ptr += stride;
138   if (idx++ == nShareVec) {
139     ptr    = work_space_global;
140     stride = stride_global;
141   }
142   PetscScalar *U = ptr;
143   ptr += stride;
144   if (idx++ == nShareVec) {
145     ptr    = work_space_global;
146     stride = stride_global;
147   }
148   PetscScalar *D = ptr;
149   ptr += stride;
150   if (idx++ == nShareVec) {
151     ptr    = work_space_global;
152     stride = stride_global;
153   }
154   PetscScalar *AUQ = V;
155 
156   // init: get b, zero x
157   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
158     int rowa         = ic[rowb];
159     R[rowb - start]  = glb_b[rowa];
160     XX[rowb - start] = 0;
161   });
162   team.team_barrier();
163   parallel_reduce(
164     Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += R[idx] * PetscConj(R[idx]); }, dpi);
165   team.team_barrier();
166   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
167   // diagnostics
168 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
169   if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e \n", 0, (double)dp); });
170 #endif
171   if (dp < atol) {
172     metad->reason = KSP_CONVERGED_ATOL_NORMAL;
173     it            = 0;
174     goto done;
175   }
176   if (0 == maxit) {
177     metad->reason = KSP_CONVERGED_ITS;
178     it            = 0;
179     goto done;
180   }
181 
182   /* Make the initial Rp = R */
183   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { RP[idx] = R[idx]; });
184   team.team_barrier();
185   /* Set the initial conditions */
186   etaold = 0.0;
187   psiold = 0.0;
188   tau    = dp;
189   dpold  = dp;
190 
191   /* rhoold = (r,rp)     */
192   parallel_reduce(
193     Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += R[idx] * PetscConj(RP[idx]); }, rhoold);
194   team.team_barrier();
195   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
196     U[idx] = R[idx];
197     P[idx] = R[idx];
198     T[idx] = Diag[idx] * P[idx];
199     D[idx] = 0;
200   });
201   team.team_barrier();
202   static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, V));
203 
204   it = 0;
205   do {
206     /* s <- (v,rp)          */
207     parallel_reduce(
208       Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += V[idx] * PetscConj(RP[idx]); }, s);
209     team.team_barrier();
210     if (s == 0) {
211       metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
212       goto done;
213     }
214     a = rhoold / s; /* a <- rho / s         */
215     /* q <- u - a v    VecWAXPY(w,alpha,x,y): w = alpha x + y.     */
216     /* t <- u + q           */
217     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
218       Q[idx] = U[idx] - a * V[idx];
219       T[idx] = U[idx] + Q[idx];
220     });
221     team.team_barrier();
222     // KSP_PCApplyBAorAB
223     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { T[idx] = Diag[idx] * T[idx]; });
224     team.team_barrier();
225     static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, AUQ));
226     /* r <- r - a K (u + q) */
227     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { R[idx] = R[idx] - a * AUQ[idx]; });
228     team.team_barrier();
229     parallel_reduce(
230       Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += R[idx] * PetscConj(R[idx]); }, dpi);
231     team.team_barrier();
232     dp = PetscSqrtReal(PetscRealPart(dpi));
233     for (m = 0; m < 2; m++) {
234       if (!m) w = PetscSqrtReal(dp * dpold);
235       else w = dp;
236       psi = w / tau;
237       cm  = 1.0 / PetscSqrtReal(1.0 + psi * psi);
238       tau = tau * psi * cm;
239       eta = cm * cm * a;
240       cf  = psiold * psiold * etaold / a;
241       if (!m) {
242         /* D = U + cf D */
243         parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { D[idx] = U[idx] + cf * D[idx]; });
244       } else {
245         /* D = Q + cf D */
246         parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { D[idx] = Q[idx] + cf * D[idx]; });
247       }
248       team.team_barrier();
249       parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { XX[idx] = XX[idx] + eta * D[idx]; });
250       team.team_barrier();
251       dpest = PetscSqrtReal(2 * it + m + 2.0) * tau;
252 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
253       if (monitor && m == 1) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e \n", it + 1, (double)dpest); });
254 #endif
255       if (dpest < atol) {
256         metad->reason = KSP_CONVERGED_ATOL_NORMAL;
257         goto done;
258       }
259       if (dpest / r0 < rtol) {
260         metad->reason = KSP_CONVERGED_RTOL_NORMAL;
261         goto done;
262       }
263 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
264       if (dpest / r0 > dtol) {
265         metad->reason = KSP_DIVERGED_DTOL;
266         Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: %d it, res=%e, r_0=%e\n", team.league_rank(), it, dpest, r0); });
267         goto done;
268       }
269 #else
270       if (dpest / r0 > dtol) {
271         metad->reason = KSP_DIVERGED_DTOL;
272         goto done;
273       }
274 #endif
275       if (it + 1 == maxit) {
276         metad->reason = KSP_CONVERGED_ITS;
277 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
278         Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: TFQMR %d:%d it, res=%e, r_0=%e r_res=%e\n", team.league_rank(), it, m, dpest, r0, dpest / r0); });
279 #endif
280         goto done;
281       }
282       etaold = eta;
283       psiold = psi;
284     }
285 
286     /* rho <- (r,rp)       */
287     parallel_reduce(
288       Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += R[idx] * PetscConj(RP[idx]); }, rho);
289     team.team_barrier();
290     if (rho == 0) {
291       metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
292       goto done;
293     }
294     b = rho / rhoold; /* b <- rho / rhoold   */
295     /* u <- r + b q        */
296     /* p <- u + b(q + b p) */
297     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
298       U[idx] = R[idx] + b * Q[idx];
299       Q[idx] = Q[idx] + b * P[idx];
300       P[idx] = U[idx] + b * Q[idx];
301     });
302     /* v <- K p  */
303     team.team_barrier();
304     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { T[idx] = Diag[idx] * P[idx]; });
305     team.team_barrier();
306     static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, V));
307 
308     rhoold = rho;
309     dpold  = dp;
310 
311     it++;
312   } while (it < maxit);
313 done:
314   // KSPUnwindPreconditioner
315   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { XX[idx] = Diag[idx] * XX[idx]; });
316   team.team_barrier();
317   // put x into Plex order
318   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
319     int rowa    = ic[rowb];
320     glb_x[rowa] = XX[rowb - start];
321   });
322   metad->its = it;
323   if (1) {
324     int nnz;
325     parallel_reduce(
326       Kokkos::TeamVectorRange(team, start, end), [=](const int idx, int &lsum) { lsum += (glb_Aai[idx + 1] - glb_Aai[idx]); }, nnz);
327     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * nnz) + 5 * Nblk);
328   } else {
329     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * 50 * Nblk) + 5 * Nblk); // guess
330   }
331   return PETSC_SUCCESS;
332 }
333 
334 // Solve Ax = y with biCG
335 KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_BICG(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space_global, const int stride_global, const int nShareVec, PetscScalar *work_space_shared, const int stride_shared, PetscReal rtol, PetscReal atol, PetscReal dtol, PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
336 {
337   using Kokkos::parallel_for;
338   using Kokkos::parallel_reduce;
339   int                Nblk = end - start, it, stride = stride_shared, idx = 0; // start in shared mem
340   PetscReal          dp, r0;
341   const PetscScalar *Di  = &glb_idiag[start];
342   PetscScalar       *ptr = work_space_shared, dpi, a = 1.0, beta, betaold = 1.0, t1, t2;
343 
344   if (idx++ == nShareVec) {
345     ptr    = work_space_global;
346     stride = stride_global;
347   }
348   PetscScalar *XX = ptr;
349   ptr += stride;
350   if (idx++ == nShareVec) {
351     ptr    = work_space_global;
352     stride = stride_global;
353   }
354   PetscScalar *Rl = ptr;
355   ptr += stride;
356   if (idx++ == nShareVec) {
357     ptr    = work_space_global;
358     stride = stride_global;
359   }
360   PetscScalar *Zl = ptr;
361   ptr += stride;
362   if (idx++ == nShareVec) {
363     ptr    = work_space_global;
364     stride = stride_global;
365   }
366   PetscScalar *Pl = ptr;
367   ptr += stride;
368   if (idx++ == nShareVec) {
369     ptr    = work_space_global;
370     stride = stride_global;
371   }
372   PetscScalar *Rr = ptr;
373   ptr += stride;
374   if (idx++ == nShareVec) {
375     ptr    = work_space_global;
376     stride = stride_global;
377   }
378   PetscScalar *Zr = ptr;
379   ptr += stride;
380   if (idx++ == nShareVec) {
381     ptr    = work_space_global;
382     stride = stride_global;
383   }
384   PetscScalar *Pr = ptr;
385   ptr += stride;
386 
387   /*     r <- b (x is 0) */
388   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
389     int rowa         = ic[rowb];
390     Rl[rowb - start] = Rr[rowb - start] = glb_b[rowa];
391     XX[rowb - start]                    = 0;
392   });
393   team.team_barrier();
394   /*     z <- Br         */
395   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
396     Zr[idx] = Di[idx] * Rr[idx];
397     Zl[idx] = Di[idx] * Rl[idx];
398   });
399   team.team_barrier();
400   /*    dp <- r'*r       */
401   parallel_reduce(
402     Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Rr[idx] * PetscConj(Rr[idx]); }, dpi);
403   team.team_barrier();
404   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
405 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
406   if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e \n", 0, (double)dp); });
407 #endif
408   if (dp < atol) {
409     metad->reason = KSP_CONVERGED_ATOL_NORMAL;
410     it            = 0;
411     goto done;
412   }
413   if (0 == maxit) {
414     metad->reason = KSP_CONVERGED_ITS;
415     it            = 0;
416     goto done;
417   }
418 
419   it = 0;
420   do {
421     /*     beta <- r'z     */
422     parallel_reduce(
423       Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += Zr[idx] * PetscConj(Rl[idx]); }, beta);
424     team.team_barrier();
425 #if PCBJKOKKOS_VERBOSE_LEVEL >= 6
426   #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
427     Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%7d beta = Z.R = %22.14e \n", i, (double)beta); });
428   #endif
429 #endif
430     if (beta == 0.0) {
431       metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
432       goto done;
433     }
434     if (it == 0) {
435       /*     p <- z          */
436       parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
437         Pr[idx] = Zr[idx];
438         Pl[idx] = Zl[idx];
439       });
440     } else {
441       t1 = beta / betaold;
442       /*     p <- z + b* p   */
443       t2 = PetscConj(t1);
444       parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
445         Pr[idx] = t1 * Pr[idx] + Zr[idx];
446         Pl[idx] = t2 * Pl[idx] + Zl[idx];
447       });
448     }
449     team.team_barrier();
450     betaold = beta;
451     /*     z <- Kp         */
452     static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, Pr, Zr));
453     static_cast<void>(MatMultTranspose(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, Pl, Zl));
454     /*     dpi <- z'p      */
455     parallel_reduce(
456       Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Zr[idx] * PetscConj(Pl[idx]); }, dpi);
457     team.team_barrier();
458     if (dpi == 0) {
459       metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
460       goto done;
461     }
462     //
463     a  = beta / dpi; /*     a = beta/p'z    */
464     t1 = -a;
465     t2 = PetscConj(t1);
466     /*     x <- x + ap     */
467     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
468       XX[idx] = XX[idx] + a * Pr[idx];
469       Rr[idx] = Rr[idx] + t1 * Zr[idx];
470       Rl[idx] = Rl[idx] + t2 * Zl[idx];
471     });
472     team.team_barrier();
473     team.team_barrier();
474     /*    dp <- r'*r       */
475     parallel_reduce(
476       Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Rr[idx] * PetscConj(Rr[idx]); }, dpi);
477     team.team_barrier();
478     dp = PetscSqrtReal(PetscRealPart(dpi));
479 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
480     if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e \n", it + 1, (double)dp); });
481 #endif
482     if (dp < atol) {
483       metad->reason = KSP_CONVERGED_ATOL_NORMAL;
484       goto done;
485     }
486     if (dp / r0 < rtol) {
487       metad->reason = KSP_CONVERGED_RTOL_NORMAL;
488       goto done;
489     }
490 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
491     if (dp / r0 > dtol) {
492       metad->reason = KSP_DIVERGED_DTOL;
493       Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: %d it, res=%e, r_0=%e (BICG does this)\n", team.league_rank(), it, dp, r0); });
494       goto done;
495     }
496 #else
497     if (dp / r0 > dtol) {
498       metad->reason = KSP_DIVERGED_DTOL;
499       goto done;
500     }
501 #endif
502     if (it + 1 == maxit) {
503       metad->reason = KSP_CONVERGED_ITS; // don't worry about hitting max iterations
504 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
505       Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: BICG %d it, res=%e, r_0=%e r_res=%e\n", team.league_rank(), it, dp, r0, dp / r0); });
506 #endif
507       goto done;
508     }
509     /* z <- Br  */
510     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
511       Zr[idx] = Di[idx] * Rr[idx];
512       Zl[idx] = Di[idx] * Rl[idx];
513     });
514 
515     it++;
516   } while (it < maxit);
517 done:
518   // put x back into Plex order
519   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
520     int rowa    = ic[rowb];
521     glb_x[rowa] = XX[rowb - start];
522   });
523   metad->its = it;
524   if (1) {
525     int nnz;
526     parallel_reduce(
527       Kokkos::TeamVectorRange(team, start, end), [=](const int idx, int &lsum) { lsum += (glb_Aai[idx + 1] - glb_Aai[idx]); }, nnz);
528     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * nnz) + 5 * Nblk);
529   } else {
530     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * 50 * Nblk) + 5 * Nblk); // guess
531   }
532   return PETSC_SUCCESS;
533 }
534 
535 // KSP solver solve Ax = b; xout is output, bin is input
536 static PetscErrorCode PCApply_BJKOKKOS(PC pc, Vec bin, Vec xout)
537 {
538   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
539   Mat            A = pc->pmat, Aseq = A;
540   PetscMPIInt    rank;
541 
542   PetscFunctionBegin;
543   PetscCallMPI(MPI_Comm_rank(PetscObjectComm((PetscObject)A), &rank));
544   if (!A->spptr) {
545     Aseq = ((Mat_MPIAIJ *)A->data)->A; // MPI
546   }
547   PetscCall(MatSeqAIJKokkosSyncDevice(Aseq));
548   {
549     PetscInt           maxit = jac->ksp->max_it;
550     const PetscInt     conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp == 0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
551     const PetscInt     nwork = jac->nwork, nBlk = jac->nBlocks;
552     PetscScalar       *glb_xdata = NULL, *dummy;
553     PetscReal          rtol = jac->ksp->rtol, atol = jac->ksp->abstol, dtol = jac->ksp->divtol;
554     const PetscScalar *glb_idiag = jac->d_idiag_k->data(), *glb_bdata = NULL;
555     const PetscInt    *glb_Aai, *glb_Aaj, *d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
556     const PetscScalar *glb_Aaa;
557     const PetscInt    *d_isicol = jac->d_isicol_k->data(), *d_isrow = jac->d_isrow_k->data();
558     PCFailedReason     pcreason;
559     KSPIndex           ksp_type_idx = jac->ksp_type_idx;
560     PetscMemType       mtype;
561     PetscContainer     container;
562     PetscInt           batch_sz;                // the number of repeated DMs, [DM_e_1, DM_e_2, DM_e_batch_sz, DM_i_1, ...]
563     VecScatter         plex_batch = NULL;       // not used
564     Vec                bvec;                    // a copy of b for scatter (just alias to bin now)
565     PetscBool          monitor  = jac->monitor; // captured
566     PetscInt           view_bid = jac->batch_target;
567     MatInfo            info;
568 
569     PetscCall(MatSeqAIJGetCSRAndMemType(Aseq, &glb_Aai, &glb_Aaj, &dummy, &mtype));
570     jac->max_nits = 0;
571     glb_Aaa       = dummy;
572     if (jac->rank_target != rank) view_bid = -1; // turn off all but one process
573     PetscCall(MatGetInfo(A, MAT_LOCAL, &info));
574     // get field major is to map plex IO to/from block/field major
575     PetscCall(PetscObjectQuery((PetscObject)A, "plex_batch_is", (PetscObject *)&container));
576     if (container) {
577       PetscCall(VecDuplicate(bin, &bvec));
578       PetscCall(PetscContainerGetPointer(container, (void **)&plex_batch));
579       PetscCall(VecScatterBegin(plex_batch, bin, bvec, INSERT_VALUES, SCATTER_FORWARD));
580       PetscCall(VecScatterEnd(plex_batch, bin, bvec, INSERT_VALUES, SCATTER_FORWARD));
581       SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_USER, "No plex_batch_is -- require NO field major ordering for now");
582     } else {
583       bvec = bin;
584     }
585     // get x
586     PetscCall(VecGetArrayAndMemType(xout, &glb_xdata, &mtype));
587 #if defined(PETSC_HAVE_CUDA)
588     PetscCheck(PetscMemTypeDevice(mtype), PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "No GPU data for x %" PetscInt_FMT " != %" PetscInt_FMT, mtype, PETSC_MEMTYPE_DEVICE);
589 #endif
590     PetscCall(VecGetArrayReadAndMemType(bvec, &glb_bdata, &mtype));
591 #if defined(PETSC_HAVE_CUDA)
592     PetscCheck(PetscMemTypeDevice(mtype), PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "No GPU data for b");
593 #endif
594     // get batch size
595     PetscCall(PetscObjectQuery((PetscObject)A, "batch size", (PetscObject *)&container));
596     if (container) {
597       PetscInt *pNf = NULL;
598       PetscCall(PetscContainerGetPointer(container, (void **)&pNf));
599       batch_sz = *pNf; // number of times to repeat the DMs
600     } else batch_sz = 1;
601     PetscCheck(nBlk % batch_sz == 0, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "batch_sz = %" PetscInt_FMT ", nBlk = %" PetscInt_FMT, batch_sz, nBlk);
602     if (ksp_type_idx == BATCH_KSP_GMRESKK_IDX) {
603       // KK solver - move PETSc data into Kokkos Views, setup solver, solve, move data out of Kokkos, process metadata (convergence tests, etc.)
604 #if defined(PETSC_HAVE_KOKKOS_KERNELS_BATCH)
605       PetscCall(PCApply_BJKOKKOSKERNELS(pc, glb_bdata, glb_xdata, glb_Aai, glb_Aaj, glb_Aaa, team_size, info, batch_sz, &pcreason));
606 #else
607       PetscCheck(ksp_type_idx != BATCH_KSP_GMRESKK_IDX, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "Type: BATCH_KSP_GMRES not supported for complex\n");
608 #endif
609     } else { // Kokkos Krylov
610       using scr_mem_t    = Kokkos::DefaultExecutionSpace::scratch_memory_space;
611       using vect2D_scr_t = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, scr_mem_t>;
612       Kokkos::View<Batch_MetaData *, Kokkos::DefaultExecutionSpace> d_metadata("solver meta data", nBlk);
613       int                                                           stride_shared, stride_global, global_buff_words;
614       d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
615       // solve each block independently
616       int scr_bytes_team_shared = 0, nShareVec = 0, nGlobBVec = 0;
617       if (jac->const_block_size) { // use shared memory for work vectors only if constant block size - todo: test efficiency loss
618         size_t      maximum_shared_mem_size = 64000;
619         PetscDevice device;
620         PetscCall(PetscDeviceGetDefault_Internal(&device));
621         PetscCall(PetscDeviceGetAttribute(device, PETSC_DEVICE_ATTR_SIZE_T_SHARED_MEM_PER_BLOCK, &maximum_shared_mem_size));
622         stride_shared = jac->const_block_size;                                                   // captured
623         nShareVec     = maximum_shared_mem_size / (jac->const_block_size * sizeof(PetscScalar)); // integer floor, number of vectors that fit in shared
624         if (nShareVec > nwork) nShareVec = nwork;
625         else nGlobBVec = nwork - nShareVec;
626         global_buff_words     = jac->n * nGlobBVec;
627         scr_bytes_team_shared = jac->const_block_size * nShareVec * sizeof(PetscScalar);
628       } else {
629         scr_bytes_team_shared = 0;
630         stride_shared         = 0;
631         global_buff_words     = jac->n * nwork;
632         nGlobBVec             = nwork; // not needed == fix
633       }
634       stride_global = jac->n; // captured
635 #if defined(PETSC_HAVE_CUDA)
636       nvtxRangePushA("batch-kokkos-solve");
637 #endif
638       Kokkos::View<PetscScalar *, Kokkos::DefaultExecutionSpace> d_work_vecs_k("workvectors", global_buff_words); // global work vectors
639 #if PCBJKOKKOS_VERBOSE_LEVEL > 1
640       PetscCall(PetscInfo(pc, "\tn = %d. %d shared bytes/team, %d global mem bytes, rtol=%e, num blocks %d, team_size=%d, %d vector threads, %d shared vectors, %d global vectors\n", (int)jac->n, scr_bytes_team_shared, global_buff_words, rtol, (int)nBlk, (int)team_size, PCBJKOKKOS_VEC_SIZE, nShareVec, nGlobBVec));
641 #endif
642       PetscScalar *d_work_vecs = d_work_vecs_k.data();
643       Kokkos::parallel_for(
644         "Solve", Kokkos::TeamPolicy<Kokkos::LaunchBounds<256, 4>>(nBlk, team_size, PCBJKOKKOS_VEC_SIZE).set_scratch_size(PCBJKOKKOS_SHARED_LEVEL, Kokkos::PerTeam(scr_bytes_team_shared)), KOKKOS_LAMBDA(const team_member team) {
645           const int    blkID = team.league_rank(), start = d_bid_eqOffset[blkID], end = d_bid_eqOffset[blkID + 1];
646           vect2D_scr_t work_vecs_shared(team.team_scratch(PCBJKOKKOS_SHARED_LEVEL), end - start, nShareVec);
647           PetscScalar *work_buff_shared = work_vecs_shared.data();
648           PetscScalar *work_buff_global = &d_work_vecs[start]; // start inc'ed in
649           bool         print            = monitor && (blkID == view_bid);
650           switch (ksp_type_idx) {
651           case BATCH_KSP_BICG_IDX:
652             static_cast<void>(BJSolve_BICG(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff_global, stride_global, nShareVec, work_buff_shared, stride_shared, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print));
653             break;
654           case BATCH_KSP_TFQMR_IDX:
655             static_cast<void>(BJSolve_TFQMR(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff_global, stride_global, nShareVec, work_buff_shared, stride_shared, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print));
656             break;
657           default:
658 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
659             printf("Unknown KSP type %d\n", ksp_type_idx);
660 #else
661             /* void */;
662 #endif
663           }
664         });
665       Kokkos::fence();
666 #if defined(PETSC_HAVE_CUDA)
667       nvtxRangePop();
668       nvtxRangePushA("Post-solve-metadata");
669 #endif
670       auto h_metadata = Kokkos::create_mirror(Kokkos::HostSpace::memory_space(), d_metadata);
671       Kokkos::deep_copy(h_metadata, d_metadata);
672       PetscInt count = -1, mbid = 0;
673       int      in[2], out[2];
674       if (jac->reason) { // -pc_bjkokkos_ksp_converged_reason
675 #if PCBJKOKKOS_VERBOSE_LEVEL >= 3
676   #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
677         PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Iterations\n"));
678   #endif
679         // assume species major
680   #if PCBJKOKKOS_VERBOSE_LEVEL < 4
681         if (batch_sz != 1) PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%s: max iterations per species:", ksp_type_idx == BATCH_KSP_BICG_IDX ? "bicg" : "tfqmr"));
682         else PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "    Linear solve converged due to %s iterations ", ksp_type_idx == BATCH_KSP_BICG_IDX ? "bicg" : "tfqmr"));
683   #endif
684         for (PetscInt dmIdx = 0, head = 0; dmIdx < jac->num_dms; dmIdx += batch_sz) {
685           for (PetscInt f = 0, idx = head; f < jac->dm_Nf[dmIdx]; f++, s++, idx++) {
686   #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
687             PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%2" PetscInt_FMT ":", s));
688             for (int bid = 0; bid < batch_sz; bid++) PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%3" PetscInt_FMT " ", h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its));
689             PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "\n"));
690   #else
691             for (int bid = 0; bid < batch_sz; bid++) {
692               if (h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its > count) {
693                 count = h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its;
694                 mbid  = bid;
695               }
696             }
697             PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%3" PetscInt_FMT " ", count));
698   #endif
699           }
700           head += batch_sz * jac->dm_Nf[dmIdx];
701         }
702   #if PCBJKOKKOS_VERBOSE_LEVEL == 3
703         PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "\n"));
704   #endif
705 #endif
706         if (count == -1) {
707           for (int blkID = 0; blkID < nBlk; blkID++) {
708             if (h_metadata[blkID].its > count) {
709               jac->max_nits = count = h_metadata[blkID].its;
710               mbid                  = blkID;
711             }
712 #if PCBJKOKKOS_VERBOSE_LEVEL > 0
713             if (h_metadata[blkID].reason < 0) {
714               PetscCall(PetscPrintf(PETSC_COMM_SELF, "ERROR reason=%s, its=%" PetscInt_FMT ". species %" PetscInt_FMT ", batch %" PetscInt_FMT "\n", KSPConvergedReasons[h_metadata[blkID].reason], h_metadata[blkID].its, blkID / batch_sz, blkID % batch_sz));
715             }
716 #endif
717             PetscCall(PetscLogGpuFlops((PetscLogDouble)h_metadata[blkID].flops));
718           }
719         }
720         in[0] = count;
721         in[1] = rank;
722         PetscCallMPI(MPI_Allreduce(in, out, 1, MPI_2INT, MPI_MAXLOC, PetscObjectComm((PetscObject)A)));
723         if (0 == rank) {
724           if (batch_sz != 1)
725             PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] Linear solve converged due to %s iterations %d, batch %" PetscInt_FMT ", species %" PetscInt_FMT " (max)\n", out[1], KSPConvergedReasons[h_metadata[mbid].reason], out[0], mbid % batch_sz, mbid / batch_sz));
726           else PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] Linear solve converged due to %s iterations %d, block %d (max)\n", out[1], KSPConvergedReasons[h_metadata[mbid].reason], out[0], mbid));
727         }
728       }
729       for (int blkID = 0; blkID < nBlk; blkID++) {
730         PetscCall(PetscLogGpuFlops((PetscLogDouble)h_metadata[blkID].flops));
731 #if PCBJKOKKOS_VERBOSE_LEVEL > 0
732         if (h_metadata[blkID].reason < 0) {
733           PetscCall(PetscPrintf(PETSC_COMM_SELF, "ERROR reason=%s, its=%" PetscInt_FMT ". species %" PetscInt_FMT ", batch %" PetscInt_FMT "\n", KSPConvergedReasons[h_metadata[blkID].reason], h_metadata[blkID].its, blkID / batch_sz, blkID % batch_sz));
734         }
735 #endif
736       }
737       {
738         int errsum;
739         Kokkos::parallel_reduce(
740           nBlk,
741           KOKKOS_LAMBDA(const int idx, int &lsum) {
742             if (d_metadata[idx].reason < 0) ++lsum;
743           },
744           errsum);
745         pcreason = errsum ? PC_SUBPC_ERROR : PC_NOERROR;
746         if (!errsum && !jac->max_nits) { // set max its to give back to top KSP
747           for (int blkID = 0; blkID < nBlk; blkID++) {
748             if (h_metadata[blkID].its > jac->max_nits) jac->max_nits = h_metadata[blkID].its;
749           }
750         } else if (errsum) {
751           PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] ERROR Kokkos batch solver did not converge in all solves\n", (int)rank));
752         }
753       }
754 #if defined(PETSC_HAVE_CUDA)
755       nvtxRangePop();
756 #endif
757     } // end of Kokkos (not Kernels) solvers block
758     PetscCall(VecRestoreArrayAndMemType(xout, &glb_xdata));
759     PetscCall(VecRestoreArrayReadAndMemType(bvec, &glb_bdata));
760     PetscCall(PCSetFailedReason(pc, pcreason));
761     // map back to Plex space - not used
762     if (plex_batch) {
763       PetscCall(VecCopy(xout, bvec));
764       PetscCall(VecScatterBegin(plex_batch, bvec, xout, INSERT_VALUES, SCATTER_REVERSE));
765       PetscCall(VecScatterEnd(plex_batch, bvec, xout, INSERT_VALUES, SCATTER_REVERSE));
766       PetscCall(VecDestroy(&bvec));
767     }
768   }
769   PetscFunctionReturn(PETSC_SUCCESS);
770 }
771 
772 static PetscErrorCode PCSetUp_BJKOKKOS(PC pc)
773 {
774   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
775   Mat            A = pc->pmat, Aseq = A; // use filtered block matrix, really "P"
776   PetscBool      flg;
777 
778   PetscFunctionBegin;
779   //PetscCheck(!pc->useAmat, PetscObjectComm((PetscObject)pc), PETSC_ERR_SUP, "No support for using 'use_amat'");
780   PetscCheck(A, PetscObjectComm((PetscObject)A), PETSC_ERR_USER, "No matrix - A is used above");
781   PetscCall(PetscObjectTypeCompareAny((PetscObject)A, &flg, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, MATAIJKOKKOS, ""));
782   PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "must use '-[dm_]mat_type aijkokkos -[dm_]vec_type kokkos' for -pc_type bjkokkos");
783   if (!A->spptr) {
784     Aseq = ((Mat_MPIAIJ *)A->data)->A; // MPI
785   }
786   PetscCall(MatSeqAIJKokkosSyncDevice(Aseq));
787   {
788     PetscInt    Istart, Iend;
789     PetscMPIInt rank;
790     PetscCallMPI(MPI_Comm_rank(PetscObjectComm((PetscObject)A), &rank));
791     PetscCall(MatGetOwnershipRange(A, &Istart, &Iend));
792     if (!jac->vec_diag) {
793       Vec     *subX = NULL;
794       DM       pack, *subDM = NULL;
795       PetscInt nDMs, n, *block_sizes = NULL;
796       IS       isrow, isicol;
797       { // Permute the matrix to get a block diagonal system: d_isrow_k, d_isicol_k
798         MatOrderingType rtype;
799         const PetscInt *rowindices, *icolindices;
800         rtype = MATORDERINGRCM;
801         // get permutation. And invert. should we convert to local indices?
802         PetscCall(MatGetOrdering(Aseq, rtype, &isrow, &isicol)); // only seems to work for seq matrix
803         PetscCall(ISDestroy(&isrow));
804         PetscCall(ISInvertPermutation(isicol, PETSC_DECIDE, &isrow)); // THIS IS BACKWARD -- isrow is inverse
805         // if (rank==1) PetscCall(ISView(isicol, PETSC_VIEWER_STDOUT_SELF));
806         if (0) {
807           Mat mat_block_order; // debug
808           PetscCall(ISShift(isicol, Istart, isicol));
809           PetscCall(MatCreateSubMatrix(A, isicol, isicol, MAT_INITIAL_MATRIX, &mat_block_order));
810           PetscCall(ISShift(isicol, -Istart, isicol));
811           PetscCall(MatViewFromOptions(mat_block_order, NULL, "-ksp_batch_reorder_view"));
812           PetscCall(MatDestroy(&mat_block_order));
813         }
814         PetscCall(ISGetIndices(isrow, &rowindices)); // local idx
815         PetscCall(ISGetIndices(isicol, &icolindices));
816         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_isrow_k((PetscInt *)rowindices, A->rmap->n);
817         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_isicol_k((PetscInt *)icolindices, A->rmap->n);
818         jac->d_isrow_k  = new Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_isrow_k));
819         jac->d_isicol_k = new Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_isicol_k));
820         Kokkos::deep_copy(*jac->d_isrow_k, h_isrow_k);
821         Kokkos::deep_copy(*jac->d_isicol_k, h_isicol_k);
822         PetscCall(ISRestoreIndices(isrow, &rowindices));
823         PetscCall(ISRestoreIndices(isicol, &icolindices));
824         // if (rank==1) PetscCall(ISView(isicol, PETSC_VIEWER_STDOUT_SELF));
825       }
826       // get block sizes & allocate vec_diag
827       PetscCall(PCGetDM(pc, &pack));
828       if (pack) {
829         PetscCall(PetscObjectTypeCompare((PetscObject)pack, DMCOMPOSITE, &flg));
830         if (flg) {
831           PetscCall(DMCompositeGetNumberDM(pack, &nDMs));
832           PetscCall(DMCreateGlobalVector(pack, &jac->vec_diag));
833         } else pack = NULL; // flag for no DM
834       }
835       if (!jac->vec_diag) { // get 'nDMs' and sizes 'block_sizes' w/o DMComposite. User could provide ISs (todo)
836         PetscInt        bsrt, bend, ncols, ntot = 0;
837         const PetscInt *colsA, nloc = Iend - Istart;
838         const PetscInt *rowindices, *icolindices;
839         PetscCall(PetscMalloc1(nloc, &block_sizes)); // very inefficient, to big
840         PetscCall(ISGetIndices(isrow, &rowindices));
841         PetscCall(ISGetIndices(isicol, &icolindices));
842         nDMs = 0;
843         bsrt = 0;
844         bend = 1;
845         for (PetscInt row_B = 0; row_B < nloc; row_B++) { // for all rows in block diagonal space
846           PetscInt rowA = icolindices[row_B], minj = PETSC_MAX_INT, maxj = 0;
847           //PetscCall(PetscPrintf(PETSC_COMM_SELF, "\t[%d] rowA = %d\n",rank,rowA));
848           PetscCall(MatGetRow(Aseq, rowA, &ncols, &colsA, NULL)); // not sorted in permutation
849           PetscCheck(ncols, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "Empty row not supported: %" PetscInt_FMT "\n", row_B);
850           for (PetscInt colj = 0; colj < ncols; colj++) {
851             PetscInt colB = rowindices[colsA[colj]]; // use local idx
852             //PetscCall(PetscPrintf(PETSC_COMM_SELF, "\t\t[%d] colB = %d\n",rank,colB));
853             PetscCheck(colB >= 0 && colB < nloc, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "colB < 0: %" PetscInt_FMT "\n", colB);
854             if (colB > maxj) maxj = colB;
855             if (colB < minj) minj = colB;
856           }
857           PetscCall(MatRestoreRow(Aseq, rowA, &ncols, &colsA, NULL));
858           if (minj >= bend) { // first column is > max of last block -- new block or last block
859             //PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "\t\t finish block %d, N loc = %d (%d,%d)\n", nDMs+1, bend - bsrt,bsrt,bend));
860             block_sizes[nDMs] = bend - bsrt;
861             ntot += block_sizes[nDMs];
862             PetscCheck(minj == bend, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "minj != bend: %" PetscInt_FMT " != %" PetscInt_FMT "\n", minj, bend);
863             bsrt = bend;
864             bend++; // start with size 1 in new block
865             nDMs++;
866           }
867           if (maxj + 1 > bend) bend = maxj + 1;
868           PetscCheck(minj >= bsrt || row_B == Iend - 1, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "%" PetscInt_FMT ") minj < bsrt: %" PetscInt_FMT " != %" PetscInt_FMT "\n", rowA, minj, bsrt);
869           //PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] %d) row %d.%d) cols %d : %d ; bsrt = %d, bend = %d\n",rank,row_B,nDMs,rowA,minj,maxj,bsrt,bend));
870         }
871         // do last block
872         //PetscCall(PetscPrintf(PETSC_COMM_SELF, "\t\t\t [%d] finish block %d, N loc = %d (%d,%d)\n", rank, nDMs+1, bend - bsrt,bsrt,bend));
873         block_sizes[nDMs] = bend - bsrt;
874         ntot += block_sizes[nDMs];
875         nDMs++;
876         // cleanup
877         PetscCheck(ntot == nloc, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "n total != n local: %" PetscInt_FMT " != %" PetscInt_FMT "\n", ntot, nloc);
878         PetscCall(ISRestoreIndices(isrow, &rowindices));
879         PetscCall(ISRestoreIndices(isicol, &icolindices));
880         PetscCall(PetscRealloc(sizeof(PetscInt) * nDMs, &block_sizes));
881         PetscCall(MatCreateVecs(A, &jac->vec_diag, NULL));
882         PetscCall(PetscInfo(pc, "Setup Matrix based meta data (not DMComposite not attached to PC) %" PetscInt_FMT " sub domains\n", nDMs));
883       }
884       PetscCall(ISDestroy(&isrow));
885       PetscCall(ISDestroy(&isicol));
886       jac->num_dms = nDMs;
887       PetscCall(VecGetLocalSize(jac->vec_diag, &n));
888       jac->n         = n;
889       jac->d_idiag_k = new Kokkos::View<PetscScalar *, Kokkos::LayoutRight>("idiag", n);
890       // options
891       PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
892       PetscCall(KSPSetFromOptions(jac->ksp));
893       PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPBICG, ""));
894       if (flg) {
895         jac->ksp_type_idx = BATCH_KSP_BICG_IDX;
896         jac->nwork        = 7;
897       } else {
898         PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPTFQMR, ""));
899         if (flg) {
900           jac->ksp_type_idx = BATCH_KSP_TFQMR_IDX;
901           jac->nwork        = 10;
902         } else {
903 #if defined(PETSC_HAVE_KOKKOS_KERNELS_BATCH)
904           PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPGMRES, ""));
905           PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Unsupported batch ksp type");
906           jac->ksp_type_idx = BATCH_KSP_GMRESKK_IDX;
907           jac->nwork        = 0;
908 #else
909           KSPType ksptype;
910           PetscCall(KSPGetType(jac->ksp, &ksptype));
911           PetscCheck(flg, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "Type: %s not supported in complex\n", ksptype);
912 #endif
913         }
914       }
915       PetscOptionsBegin(PetscObjectComm((PetscObject)jac->ksp), ((PetscObject)jac->ksp)->prefix, "Options for Kokkos batch solver", "none");
916       PetscCall(PetscOptionsBool("-ksp_converged_reason", "", "bjkokkos.kokkos.cxx.c", jac->reason, &jac->reason, NULL));
917       PetscCall(PetscOptionsBool("-ksp_monitor", "", "bjkokkos.kokkos.cxx.c", jac->monitor, &jac->monitor, NULL));
918       PetscCall(PetscOptionsInt("-ksp_batch_target", "", "bjkokkos.kokkos.cxx.c", jac->batch_target, &jac->batch_target, NULL));
919       PetscCall(PetscOptionsInt("-ksp_rank_target", "", "bjkokkos.kokkos.cxx.c", jac->rank_target, &jac->rank_target, NULL));
920       PetscCall(PetscOptionsInt("-ksp_batch_nsolves_team", "", "bjkokkos.kokkos.cxx.c", jac->nsolves_team, &jac->nsolves_team, NULL));
921       PetscCheck(jac->batch_target < jac->num_dms, PETSC_COMM_WORLD, PETSC_ERR_ARG_WRONG, "-ksp_batch_target (%" PetscInt_FMT ") >= number of DMs (%" PetscInt_FMT ")", jac->batch_target, jac->num_dms);
922       PetscOptionsEnd();
923       // get blocks - jac->d_bid_eqOffset_k
924       if (pack) {
925         PetscCall(PetscMalloc(sizeof(*subX) * nDMs, &subX));
926         PetscCall(PetscMalloc(sizeof(*subDM) * nDMs, &subDM));
927       }
928       PetscCall(PetscMalloc(sizeof(*jac->dm_Nf) * nDMs, &jac->dm_Nf));
929       PetscCall(PetscInfo(pc, "Have %" PetscInt_FMT " blocks, n=%" PetscInt_FMT " rtol=%g type = %s\n", nDMs, n, (double)jac->ksp->rtol, ((PetscObject)jac->ksp)->type_name));
930       if (pack) PetscCall(DMCompositeGetEntriesArray(pack, subDM));
931       jac->nBlocks = 0;
932       for (PetscInt ii = 0; ii < nDMs; ii++) {
933         PetscInt Nf;
934         if (subDM) {
935           DM           dm = subDM[ii];
936           PetscSection section;
937           PetscCall(DMGetLocalSection(dm, &section));
938           PetscCall(PetscSectionGetNumFields(section, &Nf));
939         } else Nf = 1;
940         jac->nBlocks += Nf;
941 #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
942         if (ii == 0) PetscCall(PetscInfo(pc, "%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n", ii, Nf, jac->nBlocks));
943 #else
944         PetscCall(PetscInfo(pc, "%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n", ii, Nf, jac->nBlocks));
945 #endif
946         jac->dm_Nf[ii] = Nf;
947       }
948       { // d_bid_eqOffset_k
949         Kokkos::View<PetscInt *, Kokkos::LayoutRight, Kokkos::HostSpace> h_block_offsets("block_offsets", jac->nBlocks + 1);
950         if (pack) PetscCall(DMCompositeGetAccessArray(pack, jac->vec_diag, nDMs, NULL, subX));
951         h_block_offsets[0]    = 0;
952         jac->const_block_size = -1;
953         for (PetscInt ii = 0, idx = 0; ii < nDMs; ii++) {
954           PetscInt nloc, nblk;
955           if (pack) PetscCall(VecGetSize(subX[ii], &nloc));
956           else nloc = block_sizes[ii];
957           nblk = nloc / jac->dm_Nf[ii];
958           PetscCheck(nloc % jac->dm_Nf[ii] == 0, PetscObjectComm((PetscObject)pc), PETSC_ERR_USER, "nloc%%jac->dm_Nf[ii] (%" PetscInt_FMT ") != 0 DMs", nloc % jac->dm_Nf[ii]);
959           for (PetscInt jj = 0; jj < jac->dm_Nf[ii]; jj++, idx++) {
960             h_block_offsets[idx + 1] = h_block_offsets[idx] + nblk;
961 #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
962             if (idx == 0) PetscCall(PetscInfo(pc, "Add first of %" PetscInt_FMT " blocks with %" PetscInt_FMT " equations\n", jac->nBlocks, nblk));
963 #else
964             PetscCall(PetscInfo(pc, "\t%" PetscInt_FMT ") Add block with %" PetscInt_FMT " equations of %" PetscInt_FMT "\n", idx + 1, nblk, jac->nBlocks));
965 #endif
966             if (jac->const_block_size == -1) jac->const_block_size = nblk;
967             else if (jac->const_block_size > 0 && jac->const_block_size != nblk) jac->const_block_size = 0;
968           }
969         }
970         if (pack) {
971           PetscCall(DMCompositeRestoreAccessArray(pack, jac->vec_diag, jac->nBlocks, NULL, subX));
972           PetscCall(PetscFree(subX));
973           PetscCall(PetscFree(subDM));
974         }
975         jac->d_bid_eqOffset_k = new Kokkos::View<PetscInt *, Kokkos::LayoutRight>(Kokkos::create_mirror(Kokkos::DefaultExecutionSpace::memory_space(), h_block_offsets));
976         Kokkos::deep_copy(*jac->d_bid_eqOffset_k, h_block_offsets);
977       }
978       if (!pack) PetscCall(PetscFree(block_sizes));
979     }
980     { // get jac->d_idiag_k (PC setup),
981       const PetscInt    *d_ai, *d_aj;
982       const PetscScalar *d_aa;
983       const PetscInt     conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp == 0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
984       const PetscInt    *d_bid_eqOffset = jac->d_bid_eqOffset_k->data(), *r = jac->d_isrow_k->data(), *ic = jac->d_isicol_k->data();
985       PetscScalar       *d_idiag = jac->d_idiag_k->data(), *dummy;
986       PetscMemType       mtype;
987       PetscCall(MatSeqAIJGetCSRAndMemType(Aseq, &d_ai, &d_aj, &dummy, &mtype));
988       d_aa = dummy;
989       Kokkos::parallel_for(
990         "Diag", Kokkos::TeamPolicy<>(jac->nBlocks, team_size, PCBJKOKKOS_VEC_SIZE), KOKKOS_LAMBDA(const team_member team) {
991           const PetscInt blkID = team.league_rank();
992           Kokkos::parallel_for(Kokkos::TeamThreadRange(team, d_bid_eqOffset[blkID], d_bid_eqOffset[blkID + 1]), [=](const int rowb) {
993             const PetscInt     rowa = ic[rowb], ai = d_ai[rowa], *aj = d_aj + ai; // grab original data
994             const PetscScalar *aa   = d_aa + ai;
995             const PetscInt     nrow = d_ai[rowa + 1] - ai;
996             int                found;
997             Kokkos::parallel_reduce(
998               Kokkos::ThreadVectorRange(team, nrow),
999               [=](const int &j, int &count) {
1000                 const PetscInt colb = r[aj[j]];
1001                 if (colb == rowb) {
1002                   d_idiag[rowb] = 1. / aa[j];
1003                   count++;
1004                 }
1005               },
1006               found);
1007 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
1008             if (found != 1) Kokkos::single(Kokkos::PerThread(team), [=]() { printf("ERRORrow %d) found = %d\n", rowb, found); });
1009 #endif
1010           });
1011         });
1012     }
1013   }
1014   PetscFunctionReturn(PETSC_SUCCESS);
1015 }
1016 
1017 /* Default destroy, if it has never been setup */
1018 static PetscErrorCode PCReset_BJKOKKOS(PC pc)
1019 {
1020   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1021 
1022   PetscFunctionBegin;
1023   PetscCall(KSPDestroy(&jac->ksp));
1024   PetscCall(VecDestroy(&jac->vec_diag));
1025   if (jac->d_bid_eqOffset_k) delete jac->d_bid_eqOffset_k;
1026   if (jac->d_idiag_k) delete jac->d_idiag_k;
1027   if (jac->d_isrow_k) delete jac->d_isrow_k;
1028   if (jac->d_isicol_k) delete jac->d_isicol_k;
1029   jac->d_bid_eqOffset_k = NULL;
1030   jac->d_idiag_k        = NULL;
1031   jac->d_isrow_k        = NULL;
1032   jac->d_isicol_k       = NULL;
1033   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSGetKSP_C", NULL)); // not published now (causes configure errors)
1034   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSSetKSP_C", NULL));
1035   PetscCall(PetscFree(jac->dm_Nf));
1036   jac->dm_Nf = NULL;
1037   if (jac->rowOffsets) delete jac->rowOffsets;
1038   if (jac->colIndices) delete jac->colIndices;
1039   if (jac->batch_b) delete jac->batch_b;
1040   if (jac->batch_x) delete jac->batch_x;
1041   if (jac->batch_values) delete jac->batch_values;
1042   jac->rowOffsets   = NULL;
1043   jac->colIndices   = NULL;
1044   jac->batch_b      = NULL;
1045   jac->batch_x      = NULL;
1046   jac->batch_values = NULL;
1047 
1048   PetscFunctionReturn(PETSC_SUCCESS);
1049 }
1050 
1051 static PetscErrorCode PCDestroy_BJKOKKOS(PC pc)
1052 {
1053   PetscFunctionBegin;
1054   PetscCall(PCReset_BJKOKKOS(pc));
1055   PetscCall(PetscFree(pc->data));
1056   PetscFunctionReturn(PETSC_SUCCESS);
1057 }
1058 
1059 static PetscErrorCode PCView_BJKOKKOS(PC pc, PetscViewer viewer)
1060 {
1061   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1062   PetscBool      iascii;
1063 
1064   PetscFunctionBegin;
1065   if (!jac->ksp) PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
1066   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
1067   if (iascii) {
1068     PetscCall(PetscViewerASCIIPrintf(viewer, "  Batched device linear solver: Krylov (KSP) method with Jacobi preconditioning\n"));
1069     PetscCall(PetscViewerASCIIPrintf(viewer, "\t\tnwork = %" PetscInt_FMT ", rel tol = %e, abs tol = %e, div tol = %e, max it =%" PetscInt_FMT ", type = %s\n", jac->nwork, jac->ksp->rtol, jac->ksp->abstol, jac->ksp->divtol, jac->ksp->max_it,
1070                                      ((PetscObject)jac->ksp)->type_name));
1071   }
1072   PetscFunctionReturn(PETSC_SUCCESS);
1073 }
1074 
/*
  PCSetFromOptions_BJKOKKOS - Opens an (empty) options section for this PC.

  The solver options are not read here. They are read in PCSetUp_BJKOKKOS(): KSPSetFromOptions() is called
  on the inner KSP (prefix -pc_bjkokkos_), and an explicit PetscOptionsBegin()/PetscOptionsEnd() block reads
  the batch-specific -pc_bjkokkos_ksp_* flags.
*/
static PetscErrorCode PCSetFromOptions_BJKOKKOS(PC pc, PetscOptionItems *PetscOptionsObject)
{
  PetscFunctionBegin;
  PetscOptionsHeadBegin(PetscOptionsObject, "PC BJKOKKOS options");
  PetscOptionsHeadEnd();
  PetscFunctionReturn(PETSC_SUCCESS);
}
1082 
1083 static PetscErrorCode PCBJKOKKOSSetKSP_BJKOKKOS(PC pc, KSP ksp)
1084 {
1085   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1086 
1087   PetscFunctionBegin;
1088   PetscCall(PetscObjectReference((PetscObject)ksp));
1089   PetscCall(KSPDestroy(&jac->ksp));
1090   jac->ksp = ksp;
1091   PetscFunctionReturn(PETSC_SUCCESS);
1092 }
1093 
/*@C
  PCBJKOKKOSSetKSP - Sets the `KSP` context for `PCBJKOKKOS`

  Collective

  Input Parameters:
+ pc  - the `PCBJKOKKOS` preconditioner context
- ksp - the `KSP` solver

  Level: advanced

  Notes:
  The `PC` and the `KSP` must have the same communicator

  If the `PC` is not `PCBJKOKKOS` this function returns without doing anything

  The `PC` takes its own reference to `ksp`, so the caller may destroy its reference afterwards.

  The first `PCSetUp()` of a `PCBJKOKKOS` creates a new inner `KSP`, so a `KSP` set before that point is not the one used.

.seealso: `PCBJKOKKOSGetKSP()`, `PCBJKOKKOS`
@*/
PetscErrorCode PCBJKOKKOSSetKSP(PC pc, KSP ksp)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(pc, PC_CLASSID, 1);
  PetscValidHeaderSpecific(ksp, KSP_CLASSID, 2);
  PetscCheckSameComm(pc, 1, ksp, 2);
  PetscTryMethod(pc, "PCBJKOKKOSSetKSP_C", (PC, KSP), (pc, ksp));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1121 
1122 static PetscErrorCode PCBJKOKKOSGetKSP_BJKOKKOS(PC pc, KSP *ksp)
1123 {
1124   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1125 
1126   PetscFunctionBegin;
1127   if (!jac->ksp) PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
1128   *ksp = jac->ksp;
1129   PetscFunctionReturn(PETSC_SUCCESS);
1130 }
1131 
/*@C
  PCBJKOKKOSGetKSP - Gets the `KSP` context for the `PCBJKOKKOS` preconditioner

  Not Collective but `KSP` returned is parallel if `PC` was parallel

  Input Parameter:
. pc - the preconditioner context

  Output Parameter:
. ksp - the `KSP` solver

  Level: advanced

  Notes:
  You must call `KSPSetUp()` before calling `PCBJKOKKOSGetKSP()`. Before the first setup a default `KSP` is returned,
  but the first `PCSetUp()` of the `PCBJKOKKOS` replaces it with a newly created one.

  The returned `KSP` is owned by the `PC`; do not destroy it.

  If the `PC` is not a `PCBJKOKKOS` object it raises an error

.seealso: `PCBJKOKKOS`, `PCBJKOKKOSSetKSP()`
@*/
PetscErrorCode PCBJKOKKOSGetKSP(PC pc, KSP *ksp)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(pc, PC_CLASSID, 1);
  PetscAssertPointer(ksp, 2);
  PetscUseMethod(pc, "PCBJKOKKOSGetKSP_C", (PC, KSP *), (pc, ksp));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1160 
1161 /*MC
     PCBJKOKKOS -  Defines a preconditioner that applies a Krylov solver, with point Jacobi preconditioning, to each diagonal block of an `MATAIJKOKKOS` matrix on the GPU using Kokkos
1163 
1164    Options Database Key:
1165 .     -pc_bjkokkos_ - options prefix for its `KSP` options
1166 
1167    Level: intermediate
1168 
1169    Note:
1170     For use with -ksp_type preonly to bypass any computation on the CPU
1171 
1172    Developer Notes:
1173    The documentation is incomplete. Is this a block Jacobi preconditioner?
1174 
1175    Why does it have its own `KSP`? Where is the `KSP` run if used with -ksp_type preonly?
1176 
1177 .seealso: `PCCreate()`, `PCSetType()`, `PCType`, `PC`, `PCBJACOBI`,
1178           `PCSHELL`, `PCCOMPOSITE`, `PCSetUseAmat()`, `PCBJKOKKOSGetKSP()`
1179 M*/
1180 
1181 PETSC_EXTERN PetscErrorCode PCCreate_BJKOKKOS(PC pc)
1182 {
1183   PC_PCBJKOKKOS *jac;
1184 
1185   PetscFunctionBegin;
1186   PetscCall(PetscNew(&jac));
1187   pc->data = (void *)jac;
1188 
1189   jac->ksp              = NULL;
1190   jac->vec_diag         = NULL;
1191   jac->d_bid_eqOffset_k = NULL;
1192   jac->d_idiag_k        = NULL;
1193   jac->d_isrow_k        = NULL;
1194   jac->d_isicol_k       = NULL;
1195   jac->nBlocks          = 1;
1196   jac->max_nits         = 0;
1197 
1198   PetscCall(PetscMemzero(pc->ops, sizeof(struct _PCOps)));
1199   pc->ops->apply          = PCApply_BJKOKKOS;
1200   pc->ops->applytranspose = NULL;
1201   pc->ops->setup          = PCSetUp_BJKOKKOS;
1202   pc->ops->reset          = PCReset_BJKOKKOS;
1203   pc->ops->destroy        = PCDestroy_BJKOKKOS;
1204   pc->ops->setfromoptions = PCSetFromOptions_BJKOKKOS;
1205   pc->ops->view           = PCView_BJKOKKOS;
1206 
1207   jac->rowOffsets   = NULL;
1208   jac->colIndices   = NULL;
1209   jac->batch_b      = NULL;
1210   jac->batch_x      = NULL;
1211   jac->batch_values = NULL;
1212 
1213   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSGetKSP_C", PCBJKOKKOSGetKSP_BJKOKKOS));
1214   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSSetKSP_C", PCBJKOKKOSSetKSP_BJKOKKOS));
1215   PetscFunctionReturn(PETSC_SUCCESS);
1216 }
1217