xref: /petsc/src/ksp/pc/impls/bjacobi/bjkokkos/bjkokkos.kokkos.cxx (revision 8e1aa562fd633201b416d8ca1397a10fc0fd2d39)
1731341e5SJunchao Zhang #define PETSC_SKIP_CXX_COMPLEX_FIX // Kokkos::complex does not need the petsc complex fix
2731341e5SJunchao Zhang 
3a4313204SMark Adams #include <petsc/private/pcbjkokkosimpl.h>
4a4313204SMark Adams 
5e607c864SMark Adams #include <petsc/private/kspimpl.h>
6e607c864SMark Adams #include <petscksp.h> /*I "petscksp.h" I*/
7a999f97fSMark Adams #include <../src/mat/impls/aij/mpi/mpiaij.h>
82c4ab24aSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
946233b44SBarry Smith #include <petscsection.h>
10e607c864SMark Adams #include <petscdmcomposite.h>
11e607c864SMark Adams 
12e607c864SMark Adams #include <../src/mat/impls/aij/seq/aij.h>
13e607c864SMark Adams #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
14e607c864SMark Adams 
150e6b6b59SJacob Faibussowitsch #include <petscdevice_cupm.h>
160e6b6b59SJacob Faibussowitsch 
17d71ae5a4SJacob Faibussowitsch static PetscErrorCode PCBJKOKKOSCreateKSP_BJKOKKOS(PC pc)
18d71ae5a4SJacob Faibussowitsch {
19e607c864SMark Adams   const char    *prefix;
20e607c864SMark Adams   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
21e607c864SMark Adams   DM             dm;
22e607c864SMark Adams 
23e607c864SMark Adams   PetscFunctionBegin;
249566063dSJacob Faibussowitsch   PetscCall(KSPCreate(PetscObjectComm((PetscObject)pc), &jac->ksp));
253821be0aSBarry Smith   PetscCall(KSPSetNestLevel(jac->ksp, pc->kspnestlevel));
269566063dSJacob Faibussowitsch   PetscCall(KSPSetErrorIfNotConverged(jac->ksp, pc->erroriffailure));
279566063dSJacob Faibussowitsch   PetscCall(PetscObjectIncrementTabLevel((PetscObject)jac->ksp, (PetscObject)pc, 1));
289566063dSJacob Faibussowitsch   PetscCall(PCGetOptionsPrefix(pc, &prefix));
299566063dSJacob Faibussowitsch   PetscCall(KSPSetOptionsPrefix(jac->ksp, prefix));
309566063dSJacob Faibussowitsch   PetscCall(KSPAppendOptionsPrefix(jac->ksp, "pc_bjkokkos_"));
319566063dSJacob Faibussowitsch   PetscCall(PCGetDM(pc, &dm));
32e607c864SMark Adams   if (dm) {
339566063dSJacob Faibussowitsch     PetscCall(KSPSetDM(jac->ksp, dm));
349566063dSJacob Faibussowitsch     PetscCall(KSPSetDMActive(jac->ksp, PETSC_FALSE));
35e607c864SMark Adams   }
36e607c864SMark Adams   jac->reason       = PETSC_FALSE;
37e607c864SMark Adams   jac->monitor      = PETSC_FALSE;
38a4313204SMark Adams   jac->batch_target = 0;
39a4313204SMark Adams   jac->rank_target  = 0;
40a1e3af9bSmarkadams4   jac->nsolves_team = 1;
41aaa8cc7dSPierre Jolivet   jac->ksp->max_it  = 50; // this is really for GMRES w/o restarts
423ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
43e607c864SMark Adams }
44e607c864SMark Adams 
45e607c864SMark Adams // y <-- Ax
46d71ae5a4SJacob Faibussowitsch KOKKOS_INLINE_FUNCTION PetscErrorCode MatMult(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
47d71ae5a4SJacob Faibussowitsch {
48e607c864SMark Adams   Kokkos::parallel_for(Kokkos::TeamThreadRange(team, start, end), [=](const int rowb) {
49e607c864SMark Adams     int                rowa = ic[rowb];
50e607c864SMark Adams     int                n    = glb_Aai[rowa + 1] - glb_Aai[rowa];
51a999f97fSMark Adams     const PetscInt    *aj   = glb_Aaj + glb_Aai[rowa]; // global
52e607c864SMark Adams     const PetscScalar *aa   = glb_Aaa + glb_Aai[rowa];
53e607c864SMark Adams     PetscScalar        sum;
54fbccb6d4SPierre Jolivet     Kokkos::parallel_reduce(Kokkos::ThreadVectorRange(team, n), [=](const int i, PetscScalar &lsum) { lsum += aa[i] * x_loc[r[aj[i]] - start]; }, sum);
55e607c864SMark Adams     Kokkos::single(Kokkos::PerThread(team), [=]() { y_loc[rowb - start] = sum; });
56e607c864SMark Adams   });
57e607c864SMark Adams   team.team_barrier();
583ba16761SJacob Faibussowitsch   return PETSC_SUCCESS;
59e607c864SMark Adams }
60e607c864SMark Adams 
61e607c864SMark Adams // temp buffer per thread with reduction at end?
62d71ae5a4SJacob Faibussowitsch KOKKOS_INLINE_FUNCTION PetscErrorCode MatMultTranspose(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, const PetscInt start, const PetscInt end, const PetscScalar *x_loc, PetscScalar *y_loc)
63d71ae5a4SJacob Faibussowitsch {
64e607c864SMark Adams   Kokkos::parallel_for(Kokkos::TeamVectorRange(team, end - start), [=](int i) { y_loc[i] = 0; });
65e607c864SMark Adams   team.team_barrier();
66e607c864SMark Adams   Kokkos::parallel_for(Kokkos::TeamThreadRange(team, start, end), [=](const int rowb) {
67e607c864SMark Adams     int                rowa = ic[rowb];
68e607c864SMark Adams     int                n    = glb_Aai[rowa + 1] - glb_Aai[rowa];
69a999f97fSMark Adams     const PetscInt    *aj   = glb_Aaj + glb_Aai[rowa]; // global
70e607c864SMark Adams     const PetscScalar *aa   = glb_Aaa + glb_Aai[rowa];
71e607c864SMark Adams     const PetscScalar  xx   = x_loc[rowb - start]; // rowb = ic[rowa] = ic[r[rowb]]
72e607c864SMark Adams     Kokkos::parallel_for(Kokkos::ThreadVectorRange(team, n), [=](const int &i) {
73e607c864SMark Adams       PetscScalar val = aa[i] * xx;
74e607c864SMark Adams       Kokkos::atomic_fetch_add(&y_loc[r[aj[i]] - start], val);
75e607c864SMark Adams     });
76e607c864SMark Adams   });
77e607c864SMark Adams   team.team_barrier();
783ba16761SJacob Faibussowitsch   return PETSC_SUCCESS;
79e607c864SMark Adams }
80e607c864SMark Adams 
819371c9d4SSatish Balay typedef struct Batch_MetaData_TAG {
82e607c864SMark Adams   PetscInt           flops;
83e607c864SMark Adams   PetscInt           its;
84e607c864SMark Adams   KSPConvergedReason reason;
85e607c864SMark Adams } Batch_MetaData;
86e607c864SMark Adams 
87e607c864SMark Adams // Solve A(BB^-1)x = y with TFQMR. Right preconditioned to get un-preconditioned residual
8866976f2fSJacob Faibussowitsch static KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_TFQMR(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space_global, const int stride_global, const int nShareVec, PetscScalar *work_space_shared, const int stride_shared, PetscReal rtol, PetscReal atol, PetscReal dtol, PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
89d71ae5a4SJacob Faibussowitsch {
90e607c864SMark Adams   using Kokkos::parallel_for;
919371c9d4SSatish Balay   using Kokkos::parallel_reduce;
92a4313204SMark Adams   int                Nblk = end - start, it, m, stride = stride_shared, idx = 0;
93e607c864SMark Adams   PetscReal          dp, dpold, w, dpest, tau, psi, cm, r0;
94e607c864SMark Adams   const PetscScalar *Diag = &glb_idiag[start];
95a1e3af9bSmarkadams4   PetscScalar       *ptr  = work_space_shared, rho, rhoold, a, s, b, eta, etaold, psiold, cf, dpi;
96a1e3af9bSmarkadams4 
979371c9d4SSatish Balay   if (idx++ == nShareVec) {
989371c9d4SSatish Balay     ptr    = work_space_global;
999371c9d4SSatish Balay     stride = stride_global;
1009371c9d4SSatish Balay   }
1019371c9d4SSatish Balay   PetscScalar *XX = ptr;
1029371c9d4SSatish Balay   ptr += stride;
1039371c9d4SSatish Balay   if (idx++ == nShareVec) {
1049371c9d4SSatish Balay     ptr    = work_space_global;
1059371c9d4SSatish Balay     stride = stride_global;
1069371c9d4SSatish Balay   }
1079371c9d4SSatish Balay   PetscScalar *R = ptr;
1089371c9d4SSatish Balay   ptr += stride;
1099371c9d4SSatish Balay   if (idx++ == nShareVec) {
1109371c9d4SSatish Balay     ptr    = work_space_global;
1119371c9d4SSatish Balay     stride = stride_global;
1129371c9d4SSatish Balay   }
1139371c9d4SSatish Balay   PetscScalar *RP = ptr;
1149371c9d4SSatish Balay   ptr += stride;
1159371c9d4SSatish Balay   if (idx++ == nShareVec) {
1169371c9d4SSatish Balay     ptr    = work_space_global;
1179371c9d4SSatish Balay     stride = stride_global;
1189371c9d4SSatish Balay   }
1199371c9d4SSatish Balay   PetscScalar *V = ptr;
1209371c9d4SSatish Balay   ptr += stride;
1219371c9d4SSatish Balay   if (idx++ == nShareVec) {
1229371c9d4SSatish Balay     ptr    = work_space_global;
1239371c9d4SSatish Balay     stride = stride_global;
1249371c9d4SSatish Balay   }
1259371c9d4SSatish Balay   PetscScalar *T = ptr;
1269371c9d4SSatish Balay   ptr += stride;
1279371c9d4SSatish Balay   if (idx++ == nShareVec) {
1289371c9d4SSatish Balay     ptr    = work_space_global;
1299371c9d4SSatish Balay     stride = stride_global;
1309371c9d4SSatish Balay   }
1319371c9d4SSatish Balay   PetscScalar *Q = ptr;
1329371c9d4SSatish Balay   ptr += stride;
1339371c9d4SSatish Balay   if (idx++ == nShareVec) {
1349371c9d4SSatish Balay     ptr    = work_space_global;
1359371c9d4SSatish Balay     stride = stride_global;
1369371c9d4SSatish Balay   }
1379371c9d4SSatish Balay   PetscScalar *P = ptr;
1389371c9d4SSatish Balay   ptr += stride;
1399371c9d4SSatish Balay   if (idx++ == nShareVec) {
1409371c9d4SSatish Balay     ptr    = work_space_global;
1419371c9d4SSatish Balay     stride = stride_global;
1429371c9d4SSatish Balay   }
1439371c9d4SSatish Balay   PetscScalar *U = ptr;
1449371c9d4SSatish Balay   ptr += stride;
1459371c9d4SSatish Balay   if (idx++ == nShareVec) {
1469371c9d4SSatish Balay     ptr    = work_space_global;
1479371c9d4SSatish Balay     stride = stride_global;
1489371c9d4SSatish Balay   }
1499371c9d4SSatish Balay   PetscScalar *D = ptr;
1509371c9d4SSatish Balay   ptr += stride;
1519371c9d4SSatish Balay   if (idx++ == nShareVec) {
1529371c9d4SSatish Balay     ptr    = work_space_global;
1539371c9d4SSatish Balay     stride = stride_global;
1549371c9d4SSatish Balay   }
155e607c864SMark Adams   PetscScalar *AUQ = V;
156e607c864SMark Adams 
157e607c864SMark Adams   // init: get b, zero x
158e607c864SMark Adams   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
159e607c864SMark Adams     int rowa         = ic[rowb];
160e607c864SMark Adams     R[rowb - start]  = glb_b[rowa];
161e607c864SMark Adams     XX[rowb - start] = 0;
162e607c864SMark Adams   });
163e607c864SMark Adams   team.team_barrier();
164fbccb6d4SPierre Jolivet   parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += R[idx] * PetscConj(R[idx]); }, dpi);
165e607c864SMark Adams   team.team_barrier();
166e607c864SMark Adams   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
167e607c864SMark Adams   // diagnostics
168a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
169e607c864SMark Adams   if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e\n", 0, (double)dp); });
170a25f047fSMark Adams #endif
1719371c9d4SSatish Balay   if (dp < atol) {
1729371c9d4SSatish Balay     metad->reason = KSP_CONVERGED_ATOL_NORMAL;
173a4313204SMark Adams     it            = 0;
174a4313204SMark Adams     goto done;
1759371c9d4SSatish Balay   }
1769371c9d4SSatish Balay   if (0 == maxit) {
17708d80769Smarkadams4     metad->reason = KSP_CONVERGED_ITS;
178a4313204SMark Adams     it            = 0;
179a4313204SMark Adams     goto done;
1809371c9d4SSatish Balay   }
181e607c864SMark Adams 
182e607c864SMark Adams   /* Make the initial Rp = R */
183e607c864SMark Adams   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { RP[idx] = R[idx]; });
184e607c864SMark Adams   team.team_barrier();
185e607c864SMark Adams   /* Set the initial conditions */
186e607c864SMark Adams   etaold = 0.0;
187e607c864SMark Adams   psiold = 0.0;
188e607c864SMark Adams   tau    = dp;
189e607c864SMark Adams   dpold  = dp;
190e607c864SMark Adams 
191e607c864SMark Adams   /* rhoold = (r,rp)     */
192fbccb6d4SPierre Jolivet   parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += R[idx] * PetscConj(RP[idx]); }, rhoold);
193e607c864SMark Adams   team.team_barrier();
1949371c9d4SSatish Balay   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
1959371c9d4SSatish Balay     U[idx] = R[idx];
1969371c9d4SSatish Balay     P[idx] = R[idx];
1979371c9d4SSatish Balay     T[idx] = Diag[idx] * P[idx];
1989371c9d4SSatish Balay     D[idx] = 0;
1999371c9d4SSatish Balay   });
200e607c864SMark Adams   team.team_barrier();
2013ba16761SJacob Faibussowitsch   static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, V));
202e607c864SMark Adams 
203a4313204SMark Adams   it = 0;
204e607c864SMark Adams   do {
205e607c864SMark Adams     /* s <- (v,rp)          */
206fbccb6d4SPierre Jolivet     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += V[idx] * PetscConj(RP[idx]); }, s);
207e607c864SMark Adams     team.team_barrier();
20808d80769Smarkadams4     if (s == 0) {
20908d80769Smarkadams4       metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
21008d80769Smarkadams4       goto done;
21108d80769Smarkadams4     }
212e607c864SMark Adams     a = rhoold / s; /* a <- rho / s         */
213e607c864SMark Adams     /* q <- u - a v    VecWAXPY(w,alpha,x,y): w = alpha x + y.     */
214e607c864SMark Adams     /* t <- u + q           */
2159371c9d4SSatish Balay     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
2169371c9d4SSatish Balay       Q[idx] = U[idx] - a * V[idx];
2179371c9d4SSatish Balay       T[idx] = U[idx] + Q[idx];
2189371c9d4SSatish Balay     });
219e607c864SMark Adams     team.team_barrier();
220e607c864SMark Adams     // KSP_PCApplyBAorAB
221e607c864SMark Adams     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { T[idx] = Diag[idx] * T[idx]; });
222e607c864SMark Adams     team.team_barrier();
2233ba16761SJacob Faibussowitsch     static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, AUQ));
224e607c864SMark Adams     /* r <- r - a K (u + q) */
225e607c864SMark Adams     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { R[idx] = R[idx] - a * AUQ[idx]; });
226e607c864SMark Adams     team.team_barrier();
227fbccb6d4SPierre Jolivet     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += R[idx] * PetscConj(R[idx]); }, dpi);
228e607c864SMark Adams     team.team_barrier();
229e607c864SMark Adams     dp = PetscSqrtReal(PetscRealPart(dpi));
230e607c864SMark Adams     for (m = 0; m < 2; m++) {
231e607c864SMark Adams       if (!m) w = PetscSqrtReal(dp * dpold);
232e607c864SMark Adams       else w = dp;
233e607c864SMark Adams       psi = w / tau;
234e607c864SMark Adams       cm  = 1.0 / PetscSqrtReal(1.0 + psi * psi);
235e607c864SMark Adams       tau = tau * psi * cm;
236e607c864SMark Adams       eta = cm * cm * a;
237e607c864SMark Adams       cf  = psiold * psiold * etaold / a;
238e607c864SMark Adams       if (!m) {
239e607c864SMark Adams         /* D = U + cf D */
240e607c864SMark Adams         parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { D[idx] = U[idx] + cf * D[idx]; });
241e607c864SMark Adams       } else {
242e607c864SMark Adams         /* D = Q + cf D */
243e607c864SMark Adams         parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { D[idx] = Q[idx] + cf * D[idx]; });
244e607c864SMark Adams       }
245e607c864SMark Adams       team.team_barrier();
246e607c864SMark Adams       parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { XX[idx] = XX[idx] + eta * D[idx]; });
247e607c864SMark Adams       team.team_barrier();
248a4313204SMark Adams       dpest = PetscSqrtReal(2 * it + m + 2.0) * tau;
249a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
250a4313204SMark Adams       if (monitor && m == 1) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e\n", it + 1, (double)dpest); });
251a25f047fSMark Adams #endif
2529371c9d4SSatish Balay       if (dpest < atol) {
2539371c9d4SSatish Balay         metad->reason = KSP_CONVERGED_ATOL_NORMAL;
2549371c9d4SSatish Balay         goto done;
2559371c9d4SSatish Balay       }
2569371c9d4SSatish Balay       if (dpest / r0 < rtol) {
2579371c9d4SSatish Balay         metad->reason = KSP_CONVERGED_RTOL_NORMAL;
2589371c9d4SSatish Balay         goto done;
2599371c9d4SSatish Balay       }
260a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
2619371c9d4SSatish Balay       if (dpest / r0 > dtol) {
2629371c9d4SSatish Balay         metad->reason = KSP_DIVERGED_DTOL;
263a4313204SMark Adams         Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: %d it, res=%e, r_0=%e\n", team.league_rank(), it, dpest, r0); });
2649371c9d4SSatish Balay         goto done;
2659371c9d4SSatish Balay       }
266e607c864SMark Adams #else
2679371c9d4SSatish Balay       if (dpest / r0 > dtol) {
2689371c9d4SSatish Balay         metad->reason = KSP_DIVERGED_DTOL;
2699371c9d4SSatish Balay         goto done;
2709371c9d4SSatish Balay       }
271e607c864SMark Adams #endif
272a4313204SMark Adams       if (it + 1 == maxit) {
27308d80769Smarkadams4         metad->reason = KSP_CONVERGED_ITS;
27408d80769Smarkadams4 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
275a4313204SMark Adams         Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: TFQMR %d:%d it, res=%e, r_0=%e r_res=%e\n", team.league_rank(), it, m, dpest, r0, dpest / r0); });
27608d80769Smarkadams4 #endif
2779371c9d4SSatish Balay         goto done;
2789371c9d4SSatish Balay       }
279e607c864SMark Adams       etaold = eta;
280e607c864SMark Adams       psiold = psi;
281e607c864SMark Adams     }
282e607c864SMark Adams 
283e607c864SMark Adams     /* rho <- (r,rp)       */
284fbccb6d4SPierre Jolivet     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += R[idx] * PetscConj(RP[idx]); }, rho);
285e607c864SMark Adams     team.team_barrier();
28608d80769Smarkadams4     if (rho == 0) {
28708d80769Smarkadams4       metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
28808d80769Smarkadams4       goto done;
28908d80769Smarkadams4     }
290e607c864SMark Adams     b = rho / rhoold; /* b <- rho / rhoold   */
291e607c864SMark Adams     /* u <- r + b q        */
292e607c864SMark Adams     /* p <- u + b(q + b p) */
2939371c9d4SSatish Balay     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
2949371c9d4SSatish Balay       U[idx] = R[idx] + b * Q[idx];
2959371c9d4SSatish Balay       Q[idx] = Q[idx] + b * P[idx];
2969371c9d4SSatish Balay       P[idx] = U[idx] + b * Q[idx];
2979371c9d4SSatish Balay     });
298e607c864SMark Adams     /* v <- K p  */
299e607c864SMark Adams     team.team_barrier();
300e607c864SMark Adams     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { T[idx] = Diag[idx] * P[idx]; });
301e607c864SMark Adams     team.team_barrier();
3023ba16761SJacob Faibussowitsch     static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, T, V));
303e607c864SMark Adams 
304e607c864SMark Adams     rhoold = rho;
305e607c864SMark Adams     dpold  = dp;
306e607c864SMark Adams 
307a4313204SMark Adams     it++;
308a4313204SMark Adams   } while (it < maxit);
309e607c864SMark Adams done:
310e607c864SMark Adams   // KSPUnwindPreconditioner
311e607c864SMark Adams   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) { XX[idx] = Diag[idx] * XX[idx]; });
312e607c864SMark Adams   team.team_barrier();
313a1e3af9bSmarkadams4   // put x into Plex order
314e607c864SMark Adams   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
315e607c864SMark Adams     int rowa    = ic[rowb];
316e607c864SMark Adams     glb_x[rowa] = XX[rowb - start];
317e607c864SMark Adams   });
318a4313204SMark Adams   metad->its = it;
319e607c864SMark Adams   if (1) {
320e607c864SMark Adams     int nnz;
321fbccb6d4SPierre Jolivet     parallel_reduce(Kokkos::TeamVectorRange(team, start, end), [=](const int idx, int &lsum) { lsum += (glb_Aai[idx + 1] - glb_Aai[idx]); }, nnz);
322e607c864SMark Adams     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * nnz) + 5 * Nblk);
323e607c864SMark Adams   } else {
324e607c864SMark Adams     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * 50 * Nblk) + 5 * Nblk); // guess
325e607c864SMark Adams   }
3263ba16761SJacob Faibussowitsch   return PETSC_SUCCESS;
327e607c864SMark Adams }
328e607c864SMark Adams 
329e607c864SMark Adams // Solve Ax = y with biCG
33066976f2fSJacob Faibussowitsch static KOKKOS_INLINE_FUNCTION PetscErrorCode BJSolve_BICG(const team_member team, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt *r, const PetscInt *ic, PetscScalar *work_space_global, const int stride_global, const int nShareVec, PetscScalar *work_space_shared, const int stride_shared, PetscReal rtol, PetscReal atol, PetscReal dtol, PetscInt maxit, Batch_MetaData *metad, const PetscInt start, const PetscInt end, const PetscScalar glb_idiag[], const PetscScalar *glb_b, PetscScalar *glb_x, bool monitor)
331d71ae5a4SJacob Faibussowitsch {
332e607c864SMark Adams   using Kokkos::parallel_for;
3339371c9d4SSatish Balay   using Kokkos::parallel_reduce;
334a4313204SMark Adams   int                Nblk = end - start, it, stride = stride_shared, idx = 0; // start in shared mem
335e607c864SMark Adams   PetscReal          dp, r0;
336e607c864SMark Adams   const PetscScalar *Di  = &glb_idiag[start];
33708d80769Smarkadams4   PetscScalar       *ptr = work_space_shared, dpi, a = 1.0, beta, betaold = 1.0, t1, t2;
338a1e3af9bSmarkadams4 
3399371c9d4SSatish Balay   if (idx++ == nShareVec) {
3409371c9d4SSatish Balay     ptr    = work_space_global;
3419371c9d4SSatish Balay     stride = stride_global;
3429371c9d4SSatish Balay   }
3439371c9d4SSatish Balay   PetscScalar *XX = ptr;
3449371c9d4SSatish Balay   ptr += stride;
3459371c9d4SSatish Balay   if (idx++ == nShareVec) {
3469371c9d4SSatish Balay     ptr    = work_space_global;
3479371c9d4SSatish Balay     stride = stride_global;
3489371c9d4SSatish Balay   }
3499371c9d4SSatish Balay   PetscScalar *Rl = ptr;
3509371c9d4SSatish Balay   ptr += stride;
3519371c9d4SSatish Balay   if (idx++ == nShareVec) {
3529371c9d4SSatish Balay     ptr    = work_space_global;
3539371c9d4SSatish Balay     stride = stride_global;
3549371c9d4SSatish Balay   }
3559371c9d4SSatish Balay   PetscScalar *Zl = ptr;
3569371c9d4SSatish Balay   ptr += stride;
3579371c9d4SSatish Balay   if (idx++ == nShareVec) {
3589371c9d4SSatish Balay     ptr    = work_space_global;
3599371c9d4SSatish Balay     stride = stride_global;
3609371c9d4SSatish Balay   }
3619371c9d4SSatish Balay   PetscScalar *Pl = ptr;
3629371c9d4SSatish Balay   ptr += stride;
3639371c9d4SSatish Balay   if (idx++ == nShareVec) {
3649371c9d4SSatish Balay     ptr    = work_space_global;
3659371c9d4SSatish Balay     stride = stride_global;
3669371c9d4SSatish Balay   }
3679371c9d4SSatish Balay   PetscScalar *Rr = ptr;
3689371c9d4SSatish Balay   ptr += stride;
3699371c9d4SSatish Balay   if (idx++ == nShareVec) {
3709371c9d4SSatish Balay     ptr    = work_space_global;
3719371c9d4SSatish Balay     stride = stride_global;
3729371c9d4SSatish Balay   }
3739371c9d4SSatish Balay   PetscScalar *Zr = ptr;
3749371c9d4SSatish Balay   ptr += stride;
3759371c9d4SSatish Balay   if (idx++ == nShareVec) {
3769371c9d4SSatish Balay     ptr    = work_space_global;
3779371c9d4SSatish Balay     stride = stride_global;
3789371c9d4SSatish Balay   }
3799371c9d4SSatish Balay   PetscScalar *Pr = ptr;
3809371c9d4SSatish Balay   ptr += stride;
381e607c864SMark Adams 
382e607c864SMark Adams   /*     r <- b (x is 0) */
383e607c864SMark Adams   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
384e607c864SMark Adams     int rowa         = ic[rowb];
385e607c864SMark Adams     Rl[rowb - start] = Rr[rowb - start] = glb_b[rowa];
386e607c864SMark Adams     XX[rowb - start]                    = 0;
387e607c864SMark Adams   });
388e607c864SMark Adams   team.team_barrier();
389e607c864SMark Adams   /*     z <- Br         */
3909371c9d4SSatish Balay   parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
3919371c9d4SSatish Balay     Zr[idx] = Di[idx] * Rr[idx];
3929371c9d4SSatish Balay     Zl[idx] = Di[idx] * Rl[idx];
3939371c9d4SSatish Balay   });
394e607c864SMark Adams   team.team_barrier();
395e607c864SMark Adams   /*    dp <- r'*r       */
396fbccb6d4SPierre Jolivet   parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Rr[idx] * PetscConj(Rr[idx]); }, dpi);
397e607c864SMark Adams   team.team_barrier();
398e607c864SMark Adams   r0 = dp = PetscSqrtReal(PetscRealPart(dpi));
399a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
400e607c864SMark Adams   if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e\n", 0, (double)dp); });
401a25f047fSMark Adams #endif
4029371c9d4SSatish Balay   if (dp < atol) {
4039371c9d4SSatish Balay     metad->reason = KSP_CONVERGED_ATOL_NORMAL;
404a4313204SMark Adams     it            = 0;
405a4313204SMark Adams     goto done;
4069371c9d4SSatish Balay   }
4079371c9d4SSatish Balay   if (0 == maxit) {
40808d80769Smarkadams4     metad->reason = KSP_CONVERGED_ITS;
409a4313204SMark Adams     it            = 0;
410a4313204SMark Adams     goto done;
4119371c9d4SSatish Balay   }
412a4313204SMark Adams 
413a4313204SMark Adams   it = 0;
414e607c864SMark Adams   do {
415e607c864SMark Adams     /*     beta <- r'z     */
416fbccb6d4SPierre Jolivet     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &dot) { dot += Zr[idx] * PetscConj(Rl[idx]); }, beta);
417e607c864SMark Adams     team.team_barrier();
418e607c864SMark Adams #if PCBJKOKKOS_VERBOSE_LEVEL >= 6
419a25f047fSMark Adams   #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
420e607c864SMark Adams     Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%7d beta = Z.R = %22.14e \n", i, (double)beta); });
421e607c864SMark Adams   #endif
422a25f047fSMark Adams #endif
423e607c864SMark Adams     if (beta == 0.0) {
42408d80769Smarkadams4       metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
425e607c864SMark Adams       goto done;
426e607c864SMark Adams     }
427a4313204SMark Adams     if (it == 0) {
428e607c864SMark Adams       /*     p <- z          */
4299371c9d4SSatish Balay       parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
4309371c9d4SSatish Balay         Pr[idx] = Zr[idx];
4319371c9d4SSatish Balay         Pl[idx] = Zl[idx];
4329371c9d4SSatish Balay       });
433e607c864SMark Adams     } else {
43408d80769Smarkadams4       t1 = beta / betaold;
435e607c864SMark Adams       /*     p <- z + b* p   */
43608d80769Smarkadams4       t2 = PetscConj(t1);
4379371c9d4SSatish Balay       parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
43808d80769Smarkadams4         Pr[idx] = t1 * Pr[idx] + Zr[idx];
43908d80769Smarkadams4         Pl[idx] = t2 * Pl[idx] + Zl[idx];
4409371c9d4SSatish Balay       });
441e607c864SMark Adams     }
442e607c864SMark Adams     team.team_barrier();
443e607c864SMark Adams     betaold = beta;
444e607c864SMark Adams     /*     z <- Kp         */
4453ba16761SJacob Faibussowitsch     static_cast<void>(MatMult(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, Pr, Zr));
4463ba16761SJacob Faibussowitsch     static_cast<void>(MatMultTranspose(team, glb_Aai, glb_Aaj, glb_Aaa, r, ic, start, end, Pl, Zl));
447e607c864SMark Adams     /*     dpi <- z'p      */
448fbccb6d4SPierre Jolivet     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Zr[idx] * PetscConj(Pl[idx]); }, dpi);
449e607c864SMark Adams     team.team_barrier();
45008d80769Smarkadams4     if (dpi == 0) {
45108d80769Smarkadams4       metad->reason = KSP_CONVERGED_HAPPY_BREAKDOWN;
45208d80769Smarkadams4       goto done;
45308d80769Smarkadams4     }
454e607c864SMark Adams     //
455e607c864SMark Adams     a  = beta / dpi; /*     a = beta/p'z    */
45608d80769Smarkadams4     t1 = -a;
45708d80769Smarkadams4     t2 = PetscConj(t1);
458e607c864SMark Adams     /*     x <- x + ap     */
4599371c9d4SSatish Balay     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
4609371c9d4SSatish Balay       XX[idx] = XX[idx] + a * Pr[idx];
46108d80769Smarkadams4       Rr[idx] = Rr[idx] + t1 * Zr[idx];
46208d80769Smarkadams4       Rl[idx] = Rl[idx] + t2 * Zl[idx];
4639371c9d4SSatish Balay     });
4649371c9d4SSatish Balay     team.team_barrier();
465e607c864SMark Adams     team.team_barrier();
466e607c864SMark Adams     /*    dp <- r'*r       */
467fbccb6d4SPierre Jolivet     parallel_reduce(Kokkos::TeamVectorRange(team, Nblk), [=](const int idx, PetscScalar &lsum) { lsum += Rr[idx] * PetscConj(Rr[idx]); }, dpi);
468e607c864SMark Adams     team.team_barrier();
469e607c864SMark Adams     dp = PetscSqrtReal(PetscRealPart(dpi));
470a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
471a4313204SMark Adams     if (monitor) Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("%3d KSP Residual norm %14.12e\n", it + 1, (double)dp); });
472a25f047fSMark Adams #endif
4739371c9d4SSatish Balay     if (dp < atol) {
4749371c9d4SSatish Balay       metad->reason = KSP_CONVERGED_ATOL_NORMAL;
4759371c9d4SSatish Balay       goto done;
4769371c9d4SSatish Balay     }
4779371c9d4SSatish Balay     if (dp / r0 < rtol) {
4789371c9d4SSatish Balay       metad->reason = KSP_CONVERGED_RTOL_NORMAL;
4799371c9d4SSatish Balay       goto done;
4809371c9d4SSatish Balay     }
481a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
4829371c9d4SSatish Balay     if (dp / r0 > dtol) {
4839371c9d4SSatish Balay       metad->reason = KSP_DIVERGED_DTOL;
484a4313204SMark Adams       Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: %d it, res=%e, r_0=%e (BICG does this)\n", team.league_rank(), it, dp, r0); });
4859371c9d4SSatish Balay       goto done;
4869371c9d4SSatish Balay     }
487e607c864SMark Adams #else
4889371c9d4SSatish Balay     if (dp / r0 > dtol) {
4899371c9d4SSatish Balay       metad->reason = KSP_DIVERGED_DTOL;
4909371c9d4SSatish Balay       goto done;
4919371c9d4SSatish Balay     }
492e607c864SMark Adams #endif
493a4313204SMark Adams     if (it + 1 == maxit) {
49408d80769Smarkadams4       metad->reason = KSP_CONVERGED_ITS; // don't worry about hitting max iterations
49508d80769Smarkadams4 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
496a4313204SMark Adams       Kokkos::single(Kokkos::PerTeam(team), [=]() { printf("ERROR block %d diverged: BICG %d it, res=%e, r_0=%e r_res=%e\n", team.league_rank(), it, dp, r0, dp / r0); });
49708d80769Smarkadams4 #endif
4989371c9d4SSatish Balay       goto done;
4999371c9d4SSatish Balay     }
500e607c864SMark Adams     /* z <- Br  */
5019371c9d4SSatish Balay     parallel_for(Kokkos::TeamVectorRange(team, Nblk), [=](int idx) {
5029371c9d4SSatish Balay       Zr[idx] = Di[idx] * Rr[idx];
5039371c9d4SSatish Balay       Zl[idx] = Di[idx] * Rl[idx];
5049371c9d4SSatish Balay     });
505a4313204SMark Adams 
506a4313204SMark Adams     it++;
507a4313204SMark Adams   } while (it < maxit);
508e607c864SMark Adams done:
509a1e3af9bSmarkadams4   // put x back into Plex order
510e607c864SMark Adams   parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
511e607c864SMark Adams     int rowa    = ic[rowb];
512e607c864SMark Adams     glb_x[rowa] = XX[rowb - start];
513e607c864SMark Adams   });
514a4313204SMark Adams   metad->its = it;
515e607c864SMark Adams   if (1) {
516e607c864SMark Adams     int nnz;
517fbccb6d4SPierre Jolivet     parallel_reduce(Kokkos::TeamVectorRange(team, start, end), [=](const int idx, int &lsum) { lsum += (glb_Aai[idx + 1] - glb_Aai[idx]); }, nnz);
518e607c864SMark Adams     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * nnz) + 5 * Nblk);
519e607c864SMark Adams   } else {
520e607c864SMark Adams     metad->flops = 2 * (metad->its * (10 * Nblk + 2 * 50 * Nblk) + 5 * Nblk); // guess
521e607c864SMark Adams   }
5223ba16761SJacob Faibussowitsch   return PETSC_SUCCESS;
523e607c864SMark Adams }
524e607c864SMark Adams 
525a999f97fSMark Adams // KSP solver solve Ax = b; xout is output, bin is input
526d71ae5a4SJacob Faibussowitsch static PetscErrorCode PCApply_BJKOKKOS(PC pc, Vec bin, Vec xout)
527d71ae5a4SJacob Faibussowitsch {
528e607c864SMark Adams   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
529ed00a6c1SMark Adams   Mat            A = pc->pmat, Aseq = A;
530a4313204SMark Adams   PetscMPIInt    rank;
531e607c864SMark Adams 
532e607c864SMark Adams   PetscFunctionBegin;
533a4313204SMark Adams   PetscCallMPI(MPI_Comm_rank(PetscObjectComm((PetscObject)A), &rank));
534ed00a6c1SMark Adams   if (!A->spptr) {
535ed00a6c1SMark Adams     Aseq = ((Mat_MPIAIJ *)A->data)->A; // MPI
536ed00a6c1SMark Adams   }
537ed00a6c1SMark Adams   PetscCall(MatSeqAIJKokkosSyncDevice(Aseq));
5380fdf79fbSJacob Faibussowitsch   {
539a1e3af9bSmarkadams4     PetscInt           maxit = jac->ksp->max_it;
540e607c864SMark Adams     const PetscInt     conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp == 0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
541e607c864SMark Adams     const PetscInt     nwork = jac->nwork, nBlk = jac->nBlocks;
542ed00a6c1SMark Adams     PetscScalar       *glb_xdata = NULL, *dummy;
543e607c864SMark Adams     PetscReal          rtol = jac->ksp->rtol, atol = jac->ksp->abstol, dtol = jac->ksp->divtol;
544e607c864SMark Adams     const PetscScalar *glb_idiag = jac->d_idiag_k->data(), *glb_bdata = NULL;
545ed00a6c1SMark Adams     const PetscInt    *glb_Aai, *glb_Aaj, *d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
546ed00a6c1SMark Adams     const PetscScalar *glb_Aaa;
547a1e3af9bSmarkadams4     const PetscInt    *d_isicol = jac->d_isicol_k->data(), *d_isrow = jac->d_isrow_k->data();
548e607c864SMark Adams     PCFailedReason     pcreason;
549e607c864SMark Adams     KSPIndex           ksp_type_idx = jac->ksp_type_idx;
550e607c864SMark Adams     PetscMemType       mtype;
551e607c864SMark Adams     PetscContainer     container;
552a999f97fSMark Adams     PetscInt           batch_sz;                // the number of repeated DMs, [DM_e_1, DM_e_2, DM_e_batch_sz, DM_i_1, ...]
553a1e3af9bSmarkadams4     VecScatter         plex_batch = NULL;       // not used
554a1e3af9bSmarkadams4     Vec                bvec;                    // a copy of b for scatter (just alias to bin now)
555e607c864SMark Adams     PetscBool          monitor  = jac->monitor; // captured
556e607c864SMark Adams     PetscInt           view_bid = jac->batch_target;
557a1e3af9bSmarkadams4     MatInfo            info;
558ed00a6c1SMark Adams 
559ed00a6c1SMark Adams     PetscCall(MatSeqAIJGetCSRAndMemType(Aseq, &glb_Aai, &glb_Aaj, &dummy, &mtype));
560a1e3af9bSmarkadams4     jac->max_nits = 0;
561a4313204SMark Adams     glb_Aaa       = dummy;
562a4313204SMark Adams     if (jac->rank_target != rank) view_bid = -1; // turn off all but one process
563a1e3af9bSmarkadams4     PetscCall(MatGetInfo(A, MAT_LOCAL, &info));
564e607c864SMark Adams     // get field major is to map plex IO to/from block/field major
5659566063dSJacob Faibussowitsch     PetscCall(PetscObjectQuery((PetscObject)A, "plex_batch_is", (PetscObject *)&container));
566e607c864SMark Adams     if (container) {
567a1e3af9bSmarkadams4       PetscCall(VecDuplicate(bin, &bvec));
5689566063dSJacob Faibussowitsch       PetscCall(PetscContainerGetPointer(container, (void **)&plex_batch));
5699566063dSJacob Faibussowitsch       PetscCall(VecScatterBegin(plex_batch, bin, bvec, INSERT_VALUES, SCATTER_FORWARD));
5709566063dSJacob Faibussowitsch       PetscCall(VecScatterEnd(plex_batch, bin, bvec, INSERT_VALUES, SCATTER_FORWARD));
571a1e3af9bSmarkadams4       SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_USER, "No plex_batch_is -- require NO field major ordering for now");
572e607c864SMark Adams     } else {
573a1e3af9bSmarkadams4       bvec = bin;
574e607c864SMark Adams     }
575e607c864SMark Adams     // get x
5769566063dSJacob Faibussowitsch     PetscCall(VecGetArrayAndMemType(xout, &glb_xdata, &mtype));
577e607c864SMark Adams #if defined(PETSC_HAVE_CUDA)
5783259942eSJunchao Zhang     PetscCheck(PetscMemTypeDevice(mtype), PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "No GPU data for x %d != %d", static_cast<int>(mtype), static_cast<int>(PETSC_MEMTYPE_DEVICE));
579e607c864SMark Adams #endif
5809566063dSJacob Faibussowitsch     PetscCall(VecGetArrayReadAndMemType(bvec, &glb_bdata, &mtype));
581e607c864SMark Adams #if defined(PETSC_HAVE_CUDA)
5829effc8a0SJed Brown     PetscCheck(PetscMemTypeDevice(mtype), PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "No GPU data for b");
583e607c864SMark Adams #endif
584e607c864SMark Adams     // get batch size
5859566063dSJacob Faibussowitsch     PetscCall(PetscObjectQuery((PetscObject)A, "batch size", (PetscObject *)&container));
586e607c864SMark Adams     if (container) {
587e607c864SMark Adams       PetscInt *pNf = NULL;
5889566063dSJacob Faibussowitsch       PetscCall(PetscContainerGetPointer(container, (void **)&pNf));
589a999f97fSMark Adams       batch_sz = *pNf; // number of times to repeat the DMs
590e607c864SMark Adams     } else batch_sz = 1;
591d618d24fSMark Adams     PetscCheck(nBlk % batch_sz == 0, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "batch_sz = %" PetscInt_FMT ", nBlk = %" PetscInt_FMT, batch_sz, nBlk);
592a4313204SMark Adams     if (ksp_type_idx == BATCH_KSP_GMRESKK_IDX) {
593a4313204SMark Adams       // KK solver - move PETSc data into Kokkos Views, setup solver, solve, move data out of Kokkos, process metadata (convergence tests, etc.)
594a4313204SMark Adams #if defined(PETSC_HAVE_KOKKOS_KERNELS_BATCH)
595a4313204SMark Adams       PetscCall(PCApply_BJKOKKOSKERNELS(pc, glb_bdata, glb_xdata, glb_Aai, glb_Aaj, glb_Aaa, team_size, info, batch_sz, &pcreason));
596a1e3af9bSmarkadams4 #else
59700045ab3SPierre Jolivet       PetscCheck(ksp_type_idx != BATCH_KSP_GMRESKK_IDX, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "Type: BATCH_KSP_GMRES not supported for complex");
598a1e3af9bSmarkadams4 #endif
599a1e3af9bSmarkadams4     } else { // Kokkos Krylov
600a1e3af9bSmarkadams4       using scr_mem_t    = Kokkos::DefaultExecutionSpace::scratch_memory_space;
601a1e3af9bSmarkadams4       using vect2D_scr_t = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, scr_mem_t>;
602a1e3af9bSmarkadams4       Kokkos::View<Batch_MetaData *, Kokkos::DefaultExecutionSpace> d_metadata("solver meta data", nBlk);
603a1e3af9bSmarkadams4       int                                                           stride_shared, stride_global, global_buff_words;
604a1e3af9bSmarkadams4       d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
605a1e3af9bSmarkadams4       // solve each block independently
606a1e3af9bSmarkadams4       int scr_bytes_team_shared = 0, nShareVec = 0, nGlobBVec = 0;
6070338c944SBarry Smith       if (jac->const_block_size) { // use shared memory for work vectors only if constant block size - TODO: test efficiency loss
60808d80769Smarkadams4         size_t      maximum_shared_mem_size = 64000;
609a16fd2c9SJacob Faibussowitsch         PetscDevice device;
610a16fd2c9SJacob Faibussowitsch         PetscCall(PetscDeviceGetDefault_Internal(&device));
611a16fd2c9SJacob Faibussowitsch         PetscCall(PetscDeviceGetAttribute(device, PETSC_DEVICE_ATTR_SIZE_T_SHARED_MEM_PER_BLOCK, &maximum_shared_mem_size));
612a1e3af9bSmarkadams4         stride_shared = jac->const_block_size;                                                   // captured
613a1e3af9bSmarkadams4         nShareVec     = maximum_shared_mem_size / (jac->const_block_size * sizeof(PetscScalar)); // integer floor, number of vectors that fit in shared
614a1e3af9bSmarkadams4         if (nShareVec > nwork) nShareVec = nwork;
615a1e3af9bSmarkadams4         else nGlobBVec = nwork - nShareVec;
616a1e3af9bSmarkadams4         global_buff_words     = jac->n * nGlobBVec;
617a1e3af9bSmarkadams4         scr_bytes_team_shared = jac->const_block_size * nShareVec * sizeof(PetscScalar);
618a1e3af9bSmarkadams4       } else {
619a1e3af9bSmarkadams4         scr_bytes_team_shared = 0;
620a1e3af9bSmarkadams4         stride_shared         = 0;
621a1e3af9bSmarkadams4         global_buff_words     = jac->n * nwork;
622a1e3af9bSmarkadams4         nGlobBVec             = nwork; // not needed == fix
623a1e3af9bSmarkadams4       }
624a1e3af9bSmarkadams4       stride_global = jac->n; // captured
625a16fd2c9SJacob Faibussowitsch #if defined(PETSC_HAVE_CUDA)
626a1e3af9bSmarkadams4       nvtxRangePushA("batch-kokkos-solve");
627a1e3af9bSmarkadams4 #endif
628a1e3af9bSmarkadams4       Kokkos::View<PetscScalar *, Kokkos::DefaultExecutionSpace> d_work_vecs_k("workvectors", global_buff_words); // global work vectors
629a4313204SMark Adams #if PCBJKOKKOS_VERBOSE_LEVEL > 1
6303ba16761SJacob Faibussowitsch       PetscCall(PetscInfo(pc, "\tn = %d. %d shared bytes/team, %d global mem bytes, rtol=%e, num blocks %d, team_size=%d, %d vector threads, %d shared vectors, %d global vectors\n", (int)jac->n, scr_bytes_team_shared, global_buff_words, rtol, (int)nBlk, (int)team_size, PCBJKOKKOS_VEC_SIZE, nShareVec, nGlobBVec));
631a4313204SMark Adams #endif
632a1e3af9bSmarkadams4       PetscScalar *d_work_vecs = d_work_vecs_k.data();
6339371c9d4SSatish Balay       Kokkos::parallel_for(
6349371c9d4SSatish Balay         "Solve", Kokkos::TeamPolicy<Kokkos::LaunchBounds<256, 4>>(nBlk, team_size, PCBJKOKKOS_VEC_SIZE).set_scratch_size(PCBJKOKKOS_SHARED_LEVEL, Kokkos::PerTeam(scr_bytes_team_shared)), KOKKOS_LAMBDA(const team_member team) {
635a1e3af9bSmarkadams4           const int    blkID = team.league_rank(), start = d_bid_eqOffset[blkID], end = d_bid_eqOffset[blkID + 1];
636a1e3af9bSmarkadams4           vect2D_scr_t work_vecs_shared(team.team_scratch(PCBJKOKKOS_SHARED_LEVEL), end - start, nShareVec);
637a1e3af9bSmarkadams4           PetscScalar *work_buff_shared = work_vecs_shared.data();
638a1e3af9bSmarkadams4           PetscScalar *work_buff_global = &d_work_vecs[start]; // start inc'ed in
639e607c864SMark Adams           bool         print            = monitor && (blkID == view_bid);
640e607c864SMark Adams           switch (ksp_type_idx) {
641e607c864SMark Adams           case BATCH_KSP_BICG_IDX:
6423ba16761SJacob Faibussowitsch             static_cast<void>(BJSolve_BICG(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff_global, stride_global, nShareVec, work_buff_shared, stride_shared, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print));
643e607c864SMark Adams             break;
644e607c864SMark Adams           case BATCH_KSP_TFQMR_IDX:
6453ba16761SJacob Faibussowitsch             static_cast<void>(BJSolve_TFQMR(team, glb_Aai, glb_Aaj, glb_Aaa, d_isrow, d_isicol, work_buff_global, stride_global, nShareVec, work_buff_shared, stride_shared, rtol, atol, dtol, maxit, &d_metadata[blkID], start, end, glb_idiag, glb_bdata, glb_xdata, print));
646e607c864SMark Adams             break;
647e607c864SMark Adams           default:
648a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
649e607c864SMark Adams             printf("Unknown KSP type %d\n", ksp_type_idx);
650e607c864SMark Adams #else
651e607c864SMark Adams             /* void */;
652e607c864SMark Adams #endif
653e607c864SMark Adams           }
654e607c864SMark Adams         });
655e607c864SMark Adams       Kokkos::fence();
656a16fd2c9SJacob Faibussowitsch #if defined(PETSC_HAVE_CUDA)
657a1e3af9bSmarkadams4       nvtxRangePop();
658a1e3af9bSmarkadams4       nvtxRangePushA("Post-solve-metadata");
659a1e3af9bSmarkadams4 #endif
660a1e3af9bSmarkadams4       auto h_metadata = Kokkos::create_mirror(Kokkos::HostSpace::memory_space(), d_metadata);
661e607c864SMark Adams       Kokkos::deep_copy(h_metadata, d_metadata);
6629bdd0a17SMark Adams       PetscInt max_nnit = -1;
6639bdd0a17SMark Adams #if PCBJKOKKOS_VERBOSE_LEVEL > 1
6649bdd0a17SMark Adams       PetscInt mbid = 0;
6659bdd0a17SMark Adams #endif
666a4313204SMark Adams       int in[2], out[2];
667a999f97fSMark Adams       if (jac->reason) { // -pc_bjkokkos_ksp_converged_reason
668e607c864SMark Adams #if PCBJKOKKOS_VERBOSE_LEVEL >= 3
669e607c864SMark Adams   #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
6709566063dSJacob Faibussowitsch         PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Iterations\n"));
671e607c864SMark Adams   #endif
672e607c864SMark Adams         // assume species major
6739bdd0a17SMark Adams   #if PCBJKOKKOS_VERBOSE_LEVEL == 3
674a999f97fSMark Adams         if (batch_sz != 1) PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%s: max iterations per species:", ksp_type_idx == BATCH_KSP_BICG_IDX ? "bicg" : "tfqmr"));
675a999f97fSMark Adams         else PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "    Linear solve converged due to %s iterations ", ksp_type_idx == BATCH_KSP_BICG_IDX ? "bicg" : "tfqmr"));
676e607c864SMark Adams   #endif
6779bdd0a17SMark Adams         for (PetscInt dmIdx = 0, head = 0, s = 0; dmIdx < jac->num_dms; dmIdx += batch_sz) {
6789bdd0a17SMark Adams           for (PetscInt f = 0, idx = head; f < jac->dm_Nf[dmIdx]; f++, idx++, s++) {
6799bdd0a17SMark Adams             for (int bid = 0; bid < batch_sz; bid++) {
6809bdd0a17SMark Adams   #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
6819bdd0a17SMark Adams               jac->max_nits += h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its; // report total number of iterations with high verbose
6829bdd0a17SMark Adams               if (h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its > max_nnit) {
6839bdd0a17SMark Adams                 max_nnit = h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its;
6849bdd0a17SMark Adams                 mbid     = bid;
6859bdd0a17SMark Adams               }
6869bdd0a17SMark Adams   #else
6879bdd0a17SMark Adams               if (h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its > max_nnit) {
6889bdd0a17SMark Adams                 jac->max_nits = max_nnit = h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its;
6899bdd0a17SMark Adams                 mbid                     = bid;
6909bdd0a17SMark Adams               }
6919bdd0a17SMark Adams   #endif
6929bdd0a17SMark Adams             }
693e607c864SMark Adams   #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
694a1e3af9bSmarkadams4             PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%2" PetscInt_FMT ":", s));
69548a46eb9SPierre Jolivet             for (int bid = 0; bid < batch_sz; bid++) PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%3" PetscInt_FMT " ", h_metadata[idx + bid * jac->dm_Nf[dmIdx]].its));
696a1e3af9bSmarkadams4             PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "\n"));
6979bdd0a17SMark Adams   #else // == 3
6989bdd0a17SMark Adams             PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%3" PetscInt_FMT " ", max_nnit));
699e607c864SMark Adams   #endif
700e607c864SMark Adams           }
701e607c864SMark Adams           head += batch_sz * jac->dm_Nf[dmIdx];
702e607c864SMark Adams         }
703a1e3af9bSmarkadams4   #if PCBJKOKKOS_VERBOSE_LEVEL == 3
704a1e3af9bSmarkadams4         PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "\n"));
705e607c864SMark Adams   #endif
706e607c864SMark Adams #endif
7079bdd0a17SMark Adams         if (max_nnit == -1) { // < 3
708e607c864SMark Adams           for (int blkID = 0; blkID < nBlk; blkID++) {
7099bdd0a17SMark Adams             if (h_metadata[blkID].its > max_nnit) {
7109bdd0a17SMark Adams               jac->max_nits = max_nnit = h_metadata[blkID].its;
7119bdd0a17SMark Adams #if PCBJKOKKOS_VERBOSE_LEVEL > 1
71226f882f0SMark Adams               mbid = blkID;
713a999f97fSMark Adams #endif
714a999f97fSMark Adams             }
715a4313204SMark Adams           }
7169bdd0a17SMark Adams         }
7179bdd0a17SMark Adams         in[0] = max_nnit;
718a4313204SMark Adams         in[1] = rank;
7196a210b70SBarry Smith         PetscCallMPI(MPIU_Allreduce(in, out, 1, MPI_2INT, MPI_MAXLOC, PetscObjectComm((PetscObject)A)));
7209bdd0a17SMark Adams #if PCBJKOKKOS_VERBOSE_LEVEL > 1
721a4313204SMark Adams         if (0 == rank) {
722a999f97fSMark Adams           if (batch_sz != 1)
7239bdd0a17SMark Adams             PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] Linear solve converged due to %s iterations %d (max), on block %" PetscInt_FMT ", species %" PetscInt_FMT " (max)\n", out[1], KSPConvergedReasons[h_metadata[mbid].reason], out[0], mbid % batch_sz, mbid / batch_sz));
7249bdd0a17SMark Adams           else PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] Linear solve converged due to %s iterations %d (max), on block %" PetscInt_FMT "\n", out[1], KSPConvergedReasons[h_metadata[mbid].reason], out[0], mbid));
725e607c864SMark Adams         }
7269bdd0a17SMark Adams #endif
727e607c864SMark Adams       }
728a999f97fSMark Adams       for (int blkID = 0; blkID < nBlk; blkID++) {
729a999f97fSMark Adams         PetscCall(PetscLogGpuFlops((PetscLogDouble)h_metadata[blkID].flops));
7309bdd0a17SMark Adams         PetscCheck(h_metadata[blkID].reason >= 0 || !jac->ksp->errorifnotconverged, PetscObjectComm((PetscObject)pc), PETSC_ERR_CONV_FAILED, "ERROR reason=%s, its=%" PetscInt_FMT ". species %" PetscInt_FMT ", batch %" PetscInt_FMT,
7319bdd0a17SMark Adams                    KSPConvergedReasons[h_metadata[blkID].reason], h_metadata[blkID].its, blkID / batch_sz, blkID % batch_sz);
73226f882f0SMark Adams       }
733e607c864SMark Adams       {
734e607c864SMark Adams         int errsum;
7359371c9d4SSatish Balay         Kokkos::parallel_reduce(
7369371c9d4SSatish Balay           nBlk,
7379371c9d4SSatish Balay           KOKKOS_LAMBDA(const int idx, int &lsum) {
73826f882f0SMark Adams             if (d_metadata[idx].reason < 0) ++lsum;
7399371c9d4SSatish Balay           },
7409371c9d4SSatish Balay           errsum);
74126f882f0SMark Adams         pcreason = errsum ? PC_SUBPC_ERROR : PC_NOERROR;
742a1e3af9bSmarkadams4         if (!errsum && !jac->max_nits) { // set max its to give back to top KSP
743a1e3af9bSmarkadams4           for (int blkID = 0; blkID < nBlk; blkID++) {
744a1e3af9bSmarkadams4             if (h_metadata[blkID].its > jac->max_nits) jac->max_nits = h_metadata[blkID].its;
745e607c864SMark Adams           }
746a1e3af9bSmarkadams4         } else if (errsum) {
74708d80769Smarkadams4           PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] ERROR Kokkos batch solver did not converge in all solves\n", (int)rank));
748a1e3af9bSmarkadams4         }
749a1e3af9bSmarkadams4       }
750a16fd2c9SJacob Faibussowitsch #if defined(PETSC_HAVE_CUDA)
751a1e3af9bSmarkadams4       nvtxRangePop();
752a1e3af9bSmarkadams4 #endif
753a1e3af9bSmarkadams4     } // end of Kokkos (not Kernels) solvers block
754a1e3af9bSmarkadams4     PetscCall(VecRestoreArrayAndMemType(xout, &glb_xdata));
755a1e3af9bSmarkadams4     PetscCall(VecRestoreArrayReadAndMemType(bvec, &glb_bdata));
7569566063dSJacob Faibussowitsch     PetscCall(PCSetFailedReason(pc, pcreason));
757a1e3af9bSmarkadams4     // map back to Plex space - not used
758e607c864SMark Adams     if (plex_batch) {
7599566063dSJacob Faibussowitsch       PetscCall(VecCopy(xout, bvec));
7609566063dSJacob Faibussowitsch       PetscCall(VecScatterBegin(plex_batch, bvec, xout, INSERT_VALUES, SCATTER_REVERSE));
7619566063dSJacob Faibussowitsch       PetscCall(VecScatterEnd(plex_batch, bvec, xout, INSERT_VALUES, SCATTER_REVERSE));
7629566063dSJacob Faibussowitsch       PetscCall(VecDestroy(&bvec));
763e607c864SMark Adams     }
764ed00a6c1SMark Adams   }
7653ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
766e607c864SMark Adams }
767e607c864SMark Adams 
768d71ae5a4SJacob Faibussowitsch static PetscErrorCode PCSetUp_BJKOKKOS(PC pc)
769d71ae5a4SJacob Faibussowitsch {
770e607c864SMark Adams   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
771a999f97fSMark Adams   Mat            A = pc->pmat, Aseq = A; // use filtered block matrix, really "P"
772e607c864SMark Adams   PetscBool      flg;
773e607c864SMark Adams 
774e607c864SMark Adams   PetscFunctionBegin;
7759bdd0a17SMark Adams   PetscCheck(A, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "No matrix - A is used above");
7769566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompareAny((PetscObject)A, &flg, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, MATAIJKOKKOS, ""));
777a999f97fSMark Adams   PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "must use '-[dm_]mat_type aijkokkos -[dm_]vec_type kokkos' for -pc_type bjkokkos");
778ed00a6c1SMark Adams   if (!A->spptr) {
779ed00a6c1SMark Adams     Aseq = ((Mat_MPIAIJ *)A->data)->A; // MPI
780a999f97fSMark Adams   }
781ed00a6c1SMark Adams   PetscCall(MatSeqAIJKokkosSyncDevice(Aseq));
7820fdf79fbSJacob Faibussowitsch   {
783a999f97fSMark Adams     PetscInt    Istart, Iend;
784a999f97fSMark Adams     PetscMPIInt rank;
785a4313204SMark Adams     PetscCallMPI(MPI_Comm_rank(PetscObjectComm((PetscObject)A), &rank));
786a999f97fSMark Adams     PetscCall(MatGetOwnershipRange(A, &Istart, &Iend));
787e607c864SMark Adams     if (!jac->vec_diag) {
788a999f97fSMark Adams       Vec     *subX = NULL;
789a999f97fSMark Adams       DM       pack, *subDM = NULL;
790a999f97fSMark Adams       PetscInt nDMs, n, *block_sizes = NULL;
791a999f97fSMark Adams       IS       isrow, isicol;
792e607c864SMark Adams       { // Permute the matrix to get a block diagonal system: d_isrow_k, d_isicol_k
793e607c864SMark Adams         MatOrderingType rtype;
794e607c864SMark Adams         const PetscInt *rowindices, *icolindices;
795a1e3af9bSmarkadams4         rtype = MATORDERINGRCM;
796a999f97fSMark Adams         // get permutation. And invert. should we convert to local indices?
797a999f97fSMark Adams         PetscCall(MatGetOrdering(Aseq, rtype, &isrow, &isicol)); // only seems to work for seq matrix
7989566063dSJacob Faibussowitsch         PetscCall(ISDestroy(&isrow));
799a999f97fSMark Adams         PetscCall(ISInvertPermutation(isicol, PETSC_DECIDE, &isrow)); // THIS IS BACKWARD -- isrow is inverse
800a999f97fSMark Adams         // if (rank==1) PetscCall(ISView(isicol, PETSC_VIEWER_STDOUT_SELF));
801a4313204SMark Adams         if (0) {
802a999f97fSMark Adams           Mat mat_block_order; // debug
803a999f97fSMark Adams           PetscCall(ISShift(isicol, Istart, isicol));
804a1e3af9bSmarkadams4           PetscCall(MatCreateSubMatrix(A, isicol, isicol, MAT_INITIAL_MATRIX, &mat_block_order));
805a999f97fSMark Adams           PetscCall(ISShift(isicol, -Istart, isicol));
806a1e3af9bSmarkadams4           PetscCall(MatViewFromOptions(mat_block_order, NULL, "-ksp_batch_reorder_view"));
807a1e3af9bSmarkadams4           PetscCall(MatDestroy(&mat_block_order));
808a999f97fSMark Adams         }
809a999f97fSMark Adams         PetscCall(ISGetIndices(isrow, &rowindices)); // local idx
8109566063dSJacob Faibussowitsch         PetscCall(ISGetIndices(isicol, &icolindices));
811e607c864SMark Adams         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_isrow_k((PetscInt *)rowindices, A->rmap->n);
812e607c864SMark Adams         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_isicol_k((PetscInt *)icolindices, A->rmap->n);
813e607c864SMark Adams         jac->d_isrow_k  = new Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_isrow_k));
814e607c864SMark Adams         jac->d_isicol_k = new Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_isicol_k));
815e607c864SMark Adams         Kokkos::deep_copy(*jac->d_isrow_k, h_isrow_k);
816e607c864SMark Adams         Kokkos::deep_copy(*jac->d_isicol_k, h_isicol_k);
8179566063dSJacob Faibussowitsch         PetscCall(ISRestoreIndices(isrow, &rowindices));
8189566063dSJacob Faibussowitsch         PetscCall(ISRestoreIndices(isicol, &icolindices));
819a999f97fSMark Adams         // if (rank==1) PetscCall(ISView(isicol, PETSC_VIEWER_STDOUT_SELF));
820a999f97fSMark Adams       }
821a999f97fSMark Adams       // get block sizes & allocate vec_diag
822a999f97fSMark Adams       PetscCall(PCGetDM(pc, &pack));
823a999f97fSMark Adams       if (pack) {
824a999f97fSMark Adams         PetscCall(PetscObjectTypeCompare((PetscObject)pack, DMCOMPOSITE, &flg));
825a999f97fSMark Adams         if (flg) {
826a999f97fSMark Adams           PetscCall(DMCompositeGetNumberDM(pack, &nDMs));
827a999f97fSMark Adams           PetscCall(DMCreateGlobalVector(pack, &jac->vec_diag));
828a999f97fSMark Adams         } else pack = NULL; // flag for no DM
829a999f97fSMark Adams       }
8300338c944SBarry Smith       if (!jac->vec_diag) { // get 'nDMs' and sizes 'block_sizes' w/o DMComposite. TODO: User could provide ISs
831a999f97fSMark Adams         PetscInt        bsrt, bend, ncols, ntot = 0;
832a999f97fSMark Adams         const PetscInt *colsA, nloc = Iend - Istart;
833a999f97fSMark Adams         const PetscInt *rowindices, *icolindices;
834a999f97fSMark Adams         PetscCall(PetscMalloc1(nloc, &block_sizes)); // very inefficient, to big
835a999f97fSMark Adams         PetscCall(ISGetIndices(isrow, &rowindices));
836a999f97fSMark Adams         PetscCall(ISGetIndices(isicol, &icolindices));
837a999f97fSMark Adams         nDMs = 0;
838a999f97fSMark Adams         bsrt = 0;
839a999f97fSMark Adams         bend = 1;
840a999f97fSMark Adams         for (PetscInt row_B = 0; row_B < nloc; row_B++) { // for all rows in block diagonal space
8411690c2aeSBarry Smith           PetscInt rowA = icolindices[row_B], minj = PETSC_INT_MAX, maxj = 0;
842a999f97fSMark Adams           //PetscCall(PetscPrintf(PETSC_COMM_SELF, "\t[%d] rowA = %d\n",rank,rowA));
843a999f97fSMark Adams           PetscCall(MatGetRow(Aseq, rowA, &ncols, &colsA, NULL)); // not sorted in permutation
844e978a55eSPierre Jolivet           PetscCheck(ncols, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "Empty row not supported: %" PetscInt_FMT, row_B);
845a999f97fSMark Adams           for (PetscInt colj = 0; colj < ncols; colj++) {
846a999f97fSMark Adams             PetscInt colB = rowindices[colsA[colj]]; // use local idx
847a999f97fSMark Adams             //PetscCall(PetscPrintf(PETSC_COMM_SELF, "\t\t[%d] colB = %d\n",rank,colB));
848e978a55eSPierre Jolivet             PetscCheck(colB >= 0 && colB < nloc, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "colB < 0: %" PetscInt_FMT, colB);
849a999f97fSMark Adams             if (colB > maxj) maxj = colB;
850a999f97fSMark Adams             if (colB < minj) minj = colB;
851a999f97fSMark Adams           }
852a999f97fSMark Adams           PetscCall(MatRestoreRow(Aseq, rowA, &ncols, &colsA, NULL));
853a999f97fSMark Adams           if (minj >= bend) { // first column is > max of last block -- new block or last block
854a999f97fSMark Adams             //PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "\t\t finish block %d, N loc = %d (%d,%d)\n", nDMs+1, bend - bsrt,bsrt,bend));
855a999f97fSMark Adams             block_sizes[nDMs] = bend - bsrt;
856a999f97fSMark Adams             ntot += block_sizes[nDMs];
857e978a55eSPierre Jolivet             PetscCheck(minj == bend, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "minj != bend: %" PetscInt_FMT " != %" PetscInt_FMT, minj, bend);
858a999f97fSMark Adams             bsrt = bend;
859a999f97fSMark Adams             bend++; // start with size 1 in new block
860a999f97fSMark Adams             nDMs++;
861a999f97fSMark Adams           }
862a999f97fSMark Adams           if (maxj + 1 > bend) bend = maxj + 1;
863e978a55eSPierre Jolivet           PetscCheck(minj >= bsrt || row_B == Iend - 1, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "%" PetscInt_FMT ") minj < bsrt: %" PetscInt_FMT " != %" PetscInt_FMT, rowA, minj, bsrt);
864a999f97fSMark Adams           //PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] %d) row %d.%d) cols %d : %d ; bsrt = %d, bend = %d\n",rank,row_B,nDMs,rowA,minj,maxj,bsrt,bend));
865a999f97fSMark Adams         }
866a999f97fSMark Adams         // do last block
867a999f97fSMark Adams         //PetscCall(PetscPrintf(PETSC_COMM_SELF, "\t\t\t [%d] finish block %d, N loc = %d (%d,%d)\n", rank, nDMs+1, bend - bsrt,bsrt,bend));
868a999f97fSMark Adams         block_sizes[nDMs] = bend - bsrt;
869a999f97fSMark Adams         ntot += block_sizes[nDMs];
870a999f97fSMark Adams         nDMs++;
871a999f97fSMark Adams         // cleanup
872e978a55eSPierre Jolivet         PetscCheck(ntot == nloc, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "n total != n local: %" PetscInt_FMT " != %" PetscInt_FMT, ntot, nloc);
873a999f97fSMark Adams         PetscCall(ISRestoreIndices(isrow, &rowindices));
874a999f97fSMark Adams         PetscCall(ISRestoreIndices(isicol, &icolindices));
875a999f97fSMark Adams         PetscCall(PetscRealloc(sizeof(PetscInt) * nDMs, &block_sizes));
876a999f97fSMark Adams         PetscCall(MatCreateVecs(A, &jac->vec_diag, NULL));
877a999f97fSMark Adams         PetscCall(PetscInfo(pc, "Setup Matrix based meta data (not DMComposite not attached to PC) %" PetscInt_FMT " sub domains\n", nDMs));
878a999f97fSMark Adams       }
8799566063dSJacob Faibussowitsch       PetscCall(ISDestroy(&isrow));
8809566063dSJacob Faibussowitsch       PetscCall(ISDestroy(&isicol));
881e607c864SMark Adams       jac->num_dms = nDMs;
8829566063dSJacob Faibussowitsch       PetscCall(VecGetLocalSize(jac->vec_diag, &n));
883e607c864SMark Adams       jac->n         = n;
884e607c864SMark Adams       jac->d_idiag_k = new Kokkos::View<PetscScalar *, Kokkos::LayoutRight>("idiag", n);
885e607c864SMark Adams       // options
8869566063dSJacob Faibussowitsch       PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
8879566063dSJacob Faibussowitsch       PetscCall(KSPSetFromOptions(jac->ksp));
8889566063dSJacob Faibussowitsch       PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPBICG, ""));
8899371c9d4SSatish Balay       if (flg) {
8909371c9d4SSatish Balay         jac->ksp_type_idx = BATCH_KSP_BICG_IDX;
8919371c9d4SSatish Balay         jac->nwork        = 7;
8929371c9d4SSatish Balay       } else {
8939566063dSJacob Faibussowitsch         PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPTFQMR, ""));
8949371c9d4SSatish Balay         if (flg) {
8959371c9d4SSatish Balay           jac->ksp_type_idx = BATCH_KSP_TFQMR_IDX;
8969371c9d4SSatish Balay           jac->nwork        = 10;
8979371c9d4SSatish Balay         } else {
898a4313204SMark Adams #if defined(PETSC_HAVE_KOKKOS_KERNELS_BATCH)
8999566063dSJacob Faibussowitsch           PetscCall(PetscObjectTypeCompareAny((PetscObject)jac->ksp, &flg, KSPGMRES, ""));
9000fdf79fbSJacob Faibussowitsch           PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Unsupported batch ksp type");
901a4313204SMark Adams           jac->ksp_type_idx = BATCH_KSP_GMRESKK_IDX;
9029371c9d4SSatish Balay           jac->nwork        = 0;
903a4313204SMark Adams #else
904a4313204SMark Adams           KSPType ksptype;
905a4313204SMark Adams           PetscCall(KSPGetType(jac->ksp, &ksptype));
90600045ab3SPierre Jolivet           PetscCheck(flg, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "Type: %s not supported in complex", ksptype);
907a4313204SMark Adams #endif
908e607c864SMark Adams         }
909a1e3af9bSmarkadams4       }
910a1e3af9bSmarkadams4       PetscOptionsBegin(PetscObjectComm((PetscObject)jac->ksp), ((PetscObject)jac->ksp)->prefix, "Options for Kokkos batch solver", "none");
911a1e3af9bSmarkadams4       PetscCall(PetscOptionsBool("-ksp_converged_reason", "", "bjkokkos.kokkos.cxx.c", jac->reason, &jac->reason, NULL));
912a1e3af9bSmarkadams4       PetscCall(PetscOptionsBool("-ksp_monitor", "", "bjkokkos.kokkos.cxx.c", jac->monitor, &jac->monitor, NULL));
913a1e3af9bSmarkadams4       PetscCall(PetscOptionsInt("-ksp_batch_target", "", "bjkokkos.kokkos.cxx.c", jac->batch_target, &jac->batch_target, NULL));
914a4313204SMark Adams       PetscCall(PetscOptionsInt("-ksp_rank_target", "", "bjkokkos.kokkos.cxx.c", jac->rank_target, &jac->rank_target, NULL));
915a1e3af9bSmarkadams4       PetscCall(PetscOptionsInt("-ksp_batch_nsolves_team", "", "bjkokkos.kokkos.cxx.c", jac->nsolves_team, &jac->nsolves_team, NULL));
9165f80ce2aSJacob Faibussowitsch       PetscCheck(jac->batch_target < jac->num_dms, PETSC_COMM_WORLD, PETSC_ERR_ARG_WRONG, "-ksp_batch_target (%" PetscInt_FMT ") >= number of DMs (%" PetscInt_FMT ")", jac->batch_target, jac->num_dms);
917a1e3af9bSmarkadams4       PetscOptionsEnd();
918e607c864SMark Adams       // get blocks - jac->d_bid_eqOffset_k
919a999f97fSMark Adams       if (pack) {
9209566063dSJacob Faibussowitsch         PetscCall(PetscMalloc(sizeof(*subX) * nDMs, &subX));
9219566063dSJacob Faibussowitsch         PetscCall(PetscMalloc(sizeof(*subDM) * nDMs, &subDM));
922a999f97fSMark Adams       }
9239566063dSJacob Faibussowitsch       PetscCall(PetscMalloc(sizeof(*jac->dm_Nf) * nDMs, &jac->dm_Nf));
924a999f97fSMark Adams       PetscCall(PetscInfo(pc, "Have %" PetscInt_FMT " blocks, n=%" PetscInt_FMT " rtol=%g type = %s\n", nDMs, n, (double)jac->ksp->rtol, ((PetscObject)jac->ksp)->type_name));
925a999f97fSMark Adams       if (pack) PetscCall(DMCompositeGetEntriesArray(pack, subDM));
926e607c864SMark Adams       jac->nBlocks = 0;
927e607c864SMark Adams       for (PetscInt ii = 0; ii < nDMs; ii++) {
928e607c864SMark Adams         PetscInt Nf;
929a999f97fSMark Adams         if (subDM) {
930e607c864SMark Adams           DM           dm = subDM[ii];
931a999f97fSMark Adams           PetscSection section;
9329566063dSJacob Faibussowitsch           PetscCall(DMGetLocalSection(dm, &section));
9339566063dSJacob Faibussowitsch           PetscCall(PetscSectionGetNumFields(section, &Nf));
934a999f97fSMark Adams         } else Nf = 1;
935e607c864SMark Adams         jac->nBlocks += Nf;
936e607c864SMark Adams #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
9379566063dSJacob Faibussowitsch         if (ii == 0) PetscCall(PetscInfo(pc, "%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n", ii, Nf, jac->nBlocks));
938e607c864SMark Adams #else
9399566063dSJacob Faibussowitsch         PetscCall(PetscInfo(pc, "%" PetscInt_FMT ") %" PetscInt_FMT " blocks (%" PetscInt_FMT " total)\n", ii, Nf, jac->nBlocks));
940e607c864SMark Adams #endif
941e607c864SMark Adams         jac->dm_Nf[ii] = Nf;
942e607c864SMark Adams       }
943e607c864SMark Adams       { // d_bid_eqOffset_k
944e607c864SMark Adams         Kokkos::View<PetscInt *, Kokkos::LayoutRight, Kokkos::HostSpace> h_block_offsets("block_offsets", jac->nBlocks + 1);
945a999f97fSMark Adams         if (pack) PetscCall(DMCompositeGetAccessArray(pack, jac->vec_diag, nDMs, NULL, subX));
946e607c864SMark Adams         h_block_offsets[0]    = 0;
947e607c864SMark Adams         jac->const_block_size = -1;
948e607c864SMark Adams         for (PetscInt ii = 0, idx = 0; ii < nDMs; ii++) {
949e607c864SMark Adams           PetscInt nloc, nblk;
950a999f97fSMark Adams           if (pack) PetscCall(VecGetSize(subX[ii], &nloc));
951a999f97fSMark Adams           else nloc = block_sizes[ii];
952e607c864SMark Adams           nblk = nloc / jac->dm_Nf[ii];
95363a3b9bcSJacob Faibussowitsch           PetscCheck(nloc % jac->dm_Nf[ii] == 0, PetscObjectComm((PetscObject)pc), PETSC_ERR_USER, "nloc%%jac->dm_Nf[ii] (%" PetscInt_FMT ") != 0 DMs", nloc % jac->dm_Nf[ii]);
954e607c864SMark Adams           for (PetscInt jj = 0; jj < jac->dm_Nf[ii]; jj++, idx++) {
955e607c864SMark Adams             h_block_offsets[idx + 1] = h_block_offsets[idx] + nblk;
956e607c864SMark Adams #if PCBJKOKKOS_VERBOSE_LEVEL <= 2
957a4313204SMark Adams             if (idx == 0) PetscCall(PetscInfo(pc, "Add first of %" PetscInt_FMT " blocks with %" PetscInt_FMT " equations\n", jac->nBlocks, nblk));
958e607c864SMark Adams #else
9599566063dSJacob Faibussowitsch             PetscCall(PetscInfo(pc, "\t%" PetscInt_FMT ") Add block with %" PetscInt_FMT " equations of %" PetscInt_FMT "\n", idx + 1, nblk, jac->nBlocks));
960e607c864SMark Adams #endif
961e607c864SMark Adams             if (jac->const_block_size == -1) jac->const_block_size = nblk;
962e607c864SMark Adams             else if (jac->const_block_size > 0 && jac->const_block_size != nblk) jac->const_block_size = 0;
963e607c864SMark Adams           }
964e607c864SMark Adams         }
965a999f97fSMark Adams         if (pack) {
9669566063dSJacob Faibussowitsch           PetscCall(DMCompositeRestoreAccessArray(pack, jac->vec_diag, jac->nBlocks, NULL, subX));
9679566063dSJacob Faibussowitsch           PetscCall(PetscFree(subX));
9689566063dSJacob Faibussowitsch           PetscCall(PetscFree(subDM));
969a999f97fSMark Adams         }
970e607c864SMark Adams         jac->d_bid_eqOffset_k = new Kokkos::View<PetscInt *, Kokkos::LayoutRight>(Kokkos::create_mirror(Kokkos::DefaultExecutionSpace::memory_space(), h_block_offsets));
971e607c864SMark Adams         Kokkos::deep_copy(*jac->d_bid_eqOffset_k, h_block_offsets);
972e607c864SMark Adams       }
973a999f97fSMark Adams       if (!pack) PetscCall(PetscFree(block_sizes));
974e607c864SMark Adams     }
975e607c864SMark Adams     { // get jac->d_idiag_k (PC setup),
976ed00a6c1SMark Adams       const PetscInt    *d_ai, *d_aj;
977ed00a6c1SMark Adams       const PetscScalar *d_aa;
978e607c864SMark Adams       const PetscInt     conc = Kokkos::DefaultExecutionSpace().concurrency(), openmp = !!(conc < 1000), team_size = (openmp == 0 && PCBJKOKKOS_VEC_SIZE != 1) ? PCBJKOKKOS_TEAM_SIZE : 1;
979a999f97fSMark Adams       const PetscInt    *d_bid_eqOffset = jac->d_bid_eqOffset_k->data(), *r = jac->d_isrow_k->data(), *ic = jac->d_isicol_k->data();
980ed00a6c1SMark Adams       PetscScalar       *d_idiag = jac->d_idiag_k->data(), *dummy;
981ed00a6c1SMark Adams       PetscMemType       mtype;
982ed00a6c1SMark Adams       PetscCall(MatSeqAIJGetCSRAndMemType(Aseq, &d_ai, &d_aj, &dummy, &mtype));
983ed00a6c1SMark Adams       d_aa = dummy;
9849371c9d4SSatish Balay       Kokkos::parallel_for(
9859371c9d4SSatish Balay         "Diag", Kokkos::TeamPolicy<>(jac->nBlocks, team_size, PCBJKOKKOS_VEC_SIZE), KOKKOS_LAMBDA(const team_member team) {
986e607c864SMark Adams           const PetscInt blkID = team.league_rank();
9879371c9d4SSatish Balay           Kokkos::parallel_for(Kokkos::TeamThreadRange(team, d_bid_eqOffset[blkID], d_bid_eqOffset[blkID + 1]), [=](const int rowb) {
988e607c864SMark Adams             const PetscInt     rowa = ic[rowb], ai = d_ai[rowa], *aj = d_aj + ai; // grab original data
989e607c864SMark Adams             const PetscScalar *aa   = d_aa + ai;
990e607c864SMark Adams             const PetscInt     nrow = d_ai[rowa + 1] - ai;
991e607c864SMark Adams             int                found;
9929371c9d4SSatish Balay             Kokkos::parallel_reduce(
9939371c9d4SSatish Balay               Kokkos::ThreadVectorRange(team, nrow),
994e607c864SMark Adams               [=](const int &j, int &count) {
995e607c864SMark Adams                 const PetscInt colb = r[aj[j]];
996e607c864SMark Adams                 if (colb == rowb) {
997e607c864SMark Adams                   d_idiag[rowb] = 1. / aa[j];
998e607c864SMark Adams                   count++;
9999371c9d4SSatish Balay                 }
10009371c9d4SSatish Balay               },
10019371c9d4SSatish Balay               found);
1002a25f047fSMark Adams #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
1003e607c864SMark Adams             if (found != 1) Kokkos::single(Kokkos::PerThread(team), [=]() { printf("ERRORrow %d) found = %d\n", rowb, found); });
1004a25f047fSMark Adams #endif
1005e607c864SMark Adams           });
1006e607c864SMark Adams         });
1007e607c864SMark Adams     }
1008e607c864SMark Adams   }
10093ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1010e607c864SMark Adams }
1011e607c864SMark Adams 
1012e607c864SMark Adams /* Default destroy, if it has never been setup */
1013d71ae5a4SJacob Faibussowitsch static PetscErrorCode PCReset_BJKOKKOS(PC pc)
1014d71ae5a4SJacob Faibussowitsch {
1015e607c864SMark Adams   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1016e607c864SMark Adams 
1017e607c864SMark Adams   PetscFunctionBegin;
10189566063dSJacob Faibussowitsch   PetscCall(KSPDestroy(&jac->ksp));
10199566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&jac->vec_diag));
1020e607c864SMark Adams   if (jac->d_bid_eqOffset_k) delete jac->d_bid_eqOffset_k;
1021e607c864SMark Adams   if (jac->d_idiag_k) delete jac->d_idiag_k;
1022e607c864SMark Adams   if (jac->d_isrow_k) delete jac->d_isrow_k;
1023e607c864SMark Adams   if (jac->d_isicol_k) delete jac->d_isicol_k;
1024e607c864SMark Adams   jac->d_bid_eqOffset_k = NULL;
1025e607c864SMark Adams   jac->d_idiag_k        = NULL;
1026e607c864SMark Adams   jac->d_isrow_k        = NULL;
1027e607c864SMark Adams   jac->d_isicol_k       = NULL;
10289566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSGetKSP_C", NULL)); // not published now (causes configure errors)
10299566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSSetKSP_C", NULL));
10309566063dSJacob Faibussowitsch   PetscCall(PetscFree(jac->dm_Nf));
1031e607c864SMark Adams   jac->dm_Nf = NULL;
1032a1e3af9bSmarkadams4   if (jac->rowOffsets) delete jac->rowOffsets;
1033a1e3af9bSmarkadams4   if (jac->colIndices) delete jac->colIndices;
1034a1e3af9bSmarkadams4   if (jac->batch_b) delete jac->batch_b;
1035a1e3af9bSmarkadams4   if (jac->batch_x) delete jac->batch_x;
1036a1e3af9bSmarkadams4   if (jac->batch_values) delete jac->batch_values;
1037a1e3af9bSmarkadams4   jac->rowOffsets   = NULL;
1038a1e3af9bSmarkadams4   jac->colIndices   = NULL;
1039a1e3af9bSmarkadams4   jac->batch_b      = NULL;
1040a1e3af9bSmarkadams4   jac->batch_x      = NULL;
1041a1e3af9bSmarkadams4   jac->batch_values = NULL;
10423ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1043e607c864SMark Adams }
1044e607c864SMark Adams 
1045d71ae5a4SJacob Faibussowitsch static PetscErrorCode PCDestroy_BJKOKKOS(PC pc)
1046d71ae5a4SJacob Faibussowitsch {
1047e607c864SMark Adams   PetscFunctionBegin;
10489566063dSJacob Faibussowitsch   PetscCall(PCReset_BJKOKKOS(pc));
10499566063dSJacob Faibussowitsch   PetscCall(PetscFree(pc->data));
10503ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1051e607c864SMark Adams }
1052e607c864SMark Adams 
1053d71ae5a4SJacob Faibussowitsch static PetscErrorCode PCView_BJKOKKOS(PC pc, PetscViewer viewer)
1054d71ae5a4SJacob Faibussowitsch {
1055e607c864SMark Adams   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1056e607c864SMark Adams   PetscBool      iascii;
1057e607c864SMark Adams 
1058e607c864SMark Adams   PetscFunctionBegin;
10599566063dSJacob Faibussowitsch   if (!jac->ksp) PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
10609566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
1061e607c864SMark Adams   if (iascii) {
10629566063dSJacob Faibussowitsch     PetscCall(PetscViewerASCIIPrintf(viewer, "  Batched device linear solver: Krylov (KSP) method with Jacobi preconditioning\n"));
10639371c9d4SSatish Balay     PetscCall(PetscViewerASCIIPrintf(viewer, "\t\tnwork = %" PetscInt_FMT ", rel tol = %e, abs tol = %e, div tol = %e, max it =%" PetscInt_FMT ", type = %s\n", jac->nwork, jac->ksp->rtol, jac->ksp->abstol, jac->ksp->divtol, jac->ksp->max_it,
10645f80ce2aSJacob Faibussowitsch                                      ((PetscObject)jac->ksp)->type_name));
1065e607c864SMark Adams   }
10663ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1067e607c864SMark Adams }
1068e607c864SMark Adams 
1069d71ae5a4SJacob Faibussowitsch static PetscErrorCode PCSetFromOptions_BJKOKKOS(PC pc, PetscOptionItems *PetscOptionsObject)
1070d71ae5a4SJacob Faibussowitsch {
1071e607c864SMark Adams   PetscFunctionBegin;
1072d0609cedSBarry Smith   PetscOptionsHeadBegin(PetscOptionsObject, "PC BJKOKKOS options");
1073d0609cedSBarry Smith   PetscOptionsHeadEnd();
10743ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1075e607c864SMark Adams }
1076e607c864SMark Adams 
1077d71ae5a4SJacob Faibussowitsch static PetscErrorCode PCBJKOKKOSSetKSP_BJKOKKOS(PC pc, KSP ksp)
1078d71ae5a4SJacob Faibussowitsch {
1079e607c864SMark Adams   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1080e607c864SMark Adams 
1081e607c864SMark Adams   PetscFunctionBegin;
10829566063dSJacob Faibussowitsch   PetscCall(PetscObjectReference((PetscObject)ksp));
10839566063dSJacob Faibussowitsch   PetscCall(KSPDestroy(&jac->ksp));
1084e607c864SMark Adams   jac->ksp = ksp;
10853ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1086e607c864SMark Adams }
1087e607c864SMark Adams 
1088cc4c1da9SBarry Smith /*@
1089f1580f4eSBarry Smith   PCBJKOKKOSSetKSP - Sets the `KSP` context for `PCBJKOKKOS`
1090e607c864SMark Adams 
1091c3339decSBarry Smith   Collective
1092e607c864SMark Adams 
1093e607c864SMark Adams   Input Parameters:
1094f1580f4eSBarry Smith + pc  - the `PCBJKOKKOS` preconditioner context
1095f1580f4eSBarry Smith - ksp - the `KSP` solver
1096e607c864SMark Adams 
10972fe279fdSBarry Smith   Level: advanced
10982fe279fdSBarry Smith 
1099e607c864SMark Adams   Notes:
1100f1580f4eSBarry Smith   The `PC` and the `KSP` must have the same communicator
1101f1580f4eSBarry Smith 
1102f1580f4eSBarry Smith   If the `PC` is not `PCBJKOKKOS` this function returns without doing anything
1103e607c864SMark Adams 
1104562efe2eSBarry Smith .seealso: [](ch_ksp), `PCBJKOKKOSGetKSP()`, `PCBJKOKKOS`
1105e607c864SMark Adams @*/
1106d71ae5a4SJacob Faibussowitsch PetscErrorCode PCBJKOKKOSSetKSP(PC pc, KSP ksp)
1107d71ae5a4SJacob Faibussowitsch {
1108e607c864SMark Adams   PetscFunctionBegin;
1109e607c864SMark Adams   PetscValidHeaderSpecific(pc, PC_CLASSID, 1);
1110e607c864SMark Adams   PetscValidHeaderSpecific(ksp, KSP_CLASSID, 2);
1111e607c864SMark Adams   PetscCheckSameComm(pc, 1, ksp, 2);
1112cac4c232SBarry Smith   PetscTryMethod(pc, "PCBJKOKKOSSetKSP_C", (PC, KSP), (pc, ksp));
11133ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1114e607c864SMark Adams }
1115e607c864SMark Adams 
1116d71ae5a4SJacob Faibussowitsch static PetscErrorCode PCBJKOKKOSGetKSP_BJKOKKOS(PC pc, KSP *ksp)
1117d71ae5a4SJacob Faibussowitsch {
1118e607c864SMark Adams   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
1119e607c864SMark Adams 
1120e607c864SMark Adams   PetscFunctionBegin;
11219566063dSJacob Faibussowitsch   if (!jac->ksp) PetscCall(PCBJKOKKOSCreateKSP_BJKOKKOS(pc));
1122e607c864SMark Adams   *ksp = jac->ksp;
11233ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1124e607c864SMark Adams }
1125e607c864SMark Adams 
1126cc4c1da9SBarry Smith /*@
1127f1580f4eSBarry Smith   PCBJKOKKOSGetKSP - Gets the `KSP` context for the `PCBJKOKKOS` preconditioner
1128e607c864SMark Adams 
1129f1580f4eSBarry Smith   Not Collective but `KSP` returned is parallel if `PC` was parallel
1130e607c864SMark Adams 
1131e607c864SMark Adams   Input Parameter:
1132e607c864SMark Adams . pc - the preconditioner context
1133e607c864SMark Adams 
1134f1580f4eSBarry Smith   Output Parameter:
1135f1580f4eSBarry Smith . ksp - the `KSP` solver
1136e607c864SMark Adams 
11372fe279fdSBarry Smith   Level: advanced
11382fe279fdSBarry Smith 
1139e607c864SMark Adams   Notes:
1140f1580f4eSBarry Smith   You must call `KSPSetUp()` before calling `PCBJKOKKOSGetKSP()`.
1141e607c864SMark Adams 
1142f1580f4eSBarry Smith   If the `PC` is not a `PCBJKOKKOS` object it raises an error
1143e607c864SMark Adams 
1144562efe2eSBarry Smith .seealso: [](ch_ksp), `PCBJKOKKOS`, `PCBJKOKKOSSetKSP()`
1145e607c864SMark Adams @*/
1146d71ae5a4SJacob Faibussowitsch PetscErrorCode PCBJKOKKOSGetKSP(PC pc, KSP *ksp)
1147d71ae5a4SJacob Faibussowitsch {
1148e607c864SMark Adams   PetscFunctionBegin;
1149e607c864SMark Adams   PetscValidHeaderSpecific(pc, PC_CLASSID, 1);
11504f572ea9SToby Isaac   PetscAssertPointer(ksp, 2);
1151cac4c232SBarry Smith   PetscUseMethod(pc, "PCBJKOKKOSGetKSP_C", (PC, KSP *), (pc, ksp));
11523ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1153e607c864SMark Adams }
1154e607c864SMark Adams 
11559bdd0a17SMark Adams static PetscErrorCode PCPostSolve_BJKOKKOS(PC pc, KSP ksp, Vec b, Vec x)
11569bdd0a17SMark Adams {
11579bdd0a17SMark Adams   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
11589bdd0a17SMark Adams 
11599bdd0a17SMark Adams   PetscFunctionBegin;
11609bdd0a17SMark Adams   PetscValidHeaderSpecific(pc, PC_CLASSID, 1);
11619bdd0a17SMark Adams   ksp->its = jac->max_nits;
11629bdd0a17SMark Adams   PetscFunctionReturn(PETSC_SUCCESS);
11639bdd0a17SMark Adams }
11649bdd0a17SMark Adams 
11659bdd0a17SMark Adams static PetscErrorCode PCPreSolve_BJKOKKOS(PC pc, KSP ksp, Vec b, Vec x)
11669bdd0a17SMark Adams {
11679bdd0a17SMark Adams   PC_PCBJKOKKOS *jac = (PC_PCBJKOKKOS *)pc->data;
11689bdd0a17SMark Adams 
11699bdd0a17SMark Adams   PetscFunctionBegin;
11709bdd0a17SMark Adams   PetscValidHeaderSpecific(pc, PC_CLASSID, 1);
11719bdd0a17SMark Adams   jac->ksp->errorifnotconverged = ksp->errorifnotconverged;
11729bdd0a17SMark Adams   PetscFunctionReturn(PETSC_SUCCESS);
11739bdd0a17SMark Adams }
11749bdd0a17SMark Adams 
1175e607c864SMark Adams /*MC
     PCBJKOKKOS - A batched Krylov/block Jacobi solver that runs a solve of each diagonal block of a block diagonal `MATSEQAIJ` in a Kokkos thread group
1177e607c864SMark Adams 
1178e607c864SMark Adams    Options Database Key:
1179f1580f4eSBarry Smith .     -pc_bjkokkos_ - options prefix for its `KSP` options
1180e607c864SMark Adams 
1181e607c864SMark Adams    Level: intermediate
1182e607c864SMark Adams 
1183f1580f4eSBarry Smith    Note:
1184f1580f4eSBarry Smith     For use with -ksp_type preonly to bypass any computation on the CPU
1185e607c864SMark Adams 
1186e607c864SMark Adams    Developer Notes:
   The entire Krylov solve (TFQMR or BICG) with diagonal preconditioning for each block of a block diagonal matrix runs in a Kokkos thread group (e.g., one block per SM on NVIDIA). It supports taking a non-block diagonal matrix, but this is not tested. One should create an explicit block diagonal matrix and use that as the preconditioning matrix in the outer `KSP` solver. Variable block sizes are supported and tested in src/ts/utils/dmplexlandau/tutorials/ex[1|2].c
1188f1580f4eSBarry Smith 
1189562efe2eSBarry Smith .seealso: [](ch_ksp), `PCCreate()`, `PCSetType()`, `PCType`, `PC`, `PCBJACOBI`,
1190db781477SPatrick Sanan           `PCSHELL`, `PCCOMPOSITE`, `PCSetUseAmat()`, `PCBJKOKKOSGetKSP()`
1191e607c864SMark Adams M*/
1192e607c864SMark Adams 
1193d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode PCCreate_BJKOKKOS(PC pc)
1194d71ae5a4SJacob Faibussowitsch {
1195e607c864SMark Adams   PC_PCBJKOKKOS *jac;
1196e607c864SMark Adams 
1197e607c864SMark Adams   PetscFunctionBegin;
11984dfa11a4SJacob Faibussowitsch   PetscCall(PetscNew(&jac));
1199e607c864SMark Adams   pc->data = (void *)jac;
1200e607c864SMark Adams 
1201e607c864SMark Adams   jac->ksp              = NULL;
1202e607c864SMark Adams   jac->vec_diag         = NULL;
1203e607c864SMark Adams   jac->d_bid_eqOffset_k = NULL;
1204e607c864SMark Adams   jac->d_idiag_k        = NULL;
1205e607c864SMark Adams   jac->d_isrow_k        = NULL;
1206e607c864SMark Adams   jac->d_isicol_k       = NULL;
1207e607c864SMark Adams   jac->nBlocks          = 1;
1208a1e3af9bSmarkadams4   jac->max_nits         = 0;
1209e607c864SMark Adams 
12109566063dSJacob Faibussowitsch   PetscCall(PetscMemzero(pc->ops, sizeof(struct _PCOps)));
1211e607c864SMark Adams   pc->ops->apply          = PCApply_BJKOKKOS;
1212e607c864SMark Adams   pc->ops->applytranspose = NULL;
1213e607c864SMark Adams   pc->ops->setup          = PCSetUp_BJKOKKOS;
1214e607c864SMark Adams   pc->ops->reset          = PCReset_BJKOKKOS;
1215e607c864SMark Adams   pc->ops->destroy        = PCDestroy_BJKOKKOS;
1216e607c864SMark Adams   pc->ops->setfromoptions = PCSetFromOptions_BJKOKKOS;
1217e607c864SMark Adams   pc->ops->view           = PCView_BJKOKKOS;
12189bdd0a17SMark Adams   pc->ops->postsolve      = PCPostSolve_BJKOKKOS;
12199bdd0a17SMark Adams   pc->ops->presolve       = PCPreSolve_BJKOKKOS;
1220a1e3af9bSmarkadams4 
1221a1e3af9bSmarkadams4   jac->rowOffsets   = NULL;
1222a1e3af9bSmarkadams4   jac->colIndices   = NULL;
1223a1e3af9bSmarkadams4   jac->batch_b      = NULL;
1224a1e3af9bSmarkadams4   jac->batch_x      = NULL;
1225a1e3af9bSmarkadams4   jac->batch_values = NULL;
1226e607c864SMark Adams 
12279566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSGetKSP_C", PCBJKOKKOSGetKSP_BJKOKKOS));
12289566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)pc, "PCBJKOKKOSSetKSP_C", PCBJKOKKOSSetKSP_BJKOKKOS));
12293ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1230e607c864SMark Adams }
1231