xref: /petsc/src/ksp/pc/impls/bjacobi/bjkokkos/bjkokkoskernels.kokkos.cxx (revision a43132046c00290928454b92120f2306b39b67d0)
1*a4313204SMark Adams #include <petsc/private/pcbjkokkosimpl.h>
2*a4313204SMark Adams 
3*a4313204SMark Adams #if defined(PETSC_HAVE_KOKKOS_KERNELS_BATCH)
4*a4313204SMark Adams   #include <fstream>
5*a4313204SMark Adams 
6*a4313204SMark Adams   #include "Kokkos_Timer.hpp"
7*a4313204SMark Adams   #include "Kokkos_Random.hpp"
8*a4313204SMark Adams   #include "Kokkos_UnorderedMap.hpp"
9*a4313204SMark Adams   #include "Kokkos_Sort.hpp"
10*a4313204SMark Adams 
11*a4313204SMark Adams   /// KokkosKernels headers
12*a4313204SMark Adams   #include "KokkosBatched_Util.hpp"
13*a4313204SMark Adams   #include "KokkosBatched_Vector.hpp"
14*a4313204SMark Adams 
15*a4313204SMark Adams   #include <Kokkos_ArithTraits.hpp>
16*a4313204SMark Adams   #include <KokkosBatched_Util.hpp>
17*a4313204SMark Adams   #include <KokkosBatched_Vector.hpp>
18*a4313204SMark Adams   #include <KokkosBatched_Copy_Decl.hpp>
19*a4313204SMark Adams   #include <KokkosBatched_Copy_Impl.hpp>
20*a4313204SMark Adams   #include <KokkosBatched_AddRadial_Decl.hpp>
21*a4313204SMark Adams   #include <KokkosBatched_AddRadial_Impl.hpp>
22*a4313204SMark Adams   #include <KokkosBatched_Gemm_Decl.hpp>
23*a4313204SMark Adams   #include <KokkosBatched_Gemm_Serial_Impl.hpp>
24*a4313204SMark Adams   #include <KokkosBatched_Gemm_Team_Impl.hpp>
25*a4313204SMark Adams   #include <KokkosBatched_Gemv_Decl.hpp>
26*a4313204SMark Adams   // #include <KokkosBatched_Gemv_Serial_Impl.hpp>
27*a4313204SMark Adams   #include <KokkosBatched_Gemv_Team_Impl.hpp>
28*a4313204SMark Adams   #include <KokkosBatched_Trsm_Decl.hpp>
29*a4313204SMark Adams   #include <KokkosBatched_Trsm_Serial_Impl.hpp>
30*a4313204SMark Adams   #include <KokkosBatched_Trsm_Team_Impl.hpp>
31*a4313204SMark Adams   #include <KokkosBatched_Trsv_Decl.hpp>
32*a4313204SMark Adams   #include <KokkosBatched_Trsv_Serial_Impl.hpp>
33*a4313204SMark Adams   #include <KokkosBatched_Trsv_Team_Impl.hpp>
34*a4313204SMark Adams   #include <KokkosBatched_LU_Decl.hpp>
35*a4313204SMark Adams   #include <KokkosBatched_LU_Serial_Impl.hpp>
36*a4313204SMark Adams   #include <KokkosBatched_LU_Team_Impl.hpp>
37*a4313204SMark Adams   #include <KokkosSparse_CrsMatrix.hpp>
38*a4313204SMark Adams   #include "KokkosBatched_Spmv.hpp"
39*a4313204SMark Adams   #include "KokkosBatched_CrsMatrix.hpp"
40*a4313204SMark Adams   #include "KokkosBatched_Krylov_Handle.hpp"
41*a4313204SMark Adams 
42*a4313204SMark Adams   #include "KokkosBatched_GMRES.hpp"
43*a4313204SMark Adams   #include "KokkosBatched_JacobiPrec.hpp"
44*a4313204SMark Adams 
45*a4313204SMark Adams template <typename DeviceType, typename ValuesViewType, typename IntView, typename VectorViewType, typename KrylovHandleType>
46*a4313204SMark Adams struct Functor_TestBatchedTeamVectorGMRES {
47*a4313204SMark Adams   const ValuesViewType _D;
48*a4313204SMark Adams   const ValuesViewType _diag;
49*a4313204SMark Adams   const IntView        _r;
50*a4313204SMark Adams   const IntView        _c;
51*a4313204SMark Adams   const VectorViewType _X;
52*a4313204SMark Adams   const VectorViewType _B;
53*a4313204SMark Adams   const int            _N_team, _team_size, _vector_length;
54*a4313204SMark Adams   const int            _N_iteration;
55*a4313204SMark Adams   const double         _tol;
56*a4313204SMark Adams   const int            _ortho_strategy;
57*a4313204SMark Adams   const int            _scratch_pad_level;
58*a4313204SMark Adams   KrylovHandleType     _handle;
59*a4313204SMark Adams 
60*a4313204SMark Adams   KOKKOS_INLINE_FUNCTION
61*a4313204SMark Adams   Functor_TestBatchedTeamVectorGMRES(const ValuesViewType &D, const IntView &r, const IntView &c, const VectorViewType &X, const VectorViewType &B, const int N_team, const int team_size, const int vector_length, const int N_iteration, const double tol, const int ortho_strategy, const int scratch_pad_level, KrylovHandleType &handle) :
62*a4313204SMark Adams     _D(D), _r(r), _c(c), _X(X), _B(B), _N_team(N_team), _team_size(team_size), _vector_length(vector_length), _N_iteration(N_iteration), _tol(tol), _ortho_strategy(ortho_strategy), _scratch_pad_level(scratch_pad_level), _handle(handle)
63*a4313204SMark Adams   {
64*a4313204SMark Adams   }
65*a4313204SMark Adams 
66*a4313204SMark Adams   KOKKOS_INLINE_FUNCTION
67*a4313204SMark Adams   Functor_TestBatchedTeamVectorGMRES(const ValuesViewType &D, const ValuesViewType &diag, const IntView &r, const IntView &c, const VectorViewType &X, const VectorViewType &B, const int N_team, const int team_size, const int vector_length, const int N_iteration, const double tol, int ortho_strategy, const int scratch_pad_level, KrylovHandleType &handle) :
68*a4313204SMark Adams     _D(D), _diag(diag), _r(r), _c(c), _X(X), _B(B), _N_team(N_team), _team_size(team_size), _vector_length(vector_length), _N_iteration(N_iteration), _tol(tol), _ortho_strategy(ortho_strategy), _scratch_pad_level(scratch_pad_level), _handle(handle)
69*a4313204SMark Adams   {
70*a4313204SMark Adams   }
71*a4313204SMark Adams 
72*a4313204SMark Adams   template <typename MemberType>
73*a4313204SMark Adams   KOKKOS_INLINE_FUNCTION void operator()(const MemberType &member) const
74*a4313204SMark Adams   {
75*a4313204SMark Adams     const int first_matrix = static_cast<int>(member.league_rank()) * _N_team;
76*a4313204SMark Adams     const int N            = _D.extent(0);
77*a4313204SMark Adams     const int last_matrix  = (static_cast<int>(member.league_rank() + 1) * _N_team < N ? static_cast<int>(member.league_rank() + 1) * _N_team : N);
78*a4313204SMark Adams     const int graphID      = static_cast<int>(member.league_rank());
79*a4313204SMark Adams     using TeamVectorCopy1D = KokkosBatched::TeamVectorCopy<MemberType, KokkosBatched::Trans::NoTranspose, 1>;
80*a4313204SMark Adams 
81*a4313204SMark Adams     auto d                         = Kokkos::subview(_D, Kokkos::make_pair(first_matrix, last_matrix), Kokkos::ALL);
82*a4313204SMark Adams     auto x                         = Kokkos::subview(_X, Kokkos::make_pair(first_matrix, last_matrix), Kokkos::ALL);
83*a4313204SMark Adams     auto b                         = Kokkos::subview(_B, Kokkos::make_pair(first_matrix, last_matrix), Kokkos::ALL);
84*a4313204SMark Adams     using ScratchPadIntViewType    = Kokkos::View<typename IntView::non_const_value_type *, typename IntView::array_layout, typename IntView::execution_space::scratch_memory_space>;
85*a4313204SMark Adams     using ScratchPadValuesViewType = Kokkos::View<typename ValuesViewType::non_const_value_type **, typename ValuesViewType::array_layout, typename ValuesViewType::execution_space::scratch_memory_space>;
86*a4313204SMark Adams 
87*a4313204SMark Adams     using Operator = KokkosBatched::CrsMatrix<ValuesViewType, ScratchPadIntViewType>;
88*a4313204SMark Adams     ScratchPadIntViewType r(member.team_scratch(1), _r.extent(1));
89*a4313204SMark Adams     ScratchPadIntViewType c(member.team_scratch(1), _c.extent(1));
90*a4313204SMark Adams 
91*a4313204SMark Adams     TeamVectorCopy1D::invoke(member, Kokkos::subview(_r, graphID, Kokkos::ALL), r);
92*a4313204SMark Adams     TeamVectorCopy1D::invoke(member, Kokkos::subview(_c, graphID, Kokkos::ALL), c);
93*a4313204SMark Adams     Operator A(d, r, c);
94*a4313204SMark Adams 
95*a4313204SMark Adams     ScratchPadValuesViewType diag(member.team_scratch(1), last_matrix - first_matrix, _diag.extent(1));
96*a4313204SMark Adams     using PrecOperator = KokkosBatched::JacobiPrec<ScratchPadValuesViewType>;
97*a4313204SMark Adams 
98*a4313204SMark Adams     KokkosBatched::TeamVectorCopy<MemberType>::invoke(member, Kokkos::subview(_diag, Kokkos::make_pair(first_matrix, last_matrix), Kokkos::ALL), diag);
99*a4313204SMark Adams     PrecOperator P(diag);
100*a4313204SMark Adams     P.setComputedInverse();
101*a4313204SMark Adams 
102*a4313204SMark Adams     KokkosBatched::TeamVectorGMRES<MemberType>::template invoke<Operator, VectorViewType, PrecOperator, KrylovHandleType>(member, A, b, x, P, _handle);
103*a4313204SMark Adams   }
104*a4313204SMark Adams   inline double run(PC pc)
105*a4313204SMark Adams   {
106*a4313204SMark Adams     //typedef typename ValuesViewType::value_type value_type;
107*a4313204SMark Adams     std::string   name("KokkosBatched::Test::TeamVectorGMRES");
108*a4313204SMark Adams     Kokkos::Timer timer;
109*a4313204SMark Adams     Kokkos::Profiling::pushRegion(name.c_str());
110*a4313204SMark Adams 
111*a4313204SMark Adams     Kokkos::TeamPolicy<DeviceType> auto_policy(ceil(1. * _D.extent(0) / _N_team), Kokkos::AUTO(), Kokkos::AUTO());
112*a4313204SMark Adams     Kokkos::TeamPolicy<DeviceType> tuned_policy(ceil(1. * _D.extent(0) / _N_team), _team_size, _vector_length);
113*a4313204SMark Adams     Kokkos::TeamPolicy<DeviceType> policy;
114*a4313204SMark Adams 
115*a4313204SMark Adams     if (_team_size < 1) policy = auto_policy;
116*a4313204SMark Adams     else policy = tuned_policy;
117*a4313204SMark Adams 
118*a4313204SMark Adams     _handle.set_max_iteration(_N_iteration);
119*a4313204SMark Adams     _handle.set_tolerance(_tol);
120*a4313204SMark Adams     _handle.set_ortho_strategy(_ortho_strategy);
121*a4313204SMark Adams     _handle.set_scratch_pad_level(_scratch_pad_level);
122*a4313204SMark Adams     _handle.set_compute_last_residual(true);
123*a4313204SMark Adams 
124*a4313204SMark Adams     int maximum_iteration = _handle.get_max_iteration();
125*a4313204SMark Adams 
126*a4313204SMark Adams     using ScalarType = typename ValuesViewType::non_const_value_type;
127*a4313204SMark Adams     using Layout     = typename ValuesViewType::array_layout;
128*a4313204SMark Adams     using EXSP       = typename ValuesViewType::execution_space;
129*a4313204SMark Adams 
130*a4313204SMark Adams     using ViewType2D    = Kokkos::View<ScalarType **, Layout, EXSP>;
131*a4313204SMark Adams     using IntViewType1D = Kokkos::View<PetscInt *, Layout, EXSP>;
132*a4313204SMark Adams 
133*a4313204SMark Adams     size_t bytes_1D      = ViewType2D::shmem_size(_N_team, 1);
134*a4313204SMark Adams     size_t bytes_row_ptr = IntViewType1D::shmem_size(_r.extent(1));
135*a4313204SMark Adams     size_t bytes_col_idc = IntViewType1D::shmem_size(_c.extent(1));
136*a4313204SMark Adams     size_t bytes_2D_1    = ViewType2D::shmem_size(_N_team, _X.extent(1));
137*a4313204SMark Adams     size_t bytes_2D_2    = ViewType2D::shmem_size(_N_team, maximum_iteration + 1);
138*a4313204SMark Adams 
139*a4313204SMark Adams     size_t bytes_diag = bytes_2D_1;
140*a4313204SMark Adams     size_t bytes_tmp  = 2 * bytes_2D_1 + 2 * bytes_1D + bytes_2D_2;
141*a4313204SMark Adams 
142*a4313204SMark Adams     policy.set_scratch_size(0, Kokkos::PerTeam(bytes_tmp));
143*a4313204SMark Adams     policy.set_scratch_size(1, Kokkos::PerTeam(bytes_col_idc + bytes_row_ptr + bytes_diag));
144*a4313204SMark Adams     PetscCall(PetscInfo(pc, "%d scratch memory(0) = %d + %d + %d bytes_diag=%d; %d scratch memory(1); %d maximum_iterations\n", (int)(bytes_tmp), 2 * (int)bytes_2D_1, 2 * (int)bytes_1D, (int)bytes_2D_2, (int)bytes_diag, (int)(bytes_row_ptr + bytes_col_idc + bytes_diag), (int)maximum_iteration));
145*a4313204SMark Adams     exec_space().fence();
146*a4313204SMark Adams     timer.reset();
147*a4313204SMark Adams     Kokkos::parallel_for(name.c_str(), policy, *this);
148*a4313204SMark Adams     exec_space().fence();
149*a4313204SMark Adams     double sec = timer.seconds();
150*a4313204SMark Adams 
151*a4313204SMark Adams     return sec;
152*a4313204SMark Adams   }
153*a4313204SMark Adams };
154*a4313204SMark Adams 
155*a4313204SMark Adams PETSC_INTERN PetscErrorCode PCApply_BJKOKKOSKERNELS(PC pc, const PetscScalar *glb_bdata, PetscScalar *glb_xdata, const PetscInt *glb_Aai, const PetscInt *glb_Aaj, const PetscScalar *glb_Aaa, const PetscInt team_size, MatInfo info, const PetscInt batch_sz, PCFailedReason *pcreason)
156*a4313204SMark Adams {
157*a4313204SMark Adams   PC_PCBJKOKKOS     *jac   = (PC_PCBJKOKKOS *)pc->data;
158*a4313204SMark Adams   Mat                A     = pc->pmat;
159*a4313204SMark Adams   const PetscInt     maxit = jac->ksp->max_it, nBlk = jac->nBlocks;
160*a4313204SMark Adams   const int          Nsolves      = nBlk;
161*a4313204SMark Adams   int                Nsolves_team = jac->nsolves_team, fill_idx = 0;
162*a4313204SMark Adams   int                Nloc           = jac->const_block_size;       // same grids
163*a4313204SMark Adams   const int          nnz            = (int)info.nz_used / Nsolves; // fix for variable grid size
164*a4313204SMark Adams   PetscReal          rtol           = jac->ksp->rtol;
165*a4313204SMark Adams   const PetscScalar *glb_idiag      = jac->d_idiag_k->data();
166*a4313204SMark Adams   const PetscInt    *d_bid_eqOffset = jac->d_bid_eqOffset_k->data();
167*a4313204SMark Adams   const PetscInt    *d_isicol = jac->d_isicol_k->data(), *d_isrow = jac->d_isrow_k->data();
168*a4313204SMark Adams 
169*a4313204SMark Adams   PetscFunctionBegin;
170*a4313204SMark Adams   if (Nsolves_team > batch_sz) Nsolves_team = batch_sz; // silently fix this
171*a4313204SMark Adams   PetscCheck(jac->const_block_size, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "Kokkos (GMRES) solver requires constant block size (but can be made to work with species ordering or N_team==1)");
172*a4313204SMark Adams   PetscCheck(Nsolves % Nsolves_team == 0, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "Nsolves.mod(Nsolves_team) != 0: Nsolves = %d, Nsolves_team = %d", Nsolves, Nsolves_team);
173*a4313204SMark Adams   PetscCheck(((int)info.nz_used) % Nsolves == 0, PetscObjectComm((PetscObject)pc), PETSC_ERR_ARG_WRONG, "info.nz_used.mod(Nsolves) != 0: info.nz_used = %g, Nsolves = %d", info.nz_used, Nsolves);
174*a4313204SMark Adams   #if defined(PETSC_HAVE_CUDA)
175*a4313204SMark Adams   nvtxRangePushA("gmres-kk");
176*a4313204SMark Adams   #endif
177*a4313204SMark Adams   Kokkos::View<PetscScalar **, layout, exec_space, Kokkos::MemoryTraits<Kokkos::Unmanaged>> inv_diag((PetscScalar *)glb_idiag, Nsolves, Nloc); // in correct order
178*a4313204SMark Adams   if (!jac->rowOffsets) {
179*a4313204SMark Adams     jac->rowOffsets   = new IntView("rowOffsets", Nsolves / Nsolves_team, Nloc + 1); // same grids
180*a4313204SMark Adams     jac->colIndices   = new IntView("colIndices", Nsolves / Nsolves_team, nnz);
181*a4313204SMark Adams     jac->batch_b      = new XYType("batch rhs", Nsolves, Nloc);
182*a4313204SMark Adams     jac->batch_x      = new XYType("batch sol", Nsolves, Nloc);
183*a4313204SMark Adams     jac->batch_values = new AMatrixValueView("batch values", Nsolves, nnz);
184*a4313204SMark Adams     fill_idx          = 1;
185*a4313204SMark Adams     PetscCall(PetscInfo(pc, "Setup indices Nloc=%d, nnz=%d\n", Nloc, nnz));
186*a4313204SMark Adams   }
187*a4313204SMark Adams   IntView          &rowOffsets   = *jac->rowOffsets;
188*a4313204SMark Adams   IntView          &colIndices   = *jac->colIndices;
189*a4313204SMark Adams   XYType           &batch_b      = *jac->batch_b;
190*a4313204SMark Adams   XYType           &batch_x      = *jac->batch_x;
191*a4313204SMark Adams   AMatrixValueView &batch_values = *jac->batch_values;
192*a4313204SMark Adams 
193*a4313204SMark Adams   Kokkos::deep_copy(batch_x, 0.);
194*a4313204SMark Adams   PetscCall(PetscInfo(pc, "\tjac->n = %d, Nloc = %d, Nsolves = %d, nnz = %d, Nsolves_team = %d, league size = %d, maxit = %d\n", (int)jac->n, Nloc, Nsolves, nnz, Nsolves_team, Nsolves / Nsolves_team, (int)maxit));
195*a4313204SMark Adams   Kokkos::parallel_for(
196*a4313204SMark Adams     "rowOffsets+map", Kokkos::TeamPolicy<>(Nsolves, team_size, PCBJKOKKOS_VEC_SIZE), KOKKOS_LAMBDA(const team_member team) {
197*a4313204SMark Adams       const int blkID = team.league_rank(), start = d_bid_eqOffset[blkID], end = d_bid_eqOffset[blkID + 1];
198*a4313204SMark Adams       if (fill_idx) {
199*a4313204SMark Adams         if (blkID % Nsolves_team == 0) {                                                        // first matrix on this member
200*a4313204SMark Adams           Kokkos::parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](const int rowb) { // Nloc
201*a4313204SMark Adams             int rowa                                           = d_isicol[rowb];
202*a4313204SMark Adams             int n                                              = glb_Aai[rowa + 1] - glb_Aai[rowa];
203*a4313204SMark Adams             rowOffsets(blkID / Nsolves_team, rowb + 1 - start) = n; // save sizes
204*a4313204SMark Adams           });
205*a4313204SMark Adams         }
206*a4313204SMark Adams       }
207*a4313204SMark Adams       // map b into field major space
208*a4313204SMark Adams       Kokkos::parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
209*a4313204SMark Adams         int rowa                     = d_isicol[rowb];
210*a4313204SMark Adams         batch_b(blkID, rowb - start) = glb_bdata[rowa];
211*a4313204SMark Adams       });
212*a4313204SMark Adams     });
213*a4313204SMark Adams   Kokkos::fence();
214*a4313204SMark Adams   if (fill_idx) {
215*a4313204SMark Adams     Kokkos::parallel_for(
216*a4313204SMark Adams       "prefix sum", Kokkos::TeamPolicy<>(Nsolves / Nsolves_team, 1, 1), KOKKOS_LAMBDA(const team_member team) {
217*a4313204SMark Adams         const int graphID      = team.league_rank();
218*a4313204SMark Adams         rowOffsets(graphID, 0) = 0;
219*a4313204SMark Adams         for (int i = 0; i < Nloc; ++i) rowOffsets(graphID, i + 1) += rowOffsets(graphID, i);
220*a4313204SMark Adams       });
221*a4313204SMark Adams     Kokkos::fence();
222*a4313204SMark Adams   }
223*a4313204SMark Adams   Kokkos::parallel_for(
224*a4313204SMark Adams     "copy matrix", Kokkos::TeamPolicy<>(Nsolves /* /batch_sz */, team_size, PCBJKOKKOS_VEC_SIZE), KOKKOS_LAMBDA(const team_member team) {
225*a4313204SMark Adams       const int blkID = team.league_rank(), start = d_bid_eqOffset[blkID], end = d_bid_eqOffset[blkID + 1], graphID = blkID / Nsolves_team;
226*a4313204SMark Adams       Kokkos::parallel_for(Kokkos::TeamThreadRange(team, start, end), [=](const int rowb) {
227*a4313204SMark Adams         int                rowa = d_isicol[rowb];
228*a4313204SMark Adams         int                n    = glb_Aai[rowa + 1] - glb_Aai[rowa];
229*a4313204SMark Adams         const PetscInt    *aj   = glb_Aaj + glb_Aai[rowa]; // global index
230*a4313204SMark Adams         const PetscScalar *aa   = glb_Aaa + glb_Aai[rowa];
231*a4313204SMark Adams         Kokkos::parallel_for(Kokkos::ThreadVectorRange(team, n), [=](const int &i) {
232*a4313204SMark Adams           PetscScalar val = aa[i];
233*a4313204SMark Adams           if (fill_idx && blkID % Nsolves_team == 0) colIndices(graphID, rowOffsets(graphID, rowb - start) + i) = d_isrow[aj[i]] - blkID * Nloc; // local" global - block start
234*a4313204SMark Adams           batch_values(blkID, rowOffsets(graphID, rowb - start) + i) = val;
235*a4313204SMark Adams         });
236*a4313204SMark Adams       });
237*a4313204SMark Adams     });
238*a4313204SMark Adams   Kokkos::fence();
239*a4313204SMark Adams   // setup solver
240*a4313204SMark Adams   using ScalarType    = typename AMatrixValueView::non_const_value_type;
241*a4313204SMark Adams   using MagnitudeType = typename Kokkos::Details::ArithTraits<ScalarType>::mag_type;
242*a4313204SMark Adams   //using NormViewType              = Kokkos::View<MagnitudeType *, layout, exec_space>;
243*a4313204SMark Adams   using Norm2DViewType   = Kokkos::View<MagnitudeType **, layout, exec_space>;
244*a4313204SMark Adams   using Scalar3DViewType = Kokkos::View<ScalarType ***, layout, exec_space>;
245*a4313204SMark Adams   using IntViewType      = Kokkos::View<int *, layout, exec_space>;
246*a4313204SMark Adams   using KrylovHandleType = KokkosBatched::KrylovHandle<Norm2DViewType, IntViewType, Scalar3DViewType>;
247*a4313204SMark Adams   const int n_iterations = maxit;
248*a4313204SMark Adams   //const int        team_size      = -1;
249*a4313204SMark Adams   const int        vector_length  = -1;
250*a4313204SMark Adams   const double     tol            = rtol;
251*a4313204SMark Adams   const int        ortho_strategy = 0;
252*a4313204SMark Adams   KrylovHandleType handle(Nsolves, Nsolves_team, n_iterations, true);
253*a4313204SMark Adams   handle.Arnoldi_view = Scalar3DViewType("", Nsolves, n_iterations, Nloc + n_iterations + 3);
254*a4313204SMark Adams   // solve
255*a4313204SMark Adams   Functor_TestBatchedTeamVectorGMRES<exec_space, AMatrixValueView, IntView, XYType, KrylovHandleType>(batch_values, inv_diag, rowOffsets, colIndices, batch_x, batch_b, Nsolves_team, -1, vector_length, n_iterations, tol, ortho_strategy, 0, handle).run(pc);
256*a4313204SMark Adams   Kokkos::fence();
257*a4313204SMark Adams   // get data back
258*a4313204SMark Adams   Kokkos::parallel_for(
259*a4313204SMark Adams     "map", Kokkos::TeamPolicy<>(Nsolves /* /batch_sz */, -1, PCBJKOKKOS_VEC_SIZE), KOKKOS_LAMBDA(const team_member team) {
260*a4313204SMark Adams       const int blkID = team.league_rank(), start = d_bid_eqOffset[blkID], end = d_bid_eqOffset[blkID + 1]; // 0
261*a4313204SMark Adams       // map x into Plex/PETSc
262*a4313204SMark Adams       Kokkos::parallel_for(Kokkos::TeamVectorRange(team, start, end), [=](int rowb) {
263*a4313204SMark Adams         int rowa        = d_isicol[rowb];
264*a4313204SMark Adams         glb_xdata[rowa] = batch_x(blkID, rowb - start);
265*a4313204SMark Adams       });
266*a4313204SMark Adams     });
267*a4313204SMark Adams   // output assume species major - clone from Kokkos solvers
268*a4313204SMark Adams   #if PCBJKOKKOS_VERBOSE_LEVEL >= 3
269*a4313204SMark Adams     #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
270*a4313204SMark Adams   PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "Iterations\n"));
271*a4313204SMark Adams     #else
272*a4313204SMark Adams   PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "max iterations per species (gmres) :"));
273*a4313204SMark Adams     #endif
274*a4313204SMark Adams   for (PetscInt dmIdx = 0, s = 0, head = 0; dmIdx < jac->num_dms; dmIdx += batch_sz) {
275*a4313204SMark Adams     for (PetscInt f = 0, idx = head; f < jac->dm_Nf[dmIdx]; f++, s++, idx++) {
276*a4313204SMark Adams     #if PCBJKOKKOS_VERBOSE_LEVEL >= 4
277*a4313204SMark Adams       PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%2D:", s));
278*a4313204SMark Adams       for (int bid = 0; bid < batch_sz; bid++) PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%3D ", handle.get_iteration_host(idx + bid * jac->dm_Nf[dmIdx])));
279*a4313204SMark Adams       PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "\n"));
280*a4313204SMark Adams     #else
281*a4313204SMark Adams       int count = 0, ii;
282*a4313204SMark Adams       for (int bid = 0; bid < batch_sz; bid++) {
283*a4313204SMark Adams         if ((ii = handle.get_iteration_host(idx + bid * jac->dm_Nf[dmIdx])) > count) count = ii;
284*a4313204SMark Adams       }
285*a4313204SMark Adams       PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "%3d", count));
286*a4313204SMark Adams     #endif
287*a4313204SMark Adams     }
288*a4313204SMark Adams     head += batch_sz * jac->dm_Nf[dmIdx];
289*a4313204SMark Adams   }
290*a4313204SMark Adams     #if PCBJKOKKOS_VERBOSE_LEVEL == 3
291*a4313204SMark Adams   PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "\n"));
292*a4313204SMark Adams     #endif
293*a4313204SMark Adams   #endif
294*a4313204SMark Adams   // return error code, get max it
295*a4313204SMark Adams   PetscInt count = 0, mbid = 0;
296*a4313204SMark Adams   if (handle.is_converged_host()) {
297*a4313204SMark Adams     *pcreason = PC_NOERROR;
298*a4313204SMark Adams     if (!jac->max_nits) {
299*a4313204SMark Adams       for (int blkID = 0; blkID < nBlk; blkID++) {
300*a4313204SMark Adams         if (handle.get_iteration_host(blkID) > jac->max_nits) {
301*a4313204SMark Adams           jac->max_nits = handle.get_iteration_host(blkID);
302*a4313204SMark Adams           mbid          = blkID;
303*a4313204SMark Adams         }
304*a4313204SMark Adams       }
305*a4313204SMark Adams     }
306*a4313204SMark Adams   } else {
307*a4313204SMark Adams     PetscCall(PetscPrintf(PETSC_COMM_SELF, "There is at least one system that did not converge."));
308*a4313204SMark Adams     *pcreason = PC_SUBPC_ERROR;
309*a4313204SMark Adams   }
310*a4313204SMark Adams   // output - assume species major order
311*a4313204SMark Adams   for (int blkID = 0; blkID < nBlk; blkID++) {
312*a4313204SMark Adams     if (jac->reason) { // -pc_bjkokkos_ksp_converged_reason
313*a4313204SMark Adams       if (jac->batch_target == blkID) {
314*a4313204SMark Adams         if (batch_sz != 1)
315*a4313204SMark Adams           PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "    Linear solve %s in %d iterations, batch %" PetscInt_FMT ", species %" PetscInt_FMT "\n", handle.is_converged_host(blkID) ? "converged" : "diverged", handle.get_iteration_host(blkID), blkID % batch_sz, blkID / batch_sz));
316*a4313204SMark Adams         else PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "    Linear solve %s in %d iterations, block %" PetscInt_FMT "\n", handle.is_converged_host(blkID) ? "converged" : "diverged", handle.get_iteration_host(blkID), blkID));
317*a4313204SMark Adams       } else if (jac->batch_target == -1 && handle.get_iteration_host(blkID) >= count) {
318*a4313204SMark Adams         jac->max_nits = count = handle.get_iteration_host(blkID);
319*a4313204SMark Adams         mbid                  = blkID;
320*a4313204SMark Adams       }
321*a4313204SMark Adams       if (!handle.is_converged_host(blkID)) PetscCall(PetscPrintf(PETSC_COMM_SELF, "ERROR species %d, batch %d did not converge with %d iterations\n", (int)(blkID / batch_sz), (int)blkID % batch_sz, handle.get_iteration_host(blkID)));
322*a4313204SMark Adams     }
323*a4313204SMark Adams   }
324*a4313204SMark Adams   if (jac->batch_target == -1 && jac->reason) {
325*a4313204SMark Adams     if (batch_sz != 1)
326*a4313204SMark Adams       PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "    Linear solve %s in %d iteration, batch %" PetscInt_FMT ", specie %" PetscInt_FMT "\n", handle.is_converged_host(mbid) ? "converged" : "diverged", jac->max_nits, mbid % batch_sz, mbid / batch_sz));
327*a4313204SMark Adams     else PetscCall(PetscPrintf(PetscObjectComm((PetscObject)A), "    Linear solve %s in %d iteration, block %" PetscInt_FMT "\n", handle.is_converged_host(mbid) ? "converged" : "diverged", jac->max_nits, mbid));
328*a4313204SMark Adams   }
329*a4313204SMark Adams   #if defined(PETSC_HAVE_CUDA)
330*a4313204SMark Adams   nvtxRangePop();
331*a4313204SMark Adams   #endif
332*a4313204SMark Adams 
333*a4313204SMark Adams   return PETSC_SUCCESS;
334*a4313204SMark Adams }
335*a4313204SMark Adams #endif
336