xref: /petsc/src/mat/impls/aij/seq/kokkos/aijkok.kokkos.cxx (revision e6519f9f8d90f4199bf6bb1df3ffafa56045e5d2)
1 #include <petsc_kokkos.hpp>
2 #include <petscvec_kokkos.hpp>
3 #include <petscpkg_version.h>
4 #include <petsc/private/petscimpl.h>
5 #include <petsc/private/sfimpl.h>
6 #include <petscsystypes.h>
7 #include <petscerror.h>
8 
9 #include <Kokkos_Core.hpp>
10 #include <KokkosBlas.hpp>
11 #include <KokkosSparse_CrsMatrix.hpp>
12 #include <KokkosSparse_spmv.hpp>
13 #include <KokkosSparse_spiluk.hpp>
14 #include <KokkosSparse_sptrsv.hpp>
15 #include <KokkosSparse_spgemm.hpp>
16 #include <KokkosSparse_spadd.hpp>
17 #include <KokkosBatched_LU_Decl.hpp>
18 #include <KokkosBatched_InverseLU_Decl.hpp>
19 
20 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
21 
22 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(3, 7, 0)
23   #include <KokkosSparse_Utils.hpp>
24 using KokkosSparse::sort_crs_matrix;
25 using KokkosSparse::Impl::transpose_matrix;
26 #else
27   #include <KokkosKernels_Sorting.hpp>
28 using KokkosKernels::sort_crs_matrix;
29 using KokkosKernels::Impl::transpose_matrix;
30 #endif
31 
32 static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat); /* Forward declaration */
33 
34 /* MatAssemblyEnd_SeqAIJKokkos() happens when we finalized nonzeros of the matrix, either after
35    we assembled the matrix on host, or after we directly produced the matrix data on device (ex., through MatMatMult).
36    In the latter case, it is important to set a_dual's sync state correctly.
37  */
38 static PetscErrorCode MatAssemblyEnd_SeqAIJKokkos(Mat A, MatAssemblyType mode)
39 {
40   Mat_SeqAIJ       *aijseq;
41   Mat_SeqAIJKokkos *aijkok;
42 
43   PetscFunctionBegin;
44   if (mode == MAT_FLUSH_ASSEMBLY) PetscFunctionReturn(PETSC_SUCCESS);
45   PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
46 
47   aijseq = static_cast<Mat_SeqAIJ *>(A->data);
48   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
49 
50   /* If aijkok does not exist, we just copy i, j to device.
51      If aijkok already exists, but the device's nonzero pattern does not match with the host's, we assume the latest data is on host.
52      In both cases, we build a new aijkok structure.
53   */
54   if (!aijkok || aijkok->nonzerostate != A->nonzerostate) { /* aijkok might not exist yet or nonzero pattern has changed */
55     delete aijkok;
56     aijkok   = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aijseq, A->nonzerostate, PETSC_FALSE /*don't copy mat values to device*/);
57     A->spptr = aijkok;
58   } else if (A->rmap->n && aijkok->diag_dual.extent(0) == 0) { // MatProduct might directly produce AIJ on device, but not the diag.
59     MatRowMapKokkosViewHost diag_h(aijseq->diag, A->rmap->n);
60     auto                    diag_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), diag_h);
61     aijkok->diag_dual              = MatRowMapKokkosDualView(diag_d, diag_h);
62   }
63   PetscFunctionReturn(PETSC_SUCCESS);
64 }
65 
66 /* Sync CSR data to device if not yet */
67 PETSC_INTERN PetscErrorCode MatSeqAIJKokkosSyncDevice(Mat A)
68 {
69   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
70 
71   PetscFunctionBegin;
72   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from host to device");
73   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
74   if (aijkok->a_dual.need_sync_device()) {
75     aijkok->a_dual.sync_device();
76     aijkok->transpose_updated = PETSC_FALSE; /* values of the transpose is out-of-date */
77     aijkok->hermitian_updated = PETSC_FALSE;
78   }
79   PetscFunctionReturn(PETSC_SUCCESS);
80 }
81 
82 /* Mark the CSR data on device as modified */
83 PETSC_INTERN PetscErrorCode MatSeqAIJKokkosModifyDevice(Mat A)
84 {
85   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
86 
87   PetscFunctionBegin;
88   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Not supported for factorized matries");
89   aijkok->a_dual.clear_sync_state();
90   aijkok->a_dual.modify_device();
91   aijkok->transpose_updated = PETSC_FALSE;
92   aijkok->hermitian_updated = PETSC_FALSE;
93   PetscCall(MatSeqAIJInvalidateDiagonal(A));
94   PetscCall(PetscObjectStateIncrease((PetscObject)A));
95   PetscFunctionReturn(PETSC_SUCCESS);
96 }
97 
98 static PetscErrorCode MatSeqAIJKokkosSyncHost(Mat A)
99 {
100   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
101   auto             &exec   = PetscGetKokkosExecutionSpace();
102 
103   PetscFunctionBegin;
104   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
105   /* We do not expect one needs factors on host  */
106   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from device to host");
107   PetscCheck(aijkok, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing AIJKOK");
108   PetscCallCXX(aijkok->a_dual.sync_host(exec));
109   PetscCallCXX(exec.fence());
110   PetscFunctionReturn(PETSC_SUCCESS);
111 }
112 
113 static PetscErrorCode MatSeqAIJGetArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
114 {
115   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
116 
117   PetscFunctionBegin;
118   /* aijkok contains valid pointers only if the host's nonzerostate matches with the device's.
119     Calling MatSeqAIJSetPreallocation() or MatSetValues() on host, where aijseq->{i,j,a} might be
120     reallocated, will lead to stale {i,j,a}_dual in aijkok. In both operations, the hosts's nonzerostate
121     must have been updated. The stale aijkok will be rebuilt during MatAssemblyEnd.
122   */
123   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
124     auto &exec = PetscGetKokkosExecutionSpace();
125     PetscCallCXX(aijkok->a_dual.sync_host(exec));
126     PetscCallCXX(exec.fence());
127     *array = aijkok->a_dual.view_host().data();
128   } else { /* Happens when calling MatSetValues on a newly created matrix */
129     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
130   }
131   PetscFunctionReturn(PETSC_SUCCESS);
132 }
133 
134 static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
135 {
136   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
137 
138   PetscFunctionBegin;
139   if (aijkok && A->nonzerostate == aijkok->nonzerostate) aijkok->a_dual.modify_host();
140   PetscFunctionReturn(PETSC_SUCCESS);
141 }
142 
143 static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
144 {
145   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
146 
147   PetscFunctionBegin;
148   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
149     auto &exec = PetscGetKokkosExecutionSpace();
150     PetscCallCXX(aijkok->a_dual.sync_host(exec));
151     PetscCallCXX(exec.fence());
152     *array = aijkok->a_dual.view_host().data();
153   } else {
154     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
155   }
156   PetscFunctionReturn(PETSC_SUCCESS);
157 }
158 
159 static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
160 {
161   PetscFunctionBegin;
162   *array = NULL;
163   PetscFunctionReturn(PETSC_SUCCESS);
164 }
165 
166 static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
167 {
168   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
169 
170   PetscFunctionBegin;
171   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
172     *array = aijkok->a_dual.view_host().data();
173   } else { /* Ex. happens with MatZeroEntries on a preallocated but not assembled matrix */
174     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
175   }
176   PetscFunctionReturn(PETSC_SUCCESS);
177 }
178 
179 static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
180 {
181   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
182 
183   PetscFunctionBegin;
184   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
185     aijkok->a_dual.clear_sync_state();
186     aijkok->a_dual.modify_host();
187   }
188   PetscFunctionReturn(PETSC_SUCCESS);
189 }
190 
191 static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJKokkos(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype)
192 {
193   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
194 
195   PetscFunctionBegin;
196   PetscCheck(aijkok != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "aijkok is NULL");
197 
198   if (i) *i = aijkok->i_device_data();
199   if (j) *j = aijkok->j_device_data();
200   if (a) {
201     aijkok->a_dual.sync_device();
202     *a = aijkok->a_device_data();
203   }
204   if (mtype) *mtype = PETSC_MEMTYPE_KOKKOS;
205   PetscFunctionReturn(PETSC_SUCCESS);
206 }
207 
208 /*
209   Generate the sparsity pattern of a MatSeqAIJKokkos matrix's transpose on device.
210 
211   Input Parameter:
212 .  A       - the MATSEQAIJKOKKOS matrix
213 
214   Output Parameters:
215 +  perm_d - the permutation array on device, which connects Ta(i) = Aa(perm(i))
216 -  T_d    - the transpose on device, whose value array is allocated but not initialized
217 */
218 static PetscErrorCode MatSeqAIJKokkosGenerateTransposeStructure(Mat A, MatRowMapKokkosView &perm_d, KokkosCsrMatrix &T_d)
219 {
220   Mat_SeqAIJ             *aseq = static_cast<Mat_SeqAIJ *>(A->data);
221   PetscInt                nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
222   const PetscInt         *Ai = aseq->i, *Aj = aseq->j;
223   MatRowMapKokkosViewHost Ti_h(NoInit("Ti"), n + 1);
224   MatRowMapType          *Ti = Ti_h.data();
225   MatColIdxKokkosViewHost Tj_h(NoInit("Tj"), nz);
226   MatRowMapKokkosViewHost perm_h(NoInit("permutation"), nz);
227   PetscInt               *Tj   = Tj_h.data();
228   PetscInt               *perm = perm_h.data();
229   PetscInt               *offset;
230 
231   PetscFunctionBegin;
232   // Populate Ti
233   PetscCallCXX(Kokkos::deep_copy(Ti_h, 0));
234   Ti++;
235   for (PetscInt i = 0; i < nz; i++) Ti[Aj[i]]++;
236   Ti--;
237   for (PetscInt i = 0; i < n; i++) Ti[i + 1] += Ti[i];
238 
239   // Populate Tj and the permutation array
240   PetscCall(PetscCalloc1(n, &offset)); // offset in each T row to fill in its column indices
241   for (PetscInt i = 0; i < m; i++) {
242     for (PetscInt j = Ai[i]; j < Ai[i + 1]; j++) { // A's (i,j) is T's (j,i)
243       PetscInt r    = Aj[j];                       // row r of T
244       PetscInt disp = Ti[r] + offset[r];
245 
246       Tj[disp]   = i; // col i of T
247       perm[disp] = j;
248       offset[r]++;
249     }
250   }
251   PetscCall(PetscFree(offset));
252 
253   // Sort each row of T, along with the permutation array
254   for (PetscInt i = 0; i < n; i++) PetscCall(PetscSortIntWithArray(Ti[i + 1] - Ti[i], Tj + Ti[i], perm + Ti[i]));
255 
256   // Output perm and T on device
257   auto Ti_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ti_h);
258   auto Tj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Tj_h);
259   PetscCallCXX(T_d = KokkosCsrMatrix("csrmatT", n, m, nz, MatScalarKokkosView("Ta", nz), Ti_d, Tj_d));
260   PetscCallCXX(perm_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), perm_h));
261   PetscFunctionReturn(PETSC_SUCCESS);
262 }
263 
264 // Generate the transpose on device and cache it internally
265 // Note: KK transpose_matrix() does not have support symbolic/numeric transpose, so we do it on our own
266 PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGenerateTranspose_Private(Mat A, KokkosCsrMatrix *csrmatT)
267 {
268   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
269   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
270   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
271   KokkosCsrMatrix  &T = akok->csrmatT;
272 
273   PetscFunctionBegin;
274   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
275   PetscCallCXX(akok->a_dual.sync_device()); // Sync A's values since we are going to access them on device
276 
277   const auto &Aa = akok->a_dual.view_device();
278 
279   if (A->symmetric == PETSC_BOOL3_TRUE) {
280     *csrmatT = akok->csrmat;
281   } else {
282     // See if we already have a cached transpose and its value is up to date
283     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
284       if (!akok->transpose_updated) {            // if the value is out of date, update the cached version
285         const auto &perm = akok->transpose_perm; // get the permutation array
286         auto       &Ta   = T.values;
287 
288         PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = Aa(perm(i)); }));
289       }
290     } else { // Generate T of size n x m for the first time
291       MatRowMapKokkosView perm;
292 
293       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
294       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
295       PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = Aa(perm(i)); }));
296     }
297     akok->transpose_updated = PETSC_TRUE;
298     *csrmatT                = akok->csrmatT;
299   }
300   PetscFunctionReturn(PETSC_SUCCESS);
301 }
302 
303 // Generate the Hermitian on device and cache it internally
304 static PetscErrorCode MatSeqAIJKokkosGenerateHermitian_Private(Mat A, KokkosCsrMatrix *csrmatH)
305 {
306   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
307   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
308   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
309   KokkosCsrMatrix  &T = akok->csrmatH;
310 
311   PetscFunctionBegin;
312   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
313   PetscCallCXX(akok->a_dual.sync_device()); // Sync A's values since we are going to access them on device
314 
315   const auto &Aa = akok->a_dual.view_device();
316 
317   if (A->hermitian == PETSC_BOOL3_TRUE) {
318     *csrmatH = akok->csrmat;
319   } else {
320     // See if we already have a cached hermitian and its value is up to date
321     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
322       if (!akok->hermitian_updated) {            // if the value is out of date, update the cached version
323         const auto &perm = akok->transpose_perm; // get the permutation array
324         auto       &Ta   = T.values;
325 
326         PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = PetscConj(Aa(perm(i))); }));
327       }
328     } else { // Generate T of size n x m for the first time
329       MatRowMapKokkosView perm;
330 
331       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
332       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
333       PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = PetscConj(Aa(perm(i))); }));
334     }
335     akok->hermitian_updated = PETSC_TRUE;
336     *csrmatH                = akok->csrmatH;
337   }
338   PetscFunctionReturn(PETSC_SUCCESS);
339 }
340 
341 /* y = A x */
342 static PetscErrorCode MatMult_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
343 {
344   Mat_SeqAIJKokkos          *aijkok;
345   ConstPetscScalarKokkosView xv;
346   PetscScalarKokkosView      yv;
347 
348   PetscFunctionBegin;
349   PetscCall(PetscLogGpuTimeBegin());
350   PetscCall(MatSeqAIJKokkosSyncDevice(A));
351   PetscCall(VecGetKokkosView(xx, &xv));
352   PetscCall(VecGetKokkosViewWrite(yy, &yv));
353   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
354   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), "N", 1.0 /*alpha*/, aijkok->csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A x + beta y */
355   PetscCall(VecRestoreKokkosView(xx, &xv));
356   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
357   /* 2.0*nnz - numRows seems more accurate here but assumes there are no zero-rows. So a little sloppy here. */
358   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
359   PetscCall(PetscLogGpuTimeEnd());
360   PetscFunctionReturn(PETSC_SUCCESS);
361 }
362 
363 /* y = A^T x */
364 static PetscErrorCode MatMultTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
365 {
366   Mat_SeqAIJKokkos          *aijkok;
367   const char                *mode;
368   ConstPetscScalarKokkosView xv;
369   PetscScalarKokkosView      yv;
370   KokkosCsrMatrix            csrmat;
371 
372   PetscFunctionBegin;
373   PetscCall(PetscLogGpuTimeBegin());
374   PetscCall(MatSeqAIJKokkosSyncDevice(A));
375   PetscCall(VecGetKokkosView(xx, &xv));
376   PetscCall(VecGetKokkosViewWrite(yy, &yv));
377   if (A->form_explicit_transpose) {
378     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
379     mode = "N";
380   } else {
381     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
382     csrmat = aijkok->csrmat;
383     mode   = "T";
384   }
385   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^T x + beta y */
386   PetscCall(VecRestoreKokkosView(xx, &xv));
387   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
388   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
389   PetscCall(PetscLogGpuTimeEnd());
390   PetscFunctionReturn(PETSC_SUCCESS);
391 }
392 
393 /* y = A^H x */
394 static PetscErrorCode MatMultHermitianTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
395 {
396   Mat_SeqAIJKokkos          *aijkok;
397   const char                *mode;
398   ConstPetscScalarKokkosView xv;
399   PetscScalarKokkosView      yv;
400   KokkosCsrMatrix            csrmat;
401 
402   PetscFunctionBegin;
403   PetscCall(PetscLogGpuTimeBegin());
404   PetscCall(MatSeqAIJKokkosSyncDevice(A));
405   PetscCall(VecGetKokkosView(xx, &xv));
406   PetscCall(VecGetKokkosViewWrite(yy, &yv));
407   if (A->form_explicit_transpose) {
408     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
409     mode = "N";
410   } else {
411     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
412     csrmat = aijkok->csrmat;
413     mode   = "C";
414   }
415   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^H x + beta y */
416   PetscCall(VecRestoreKokkosView(xx, &xv));
417   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
418   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
419   PetscCall(PetscLogGpuTimeEnd());
420   PetscFunctionReturn(PETSC_SUCCESS);
421 }
422 
423 /* z = A x + y */
424 static PetscErrorCode MatMultAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
425 {
426   Mat_SeqAIJKokkos          *aijkok;
427   ConstPetscScalarKokkosView xv;
428   PetscScalarKokkosView      zv;
429 
430   PetscFunctionBegin;
431   PetscCall(PetscLogGpuTimeBegin());
432   PetscCall(MatSeqAIJKokkosSyncDevice(A));
433   if (zz != yy) PetscCall(VecCopy(yy, zz)); // depending on yy's sync flags, zz might get its latest data on host
434   PetscCall(VecGetKokkosView(xx, &xv));
435   PetscCall(VecGetKokkosView(zz, &zv)); // do after VecCopy(yy, zz) to get the latest data on device
436   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
437   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), "N", 1.0 /*alpha*/, aijkok->csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A x + beta z */
438   PetscCall(VecRestoreKokkosView(xx, &xv));
439   PetscCall(VecRestoreKokkosView(zz, &zv));
440   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
441   PetscCall(PetscLogGpuTimeEnd());
442   PetscFunctionReturn(PETSC_SUCCESS);
443 }
444 
445 /* z = A^T x + y */
446 static PetscErrorCode MatMultTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
447 {
448   Mat_SeqAIJKokkos          *aijkok;
449   const char                *mode;
450   ConstPetscScalarKokkosView xv;
451   PetscScalarKokkosView      zv;
452   KokkosCsrMatrix            csrmat;
453 
454   PetscFunctionBegin;
455   PetscCall(PetscLogGpuTimeBegin());
456   PetscCall(MatSeqAIJKokkosSyncDevice(A));
457   if (zz != yy) PetscCall(VecCopy(yy, zz));
458   PetscCall(VecGetKokkosView(xx, &xv));
459   PetscCall(VecGetKokkosView(zz, &zv));
460   if (A->form_explicit_transpose) {
461     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
462     mode = "N";
463   } else {
464     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
465     csrmat = aijkok->csrmat;
466     mode   = "T";
467   }
468   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^T x + beta z */
469   PetscCall(VecRestoreKokkosView(xx, &xv));
470   PetscCall(VecRestoreKokkosView(zz, &zv));
471   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
472   PetscCall(PetscLogGpuTimeEnd());
473   PetscFunctionReturn(PETSC_SUCCESS);
474 }
475 
476 /* z = A^H x + y */
477 static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
478 {
479   Mat_SeqAIJKokkos          *aijkok;
480   const char                *mode;
481   ConstPetscScalarKokkosView xv;
482   PetscScalarKokkosView      zv;
483   KokkosCsrMatrix            csrmat;
484 
485   PetscFunctionBegin;
486   PetscCall(PetscLogGpuTimeBegin());
487   PetscCall(MatSeqAIJKokkosSyncDevice(A));
488   if (zz != yy) PetscCall(VecCopy(yy, zz));
489   PetscCall(VecGetKokkosView(xx, &xv));
490   PetscCall(VecGetKokkosView(zz, &zv));
491   if (A->form_explicit_transpose) {
492     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
493     mode = "N";
494   } else {
495     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
496     csrmat = aijkok->csrmat;
497     mode   = "C";
498   }
499   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^H x + beta z */
500   PetscCall(VecRestoreKokkosView(xx, &xv));
501   PetscCall(VecRestoreKokkosView(zz, &zv));
502   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
503   PetscCall(PetscLogGpuTimeEnd());
504   PetscFunctionReturn(PETSC_SUCCESS);
505 }
506 
507 static PetscErrorCode MatSetOption_SeqAIJKokkos(Mat A, MatOption op, PetscBool flg)
508 {
509   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
510 
511   PetscFunctionBegin;
512   switch (op) {
513   case MAT_FORM_EXPLICIT_TRANSPOSE:
514     /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */
515     if (A->form_explicit_transpose && !flg && aijkok) PetscCall(aijkok->DestroyMatTranspose());
516     A->form_explicit_transpose = flg;
517     break;
518   default:
519     PetscCall(MatSetOption_SeqAIJ(A, op, flg));
520     break;
521   }
522   PetscFunctionReturn(PETSC_SUCCESS);
523 }
524 
525 /* Depending on reuse, either build a new mat, or use the existing mat */
526 PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
527 {
528   Mat_SeqAIJ *aseq;
529 
530   PetscFunctionBegin;
531   PetscCall(PetscKokkosInitializeCheck());
532   if (reuse == MAT_INITIAL_MATRIX) {                      /* Build a brand new mat */
533     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));  /* the returned newmat is a SeqAIJKokkos */
534   } else if (reuse == MAT_REUSE_MATRIX) {                 /* Reuse the mat created before */
535     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); /* newmat is already a SeqAIJKokkos */
536   } else if (reuse == MAT_INPLACE_MATRIX) {               /* newmat is A */
537     PetscCheck(A == *newmat, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "A != *newmat with MAT_INPLACE_MATRIX");
538     PetscCall(PetscFree(A->defaultvectype));
539     PetscCall(PetscStrallocpy(VECKOKKOS, &A->defaultvectype)); /* Allocate and copy the string */
540     PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATSEQAIJKOKKOS));
541     PetscCall(MatSetOps_SeqAIJKokkos(A));
542     aseq = static_cast<Mat_SeqAIJ *>(A->data);
543     if (A->assembled) { /* Copy i, j (but not values) to device for an assembled matrix if not yet */
544       PetscCheck(!A->spptr, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Expect NULL (Mat_SeqAIJKokkos*)A->spptr");
545       A->spptr = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aseq, A->nonzerostate, PETSC_FALSE);
546     }
547   }
548   PetscFunctionReturn(PETSC_SUCCESS);
549 }
550 
551 /* MatDuplicate always creates a new matrix. MatDuplicate can be called either on an assembled matrix or
552    an unassembled matrix, even though MAT_COPY_VALUES is not allowed for unassembled matrix.
553  */
554 static PetscErrorCode MatDuplicate_SeqAIJKokkos(Mat A, MatDuplicateOption dupOption, Mat *B)
555 {
556   Mat_SeqAIJ       *bseq;
557   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok;
558   Mat               mat;
559 
560   PetscFunctionBegin;
561   /* Do not copy values on host as A's latest values might be on device. We don't want to do sync blindly */
562   PetscCall(MatDuplicate_SeqAIJ(A, MAT_DO_NOT_COPY_VALUES, B));
563   mat = *B;
564   if (A->assembled) {
565     bseq = static_cast<Mat_SeqAIJ *>(mat->data);
566     bkok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, bseq, mat->nonzerostate, PETSC_FALSE);
567     bkok->a_dual.clear_sync_state(); /* Clear B's sync state as it will be decided below */
568     /* Now copy values to B if needed */
569     if (dupOption == MAT_COPY_VALUES) {
570       if (akok->a_dual.need_sync_device()) {
571         Kokkos::deep_copy(bkok->a_dual.view_host(), akok->a_dual.view_host());
572         bkok->a_dual.modify_host();
573       } else { /* If device has the latest data, we only copy data on device */
574         Kokkos::deep_copy(bkok->a_dual.view_device(), akok->a_dual.view_device());
575         bkok->a_dual.modify_device();
576       }
577     } else { /* MAT_DO_NOT_COPY_VALUES or MAT_SHARE_NONZERO_PATTERN. B's values should be zeroed */
578       /* B's values on host should be already zeroed by MatDuplicate_SeqAIJ() */
579       bkok->a_dual.modify_host();
580     }
581     mat->spptr = bkok;
582   }
583 
584   PetscCall(PetscFree(mat->defaultvectype));
585   PetscCall(PetscStrallocpy(VECKOKKOS, &mat->defaultvectype)); /* Allocate and copy the string */
586   PetscCall(PetscObjectChangeTypeName((PetscObject)mat, MATSEQAIJKOKKOS));
587   PetscCall(MatSetOps_SeqAIJKokkos(mat));
588   PetscFunctionReturn(PETSC_SUCCESS);
589 }
590 
591 static PetscErrorCode MatTranspose_SeqAIJKokkos(Mat A, MatReuse reuse, Mat *B)
592 {
593   Mat               At;
594   KokkosCsrMatrix   internT;
595   Mat_SeqAIJKokkos *atkok, *bkok;
596 
597   PetscFunctionBegin;
598   if (reuse == MAT_REUSE_MATRIX) PetscCall(MatTransposeCheckNonzeroState_Private(A, *B));
599   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &internT)); /* Generate a transpose internally */
600   if (reuse == MAT_INITIAL_MATRIX || reuse == MAT_INPLACE_MATRIX) {
601     /* Deep copy internT, as we want to isolate the internal transpose */
602     PetscCallCXX(atkok = new Mat_SeqAIJKokkos(KokkosCsrMatrix("csrmat", internT)));
603     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PetscObjectComm((PetscObject)A), atkok, &At));
604     if (reuse == MAT_INITIAL_MATRIX) *B = At;
605     else PetscCall(MatHeaderReplace(A, &At)); /* Replace A with At inplace */
606   } else {                                    /* MAT_REUSE_MATRIX, just need to copy values to B on device */
607     if ((*B)->assembled) {
608       bkok = static_cast<Mat_SeqAIJKokkos *>((*B)->spptr);
609       PetscCallCXX(Kokkos::deep_copy(bkok->a_dual.view_device(), internT.values));
610       PetscCall(MatSeqAIJKokkosModifyDevice(*B));
611     } else if ((*B)->preallocated) { /* It is ok for B to be only preallocated, as needed in MatTranspose_MPIAIJ */
612       Mat_SeqAIJ             *bseq = static_cast<Mat_SeqAIJ *>((*B)->data);
613       MatScalarKokkosViewHost a_h(bseq->a, internT.nnz()); /* bseq->nz = 0 if unassembled */
614       MatColIdxKokkosViewHost j_h(bseq->j, internT.nnz());
615       PetscCallCXX(Kokkos::deep_copy(a_h, internT.values));
616       PetscCallCXX(Kokkos::deep_copy(j_h, internT.graph.entries));
617     } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "B must be assembled or preallocated");
618   }
619   PetscFunctionReturn(PETSC_SUCCESS);
620 }
621 
622 static PetscErrorCode MatDestroy_SeqAIJKokkos(Mat A)
623 {
624   Mat_SeqAIJKokkos *aijkok;
625 
626   PetscFunctionBegin;
627   if (A->factortype == MAT_FACTOR_NONE) {
628     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
629     delete aijkok;
630   } else {
631     delete static_cast<Mat_SeqAIJKokkosTriFactors *>(A->spptr);
632   }
633   A->spptr = NULL;
634   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
635   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
636   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
637 #if defined(PETSC_HAVE_HYPRE)
638   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijkokkos_hypre_C", NULL));
639 #endif
640   PetscCall(MatDestroy_SeqAIJ(A));
641   PetscFunctionReturn(PETSC_SUCCESS);
642 }
643 
644 /*MC
645    MATSEQAIJKOKKOS - MATAIJKOKKOS = "(seq)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos
646 
647    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
648 
649    Options Database Key:
650 .  -mat_type aijkokkos - sets the matrix type to `MATSEQAIJKOKKOS` during a call to `MatSetFromOptions()`
651 
652   Level: beginner
653 
654 .seealso: [](ch_matrices), `Mat`, `MatCreateSeqAIJKokkos()`, `MATMPIAIJKOKKOS`
655 M*/
656 PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJKokkos(Mat A)
657 {
658   PetscFunctionBegin;
659   PetscCall(PetscKokkosInitializeCheck());
660   PetscCall(MatCreate_SeqAIJ(A));
661   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
662   PetscFunctionReturn(PETSC_SUCCESS);
663 }
664 
665 /* Merge A, B into a matrix C. A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n) */
666 PetscErrorCode MatSeqAIJKokkosMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C)
667 {
668   Mat_SeqAIJ         *a, *b;
669   Mat_SeqAIJKokkos   *akok, *bkok, *ckok;
670   MatScalarKokkosView aa, ba, ca;
671   MatRowMapKokkosView ai, bi, ci;
672   MatColIdxKokkosView aj, bj, cj;
673   PetscInt            m, n, nnz, aN;
674 
675   PetscFunctionBegin;
676   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
677   PetscValidHeaderSpecific(B, MAT_CLASSID, 2);
678   PetscAssertPointer(C, 4);
679   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
680   PetscCheckTypeName(B, MATSEQAIJKOKKOS);
681   PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n);
682   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
683 
684   PetscCall(MatSeqAIJKokkosSyncDevice(A));
685   PetscCall(MatSeqAIJKokkosSyncDevice(B));
686   a    = static_cast<Mat_SeqAIJ *>(A->data);
687   b    = static_cast<Mat_SeqAIJ *>(B->data);
688   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
689   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
690   aa   = akok->a_dual.view_device();
691   ai   = akok->i_dual.view_device();
692   ba   = bkok->a_dual.view_device();
693   bi   = bkok->i_dual.view_device();
694   m    = A->rmap->n; /* M, N and nnz of C */
695   n    = A->cmap->n + B->cmap->n;
696   nnz  = a->nz + b->nz;
697   aN   = A->cmap->n; /* N of A */
698   if (reuse == MAT_INITIAL_MATRIX) {
699     aj           = akok->j_dual.view_device();
700     bj           = bkok->j_dual.view_device();
701     auto ca_dual = MatScalarKokkosDualView("a", aa.extent(0) + ba.extent(0));
702     auto ci_dual = MatRowMapKokkosDualView("i", ai.extent(0));
703     auto cj_dual = MatColIdxKokkosDualView("j", aj.extent(0) + bj.extent(0));
704     ca           = ca_dual.view_device();
705     ci           = ci_dual.view_device();
706     cj           = cj_dual.view_device();
707 
708     /* Concatenate A and B in parallel using Kokkos hierarchical parallelism */
709     Kokkos::parallel_for(
710       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
711         PetscInt i       = t.league_rank(); /* row i */
712         PetscInt coffset = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
713 
714         Kokkos::single(Kokkos::PerTeam(t), [=]() { /* this side effect only happens once per whole team */
715                                                    ci(i) = coffset;
716                                                    if (i == m - 1) ci(m) = ai(m) + bi(m);
717         });
718 
719         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
720           if (k < alen) {
721             ca(coffset + k) = aa(ai(i) + k);
722             cj(coffset + k) = aj(ai(i) + k);
723           } else {
724             ca(coffset + k) = ba(bi(i) + k - alen);
725             cj(coffset + k) = bj(bi(i) + k - alen) + aN; /* Entries in B get new column indices in C */
726           }
727         });
728       });
729     ca_dual.modify_device();
730     ci_dual.modify_device();
731     cj_dual.modify_device();
732     PetscCallCXX(ckok = new Mat_SeqAIJKokkos(m, n, nnz, ci_dual, cj_dual, ca_dual));
733     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, C));
734   } else if (reuse == MAT_REUSE_MATRIX) {
735     PetscValidHeaderSpecific(*C, MAT_CLASSID, 4);
736     PetscCheckTypeName(*C, MATSEQAIJKOKKOS);
737     ckok = static_cast<Mat_SeqAIJKokkos *>((*C)->spptr);
738     ca   = ckok->a_dual.view_device();
739     ci   = ckok->i_dual.view_device();
740 
741     Kokkos::parallel_for(
742       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
743         PetscInt i    = t.league_rank(); /* row i */
744         PetscInt alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
745         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
746           if (k < alen) ca(ci(i) + k) = aa(ai(i) + k);
747           else ca(ci(i) + k) = ba(bi(i) + k - alen);
748         });
749       });
750     PetscCall(MatSeqAIJKokkosModifyDevice(*C));
751   }
752   PetscFunctionReturn(PETSC_SUCCESS);
753 }
754 
755 static PetscErrorCode MatProductDataDestroy_SeqAIJKokkos(void *pdata)
756 {
757   PetscFunctionBegin;
758   delete static_cast<MatProductData_SeqAIJKokkos *>(pdata);
759   PetscFunctionReturn(PETSC_SUCCESS);
760 }
761 
762 static PetscErrorCode MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos(Mat C)
763 {
764   Mat_Product                 *product = C->product;
765   Mat                          A, B;
766   bool                         transA, transB; /* use bool, since KK needs this type */
767   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
768   Mat_SeqAIJ                  *c;
769   MatProductData_SeqAIJKokkos *pdata;
770   KokkosCsrMatrix              csrmatA, csrmatB;
771 
772   PetscFunctionBegin;
773   MatCheckProduct(C, 1);
774   PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data empty");
775   pdata = static_cast<MatProductData_SeqAIJKokkos *>(C->product->data);
776 
777   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
778   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
779   // we still do numeric.
780   if (pdata->reusesym) { // numeric reuses results from symbolic
781     pdata->reusesym = PETSC_FALSE;
782     PetscFunctionReturn(PETSC_SUCCESS);
783   }
784 
785   switch (product->type) {
786   case MATPRODUCT_AB:
787     transA = false;
788     transB = false;
789     break;
790   case MATPRODUCT_AtB:
791     transA = true;
792     transB = false;
793     break;
794   case MATPRODUCT_ABt:
795     transA = false;
796     transB = true;
797     break;
798   default:
799     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
800   }
801 
802   A = product->A;
803   B = product->B;
804   PetscCall(MatSeqAIJKokkosSyncDevice(A));
805   PetscCall(MatSeqAIJKokkosSyncDevice(B));
806   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
807   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
808   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
809 
810   PetscCheck(ckok, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Device data structure spptr is empty");
811 
812   csrmatA = akok->csrmat;
813   csrmatB = bkok->csrmat;
814 
815   /* TODO: Once KK spgemm implements transpose, we can get rid of the explicit transpose here */
816   if (transA) {
817     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
818     transA = false;
819   }
820 
821   if (transB) {
822     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
823     transB = false;
824   }
825   PetscCall(PetscLogGpuTimeBegin());
826   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, ckok->csrmat));
827 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
828   auto spgemmHandle = pdata->kh.get_spgemm_handle();
829   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(ckok->csrmat)); /* without sort, mat_tests-ex62_14_seqaijkokkos fails */
830 #endif
831 
832   PetscCall(PetscLogGpuTimeEnd());
833   PetscCall(MatSeqAIJKokkosModifyDevice(C));
834   /* shorter version of MatAssemblyEnd_SeqAIJ */
835   c = (Mat_SeqAIJ *)C->data;
836   PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded,%" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz));
837   PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n"));
838   PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax));
839   c->reallocs         = 0;
840   C->info.mallocs     = 0;
841   C->info.nz_unneeded = 0;
842   C->assembled = C->was_assembled = PETSC_TRUE;
843   C->num_ass++;
844   PetscFunctionReturn(PETSC_SUCCESS);
845 }
846 
847 static PetscErrorCode MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos(Mat C)
848 {
849   Mat_Product                 *product = C->product;
850   MatProductType               ptype;
851   Mat                          A, B;
852   bool                         transA, transB;
853   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
854   MatProductData_SeqAIJKokkos *pdata;
855   MPI_Comm                     comm;
856   KokkosCsrMatrix              csrmatA, csrmatB, csrmatC;
857 
858   PetscFunctionBegin;
859   MatCheckProduct(C, 1);
860   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
861   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
862   A = product->A;
863   B = product->B;
864   PetscCall(MatSeqAIJKokkosSyncDevice(A));
865   PetscCall(MatSeqAIJKokkosSyncDevice(B));
866   akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
867   bkok    = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
868   csrmatA = akok->csrmat;
869   csrmatB = bkok->csrmat;
870 
871   ptype = product->type;
872   // Take advantage of the symmetry if true
873   if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) {
874     ptype                                          = MATPRODUCT_AB;
875     product->symbolic_used_the_fact_A_is_symmetric = PETSC_TRUE;
876   }
877   if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) {
878     ptype                                          = MATPRODUCT_AB;
879     product->symbolic_used_the_fact_B_is_symmetric = PETSC_TRUE;
880   }
881 
882   switch (ptype) {
883   case MATPRODUCT_AB:
884     transA = false;
885     transB = false;
886     PetscCall(MatSetBlockSizesFromMats(C, A, B));
887     break;
888   case MATPRODUCT_AtB:
889     transA = true;
890     transB = false;
891     if (A->cmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->rmap, A->cmap->bs));
892     if (B->cmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->cmap, B->cmap->bs));
893     break;
894   case MATPRODUCT_ABt:
895     transA = false;
896     transB = true;
897     if (A->rmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->rmap, A->rmap->bs));
898     if (B->rmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->cmap, B->rmap->bs));
899     break;
900   default:
901     SETERRQ(comm, PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
902   }
903   PetscCallCXX(product->data = pdata = new MatProductData_SeqAIJKokkos());
904   pdata->reusesym = product->api_user;
905 
906   /* TODO: add command line options to select spgemm algorithms */
907   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; /* default alg is TPL if enabled, otherwise KK */
908 
909   /* CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 */
910 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
911   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
912   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
913   #endif
914 #endif
915   PetscCallCXX(pdata->kh.create_spgemm_handle(spgemm_alg));
916 
917   PetscCall(PetscLogGpuTimeBegin());
918   /* TODO: Get rid of the explicit transpose once KK-spgemm implements the transpose option */
919   if (transA) {
920     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
921     transA = false;
922   }
923 
924   if (transB) {
925     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
926     transB = false;
927   }
928 
929   PetscCallCXX(KokkosSparse::spgemm_symbolic(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
930   /* spgemm_symbolic() only populates C's rowmap, but not C's column indices.
931     So we have to do a fake spgemm_numeric() here to get csrmatC.j_d setup, before
932     calling new Mat_SeqAIJKokkos().
933     TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
934   */
935   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
936 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
937   /* Query if KK outputs a sorted matrix. If not, we need to sort it */
938   auto spgemmHandle = pdata->kh.get_spgemm_handle();
939   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(csrmatC)); /* sort_option defaults to -1 in KK!*/
940 #endif
941   PetscCall(PetscLogGpuTimeEnd());
942 
943   PetscCallCXX(ckok = new Mat_SeqAIJKokkos(csrmatC));
944   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(C, ckok));
945   C->product->destroy = MatProductDataDestroy_SeqAIJKokkos;
946   PetscFunctionReturn(PETSC_SUCCESS);
947 }
948 
949 /* handles sparse matrix matrix ops */
950 static PetscErrorCode MatProductSetFromOptions_SeqAIJKokkos(Mat mat)
951 {
952   Mat_Product *product = mat->product;
953   PetscBool    Biskok = PETSC_FALSE, Ciskok = PETSC_TRUE;
954 
955   PetscFunctionBegin;
956   MatCheckProduct(mat, 1);
957   PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJKOKKOS, &Biskok));
958   if (product->type == MATPRODUCT_ABC) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJKOKKOS, &Ciskok));
959   if (Biskok && Ciskok) {
960     switch (product->type) {
961     case MATPRODUCT_AB:
962     case MATPRODUCT_AtB:
963     case MATPRODUCT_ABt:
964       mat->ops->productsymbolic = MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos;
965       break;
966     case MATPRODUCT_PtAP:
967     case MATPRODUCT_RARt:
968     case MATPRODUCT_ABC:
969       mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
970       break;
971     default:
972       break;
973     }
974   } else { /* fallback for AIJ */
975     PetscCall(MatProductSetFromOptions_SeqAIJ(mat));
976   }
977   PetscFunctionReturn(PETSC_SUCCESS);
978 }
979 
980 static PetscErrorCode MatScale_SeqAIJKokkos(Mat A, PetscScalar a)
981 {
982   Mat_SeqAIJKokkos *aijkok;
983 
984   PetscFunctionBegin;
985   PetscCall(PetscLogGpuTimeBegin());
986   PetscCall(MatSeqAIJKokkosSyncDevice(A));
987   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
988   KokkosBlas::scal(PetscGetKokkosExecutionSpace(), aijkok->a_dual.view_device(), a, aijkok->a_dual.view_device());
989   PetscCall(MatSeqAIJKokkosModifyDevice(A));
990   PetscCall(PetscLogGpuFlops(aijkok->a_dual.extent(0)));
991   PetscCall(PetscLogGpuTimeEnd());
992   PetscFunctionReturn(PETSC_SUCCESS);
993 }
994 
995 // add a to A's diagonal (if A is square) or main diagonal (if A is rectangular)
996 static PetscErrorCode MatShift_SeqAIJKokkos(Mat A, PetscScalar a)
997 {
998   Mat_SeqAIJ *aijseq = static_cast<Mat_SeqAIJ *>(A->data);
999 
1000   PetscFunctionBegin;
1001   if (A->assembled && aijseq->diagonaldense) { // no missing diagonals
1002     PetscInt n = PetscMin(A->rmap->n, A->cmap->n);
1003 
1004     PetscCall(PetscLogGpuTimeBegin());
1005     PetscCall(MatSeqAIJKokkosSyncDevice(A));
1006     const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1007     const auto &Aa     = aijkok->a_dual.view_device();
1008     const auto &Adiag  = aijkok->diag_dual.view_device();
1009     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) { Aa(Adiag(i)) += a; }));
1010     PetscCall(MatSeqAIJKokkosModifyDevice(A));
1011     PetscCall(PetscLogGpuFlops(n));
1012     PetscCall(PetscLogGpuTimeEnd());
1013   } else { // need reassembly, very slow!
1014     PetscCall(MatShift_Basic(A, a));
1015   }
1016   PetscFunctionReturn(PETSC_SUCCESS);
1017 }
1018 
1019 static PetscErrorCode MatDiagonalSet_SeqAIJKokkos(Mat Y, Vec D, InsertMode is)
1020 {
1021   Mat_SeqAIJ *aijseq = static_cast<Mat_SeqAIJ *>(Y->data);
1022 
1023   PetscFunctionBegin;
1024   if (Y->assembled && aijseq->diagonaldense) { // no missing diagonals
1025     ConstPetscScalarKokkosView dv;
1026     PetscInt                   n, nv;
1027 
1028     PetscCall(PetscLogGpuTimeBegin());
1029     PetscCall(MatSeqAIJKokkosSyncDevice(Y));
1030     PetscCall(VecGetKokkosView(D, &dv));
1031     PetscCall(VecGetLocalSize(D, &nv));
1032     n = PetscMin(Y->rmap->n, Y->cmap->n);
1033     PetscCheck(n == nv, PetscObjectComm((PetscObject)Y), PETSC_ERR_ARG_SIZ, "Matrix size and vector size do not match");
1034 
1035     const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1036     const auto &Aa     = aijkok->a_dual.view_device();
1037     const auto &Adiag  = aijkok->diag_dual.view_device();
1038     PetscCallCXX(Kokkos::parallel_for(
1039       Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) {
1040         if (is == INSERT_VALUES) Aa(Adiag(i)) = dv(i);
1041         else Aa(Adiag(i)) += dv(i);
1042       }));
1043     PetscCall(VecRestoreKokkosView(D, &dv));
1044     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1045     PetscCall(PetscLogGpuFlops(n));
1046     PetscCall(PetscLogGpuTimeEnd());
1047   } else { // need reassembly, very slow!
1048     PetscCall(MatDiagonalSet_Default(Y, D, is));
1049   }
1050   PetscFunctionReturn(PETSC_SUCCESS);
1051 }
1052 
1053 static PetscErrorCode MatDiagonalScale_SeqAIJKokkos(Mat A, Vec ll, Vec rr)
1054 {
1055   Mat_SeqAIJ                *aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1056   PetscInt                   m = A->rmap->n, n = A->cmap->n, nz = aijseq->nz;
1057   ConstPetscScalarKokkosView lv, rv;
1058 
1059   PetscFunctionBegin;
1060   PetscCall(PetscLogGpuTimeBegin());
1061   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1062   const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1063   const auto &Aa     = aijkok->a_dual.view_device();
1064   const auto &Ai     = aijkok->i_dual.view_device();
1065   const auto &Aj     = aijkok->j_dual.view_device();
1066   if (ll) {
1067     PetscCall(VecGetLocalSize(ll, &m));
1068     PetscCheck(m == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Left scaling vector wrong length");
1069     PetscCall(VecGetKokkosView(ll, &lv));
1070     PetscCallCXX(Kokkos::parallel_for( // for each row
1071       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1072         PetscInt i   = t.league_rank(); // row i
1073         PetscInt len = Ai(i + 1) - Ai(i);
1074         // scale entries on the row
1075         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt j) { Aa(Ai(i) + j) *= lv(i); });
1076       }));
1077     PetscCall(VecRestoreKokkosView(ll, &lv));
1078     PetscCall(PetscLogGpuFlops(nz));
1079   }
1080   if (rr) {
1081     PetscCall(VecGetLocalSize(rr, &n));
1082     PetscCheck(n == A->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Right scaling vector wrong length");
1083     PetscCall(VecGetKokkosView(rr, &rv));
1084     PetscCallCXX(Kokkos::parallel_for( // for each nonzero
1085       Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt k) { Aa(k) *= rv(Aj(k)); }));
1086     PetscCall(VecRestoreKokkosView(rr, &lv));
1087     PetscCall(PetscLogGpuFlops(nz));
1088   }
1089   PetscCall(MatSeqAIJKokkosModifyDevice(A));
1090   PetscCall(PetscLogGpuTimeEnd());
1091   PetscFunctionReturn(PETSC_SUCCESS);
1092 }
1093 
1094 static PetscErrorCode MatZeroEntries_SeqAIJKokkos(Mat A)
1095 {
1096   Mat_SeqAIJKokkos *aijkok;
1097 
1098   PetscFunctionBegin;
1099   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1100   if (aijkok) { /* Only zero the device if data is already there */
1101     KokkosBlas::fill(PetscGetKokkosExecutionSpace(), aijkok->a_dual.view_device(), 0.0);
1102     PetscCall(MatSeqAIJKokkosModifyDevice(A));
1103   } else { /* Might be preallocated but not assembled */
1104     PetscCall(MatZeroEntries_SeqAIJ(A));
1105   }
1106   PetscFunctionReturn(PETSC_SUCCESS);
1107 }
1108 
1109 static PetscErrorCode MatGetDiagonal_SeqAIJKokkos(Mat A, Vec x)
1110 {
1111   Mat_SeqAIJKokkos     *aijkok;
1112   PetscInt              n;
1113   PetscScalarKokkosView xv;
1114 
1115   PetscFunctionBegin;
1116   PetscCall(VecGetLocalSize(x, &n));
1117   PetscCheck(n == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Nonconforming matrix and vector");
1118   PetscCheck(A->factortype == MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_SUP, "MatGetDiagonal_SeqAIJKokkos not supported on factored matrices");
1119 
1120   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1121   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1122 
1123   const auto &Aa    = aijkok->a_dual.view_device();
1124   const auto &Ai    = aijkok->i_dual.view_device();
1125   const auto &Adiag = aijkok->diag_dual.view_device();
1126 
1127   PetscCall(VecGetKokkosViewWrite(x, &xv));
1128   Kokkos::parallel_for(
1129     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) {
1130       if (Adiag(i) < Ai(i + 1)) xv(i) = Aa(Adiag(i));
1131       else xv(i) = 0;
1132     });
1133   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
1134   PetscFunctionReturn(PETSC_SUCCESS);
1135 }
1136 
1137 /* Get a Kokkos View from a mat of type MatSeqAIJKokkos */
1138 PetscErrorCode MatSeqAIJGetKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1139 {
1140   Mat_SeqAIJKokkos *aijkok;
1141 
1142   PetscFunctionBegin;
1143   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1144   PetscAssertPointer(kv, 2);
1145   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1146   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1147   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1148   *kv    = aijkok->a_dual.view_device();
1149   PetscFunctionReturn(PETSC_SUCCESS);
1150 }
1151 
1152 PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1153 {
1154   PetscFunctionBegin;
1155   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1156   PetscAssertPointer(kv, 2);
1157   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1158   PetscFunctionReturn(PETSC_SUCCESS);
1159 }
1160 
1161 PetscErrorCode MatSeqAIJGetKokkosView(Mat A, MatScalarKokkosView *kv)
1162 {
1163   Mat_SeqAIJKokkos *aijkok;
1164 
1165   PetscFunctionBegin;
1166   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1167   PetscAssertPointer(kv, 2);
1168   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1169   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1170   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1171   *kv    = aijkok->a_dual.view_device();
1172   PetscFunctionReturn(PETSC_SUCCESS);
1173 }
1174 
1175 PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, MatScalarKokkosView *kv)
1176 {
1177   PetscFunctionBegin;
1178   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1179   PetscAssertPointer(kv, 2);
1180   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1181   PetscCall(MatSeqAIJKokkosModifyDevice(A));
1182   PetscFunctionReturn(PETSC_SUCCESS);
1183 }
1184 
1185 PetscErrorCode MatSeqAIJGetKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1186 {
1187   Mat_SeqAIJKokkos *aijkok;
1188 
1189   PetscFunctionBegin;
1190   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1191   PetscAssertPointer(kv, 2);
1192   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1193   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1194   *kv    = aijkok->a_dual.view_device();
1195   PetscFunctionReturn(PETSC_SUCCESS);
1196 }
1197 
1198 PetscErrorCode MatSeqAIJRestoreKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1199 {
1200   PetscFunctionBegin;
1201   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1202   PetscAssertPointer(kv, 2);
1203   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1204   PetscCall(MatSeqAIJKokkosModifyDevice(A));
1205   PetscFunctionReturn(PETSC_SUCCESS);
1206 }
1207 
1208 /* Computes Y += alpha X */
1209 static PetscErrorCode MatAXPY_SeqAIJKokkos(Mat Y, PetscScalar alpha, Mat X, MatStructure pattern)
1210 {
1211   Mat_SeqAIJ              *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data;
1212   Mat_SeqAIJKokkos        *xkok, *ykok, *zkok;
1213   ConstMatScalarKokkosView Xa;
1214   MatScalarKokkosView      Ya;
1215   auto                    &exec = PetscGetKokkosExecutionSpace();
1216 
1217   PetscFunctionBegin;
1218   PetscCheckTypeName(Y, MATSEQAIJKOKKOS);
1219   PetscCheckTypeName(X, MATSEQAIJKOKKOS);
1220   PetscCall(MatSeqAIJKokkosSyncDevice(Y));
1221   PetscCall(MatSeqAIJKokkosSyncDevice(X));
1222   PetscCall(PetscLogGpuTimeBegin());
1223 
1224   if (pattern != SAME_NONZERO_PATTERN && x->nz == y->nz) {
1225     PetscBool e;
1226     PetscCall(PetscArraycmp(x->i, y->i, Y->rmap->n + 1, &e));
1227     if (e) {
1228       PetscCall(PetscArraycmp(x->j, y->j, y->nz, &e));
1229       if (e) pattern = SAME_NONZERO_PATTERN;
1230     }
1231   }
1232 
1233   /* cusparseDcsrgeam2() computes C = alpha A + beta B. If one knew sparsity pattern of C, one can skip
1234     cusparseScsrgeam2_bufferSizeExt() / cusparseXcsrgeam2Nnz(), and directly call cusparseScsrgeam2().
1235     If X is SUBSET_NONZERO_PATTERN of Y, we could take advantage of this cusparse feature. However,
1236     KokkosSparse::spadd(alpha,A,beta,B,C) has symbolic and numeric phases, MatAXPY does not.
1237   */
1238   ykok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1239   xkok = static_cast<Mat_SeqAIJKokkos *>(X->spptr);
1240   Xa   = xkok->a_dual.view_device();
1241   Ya   = ykok->a_dual.view_device();
1242 
1243   if (pattern == SAME_NONZERO_PATTERN) {
1244     KokkosBlas::axpy(exec, alpha, Xa, Ya);
1245     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1246   } else if (pattern == SUBSET_NONZERO_PATTERN) {
1247     MatRowMapKokkosView Xi = xkok->i_dual.view_device(), Yi = ykok->i_dual.view_device();
1248     MatColIdxKokkosView Xj = xkok->j_dual.view_device(), Yj = ykok->j_dual.view_device();
1249 
1250     Kokkos::parallel_for(
1251       Kokkos::TeamPolicy<>(exec, Y->rmap->n, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1252         PetscInt i = t.league_rank(); // row i
1253         Kokkos::single(Kokkos::PerTeam(t), [=]() {
1254           // Only one thread works in a team
1255           PetscInt p, q = Yi(i);
1256           for (p = Xi(i); p < Xi(i + 1); p++) {          // For each nonzero on row i of X,
1257             while (Xj(p) != Yj(q) && q < Yi(i + 1)) q++; // find the matching nonzero on row i of Y.
1258             if (Xj(p) == Yj(q)) {                        // Found it
1259               Ya(q) += alpha * Xa(p);
1260               q++;
1261             } else {
1262             // If not found, it indicates the input is wrong (X is not a SUBSET_NONZERO_PATTERN of Y).
1263             // Just insert a NaN at the beginning of row i if it is not empty, to make the result wrong.
1264 #if PETSC_PKG_KOKKOS_VERSION_GE(3, 7, 0)
1265               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::ArithTraits<PetscScalar>::nan();
1266 #else
1267               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::Experimental::nan("1");
1268 #endif
1269             }
1270           }
1271         });
1272       });
1273     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1274   } else { // different nonzero patterns
1275     Mat             Z;
1276     KokkosCsrMatrix zcsr;
1277     KernelHandle    kh;
1278     kh.create_spadd_handle(true); // X, Y are sorted
1279     KokkosSparse::spadd_symbolic(&kh, xkok->csrmat, ykok->csrmat, zcsr);
1280     KokkosSparse::spadd_numeric(&kh, alpha, xkok->csrmat, (PetscScalar)1.0, ykok->csrmat, zcsr);
1281     zkok = new Mat_SeqAIJKokkos(zcsr);
1282     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, zkok, &Z));
1283     PetscCall(MatHeaderReplace(Y, &Z));
1284     kh.destroy_spadd_handle();
1285   }
1286   PetscCall(PetscLogGpuTimeEnd());
1287   PetscCall(PetscLogGpuFlops(xkok->a_dual.extent(0) * 2)); // Because we scaled X and then added it to Y
1288   PetscFunctionReturn(PETSC_SUCCESS);
1289 }
1290 
1291 struct MatCOOStruct_SeqAIJKokkos {
1292   PetscCount           n;
1293   PetscCount           Atot;
1294   PetscInt             nz;
1295   PetscCountKokkosView jmap;
1296   PetscCountKokkosView perm;
1297 
1298   MatCOOStruct_SeqAIJKokkos(const MatCOOStruct_SeqAIJ *coo_h)
1299   {
1300     nz   = coo_h->nz;
1301     n    = coo_h->n;
1302     Atot = coo_h->Atot;
1303     jmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->jmap, nz + 1));
1304     perm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->perm, Atot));
1305   }
1306 };
1307 
1308 static PetscErrorCode MatCOOStructDestroy_SeqAIJKokkos(void **data)
1309 {
1310   PetscFunctionBegin;
1311   PetscCallCXX(delete static_cast<MatCOOStruct_SeqAIJKokkos *>(*data));
1312   PetscFunctionReturn(PETSC_SUCCESS);
1313 }
1314 
1315 static PetscErrorCode MatSetPreallocationCOO_SeqAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1316 {
1317   Mat_SeqAIJKokkos          *akok;
1318   Mat_SeqAIJ                *aseq;
1319   PetscContainer             container_h;
1320   MatCOOStruct_SeqAIJ       *coo_h;
1321   MatCOOStruct_SeqAIJKokkos *coo_d;
1322 
1323   PetscFunctionBegin;
1324   PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, coo_i, coo_j));
1325   aseq = static_cast<Mat_SeqAIJ *>(mat->data);
1326   akok = static_cast<Mat_SeqAIJKokkos *>(mat->spptr);
1327   delete akok;
1328   mat->spptr = akok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, aseq, mat->nonzerostate + 1, PETSC_FALSE);
1329   PetscCall(MatZeroEntries_SeqAIJKokkos(mat));
1330 
1331   // Copy the COO struct to device
1332   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
1333   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
1334   PetscCallCXX(coo_d = new MatCOOStruct_SeqAIJKokkos(coo_h));
1335 
1336   // Put the COO struct in a container and then attach that to the matrix
1337   PetscCall(PetscObjectContainerCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", coo_d, MatCOOStructDestroy_SeqAIJKokkos));
1338   PetscFunctionReturn(PETSC_SUCCESS);
1339 }
1340 
1341 static PetscErrorCode MatSetValuesCOO_SeqAIJKokkos(Mat A, const PetscScalar v[], InsertMode imode)
1342 {
1343   MatScalarKokkosView        Aa;
1344   ConstMatScalarKokkosView   kv;
1345   PetscMemType               memtype;
1346   PetscContainer             container;
1347   MatCOOStruct_SeqAIJKokkos *coo;
1348 
1349   PetscFunctionBegin;
1350   PetscCall(PetscObjectQuery((PetscObject)A, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
1351   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
1352 
1353   const auto &n    = coo->n;
1354   const auto &Annz = coo->nz;
1355   const auto &jmap = coo->jmap;
1356   const auto &perm = coo->perm;
1357 
1358   PetscCall(PetscGetMemType(v, &memtype));
1359   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */
1360     kv = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatScalarKokkosViewHost(v, n));
1361   } else {
1362     kv = ConstMatScalarKokkosView(v, n); /* Directly use v[]'s memory */
1363   }
1364 
1365   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1366   else PetscCall(MatSeqAIJGetKokkosView(A, &Aa));                             /* read & write matrix values */
1367 
1368   PetscCall(PetscLogGpuTimeBegin());
1369   Kokkos::parallel_for(
1370     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz), KOKKOS_LAMBDA(const PetscCount i) {
1371       PetscScalar sum = 0.0;
1372       for (PetscCount k = jmap(i); k < jmap(i + 1); k++) sum += kv(perm(k));
1373       Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1374     });
1375   PetscCall(PetscLogGpuTimeEnd());
1376 
1377   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
1378   else PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1379   PetscFunctionReturn(PETSC_SUCCESS);
1380 }
1381 
1382 static PetscErrorCode MatLUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1383 {
1384   PetscFunctionBegin;
1385   PetscCall(MatSeqAIJKokkosSyncHost(A));
1386   PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info));
1387   B->offloadmask = PETSC_OFFLOAD_CPU;
1388   PetscFunctionReturn(PETSC_SUCCESS);
1389 }
1390 
1391 static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat A)
1392 {
1393   Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
1394 
1395   PetscFunctionBegin;
1396   A->offloadmask = PETSC_OFFLOAD_KOKKOS; /* We do not really use this flag */
1397   A->boundtocpu  = PETSC_FALSE;
1398 
1399   A->ops->assemblyend               = MatAssemblyEnd_SeqAIJKokkos;
1400   A->ops->destroy                   = MatDestroy_SeqAIJKokkos;
1401   A->ops->duplicate                 = MatDuplicate_SeqAIJKokkos;
1402   A->ops->axpy                      = MatAXPY_SeqAIJKokkos;
1403   A->ops->scale                     = MatScale_SeqAIJKokkos;
1404   A->ops->zeroentries               = MatZeroEntries_SeqAIJKokkos;
1405   A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJKokkos;
1406   A->ops->mult                      = MatMult_SeqAIJKokkos;
1407   A->ops->multadd                   = MatMultAdd_SeqAIJKokkos;
1408   A->ops->multtranspose             = MatMultTranspose_SeqAIJKokkos;
1409   A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJKokkos;
1410   A->ops->multhermitiantranspose    = MatMultHermitianTranspose_SeqAIJKokkos;
1411   A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJKokkos;
1412   A->ops->productnumeric            = MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos;
1413   A->ops->transpose                 = MatTranspose_SeqAIJKokkos;
1414   A->ops->setoption                 = MatSetOption_SeqAIJKokkos;
1415   A->ops->getdiagonal               = MatGetDiagonal_SeqAIJKokkos;
1416   A->ops->shift                     = MatShift_SeqAIJKokkos;
1417   A->ops->diagonalset               = MatDiagonalSet_SeqAIJKokkos;
1418   A->ops->diagonalscale             = MatDiagonalScale_SeqAIJKokkos;
1419   a->ops->getarray                  = MatSeqAIJGetArray_SeqAIJKokkos;
1420   a->ops->restorearray              = MatSeqAIJRestoreArray_SeqAIJKokkos;
1421   a->ops->getarrayread              = MatSeqAIJGetArrayRead_SeqAIJKokkos;
1422   a->ops->restorearrayread          = MatSeqAIJRestoreArrayRead_SeqAIJKokkos;
1423   a->ops->getarraywrite             = MatSeqAIJGetArrayWrite_SeqAIJKokkos;
1424   a->ops->restorearraywrite         = MatSeqAIJRestoreArrayWrite_SeqAIJKokkos;
1425   a->ops->getcsrandmemtype          = MatSeqAIJGetCSRAndMemType_SeqAIJKokkos;
1426 
1427   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJKokkos));
1428   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJKokkos));
1429 #if defined(PETSC_HAVE_HYPRE)
1430   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijkokkos_hypre_C", MatConvert_AIJ_HYPRE));
1431 #endif
1432   PetscFunctionReturn(PETSC_SUCCESS);
1433 }
1434 
1435 /*
1436    Extract the (prescribled) diagonal blocks of the matrix and then invert them
1437 
1438   Input Parameters:
1439 +  A       - the MATSEQAIJKOKKOS matrix
1440 .  bs      - block sizes in 'csr' format, i.e., the i-th block has size bs(i+1) - bs(i)
1441 .  bs2     - square of block sizes in 'csr' format, i.e., the i-th block should be stored at offset bs2(i) in diagVal[]
1442 .  blkMap  - map row ids to block ids, i.e., row i belongs to the block blkMap(i)
1443 -  work    - a pre-allocated work buffer (as big as diagVal) for use by this routine
1444 
1445   Output Parameter:
1446 .  diagVal - the (pre-allocated) buffer to store the inverted blocks (each block is stored in column-major order)
1447 */
1448 PETSC_INTERN PetscErrorCode MatInvertVariableBlockDiagonal_SeqAIJKokkos(Mat A, const PetscIntKokkosView &bs, const PetscIntKokkosView &bs2, const PetscIntKokkosView &blkMap, PetscScalarKokkosView &work, PetscScalarKokkosView &diagVal)
1449 {
1450   Mat_SeqAIJKokkos *akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1451   PetscInt          nblocks = bs.extent(0) - 1;
1452 
1453   PetscFunctionBegin;
1454   PetscCall(MatSeqAIJKokkosSyncDevice(A)); // Since we'll access A's value on device
1455 
1456   // Pull out the diagonal blocks of the matrix and then invert the blocks
1457   auto Aa    = akok->a_dual.view_device();
1458   auto Ai    = akok->i_dual.view_device();
1459   auto Aj    = akok->j_dual.view_device();
1460   auto Adiag = akok->diag_dual.view_device();
1461   // TODO: how to tune the team size?
1462 #if defined(KOKKOS_ENABLE_DEFAULT_DEVICE_TYPE_HOST)
1463   auto ts = Kokkos::AUTO();
1464 #else
1465   auto ts = 16; // improved performance 30% over Kokkos::AUTO() with CUDA, but failed with "Kokkos::abort: Requested Team Size is too large!" on CPUs
1466 #endif
1467   PetscCallCXX(Kokkos::parallel_for(
1468     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), nblocks, ts), KOKKOS_LAMBDA(const KokkosTeamMemberType &teamMember) {
1469       const PetscInt bid    = teamMember.league_rank();                                                   // block id
1470       const PetscInt rstart = bs(bid);                                                                    // this block starts from this row
1471       const PetscInt m      = bs(bid + 1) - bs(bid);                                                      // size of this block
1472       const auto    &B      = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft>(&diagVal(bs2(bid)), m, m); // column-major order
1473       const auto    &W      = PetscScalarKokkosView(&work(bs2(bid)), m * m);
1474 
1475       Kokkos::parallel_for(Kokkos::TeamThreadRange(teamMember, m), [=](const PetscInt &r) { // r-th row in B
1476         PetscInt i = rstart + r;                                                            // i-th row in A
1477 
1478         if (Ai(i) <= Adiag(i) && Adiag(i) < Ai(i + 1)) { // if the diagonal exists (common case)
1479           PetscInt first = Adiag(i) - r;                 // we start to check nonzeros from here along this row
1480 
1481           for (PetscInt c = 0; c < m; c++) {                   // walk n steps to see what column indices we will meet
1482             if (first + c < Ai(i) || first + c >= Ai(i + 1)) { // this entry (first+c) is out of range of this row, in other words, its value is zero
1483               B(r, c) = 0.0;
1484             } else if (Aj(first + c) == rstart + c) { // this entry is right on the (rstart+c) column
1485               B(r, c) = Aa(first + c);
1486             } else { // this entry does not show up in the CSR
1487               B(r, c) = 0.0;
1488             }
1489           }
1490         } else { // rare case that the diagonal does not exist
1491           const PetscInt begin = Ai(i);
1492           const PetscInt end   = Ai(i + 1);
1493           for (PetscInt c = 0; c < m; c++) B(r, c) = 0.0;
1494           for (PetscInt j = begin; j < end; j++) { // scan the whole row; could use binary search but this is a rare case so we did not.
1495             if (rstart <= Aj(j) && Aj(j) < rstart + m) B(r, Aj(j) - rstart) = Aa(j);
1496             else if (Aj(j) >= rstart + m) break;
1497           }
1498         }
1499       });
1500 
1501       // LU-decompose B (w/o pivoting) and then invert B
1502       KokkosBatched::TeamLU<KokkosTeamMemberType, KokkosBatched::Algo::LU::Unblocked>::invoke(teamMember, B, 0.0);
1503       KokkosBatched::TeamInverseLU<KokkosTeamMemberType, KokkosBatched::Algo::InverseLU::Unblocked>::invoke(teamMember, B, W);
1504     }));
1505   // PetscLogGpuFlops() is done in the caller PCSetUp_VPBJacobi_Kokkos as we don't want to compute the flops in kernels
1506   PetscFunctionReturn(PETSC_SUCCESS);
1507 }
1508 
1509 PETSC_INTERN PetscErrorCode MatSetSeqAIJKokkosWithCSRMatrix(Mat A, Mat_SeqAIJKokkos *akok)
1510 {
1511   Mat_SeqAIJ *aseq;
1512   PetscInt    i, m, n;
1513   auto       &exec = PetscGetKokkosExecutionSpace();
1514 
1515   PetscFunctionBegin;
1516   PetscCheck(!A->spptr, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A->spptr is supposed to be empty");
1517 
1518   m = akok->nrows();
1519   n = akok->ncols();
1520   PetscCall(MatSetSizes(A, m, n, m, n));
1521   PetscCall(MatSetType(A, MATSEQAIJKOKKOS));
1522 
1523   /* Set up data structures of A as a MATSEQAIJ */
1524   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(A, MAT_SKIP_ALLOCATION, NULL));
1525   aseq = (Mat_SeqAIJ *)A->data;
1526 
1527   PetscCallCXX(akok->i_dual.sync_host(exec)); /* We always need sync'ed i, j on host */
1528   PetscCallCXX(akok->j_dual.sync_host(exec));
1529   PetscCallCXX(exec.fence());
1530 
1531   aseq->i       = akok->i_host_data();
1532   aseq->j       = akok->j_host_data();
1533   aseq->a       = akok->a_host_data();
1534   aseq->nonew   = -1; /*this indicates that inserting a new value in the matrix that generates a new nonzero is an error*/
1535   aseq->free_a  = PETSC_FALSE;
1536   aseq->free_ij = PETSC_FALSE;
1537   aseq->nz      = akok->nnz();
1538   aseq->maxnz   = aseq->nz;
1539 
1540   PetscCall(PetscMalloc1(m, &aseq->imax));
1541   PetscCall(PetscMalloc1(m, &aseq->ilen));
1542   for (i = 0; i < m; i++) aseq->ilen[i] = aseq->imax[i] = aseq->i[i + 1] - aseq->i[i];
1543 
1544   /* It is critical to set the nonzerostate, as we use it to check if sparsity pattern (hence data) has changed on host in MatAssemblyEnd */
1545   akok->nonzerostate = A->nonzerostate;
1546   A->spptr           = akok; /* Set A->spptr before MatAssembly so that A->spptr won't be allocated again there */
1547   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
1548   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
1549   PetscFunctionReturn(PETSC_SUCCESS);
1550 }
1551 
1552 PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetKokkosCsrMatrix(Mat A, KokkosCsrMatrix *csr)
1553 {
1554   PetscFunctionBegin;
1555   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1556   *csr = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
1557   PetscFunctionReturn(PETSC_SUCCESS);
1558 }
1559 
1560 PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithKokkosCsrMatrix(MPI_Comm comm, KokkosCsrMatrix csr, Mat *A)
1561 {
1562   Mat_SeqAIJKokkos *akok;
1563 
1564   PetscFunctionBegin;
1565   PetscCallCXX(akok = new Mat_SeqAIJKokkos(csr));
1566   PetscCall(MatCreate(comm, A));
1567   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
1568   PetscFunctionReturn(PETSC_SUCCESS);
1569 }
1570 
1571 /* Crete a SEQAIJKOKKOS matrix with a Mat_SeqAIJKokkos data structure
1572 
1573    Note we have names like MatSeqAIJSetPreallocationCSR, so I use capitalized CSR
1574  */
1575 PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithCSRMatrix(MPI_Comm comm, Mat_SeqAIJKokkos *akok, Mat *A)
1576 {
1577   PetscFunctionBegin;
1578   PetscCall(MatCreate(comm, A));
1579   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
1580   PetscFunctionReturn(PETSC_SUCCESS);
1581 }
1582 
1583 /*@C
1584   MatCreateSeqAIJKokkos - Creates a sparse matrix in `MATSEQAIJKOKKOS` (compressed row) format
1585   (the default parallel PETSc format). This matrix will ultimately be handled by
1586   Kokkos for calculations.
1587 
1588   Collective
1589 
1590   Input Parameters:
1591 + comm - MPI communicator, set to `PETSC_COMM_SELF`
1592 . m    - number of rows
1593 . n    - number of columns
1594 . nz   - number of nonzeros per row (same for all rows), ignored if `nnz` is provided
1595 - nnz  - array containing the number of nonzeros in the various rows (possibly different for each row) or `NULL`
1596 
1597   Output Parameter:
1598 . A - the matrix
1599 
1600   Level: intermediate
1601 
1602   Notes:
1603   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1604   MatXXXXSetPreallocation() paradgm instead of this routine directly.
1605   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1606 
1607   The AIJ format, also called
1608   compressed row storage, is fully compatible with standard Fortran
1609   storage.  That is, the stored row and column indices can begin at
1610   either one (as in Fortran) or zero.
1611 
1612   Specify the preallocated storage with either `nz` or `nnz` (not both).
1613   Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory
1614   allocation.
1615 
1616 .seealso: [](ch_matrices), `Mat`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1617 @*/
1618 PetscErrorCode MatCreateSeqAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A)
1619 {
1620   PetscFunctionBegin;
1621   PetscCall(PetscKokkosInitializeCheck());
1622   PetscCall(MatCreate(comm, A));
1623   PetscCall(MatSetSizes(*A, m, n, m, n));
1624   PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1625   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
1626   PetscFunctionReturn(PETSC_SUCCESS);
1627 }
1628 
1629 static PetscErrorCode MatLUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1630 {
1631   PetscFunctionBegin;
1632   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
1633   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
1634   PetscFunctionReturn(PETSC_SUCCESS);
1635 }
1636 
1637 static PetscErrorCode MatSeqAIJKokkosSymbolicSolveCheck(Mat A)
1638 {
1639   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1640 
1641   PetscFunctionBegin;
1642   if (!factors->sptrsv_symbolic_completed) {
1643     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d);
1644     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d);
1645     factors->sptrsv_symbolic_completed = PETSC_TRUE;
1646   }
1647   PetscFunctionReturn(PETSC_SUCCESS);
1648 }
1649 
1650 /* Check if we need to update factors etc for transpose solve */
1651 static PetscErrorCode MatSeqAIJKokkosTransposeSolveCheck(Mat A)
1652 {
1653   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1654   MatColIdxType               n       = A->rmap->n;
1655 
1656   PetscFunctionBegin;
1657   if (!factors->transpose_updated) { /* TODO: KK needs to provide functions to do numeric transpose only */
1658     /* Update L^T and do sptrsv symbolic */
1659     factors->iLt_d = MatRowMapKokkosView("factors->iLt_d", n + 1); // KK requires 0
1660     factors->jLt_d = MatColIdxKokkosView(NoInit("factors->jLt_d"), factors->jL_d.extent(0));
1661     factors->aLt_d = MatScalarKokkosView(NoInit("factors->aLt_d"), factors->aL_d.extent(0));
1662 
1663     transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iL_d, factors->jL_d, factors->aL_d,
1664                                                                                                                                                                                                               factors->iLt_d, factors->jLt_d, factors->aLt_d);
1665 
1666     /* TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices.
1667       We have to sort the indices, until KK provides finer control options.
1668     */
1669     sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iLt_d, factors->jLt_d, factors->aLt_d);
1670 
1671     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d);
1672 
1673     /* Update U^T and do sptrsv symbolic */
1674     factors->iUt_d = MatRowMapKokkosView("factors->iUt_d", n + 1); // KK requires 0
1675     factors->jUt_d = MatColIdxKokkosView(NoInit("factors->jUt_d"), factors->jU_d.extent(0));
1676     factors->aUt_d = MatScalarKokkosView(NoInit("factors->aUt_d"), factors->aU_d.extent(0));
1677 
1678     transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iU_d, factors->jU_d, factors->aU_d,
1679                                                                                                                                                                                                               factors->iUt_d, factors->jUt_d, factors->aUt_d);
1680 
1681     /* Sort indices. See comments above */
1682     sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iUt_d, factors->jUt_d, factors->aUt_d);
1683 
1684     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d);
1685     factors->transpose_updated = PETSC_TRUE;
1686   }
1687   PetscFunctionReturn(PETSC_SUCCESS);
1688 }
1689 
1690 /* Solve Ax = b, with A = LU */
1691 static PetscErrorCode MatSolve_SeqAIJKokkos(Mat A, Vec b, Vec x)
1692 {
1693   ConstPetscScalarKokkosView  bv;
1694   PetscScalarKokkosView       xv;
1695   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1696 
1697   PetscFunctionBegin;
1698   PetscCall(PetscLogGpuTimeBegin());
1699   PetscCall(MatSeqAIJKokkosSymbolicSolveCheck(A));
1700   PetscCall(VecGetKokkosView(b, &bv));
1701   PetscCall(VecGetKokkosViewWrite(x, &xv));
1702   /* Solve L tmpv = b */
1703   PetscCallCXX(KokkosSparse::Experimental::sptrsv_solve(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d, bv, factors->workVector));
1704   /* Solve Ux = tmpv */
1705   PetscCallCXX(KokkosSparse::Experimental::sptrsv_solve(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, factors->workVector, xv));
1706   PetscCall(VecRestoreKokkosView(b, &bv));
1707   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
1708   PetscCall(PetscLogGpuTimeEnd());
1709   PetscFunctionReturn(PETSC_SUCCESS);
1710 }
1711 
1712 /* Solve A^T x = b, where A^T = U^T L^T */
1713 static PetscErrorCode MatSolveTranspose_SeqAIJKokkos(Mat A, Vec b, Vec x)
1714 {
1715   ConstPetscScalarKokkosView  bv;
1716   PetscScalarKokkosView       xv;
1717   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1718 
1719   PetscFunctionBegin;
1720   PetscCall(PetscLogGpuTimeBegin());
1721   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A));
1722   PetscCall(VecGetKokkosView(b, &bv));
1723   PetscCall(VecGetKokkosViewWrite(x, &xv));
1724   /* Solve U^T tmpv = b */
1725   KokkosSparse::Experimental::sptrsv_solve(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, bv, factors->workVector);
1726 
1727   /* Solve L^T x = tmpv */
1728   KokkosSparse::Experimental::sptrsv_solve(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d, factors->workVector, xv);
1729   PetscCall(VecRestoreKokkosView(b, &bv));
1730   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
1731   PetscCall(PetscLogGpuTimeEnd());
1732   PetscFunctionReturn(PETSC_SUCCESS);
1733 }
1734 
1735 static PetscErrorCode MatILUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1736 {
1737   Mat_SeqAIJKokkos           *aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
1738   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
1739   PetscInt                    fill_lev = info->levels;
1740 
1741   PetscFunctionBegin;
1742   PetscCall(PetscLogGpuTimeBegin());
1743   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1744 
1745   auto a_d = aijkok->a_dual.view_device();
1746   auto i_d = aijkok->i_dual.view_device();
1747   auto j_d = aijkok->j_dual.view_device();
1748 
1749   KokkosSparse::Experimental::spiluk_numeric(&factors->kh, fill_lev, i_d, j_d, a_d, factors->iL_d, factors->jL_d, factors->aL_d, factors->iU_d, factors->jU_d, factors->aU_d);
1750 
1751   B->assembled              = PETSC_TRUE;
1752   B->preallocated           = PETSC_TRUE;
1753   B->ops->solve             = MatSolve_SeqAIJKokkos;
1754   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJKokkos;
1755   B->ops->matsolve          = NULL;
1756   B->ops->matsolvetranspose = NULL;
1757   B->offloadmask            = PETSC_OFFLOAD_GPU;
1758 
1759   /* Once the factors' value changed, we need to update their transpose and sptrsv handle */
1760   factors->transpose_updated         = PETSC_FALSE;
1761   factors->sptrsv_symbolic_completed = PETSC_FALSE;
1762   /* TODO: log flops, but how to know that? */
1763   PetscCall(PetscLogGpuTimeEnd());
1764   PetscFunctionReturn(PETSC_SUCCESS);
1765 }
1766 
1767 static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1768 {
1769   Mat_SeqAIJKokkos           *aijkok;
1770   Mat_SeqAIJ                 *b;
1771   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
1772   PetscInt                    fill_lev = info->levels;
1773   PetscInt                    nnzA     = ((Mat_SeqAIJ *)A->data)->nz, nnzL, nnzU;
1774   PetscInt                    n        = A->rmap->n;
1775 
1776   PetscFunctionBegin;
1777   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1778   /* Rebuild factors */
1779   if (factors) {
1780     factors->Destroy();
1781   } /* Destroy the old if it exists */
1782   else {
1783     B->spptr = factors = new Mat_SeqAIJKokkosTriFactors(n);
1784   }
1785 
1786   /* Create a spiluk handle and then do symbolic factorization */
1787   nnzL = nnzU = PetscRealIntMultTruncate(info->fill, nnzA);
1788   factors->kh.create_spiluk_handle(KokkosSparse::Experimental::SPILUKAlgorithm::SEQLVLSCHD_TP1, n, nnzL, nnzU);
1789 
1790   auto spiluk_handle = factors->kh.get_spiluk_handle();
1791 
1792   Kokkos::realloc(factors->iL_d, n + 1); /* Free old arrays and realloc */
1793   Kokkos::realloc(factors->jL_d, spiluk_handle->get_nnzL());
1794   Kokkos::realloc(factors->iU_d, n + 1);
1795   Kokkos::realloc(factors->jU_d, spiluk_handle->get_nnzU());
1796 
1797   aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
1798   auto i_d = aijkok->i_dual.view_device();
1799   auto j_d = aijkok->j_dual.view_device();
1800   KokkosSparse::Experimental::spiluk_symbolic(&factors->kh, fill_lev, i_d, j_d, factors->iL_d, factors->jL_d, factors->iU_d, factors->jU_d);
1801   /* TODO: if spiluk_symbolic is asynchronous, do we need to sync before calling get_nnzL()? */
1802 
1803   Kokkos::resize(factors->jL_d, spiluk_handle->get_nnzL()); /* Shrink or expand, and retain old value */
1804   Kokkos::resize(factors->jU_d, spiluk_handle->get_nnzU());
1805   Kokkos::realloc(factors->aL_d, spiluk_handle->get_nnzL()); /* No need to retain old value */
1806   Kokkos::realloc(factors->aU_d, spiluk_handle->get_nnzU());
1807 
1808   /* TODO: add options to select sptrsv algorithms */
1809   /* Create sptrsv handles for L, U and their transpose */
1810 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1811   auto sptrsv_alg = KokkosSparse::Experimental::SPTRSVAlgorithm::SPTRSV_CUSPARSE;
1812 #else
1813   auto sptrsv_alg = KokkosSparse::Experimental::SPTRSVAlgorithm::SEQLVLSCHD_TP1;
1814 #endif
1815 
1816   factors->khL.create_sptrsv_handle(sptrsv_alg, n, true /* L is lower tri */);
1817   factors->khU.create_sptrsv_handle(sptrsv_alg, n, false /* U is not lower tri */);
1818   factors->khLt.create_sptrsv_handle(sptrsv_alg, n, false /* L^T is not lower tri */);
1819   factors->khUt.create_sptrsv_handle(sptrsv_alg, n, true /* U^T is lower tri */);
1820 
1821   /* Fill fields of the factor matrix B */
1822   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(B, MAT_SKIP_ALLOCATION, NULL));
1823   b     = (Mat_SeqAIJ *)B->data;
1824   b->nz = b->maxnz          = spiluk_handle->get_nnzL() + spiluk_handle->get_nnzU();
1825   B->info.fill_ratio_given  = info->fill;
1826   B->info.fill_ratio_needed = nnzA > 0 ? ((PetscReal)b->nz) / ((PetscReal)nnzA) : 1.0;
1827 
1828   B->offloadmask          = PETSC_OFFLOAD_GPU;
1829   B->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJKokkos;
1830   PetscFunctionReturn(PETSC_SUCCESS);
1831 }
1832 
1833 static PetscErrorCode MatFactorGetSolverType_SeqAIJKokkos(Mat A, MatSolverType *type)
1834 {
1835   PetscFunctionBegin;
1836   *type = MATSOLVERKOKKOS;
1837   PetscFunctionReturn(PETSC_SUCCESS);
1838 }
1839 
1840 /*MC
1841   MATSOLVERKOKKOS = "Kokkos" - A matrix solver type providing triangular solvers for sequential matrices
1842   on a single GPU of type, `MATSEQAIJKOKKOS`, `MATAIJKOKKOS`.
1843 
1844   Level: beginner
1845 
1846 .seealso: [](ch_matrices), `Mat`, `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJKokkos()`, `MATAIJKOKKOS`, `MatKokkosSetFormat()`, `MatKokkosStorageFormat`, `MatKokkosFormatOperation`
1847 M*/
1848 PETSC_EXTERN PetscErrorCode MatGetFactor_SeqAIJKokkos_Kokkos(Mat A, MatFactorType ftype, Mat *B) /* MatGetFactor_<MatType>_<MatSolverType> */
1849 {
1850   PetscInt n = A->rmap->n;
1851 
1852   PetscFunctionBegin;
1853   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), B));
1854   PetscCall(MatSetSizes(*B, n, n, n, n));
1855   (*B)->factortype = ftype;
1856   PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
1857   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
1858 
1859   if (ftype == MAT_FACTOR_LU) {
1860     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
1861     (*B)->canuseordering        = PETSC_TRUE;
1862     (*B)->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqAIJKokkos;
1863   } else if (ftype == MAT_FACTOR_ILU) {
1864     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
1865     (*B)->canuseordering         = PETSC_FALSE;
1866     (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJKokkos;
1867   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "MatFactorType %s is not supported by MatType SeqAIJKokkos", MatFactorTypes[ftype]);
1868 
1869   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
1870   PetscCall(PetscObjectComposeFunction((PetscObject)*B, "MatFactorGetSolverType_C", MatFactorGetSolverType_SeqAIJKokkos));
1871   PetscFunctionReturn(PETSC_SUCCESS);
1872 }
1873 
1874 PETSC_INTERN PetscErrorCode MatSolverTypeRegister_KOKKOS(void)
1875 {
1876   PetscFunctionBegin;
1877   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_SeqAIJKokkos_Kokkos));
1878   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ILU, MatGetFactor_SeqAIJKokkos_Kokkos));
1879   PetscFunctionReturn(PETSC_SUCCESS);
1880 }
1881 
1882 /* Utility to print out a KokkosCsrMatrix for debugging */
1883 PETSC_INTERN PetscErrorCode PrintCsrMatrix(const KokkosCsrMatrix &csrmat)
1884 {
1885   const auto        &iv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.graph.row_map);
1886   const auto        &jv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.graph.entries);
1887   const auto        &av = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.values);
1888   const PetscInt    *i  = iv.data();
1889   const PetscInt    *j  = jv.data();
1890   const PetscScalar *a  = av.data();
1891   PetscInt           m = csrmat.numRows(), n = csrmat.numCols(), nnz = csrmat.nnz();
1892 
1893   PetscFunctionBegin;
1894   PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT " x %" PetscInt_FMT " SeqAIJKokkos, with %" PetscInt_FMT " nonzeros\n", m, n, nnz));
1895   for (PetscInt k = 0; k < m; k++) {
1896     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT ": ", k));
1897     for (PetscInt p = i[k]; p < i[k + 1]; p++) PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT "(%.1f), ", j[p], (double)PetscRealPart(a[p])));
1898     PetscCall(PetscPrintf(PETSC_COMM_SELF, "\n"));
1899   }
1900   PetscFunctionReturn(PETSC_SUCCESS);
1901 }
1902