xref: /petsc/src/ksp/pc/impls/bjacobi/bjkokkos/bjkokkos.kokkos.cxx (revision e91c04dfc8a52dee1965211bb1cc8e5bf775178f)
1 #define PETSC_SKIP_CXX_COMPLEX_FIX // Kokkos::complex does not need the petsc complex fix
2 
3 #include <petsc/private/pcbjkokkosimpl.h>
4 
5 #include <petsc/private/kspimpl.h>
6 #include <petscksp.h> /*I "petscksp.h" I*/
7 #include <../src/mat/impls/aij/mpi/mpiaij.h>
8 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
9 #include <petscsection.h>
10 #include <petscdmcomposite.h>
11 
12 #include <../src/mat/impls/aij/seq/aij.h>
13 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
14 
15 #include <petscdevice_cupm.h>
16 
17 static PetscErrorCode PCBJKOKKOSCreateKSP_BJKOKKOS(PC pc)
18 {
19   const char    *prefix;
20   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
21   DM             dm;
22 
23   PetscFunctionBegin;
24   PetscCall(KSPCreate(PetscObjectComm((PetscObject)pc), &jac->ksp));
25   PetscCall(KSPSetNestLevel(jac->ksp, pc->kspnestlevel));
26   PetscCall(KSPSetErrorIfNotConverged(jac->ksp, pc->erroriffailure));
27   PetscCall(PetscObjectIncrementTabLevel((PetscObject)jac->ksp, (PetscObject)pc, 1));
28   PetscCall(PCGetOptionsPrefix(pc, &prefix));
29   PetscCall(KSPSetOptionsPrefix(jac->ksp, prefix));
30   PetscCall(KSPAppendOptionsPrefix(jac->ksp, "pc_bjkokkos_"));
31   PetscCall(PCGetDM(pc, &dm));
32   if (dm) {
33     PetscCall(KSPSetDM(jac->ksp, dm));
34     PetscCall(KSPSetDMActive(jac->ksp, PETSC_FALSE));
35   }
36   jac->reason       = PETSC_FALSE;
37   jac->monitor      = PETSC_FALSE;
38   jac->batch_target = 0;
39   jac->rank_target  = 0;
40   jac->nsolves_team = 1;
41   jac->ksp->max_it  = 50; // this is really for GMRES w/o restarts
42   PetscFunctionReturn(PETSC_SUCCESS);
43 }
44 
45 // y <-- Ax
46 KOKKOS_INLINE_FUNCTION PetscErrorCode MatMult(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
47 {
48   Kokkos::parallel_for(Kokkos::TeamThreadRange(team, start, end), [=](const int rowb) {
49     int                rowa = ic[rowb];
50     int                n    = glb_Aai[rowa + 1] - glb_Aai[rowa];
51     const PetscInt    *aj   = glb_Aaj + glb_Aai[rowa]; // global
52     const PetscScalar *aa   = glb_Aaa + glb_Aai[rowa];
53     PetscScalar        sum;
54     Kokkos::parallel_reduce(Kokkos::ThreadVectorRange(team, n), [=](const int i, PetscScalar &lsum) { lsum += aa[i] * x_loc[r[aj[i]] - start]; }, sum);
55     Kokkos::single(Kokkos::PerThread(team), [=]() { y_loc[rowb - start] = sum; });
56   });
57   team.team_barrier();
58   return PETSC_SUCCESS;
59 }
60 
61 // temp buffer per thread with reduction at end?
62 KOKKOS_INLINE_FUNCTION PetscErrorCode MatMultTranspose(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
63 {
64   Kokkos::parallel_for(Kokkos::TeamVectorRange(team, end - start), [=](int i) { y_loc[i] = 0; });
65   team.team_barrier();
66   Kokkos::parallel_for(Kokkos::TeamThreadRange(team, start, end), [=](const int rowb) {
67     int                rowa = ic[rowb];
68     int                n    = glb_Aai[rowa + 1] - glb_Aai[rowa];
69     const PetscInt    *aj   = glb_Aaj + glb_Aai[rowa]; // global
70     const PetscScalar *aa   = glb_Aaa + glb_Aai[rowa];
71     const PetscScalar  xx   = x_loc[rowb - start]; // rowb = ic[rowa] = ic[r[rowb]]
72     Kokkos::parallel_for(Kokkos::ThreadVectorRange(team, n), [=](const int &i) {
73       PetscScalar val = aa[i] * xx;
74       Kokkos::atomic_fetch_add(&y_loc[r[aj[i]] - start], val);
75     });
76   });
77   team.team_barrier();
78   return PETSC_SUCCESS;
79 }
80 
81 typedef struct Batch_MetaData_TAG {
82   PetscInt           flops;
83   PetscInt           its;
84   KSPConvergedReason reason;
85 } Batch_MetaData;
86 
87 // Solve A(BB^-1)x = y with TFQMR. Right preconditioned to get un-preconditioned residual
88 static KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_TFQMR(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space_global, const int stride_global, const int nShareVec, PetscScalar *work_space_shared, const int stride_shared, PetscReal rtol, PetscReal atol, PetscReal dtol, PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
89 {
90   using Kokkos::parallel_for;
91   using Kokkos::parallel_reduce;
92   int                Nblk = end - start, it, m, stride = stride_shared, idx = 0;
93   PetscReal          dp, dpold, w, dpest, tau, psi, cm, r0;
94   const PetscScalar *Diag = &glb_idiag[start];
95   PetscScalar       *ptr  = work_space_shared, rho, rhoold, a, s, b, eta, etaold, psiold, cf, dpi;
96 
97   if (idx++ == nShareVec) {
98     ptr    = work_space_global;
99     stride = stride_global;
100   }
101   PetscScalar *XX = ptr;
102   ptr += stride;
103   if (idx++ == nShareVec) {
104     ptr    = work_space_global;
105     stride = stride_global;
106   }
107   PetscScalar *R = ptr;
108   ptr += stride;
109   if (idx++ == nShareVec) {
110     ptr    = work_space_global;
111     stride = stride_global;
112   }
113   PetscScalar *RP = ptr;
114   ptr += stride;
115   if (idx++ == nShareVec) {
116     ptr    = work_space_global;
117     stride = stride_global;
118   }
119   PetscScalar *V = ptr;
120   ptr += stride;
121   if (idx++ == nShareVec) {
122     ptr    = work_space_global;
123     stride = stride_global;
124   }
125   PetscScalar *T = ptr;
126   ptr += stride;
127   if (idx++ == nShareVec) {
128     ptr    = work_space_global;
129     stride = stride_global;
130   }
131   PetscScalar *Q = ptr;
132   ptr += stride;
133   if (idx++ == nShareVec) {
134     ptr    = work_space_global;
135     stride = stride_global;
136   }
137   PetscScalar *P = ptr;
138   ptr += stride;
139   if (idx++ == nShareVec) {
140     ptr    = work_space_global;
141     stride = stride_global;
142   }
143   PetscScalar *U = ptr;
144   ptr += stride;
145   if (idx++ == nShareVec) {
146     ptr    = work_space_global;
147     stride = stride_global;
148   }
149   PetscScalar *D = ptr;
150   ptr += stride;
151   if (idx++ == nShareVec) {
152     ptr    = work_space_global;
153     stride = stride_global;
154   }
155   PetscScalar *AUQ = V;
156 
157   // init: get b, zero x
158   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
159     int rowa         = ic[rowb];
160     R[rowb - start]  = glb_b[rowa];
161     XX[rowb - start] = 0;
162   });
163   team.team_barrier();
164   parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += R[idx] * PetscConj(R[idx]); }, dpi);
165   team.team_barrier();
166   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
167   // diagnostics
168 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
169   if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e\n", 0, (double)dp); });
170 #endif
171   if (dp < atol) {
172     metad->reason = KSP_CONVERGED_ATOL_NORMAL;
173     it            = 0;
174     goto done;
175   }
176   if (0 == maxit) {
177     metad->reason = KSP_CONVERGED_ITS;
178     it            = 0;
179     goto done;
180   }
181 
182   /* Make the initial Rp = R */
183   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { RP[idx] = R[idx]; });
184   team.team_barrier();
185   /* Set the initial conditions */
186   etaold = 0.0;
187   psiold = 0.0;
188   tau    = dp;
189   dpold  = dp;
190 
191   /* rhoold = (r,rp)     */
192   parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += R[idx] * PetscConj(RP[idx]); }, rhoold);
193   team.team_barrier();
194   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
195     U[idx] = R[idx];
196     P[idx] = R[idx];
197     T[idx] = Diag[idx] * P[idx];
198     D[idx] = 0;
199   });
200   team.team_barrier();
201   static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, V));
202 
203   it = 0;
204   do {
205     /* s <- (v,rp)          */
206     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += V[idx] * PetscConj(RP[idx]); }, s);
207     team.team_barrier();
208     if (s == 0) {
209       metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
210       goto done;
211     }
212     a = rhoold / s; /* a <- rho / s         */
213     /* q <- u - a v    VecWAXPY(w,alpha,x,y): w = alpha x + y.     */
214     /* t <- u + q           */
215     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
216       Q[idx] = U[idx] - a * V[idx];
217       T[idx] = U[idx] + Q[idx];
218     });
219     team.team_barrier();
220     // KSP_PCApplyBAorAB
221     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { T[idx] = Diag[idx] * T[idx]; });
222     team.team_barrier();
223     static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, AUQ));
224     /* r <- r - a K (u + q) */
225     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { R[idx] = R[idx] - a * AUQ[idx]; });
226     team.team_barrier();
227     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += R[idx] * PetscConj(R[idx]); }, dpi);
228     team.team_barrier();
229     dp = PetscSqrtReal(PetscRealPart(dpi));
230     for (m = 0; m < 2; m++) {
231       if (!m) w = PetscSqrtReal(dp * dpold);
232       else w = dp;
233       psi = w / tau;
234       cm  = 1.0 / PetscSqrtReal(1.0 + psi * psi);
235       tau = tau * psi * cm;
236       eta = cm * cm * a;
237       cf  = psiold * psiold * etaold / a;
238       if (!m) {
239         /* D = U + cf D */
240         parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { D[idx] = U[idx] + cf * D[idx]; });
241       } else {
242         /* D = Q + cf D */
243         parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { D[idx] = Q[idx] + cf * D[idx]; });
244       }
245       team.team_barrier();
246       parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { XX[idx] = XX[idx] + eta * D[idx]; });
247       team.team_barrier();
248       dpest = PetscSqrtReal(2 * it + m + 2.0) * tau;
249 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
250       if (monitor && m == 1) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e\n", it + 1, (double)dpest); });
251 #endif
252       if (dpest < atol) {
253         metad->reason = KSP_CONVERGED_ATOL_NORMAL;
254         goto done;
255       }
256       if (dpest / r0 < rtol) {
257         metad->reason = KSP_CONVERGED_RTOL_NORMAL;
258         goto done;
259       }
260 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
261       if (dpest / r0 > dtol) {
262         metad->reason = KSP_DIVERGED_DTOL;
263         Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: %d it, res=%e, r_0=%e\n", team.league_rank(), it, dpest, r0); });
264         goto done;
265       }
266 #else
267       if (dpest / r0 > dtol) {
268         metad->reason = KSP_DIVERGED_DTOL;
269         goto done;
270       }
271 #endif
272       if (it + 1 == maxit) {
273         metad->reason = KSP_CONVERGED_ITS;
274 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
275         Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: TFQMR %d:%d it, res=%e, r_0=%e r_res=%e\n", team.league_rank(), it, m, dpest, r0, dpest / r0); });
276 #endif
277         goto done;
278       }
279       etaold = eta;
280       psiold = psi;
281     }
282 
283     /* rho <- (r,rp)       */
284     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += R[idx] * PetscConj(RP[idx]); }, rho);
285     team.team_barrier();
286     if (rho == 0) {
287       metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
288       goto done;
289     }
290     b = rho / rhoold; /* b <- rho / rhoold   */
291     /* u <- r + b q        */
292     /* p <- u + b(q + b p) */
293     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
294       U[idx] = R[idx] + b * Q[idx];
295       Q[idx] = Q[idx] + b * P[idx];
296       P[idx] = U[idx] + b * Q[idx];
297     });
298     /* v <- K p  */
299     team.team_barrier();
300     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { T[idx] = Diag[idx] * P[idx]; });
301     team.team_barrier();
302     static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, V));
303 
304     rhoold = rho;
305     dpold  = dp;
306 
307     it++;
308   } while (it < maxit);
309 done:
310   // KSPUnwindPreconditioner
311   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { XX[idx] = Diag[idx] * XX[idx]; });
312   team.team_barrier();
313   // put x into Plex order
314   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
315     int rowa    = ic[rowb];
316     glb_x[rowa] = XX[rowb - start];
317   });
318   metad->its = it;
319   if (1) {
320     int nnz;
321     parallel_reduce(Kokkos::TeamVectorRange(team, start, end), [=](const int idx, int &lsum) { lsum += (glb_Aai[idx + 1] - glb_Aai[idx]); }, nnz);
322     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * nnz) + 5 * Nblk);
323   } else {
324     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * 50 * Nblk) + 5 * Nblk); // guess
325   }
326   return PETSC_SUCCESS;
327 }
328 
329 // Solve Ax = y with biCG
330 static KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_BICG(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space_global, const int stride_global, const int nShareVec, PetscScalar *work_space_shared, const int stride_shared, PetscReal rtol, PetscReal atol, PetscReal dtol, PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
331 {
332   using Kokkos::parallel_for;
333   using Kokkos::parallel_reduce;
334   int                Nblk = end - start, it, stride = stride_shared, idx = 0; // start in shared mem
335   PetscReal          dp, r0;
336   const PetscScalar *Di  = &glb_idiag[start];
337   PetscScalar       *ptr = work_space_shared, dpi, a = 1.0, beta, betaold = 1.0, t1, t2;
338 
339   if (idx++ == nShareVec) {
340     ptr    = work_space_global;
341     stride = stride_global;
342   }
343   PetscScalar *XX = ptr;
344   ptr += stride;
345   if (idx++ == nShareVec) {
346     ptr    = work_space_global;
347     stride = stride_global;
348   }
349   PetscScalar *Rl = ptr;
350   ptr += stride;
351   if (idx++ == nShareVec) {
352     ptr    = work_space_global;
353     stride = stride_global;
354   }
355   PetscScalar *Zl = ptr;
356   ptr += stride;
357   if (idx++ == nShareVec) {
358     ptr    = work_space_global;
359     stride = stride_global;
360   }
361   PetscScalar *Pl = ptr;
362   ptr += stride;
363   if (idx++ == nShareVec) {
364     ptr    = work_space_global;
365     stride = stride_global;
366   }
367   PetscScalar *Rr = ptr;
368   ptr += stride;
369   if (idx++ == nShareVec) {
370     ptr    = work_space_global;
371     stride = stride_global;
372   }
373   PetscScalar *Zr = ptr;
374   ptr += stride;
375   if (idx++ == nShareVec) {
376     ptr    = work_space_global;
377     stride = stride_global;
378   }
379   PetscScalar *Pr = ptr;
380   ptr += stride;
381 
382   /*     r <- b (x is 0) */
383   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
384     int rowa         = ic[rowb];
385     Rl[rowb - start] = Rr[rowb - start] = glb_b[rowa];
386     XX[rowb - start]                    = 0;
387   });
388   team.team_barrier();
389   /*     z <- Br         */
390   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
391     Zr[idx] = Di[idx] * Rr[idx];
392     Zl[idx] = Di[idx] * Rl[idx];
393   });
394   team.team_barrier();
395   /*    dp <- r'*r       */
396   parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Rr[idx] * PetscConj(Rr[idx]); }, dpi);
397   team.team_barrier();
398   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
399 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
400   if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e\n", 0, (double)dp); });
401 #endif
402   if (dp < atol) {
403     metad->reason = KSP_CONVERGED_ATOL_NORMAL;
404     it            = 0;
405     goto done;
406   }
407   if (0 == maxit) {
408     metad->reason = KSP_CONVERGED_ITS;
409     it            = 0;
410     goto done;
411   }
412 
413   it = 0;
414   do {
415     /*     beta <- r'z     */
416     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += Zr[idx] * PetscConj(Rl[idx]); }, beta);
417     team.team_barrier();
418 #if PCBJKOKKOS_VERBOSE_LEVEL >= 6
419   #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
420     Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%7d beta = Z.R = %22.14e \n", i, (double)beta); });
421   #endif
422 #endif
423     if (beta == 0.0) {
424       metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
425       goto done;
426     }
427     if (it == 0) {
428       /*     p <- z          */
429       parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
430         Pr[idx] = Zr[idx];
431         Pl[idx] = Zl[idx];
432       });
433     } else {
434       t1 = beta / betaold;
435       /*     p <- z + b* p   */
436       t2 = PetscConj(t1);
437       parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
438         Pr[idx] = t1 * Pr[idx] + Zr[idx];
439         Pl[idx] = t2 * Pl[idx] + Zl[idx];
440       });
441     }
442     team.team_barrier();
443     betaold = beta;
444     /*     z <- Kp         */
445     static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, Pr, Zr));
446     static_cast<void>(MatMultTranspose(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, Pl, Zl));
447     /*     dpi <- z'p      */
448     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Zr[idx] * PetscConj(Pl[idx]); }, dpi);
449     team.team_barrier();
450     if (dpi == 0) {
451       metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
452       goto done;
453     }
454     //
455     a  = beta / dpi; /*     a = beta/p'z    */
456     t1 = -a;
457     t2 = PetscConj(t1);
458     /*     x <- x + ap     */
459     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
460       XX[idx] = XX[idx] + a * Pr[idx];
461       Rr[idx] = Rr[idx] + t1 * Zr[idx];
462       Rl[idx] = Rl[idx] + t2 * Zl[idx];
463     });
464     team.team_barrier();
465     team.team_barrier();
466     /*    dp <- r'*r       */
467     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Rr[idx] * PetscConj(Rr[idx]); }, dpi);
468     team.team_barrier();
469     dp = PetscSqrtReal(PetscRealPart(dpi));
470 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
471     if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e\n", it + 1, (double)dp); });
472 #endif
473     if (dp < atol) {
474       metad->reason = KSP_CONVERGED_ATOL_NORMAL;
475       goto done;
476     }
477     if (dp / r0 < rtol) {
478       metad->reason = KSP_CONVERGED_RTOL_NORMAL;
479       goto done;
480     }
481 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
482     if (dp / r0 > dtol) {
483       metad->reason = KSP_DIVERGED_DTOL;
484       Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: %d it, res=%e, r_0=%e (BICG does this)\n", team.league_rank(), it, dp, r0); });
485       goto done;
486     }
487 #else
488     if (dp / r0 > dtol) {
489       metad->reason = KSP_DIVERGED_DTOL;
490       goto done;
491     }
492 #endif
493     if (it + 1 == maxit) {
494       metad->reason = KSP_CONVERGED_ITS; // don't worry about hitting max iterations
495 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
496       Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: BICG %d it, res=%e, r_0=%e r_res=%e\n", team.league_rank(), it, dp, r0, dp / r0); });
497 #endif
498       goto done;
499     }
500     /* z <- Br  */
501     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
502       Zr[idx] = Di[idx] * Rr[idx];
503       Zl[idx] = Di[idx] * Rl[idx];
504     });
505 
506     it++;
507   } while (it < maxit);
508 done:
509   // put x back into Plex order
510   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
511     int rowa    = ic[rowb];
512     glb_x[rowa] = XX[rowb - start];
513   });
514   metad->its = it;
515   if (1) {
516     int nnz;
517     parallel_reduce(Kokkos::TeamVectorRange(team, start, end), [=](const int idx, int &lsum) { lsum += (glb_Aai[idx + 1] - glb_Aai[idx]); }, nnz);
518     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * nnz) + 5 * Nblk);
519   } else {
520     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * 50 * Nblk) + 5 * Nblk); // guess
521   }
522   return PETSC_SUCCESS;
523 }
524 
525 // KSP solver solve Ax = b; xout is output, bin is input
526 static PetscErrorCode PCApply_BJKOKKOS(PC pc, Vec bin, Vec xout)
527 {
528   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
529   Mat            A = pc->pmat, Aseq = A;
530   PetscMPIInt    rank;
531 
532   PetscFunctionBegin;
533   PetscCallMPI(MPI_Comm_rank(PetscObjectComm((PetscObject)A), &rank));
534   if (!A->spptr) {
535     Aseq = ((Mat_MPIAIJ *)A->data)->A; // MPI
536   }
537   PetscCall(MatSeqAIJKokkosSyncDevice(Aseq));
538   {
539     PetscInt           maxit = jac->ksp->max_it;
540     const PetscInt     conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp == 0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
541     const PetscInt     nwork = jac->nwork, nBlk = jac->nBlocks;
542     PetscScalar       *glb_xdata = NULL, *dummy;
543     PetscReal          rtol = jac->ksp->rtol, atol = jac->ksp->abstol, dtol = jac->ksp->divtol;
544     const PetscScalar *glb_idiag = jac->d_idiag_k->data(), *glb_bdata = NULL;
545     const PetscInt    *glb_Aai, *glb_Aaj, *d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
546     const PetscScalar *glb_Aaa;
547     const PetscInt    *d_isicol = jac->d_isicol_k->data(), *d_isrow = jac->d_isrow_k->data();
548     PCFailedReason     pcreason;
549     KSPIndex           ksp_type_idx = jac->ksp_type_idx;
550     PetscMemType       mtype;
551     PetscContainer     container;
552     PetscInt           batch_sz;                // the number of repeated DMs, [DM_e_1, DM_e_2, DM_e_batch_sz, DM_i_1, ...]
553     VecScatter         plex_batch = NULL;       // not used
554     Vec                bvec;                    // a copy of b for scatter (just alias to bin now)
555     PetscBool          monitor  = jac->monitor; // captured
556     PetscInt           view_bid = jac->batch_target;
557     MatInfo            info;
558 
559     PetscCall(MatSeqAIJGetCSRAndMemType(Aseq, &glb_Aai, &glb_Aaj, &dummy, &mtype));
560     jac->max_nits = 0;
561     glb_Aaa       = dummy;
562     if (jac->rank_target != rank) view_bid = -1; // turn off all but one process
563     PetscCall(MatGetInfo(A, MAT_LOCAL, &info));
564     // get field major is to map plex IO to/from block/field major
565     PetscCall(PetscObjectQuery((PetscObject)A, "plex_batch_is", (PetscObject *)&container));
566     if (container) {
567       PetscCall(VecDuplicate(bin, &bvec));
568       PetscCall(PetscContainerGetPointer(container, (void **)&plex_batch));
569       PetscCall(VecScatterBegin(plex_batch, bin, bvec, INSERT_VALUES, SCATTER_FORWARD));
570       PetscCall(VecScatterEnd(plex_batch, bin, bvec, INSERT_VALUES, SCATTER_FORWARD));
571       SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_USER, "No plex_batch_is -- require NO field major ordering for now");
572     } else {
573       bvec = bin;
574     }
575     // get x
576     PetscCall(VecGetArrayAndMemType(xout, &glb_xdata, &mtype));
577 #if defined(PETSC_HAVE_CUDA)
578     PetscCheck(PetscMemTypeDevice(mtype), PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "No GPU data for x %d != %d", static_cast<int>(mtype), static_cast<int>(PETSC_MEMTYPE_DEVICE));
579 #endif
580     PetscCall(VecGetArrayReadAndMemType(bvec, &glb_bdata, &mtype));
581 #if defined(PETSC_HAVE_CUDA)
582     PetscCheck(PetscMemTypeDevice(mtype), PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "No GPU data for b");
583 #endif
584     // get batch size
585     PetscCall(PetscObjectQuery((PetscObject)A, "batch size", (PetscObject *)&container));
586     if (container) {
587       PetscInt *pNf = NULL;
588       PetscCall(PetscContainerGetPointer(container, (void **)&pNf));
589       batch_sz = *pNf; // number of times to repeat the DMs
590     } else batch_sz = 1;
591     PetscCheck(nBlk % batch_sz == 0, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "batch_sz = %" PetscInt_FMT ", nBlk = %" PetscInt_FMT, batch_sz, nBlk);
592     if (ksp_type_idx == BATCH_KSP_GMRESKK_IDX) {
593       // KK solver - move PETSc data into Kokkos Views, setup solver, solve, move data out of Kokkos, process metadata (convergence tests, etc.)
594 #if defined(PETSC_HAVE_KOKKOS_KERNELS_BATCH)
595       PetscCall(PCApply_BJKOKKOSKERNELS(pc, glb_bdata, glb_xdata, glb_Aai, glb_Aaj, glb_Aaa, team_size, info, batch_sz, &pcreason));
596 #else
597       PetscCheck(ksp_type_idx != BATCH_KSP_GMRESKK_IDX, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "Type: BATCH_KSP_GMRES not supported for complex");
598 #endif
599     } else { // Kokkos Krylov
600       using scr_mem_t    = Kokkos::DefaultExecutionSpace::scratch_memory_space;
601       using vect2D_scr_t = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, scr_mem_t>;
602       Kokkos::View<Batch_MetaData *, Kokkos::DefaultExecutionSpace> d_metadata("solver meta data", nBlk);
603       int                                                           stride_shared, stride_global, global_buff_words;
604       d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
605       // solve each block independently
606       int scr_bytes_team_shared = 0, nShareVec = 0, nGlobBVec = 0;
607       if (jac->const_block_size) { // use shared memory for work vectors only if constant block size - TODO: test efficiency loss
608         size_t      maximum_shared_mem_size = 64000;
609         PetscDevice device;
610         PetscCall(PetscDeviceGetDefault_Internal(&device));
611         PetscCall(PetscDeviceGetAttribute(device, PETSC_DEVICE_ATTR_SIZE_T_SHARED_MEM_PER_BLOCK, &maximum_shared_mem_size));
612         stride_shared = jac->const_block_size;                                                   // captured
613         nShareVec     = maximum_shared_mem_size / (jac->const_block_size * sizeof(PetscScalar)); // integer floor, number of vectors that fit in shared
614         if (nShareVec > nwork) nShareVec = nwork;
615         else nGlobBVec = nwork - nShareVec;
616         global_buff_words     = jac->n * nGlobBVec;
617         scr_bytes_team_shared = jac->const_block_size * nShareVec * sizeof(PetscScalar);
618       } else {
619         scr_bytes_team_shared = 0;
620         stride_shared         = 0;
621         global_buff_words     = jac->n * nwork;
622         nGlobBVec             = nwork; // not needed == fix
623       }
624       stride_global = jac->n; // captured
625 #if defined(PETSC_HAVE_CUDA)
626       nvtxRangePushA("batch-kokkos-solve");
627 #endif
628       Kokkos::View<PetscScalar *, Kokkos::DefaultExecutionSpace> d_work_vecs_k("workvectors", global_buff_words); // global work vectors
629 #if PCBJKOKKOS_VERBOSE_LEVEL > 1
630       PetscCall(PetscInfo(pc, "\tn = %d. %d shared bytes/team, %d global mem bytes, rtol=%e, num blocks %d, team_size=%d, %d vector threads, %d shared vectors, %d global vectors\n", (int)jac->n, scr_bytes_team_shared, global_buff_words, rtol, (int)nBlk, (int)team_size, PCBJKOKKOS_VEC_SIZE, nShareVec, nGlobBVec));
631 #endif
632       PetscScalar *d_work_vecs = d_work_vecs_k.data();
633       Kokkos::parallel_for(
634         "Solve", Kokkos::TeamPolicy<Kokkos::LaunchBounds<256, 4>>(nBlk, team_size, PCBJKOKKOS_VEC_SIZE).set_scratch_size(PCBJKOKKOS_SHARED_LEVEL, Kokkos::PerTeam(scr_bytes_team_shared)), KOKKOS_LAMBDA(const team_member team) {
635           const int    blkID = team.league_rank(), start = d_bid_eqOffset[blkID], end = d_bid_eqOffset[blkID + 1];
636           vect2D_scr_t work_vecs_shared(team.team_scratch(PCBJKOKKOS_SHARED_LEVEL), end - start, nShareVec);
637           PetscScalar *work_buff_shared = work_vecs_shared.data();
638           PetscScalar *work_buff_global = &d_work_vecs[start]; // start inc'ed in
639           bool         print            = monitor && (blkID == view_bid);
640           switch (ksp_type_idx) {
641           case BATCH_KSP_BICG_IDX:
642             static_cast<void>(BJSolve_BICG(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff_global, stride_global, nShareVec, work_buff_shared, stride_shared, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print));
643             break;
644           case BATCH_KSP_TFQMR_IDX:
645             static_cast<void>(BJSolve_TFQMR(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff_global, stride_global, nShareVec, work_buff_shared, stride_shared, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print));
646             break;
647           default:
648 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
649             printf("Unknown KSP type %d\n", ksp_type_idx);
650 #else
651             /* void */;
652 #endif
653           }
654         });
655       Kokkos::fence();
656 #if defined(PETSC_HAVE_CUDA)
657       nvtxRangePop();
658       nvtxRangePushA("Post-solve-metadata");
659 #endif
660       auto h_metadata = Kokkos::create_mirror(Kokkos::HostSpace::memory_space(), d_metadata);
661       Kokkos::deep_copy(h_metadata, d_metadata);
662       PetscInt max_nnit = -1;
663 #if PCBJKOKKOS_VERBOSE_LEVEL > 1
664       PetscInt mbid = 0;
665 #endif
666       int in[2], out[2];
667       if (jac->reason) { // -pc_bjkokkos_ksp_converged_reason
668 #if PCBJKOKKOS_VERBOSE_LEVEL >= 3
669   #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
670         PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Iterations\n"));
671   #endif
672         // assume species major
673   #if PCBJKOKKOS_VERBOSE_LEVEL == 3
674         if (batch_sz != 1) PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%s: max iterations per species:", ksp_type_idx == BATCH_KSP_BICG_IDX ? "bicg" : "tfqmr"));
675         else PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "    Linear solve converged due to %s iterations ", ksp_type_idx == BATCH_KSP_BICG_IDX ? "bicg" : "tfqmr"));
676   #endif
677         for (PetscInt dmIdx = 0, head = 0, s = 0; dmIdx < jac->num_dms; dmIdx += batch_sz) {
678           for (PetscInt f = 0, idx = head; f < jac->dm_Nf[dmIdx]; f++, idx++, s++) {
679             for (int bid = 0; bid < batch_sz; bid++) {
680   #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
681               jac->max_nits += h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its; // report total number of iterations with high verbose
682               if (h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its > max_nnit) {
683                 max_nnit = h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its;
684                 mbid     = bid;
685               }
686   #else
687               if (h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its > max_nnit) {
688                 jac->max_nits = max_nnit = h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its;
689                 mbid                     = bid;
690               }
691   #endif
692             }
693   #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
694             PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%2" PetscInt_FMT ":", s));
695             for (int bid = 0; bid < batch_sz; bid++) PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%3" PetscInt_FMT " ", h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its));
696             PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "\n"));
697   #else // == 3
698             PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%3" PetscInt_FMT " ", max_nnit));
699   #endif
700           }
701           head += batch_sz * jac->dm_Nf[dmIdx];
702         }
703   #if PCBJKOKKOS_VERBOSE_LEVEL == 3
704         PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "\n"));
705   #endif
706 #endif
707         if (max_nnit == -1) { // < 3
708           for (int blkID = 0; blkID < nBlk; blkID++) {
709             if (h_metadata[blkID].its > max_nnit) {
710               jac->max_nits = max_nnit = h_metadata[blkID].its;
711 #if PCBJKOKKOS_VERBOSE_LEVEL > 1
712               mbid = blkID;
713 #endif
714             }
715           }
716         }
717         in[0] = max_nnit;
718         in[1] = rank;
719         PetscCallMPI(MPIU_Allreduce(in, out, 1, MPI_2INT, MPI_MAXLOC, PetscObjectComm((PetscObject)A)));
720 #if PCBJKOKKOS_VERBOSE_LEVEL > 1
721         if (0 == rank) {
722           if (batch_sz != 1)
723             PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] Linear solve converged due to %s iterations %d (max), on block %" PetscInt_FMT ", species %" PetscInt_FMT " (max)\n", out[1], KSPConvergedReasons[h_metadata[mbid].reason], out[0], mbid % batch_sz, mbid / batch_sz));
724           else PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] Linear solve converged due to %s iterations %d (max), on block %" PetscInt_FMT "\n", out[1], KSPConvergedReasons[h_metadata[mbid].reason], out[0], mbid));
725         }
726 #endif
727       }
728       for (int blkID = 0; blkID < nBlk; blkID++) {
729         PetscCall(PetscLogGpuFlops((PetscLogDouble)h_metadata[blkID].flops));
730         PetscCheck(h_metadata[blkID].reason >= 0 || !jac->ksp->errorifnotconverged, PetscObjectComm((PetscObject)pc), PETSC_ERR_CONV_FAILED, "ERROR reason=%s, its=%" PetscInt_FMT ". species %" PetscInt_FMT ", batch %" PetscInt_FMT,
731                    KSPConvergedReasons[h_metadata[blkID].reason], h_metadata[blkID].its, blkID / batch_sz, blkID % batch_sz);
732       }
733       {
734         int errsum;
735         Kokkos::parallel_reduce(
736           nBlk,
737           KOKKOS_LAMBDA(const int idx, int &lsum) {
738             if (d_metadata[idx].reason < 0) ++lsum;
739           },
740           errsum);
741         pcreason = errsum ? PC_SUBPC_ERROR : PC_NOERROR;
742         if (!errsum && !jac->max_nits) { // set max its to give back to top KSP
743           for (int blkID = 0; blkID < nBlk; blkID++) {
744             if (h_metadata[blkID].its > jac->max_nits) jac->max_nits = h_metadata[blkID].its;
745           }
746         } else if (errsum) {
747           PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] ERROR Kokkos batch solver did not converge in all solves\n", (int)rank));
748         }
749       }
750 #if defined(PETSC_HAVE_CUDA)
751       nvtxRangePop();
752 #endif
753     } // end of Kokkos (not Kernels) solvers block
754     PetscCall(VecRestoreArrayAndMemType(xout, &glb_xdata));
755     PetscCall(VecRestoreArrayReadAndMemType(bvec, &glb_bdata));
756     PetscCall(PCSetFailedReason(pc, pcreason));
757     // map back to Plex space - not used
758     if (plex_batch) {
759       PetscCall(VecCopy(xout, bvec));
760       PetscCall(VecScatterBegin(plex_batch, bvec, xout, INSERT_VALUES, SCATTER_REVERSE));
761       PetscCall(VecScatterEnd(plex_batch, bvec, xout, INSERT_VALUES, SCATTER_REVERSE));
762       PetscCall(VecDestroy(&bvec));
763     }
764   }
765   PetscFunctionReturn(PETSC_SUCCESS);
766 }
767 
768 static PetscErrorCode PCSetUp_BJKOKKOS(PC pc)
769 {
770   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
771   Mat            A = pc->pmat, Aseq = A; // use filtered block matrix, really "P"
772   PetscBool      flg;
773 
774   PetscFunctionBegin;
775   PetscCheck(A, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "No matrix - A is used above");
776   PetscCall(PetscObjectTypeCompareAny((PetscObject)A, &flg, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, MATAIJKOKKOS, ""));
777   PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "must use '-[dm_]mat_type aijkokkos -[dm_]vec_type kokkos' for -pc_type bjkokkos");
778   if (!A->spptr) {
779     Aseq = ((Mat_MPIAIJ *)A->data)->A; // MPI
780   }
781   PetscCall(MatSeqAIJKokkosSyncDevice(Aseq));
782   {
783     PetscInt    Istart, Iend;
784     PetscMPIInt rank;
785     PetscCallMPI(MPI_Comm_rank(PetscObjectComm((PetscObject)A), &rank));
786     PetscCall(MatGetOwnershipRange(A, &Istart, &Iend));
787     if (!jac->vec_diag) {
788       Vec     *subX = NULL;
789       DM       pack, *subDM = NULL;
790       PetscInt nDMs, n, *block_sizes = NULL;
791       IS       isrow, isicol;
792       { // Permute the matrix to get a block diagonal system: d_isrow_k, d_isicol_k
793         MatOrderingType rtype;
794         const PetscInt *rowindices, *icolindices;
795         rtype = MATORDERINGRCM;
796         // get permutation. And invert. should we convert to local indices?
797         PetscCall(MatGetOrdering(Aseq, rtype, &isrow, &isicol)); // only seems to work for seq matrix
798         PetscCall(ISDestroy(&isrow));
799         PetscCall(ISInvertPermutation(isicol, PETSC_DECIDE, &isrow)); // THIS IS BACKWARD -- isrow is inverse
800         // if (rank==1) PetscCall(ISView(isicol, PETSC_VIEWER_STDOUT_SELF));
801         if (0) {
802           Mat mat_block_order; // debug
803           PetscCall(ISShift(isicol, Istart, isicol));
804           PetscCall(MatCreateSubMatrix(A, isicol, isicol, MAT_INITIAL_MATRIX, &mat_block_order));
805           PetscCall(ISShift(isicol, -Istart, isicol));
806           PetscCall(MatViewFromOptions(mat_block_order, NULL, "-ksp_batch_reorder_view"));
807           PetscCall(MatDestroy(&mat_block_order));
808         }
809         PetscCall(ISGetIndices(isrow, &rowindices)); // local idx
810         PetscCall(ISGetIndices(isicol, &icolindices));
811         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_isrow_k((PetscInt *)rowindices, A->rmap->n);
812         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_isicol_k((PetscInt *)icolindices, A->rmap->n);
813         jac->d_isrow_k  = new Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_isrow_k));
814         jac->d_isicol_k = new Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_isicol_k));
815         Kokkos::deep_copy(*jac->d_isrow_k, h_isrow_k);
816         Kokkos::deep_copy(*jac->d_isicol_k, h_isicol_k);
817         PetscCall(ISRestoreIndices(isrow, &rowindices));
818         PetscCall(ISRestoreIndices(isicol, &icolindices));
819         // if (rank==1) PetscCall(ISView(isicol, PETSC_VIEWER_STDOUT_SELF));
820       }
821       // get block sizes & allocate vec_diag
822       PetscCall(PCGetDM(pc, &pack));
823       if (pack) {
824         PetscCall(PetscObjectTypeCompare((PetscObject)pack, DMCOMPOSITE, &flg));
825         if (flg) {
826           PetscCall(DMCompositeGetNumberDM(pack, &nDMs));
827           PetscCall(DMCreateGlobalVector(pack, &jac->vec_diag));
828         } else pack = NULL; // flag for no DM
829       }
830       if (!jac->vec_diag) { // get 'nDMs' and sizes 'block_sizes' w/o DMComposite. TODO: User could provide ISs
831         PetscInt        bsrt, bend, ncols, ntot = 0;
832         const PetscInt *colsA, nloc = Iend - Istart;
833         const PetscInt *rowindices, *icolindices;
834         PetscCall(PetscMalloc1(nloc, &block_sizes)); // very inefficient, to big
835         PetscCall(ISGetIndices(isrow, &rowindices));
836         PetscCall(ISGetIndices(isicol, &icolindices));
837         nDMs = 0;
838         bsrt = 0;
839         bend = 1;
840         for (PetscInt row_B = 0; row_B < nloc; row_B++) { // for all rows in block diagonal space
841           PetscInt rowA = icolindices[row_B], minj = PETSC_INT_MAX, maxj = 0;
842           //PetscCall(PetscPrintf(PETSC_COMM_SELF, "\t[%d] rowA = %d\n",rank,rowA));
843           PetscCall(MatGetRow(Aseq, rowA, &ncols, &colsA, NULL)); // not sorted in permutation
844           PetscCheck(ncols, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "Empty row not supported: %" PetscInt_FMT, row_B);
845           for (PetscInt colj = 0; colj < ncols; colj++) {
846             PetscInt colB = rowindices[colsA[colj]]; // use local idx
847             //PetscCall(PetscPrintf(PETSC_COMM_SELF, "\t\t[%d] colB = %d\n",rank,colB));
848             PetscCheck(colB >= 0 && colB < nloc, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "colB < 0: %" PetscInt_FMT, colB);
849             if (colB > maxj) maxj = colB;
850             if (colB < minj) minj = colB;
851           }
852           PetscCall(MatRestoreRow(Aseq, rowA, &ncols, &colsA, NULL));
853           if (minj >= bend) { // first column is > max of last block -- new block or last block
854             //PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "\t\t finish block %d, N loc = %d (%d,%d)\n", nDMs+1, bend - bsrt,bsrt,bend));
855             block_sizes[nDMs] = bend - bsrt;
856             ntot += block_sizes[nDMs];
857             PetscCheck(minj == bend, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "minj != bend: %" PetscInt_FMT " != %" PetscInt_FMT, minj, bend);
858             bsrt = bend;
859             bend++; // start with size 1 in new block
860             nDMs++;
861           }
862           if (maxj + 1 > bend) bend = maxj + 1;
863           PetscCheck(minj >= bsrt || row_B == Iend - 1, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "%" PetscInt_FMT ") minj < bsrt: %" PetscInt_FMT " != %" PetscInt_FMT, rowA, minj, bsrt);
864           //PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] %d) row %d.%d) cols %d : %d ; bsrt = %d, bend = %d\n",rank,row_B,nDMs,rowA,minj,maxj,bsrt,bend));
865         }
866         // do last block
867         //PetscCall(PetscPrintf(PETSC_COMM_SELF, "\t\t\t [%d] finish block %d, N loc = %d (%d,%d)\n", rank, nDMs+1, bend - bsrt,bsrt,bend));
868         block_sizes[nDMs] = bend - bsrt;
869         ntot += block_sizes[nDMs];
870         nDMs++;
871         // cleanup
872         PetscCheck(ntot == nloc, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "n total != n local: %" PetscInt_FMT " != %" PetscInt_FMT, ntot, nloc);
873         PetscCall(ISRestoreIndices(isrow, &rowindices));
874         PetscCall(ISRestoreIndices(isicol, &icolindices));
875         PetscCall(PetscRealloc(sizeof(PetscInt) * nDMs, &block_sizes));
876         PetscCall(MatCreateVecs(A, &jac->vec_diag, NULL));
877         PetscCall(PetscInfo(pc, "Setup Matrix based meta data (not DMComposite not attached to PC) %" PetscInt_FMT " sub domains\n", nDMs));
878       }
879       PetscCall(ISDestroy(&isrow));
880       PetscCall(ISDestroy(&isicol));
881       jac->num_dms = nDMs;
882       PetscCall(VecGetLocalSize(jac->vec_diag, &n));
883       jac->n         = n;
884       jac->d_idiag_k = new Kokkos::View<PetscScalar *, Kokkos::LayoutRight>("idiag", n);
885       // options
886       PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
887       PetscCall(KSPSetFromOptions(jac->ksp));
888       PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPBICG, ""));
889       if (flg) {
890         jac->ksp_type_idx = BATCH_KSP_BICG_IDX;
891         jac->nwork        = 7;
892       } else {
893         PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPTFQMR, ""));
894         if (flg) {
895           jac->ksp_type_idx = BATCH_KSP_TFQMR_IDX;
896           jac->nwork        = 10;
897         } else {
898 #if defined(PETSC_HAVE_KOKKOS_KERNELS_BATCH)
899           PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPGMRES, ""));
900           PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Unsupported batch ksp type");
901           jac->ksp_type_idx = BATCH_KSP_GMRESKK_IDX;
902           jac->nwork        = 0;
903 #else
904           KSPType ksptype;
905           PetscCall(KSPGetType(jac->ksp, &ksptype));
906           PetscCheck(flg, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "Type: %s not supported in complex", ksptype);
907 #endif
908         }
909       }
910       PetscOptionsBegin(PetscObjectComm((PetscObject)jac->ksp), ((PetscObject)jac->ksp)->prefix, "Options for Kokkos batch solver", "none");
911       PetscCall(PetscOptionsBool("-ksp_converged_reason", "", "bjkokkos.kokkos.cxx.c", jac->reason, &jac->reason, NULL));
912       PetscCall(PetscOptionsBool("-ksp_monitor", "", "bjkokkos.kokkos.cxx.c", jac->monitor, &jac->monitor, NULL));
913       PetscCall(PetscOptionsInt("-ksp_batch_target", "", "bjkokkos.kokkos.cxx.c", jac->batch_target, &jac->batch_target, NULL));
914       PetscCall(PetscOptionsInt("-ksp_rank_target", "", "bjkokkos.kokkos.cxx.c", jac->rank_target, &jac->rank_target, NULL));
915       PetscCall(PetscOptionsInt("-ksp_batch_nsolves_team", "", "bjkokkos.kokkos.cxx.c", jac->nsolves_team, &jac->nsolves_team, NULL));
916       PetscCheck(jac->batch_target < jac->num_dms, PETSC_COMM_WORLD, PETSC_ERR_ARG_WRONG, "-ksp_batch_target (%" PetscInt_FMT ") >= number of DMs (%" PetscInt_FMT ")", jac->batch_target, jac->num_dms);
917       PetscOptionsEnd();
918       // get blocks - jac->d_bid_eqOffset_k
919       if (pack) {
920         PetscCall(PetscMalloc(sizeof(*subX) * nDMs, &subX));
921         PetscCall(PetscMalloc(sizeof(*subDM) * nDMs, &subDM));
922       }
923       PetscCall(PetscMalloc(sizeof(*jac->dm_Nf) * nDMs, &jac->dm_Nf));
924       PetscCall(PetscInfo(pc, "Have %" PetscInt_FMT " blocks, n=%" PetscInt_FMT " rtol=%g type = %s\n", nDMs, n, (double)jac->ksp->rtol, ((PetscObject)jac->ksp)->type_name));
925       if (pack) PetscCall(DMCompositeGetEntriesArray(pack, subDM));
926       jac->nBlocks = 0;
927       for (PetscInt ii = 0; ii < nDMs; ii++) {
928         PetscInt Nf;
929         if (subDM) {
930           DM           dm = subDM[ii];
931           PetscSection section;
932           PetscCall(DMGetLocalSection(dm, &section));
933           PetscCall(PetscSectionGetNumFields(section, &Nf));
934         } else Nf = 1;
935         jac->nBlocks += Nf;
936 #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
937         if (ii == 0) PetscCall(PetscInfo(pc, "%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n", ii, Nf, jac->nBlocks));
938 #else
939         PetscCall(PetscInfo(pc, "%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n", ii, Nf, jac->nBlocks));
940 #endif
941         jac->dm_Nf[ii] = Nf;
942       }
943       { // d_bid_eqOffset_k
944         Kokkos::View<PetscInt *, Kokkos::LayoutRight, Kokkos::HostSpace> h_block_offsets("block_offsets", jac->nBlocks + 1);
945         if (pack) PetscCall(DMCompositeGetAccessArray(pack, jac->vec_diag, nDMs, NULL, subX));
946         h_block_offsets[0]    = 0;
947         jac->const_block_size = -1;
948         for (PetscInt ii = 0, idx = 0; ii < nDMs; ii++) {
949           PetscInt nloc, nblk;
950           if (pack) PetscCall(VecGetSize(subX[ii], &nloc));
951           else nloc = block_sizes[ii];
952           nblk = nloc / jac->dm_Nf[ii];
953           PetscCheck(nloc % jac->dm_Nf[ii] == 0, PetscObjectComm((PetscObject)pc), PETSC_ERR_USER, "nloc%%jac->dm_Nf[ii] (%" PetscInt_FMT ") != 0 DMs", nloc % jac->dm_Nf[ii]);
954           for (PetscInt jj = 0; jj < jac->dm_Nf[ii]; jj++, idx++) {
955             h_block_offsets[idx + 1] = h_block_offsets[idx] + nblk;
956 #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
957             if (idx == 0) PetscCall(PetscInfo(pc, "Add first of %" PetscInt_FMT " blocks with %" PetscInt_FMT " equations\n", jac->nBlocks, nblk));
958 #else
959             PetscCall(PetscInfo(pc, "\t%" PetscInt_FMT ") Add block with %" PetscInt_FMT " equations of %" PetscInt_FMT "\n", idx + 1, nblk, jac->nBlocks));
960 #endif
961             if (jac->const_block_size == -1) jac->const_block_size = nblk;
962             else if (jac->const_block_size > 0 && jac->const_block_size != nblk) jac->const_block_size = 0;
963           }
964         }
965         if (pack) {
966           PetscCall(DMCompositeRestoreAccessArray(pack, jac->vec_diag, jac->nBlocks, NULL, subX));
967           PetscCall(PetscFree(subX));
968           PetscCall(PetscFree(subDM));
969         }
970         jac->d_bid_eqOffset_k = new Kokkos::View<PetscInt *, Kokkos::LayoutRight>(Kokkos::create_mirror(Kokkos::DefaultExecutionSpace::memory_space(), h_block_offsets));
971         Kokkos::deep_copy(*jac->d_bid_eqOffset_k, h_block_offsets);
972       }
973       if (!pack) PetscCall(PetscFree(block_sizes));
974     }
975     { // get jac->d_idiag_k (PC setup),
976       const PetscInt    *d_ai, *d_aj;
977       const PetscScalar *d_aa;
978       const PetscInt     conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp == 0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
979       const PetscInt    *d_bid_eqOffset = jac->d_bid_eqOffset_k->data(), *r = jac->d_isrow_k->data(), *ic = jac->d_isicol_k->data();
980       PetscScalar       *d_idiag = jac->d_idiag_k->data(), *dummy;
981       PetscMemType       mtype;
982       PetscCall(MatSeqAIJGetCSRAndMemType(Aseq, &d_ai, &d_aj, &dummy, &mtype));
983       d_aa = dummy;
984       Kokkos::parallel_for(
985         "Diag", Kokkos::TeamPolicy<>(jac->nBlocks, team_size, PCBJKOKKOS_VEC_SIZE), KOKKOS_LAMBDA(const team_member team) {
986           const PetscInt blkID = team.league_rank();
987           Kokkos::parallel_for(Kokkos::TeamThreadRange(team, d_bid_eqOffset[blkID], d_bid_eqOffset[blkID + 1]), [=](const int rowb) {
988             const PetscInt     rowa = ic[rowb], ai = d_ai[rowa], *aj = d_aj + ai; // grab original data
989             const PetscScalar *aa   = d_aa + ai;
990             const PetscInt     nrow = d_ai[rowa + 1] - ai;
991             int                found;
992             Kokkos::parallel_reduce(
993               Kokkos::ThreadVectorRange(team, nrow),
994               [=](const int &j, int &count) {
995                 const PetscInt colb = r[aj[j]];
996                 if (colb == rowb) {
997                   d_idiag[rowb] = 1. / aa[j];
998                   count++;
999                 }
1000               },
1001               found);
1002 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
1003             if (found != 1) Kokkos::single(Kokkos::PerThread(team), [=]() { printf("ERRORrow %d) found = %d\n", rowb, found); });
1004 #endif
1005           });
1006         });
1007     }
1008   }
1009   PetscFunctionReturn(PETSC_SUCCESS);
1010 }
1011 
1012 /* Default destroy, if it has never been setup */
1013 static PetscErrorCode PCReset_BJKOKKOS(PC pc)
1014 {
1015   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1016 
1017   PetscFunctionBegin;
1018   PetscCall(KSPDestroy(&jac->ksp));
1019   PetscCall(VecDestroy(&jac->vec_diag));
1020   if (jac->d_bid_eqOffset_k) delete jac->d_bid_eqOffset_k;
1021   if (jac->d_idiag_k) delete jac->d_idiag_k;
1022   if (jac->d_isrow_k) delete jac->d_isrow_k;
1023   if (jac->d_isicol_k) delete jac->d_isicol_k;
1024   jac->d_bid_eqOffset_k = NULL;
1025   jac->d_idiag_k        = NULL;
1026   jac->d_isrow_k        = NULL;
1027   jac->d_isicol_k       = NULL;
1028   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSGetKSP_C", NULL)); // not published now (causes configure errors)
1029   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSSetKSP_C", NULL));
1030   PetscCall(PetscFree(jac->dm_Nf));
1031   jac->dm_Nf = NULL;
1032   if (jac->rowOffsets) delete jac->rowOffsets;
1033   if (jac->colIndices) delete jac->colIndices;
1034   if (jac->batch_b) delete jac->batch_b;
1035   if (jac->batch_x) delete jac->batch_x;
1036   if (jac->batch_values) delete jac->batch_values;
1037   jac->rowOffsets   = NULL;
1038   jac->colIndices   = NULL;
1039   jac->batch_b      = NULL;
1040   jac->batch_x      = NULL;
1041   jac->batch_values = NULL;
1042   PetscFunctionReturn(PETSC_SUCCESS);
1043 }
1044 
1045 static PetscErrorCode PCDestroy_BJKOKKOS(PC pc)
1046 {
1047   PetscFunctionBegin;
1048   PetscCall(PCReset_BJKOKKOS(pc));
1049   PetscCall(PetscFree(pc->data));
1050   PetscFunctionReturn(PETSC_SUCCESS);
1051 }
1052 
1053 static PetscErrorCode PCView_BJKOKKOS(PC pc, PetscViewer viewer)
1054 {
1055   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1056   PetscBool      iascii;
1057 
1058   PetscFunctionBegin;
1059   if (!jac->ksp) PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
1060   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
1061   if (iascii) {
1062     PetscCall(PetscViewerASCIIPrintf(viewer, "  Batched device linear solver: Krylov (KSP) method with Jacobi preconditioning\n"));
1063     PetscCall(PetscViewerASCIIPrintf(viewer, "\t\tnwork = %" PetscInt_FMT ", rel tol = %e, abs tol = %e, div tol = %e, max it =%" PetscInt_FMT ", type = %s\n", jac->nwork, jac->ksp->rtol, jac->ksp->abstol, jac->ksp->divtol, jac->ksp->max_it,
1064                                      ((PetscObject)jac->ksp)->type_name));
1065   }
1066   PetscFunctionReturn(PETSC_SUCCESS);
1067 }
1068 
1069 static PetscErrorCode PCSetFromOptions_BJKOKKOS(PC pc, PetscOptionItems *PetscOptionsObject)
1070 {
1071   PetscFunctionBegin;
1072   PetscOptionsHeadBegin(PetscOptionsObject, "PC BJKOKKOS options");
1073   PetscOptionsHeadEnd();
1074   PetscFunctionReturn(PETSC_SUCCESS);
1075 }
1076 
1077 static PetscErrorCode PCBJKOKKOSSetKSP_BJKOKKOS(PC pc, KSP ksp)
1078 {
1079   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1080 
1081   PetscFunctionBegin;
1082   PetscCall(PetscObjectReference((PetscObject)ksp));
1083   PetscCall(KSPDestroy(&jac->ksp));
1084   jac->ksp = ksp;
1085   PetscFunctionReturn(PETSC_SUCCESS);
1086 }
1087 
1088 /*@
1089   PCBJKOKKOSSetKSP - Sets the `KSP` context for `PCBJKOKKOS`
1090 
1091   Collective
1092 
1093   Input Parameters:
1094 + pc  - the `PCBJKOKKOS` preconditioner context
1095 - ksp - the `KSP` solver
1096 
1097   Level: advanced
1098 
1099   Notes:
1100   The `PC` and the `KSP` must have the same communicator
1101 
1102   If the `PC` is not `PCBJKOKKOS` this function returns without doing anything
1103 
1104 .seealso: [](ch_ksp), `PCBJKOKKOSGetKSP()`, `PCBJKOKKOS`
1105 @*/
1106 PetscErrorCode PCBJKOKKOSSetKSP(PC pc, KSP ksp)
1107 {
1108   PetscFunctionBegin;
1109   PetscValidHeaderSpecific(pc, PC_CLASSID, 1);
1110   PetscValidHeaderSpecific(ksp, KSP_CLASSID, 2);
1111   PetscCheckSameComm(pc, 1, ksp, 2);
1112   PetscTryMethod(pc, "PCBJKOKKOSSetKSP_C", (PC, KSP), (pc, ksp));
1113   PetscFunctionReturn(PETSC_SUCCESS);
1114 }
1115 
1116 static PetscErrorCode PCBJKOKKOSGetKSP_BJKOKKOS(PC pc, KSP *ksp)
1117 {
1118   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1119 
1120   PetscFunctionBegin;
1121   if (!jac->ksp) PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
1122   *ksp = jac->ksp;
1123   PetscFunctionReturn(PETSC_SUCCESS);
1124 }
1125 
1126 /*@
1127   PCBJKOKKOSGetKSP - Gets the `KSP` context for the `PCBJKOKKOS` preconditioner
1128 
1129   Not Collective but `KSP` returned is parallel if `PC` was parallel
1130 
1131   Input Parameter:
1132 . pc - the preconditioner context
1133 
1134   Output Parameter:
1135 . ksp - the `KSP` solver
1136 
1137   Level: advanced
1138 
1139   Notes:
1140   You must call `KSPSetUp()` before calling `PCBJKOKKOSGetKSP()`.
1141 
1142   If the `PC` is not a `PCBJKOKKOS` object it raises an error
1143 
1144 .seealso: [](ch_ksp), `PCBJKOKKOS`, `PCBJKOKKOSSetKSP()`
1145 @*/
1146 PetscErrorCode PCBJKOKKOSGetKSP(PC pc, KSP *ksp)
1147 {
1148   PetscFunctionBegin;
1149   PetscValidHeaderSpecific(pc, PC_CLASSID, 1);
1150   PetscAssertPointer(ksp, 2);
1151   PetscUseMethod(pc, "PCBJKOKKOSGetKSP_C", (PC, KSP *), (pc, ksp));
1152   PetscFunctionReturn(PETSC_SUCCESS);
1153 }
1154 
1155 static PetscErrorCode PCPostSolve_BJKOKKOS(PC pc, KSP ksp, Vec b, Vec x)
1156 {
1157   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1158 
1159   PetscFunctionBegin;
1160   PetscValidHeaderSpecific(pc, PC_CLASSID, 1);
1161   ksp->its = jac->max_nits;
1162   PetscFunctionReturn(PETSC_SUCCESS);
1163 }
1164 
1165 static PetscErrorCode PCPreSolve_BJKOKKOS(PC pc, KSP ksp, Vec b, Vec x)
1166 {
1167   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1168 
1169   PetscFunctionBegin;
1170   PetscValidHeaderSpecific(pc, PC_CLASSID, 1);
1171   jac->ksp->errorifnotconverged = ksp->errorifnotconverged;
1172   PetscFunctionReturn(PETSC_SUCCESS);
1173 }
1174 
/*MC
     PCBJKOKKOS - A batched Krylov/block Jacobi solver that runs a solve of each diagonal block of a block diagonal `MATSEQAIJ` in a Kokkos thread group

   Options Database Key:
.     -pc_bjkokkos_ - options prefix for its `KSP` options

   Level: intermediate

   Note:
    For use with -ksp_type preonly to bypass any computation on the CPU

   Developer Notes:
   The entire Krylov (TFQMR or BICG) with diagonal preconditioning for each block of a block diagonal matrix runs in a Kokkos thread group (e.g., one block per SM on NVIDIA). It supports taking a non-block diagonal matrix but this is not tested. One should create an explicit block diagonal matrix and use that as the preconditioning matrix in the outer KSP solver. Variable block sizes are supported and tested in src/ts/utils/dmplexlandau/tutorials/ex[1|2].c

.seealso: [](ch_ksp), `PCCreate()`, `PCSetType()`, `PCType`, `PC`, `PCBJACOBI`,
          `PCSHELL`, `PCCOMPOSITE`, `PCSetUseAmat()`, `PCBJKOKKOSGetKSP()`
M*/
1192 
1193 PETSC_EXTERN PetscErrorCode PCCreate_BJKOKKOS(PC pc)
1194 {
1195   PC_PCBJKOKKOS *jac;
1196 
1197   PetscFunctionBegin;
1198   PetscCall(PetscNew(&jac));
1199   pc->data = (void *)jac;
1200 
1201   jac->ksp              = NULL;
1202   jac->vec_diag         = NULL;
1203   jac->d_bid_eqOffset_k = NULL;
1204   jac->d_idiag_k        = NULL;
1205   jac->d_isrow_k        = NULL;
1206   jac->d_isicol_k       = NULL;
1207   jac->nBlocks          = 1;
1208   jac->max_nits         = 0;
1209 
1210   PetscCall(PetscMemzero(pc->ops, sizeof(struct _PCOps)));
1211   pc->ops->apply          = PCApply_BJKOKKOS;
1212   pc->ops->applytranspose = NULL;
1213   pc->ops->setup          = PCSetUp_BJKOKKOS;
1214   pc->ops->reset          = PCReset_BJKOKKOS;
1215   pc->ops->destroy        = PCDestroy_BJKOKKOS;
1216   pc->ops->setfromoptions = PCSetFromOptions_BJKOKKOS;
1217   pc->ops->view           = PCView_BJKOKKOS;
1218   pc->ops->postsolve      = PCPostSolve_BJKOKKOS;
1219   pc->ops->presolve       = PCPreSolve_BJKOKKOS;
1220 
1221   jac->rowOffsets   = NULL;
1222   jac->colIndices   = NULL;
1223   jac->batch_b      = NULL;
1224   jac->batch_x      = NULL;
1225   jac->batch_values = NULL;
1226 
1227   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSGetKSP_C", PCBJKOKKOSGetKSP_BJKOKKOS));
1228   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSSetKSP_C", PCBJKOKKOSSetKSP_BJKOKKOS));
1229   PetscFunctionReturn(PETSC_SUCCESS);
1230 }
1231