xref: /petsc/src/mat/impls/aij/seq/kokkos/aijkok.kokkos.cxx (revision 2d30e087755efd99e28fdfe792ffbeb2ee1ea928)
1 #include <petscvec_kokkos.hpp>
2 #include <petscpkg_version.h>
3 #include <petsc/private/petscimpl.h>
4 #include <petsc/private/sfimpl.h>
5 #include <petscsystypes.h>
6 #include <petscerror.h>
7 
8 #include <Kokkos_Core.hpp>
9 #include <KokkosBlas.hpp>
10 #include <KokkosSparse_CrsMatrix.hpp>
11 #include <KokkosSparse_spmv.hpp>
12 #include <KokkosSparse_spiluk.hpp>
13 #include <KokkosSparse_sptrsv.hpp>
14 #include <KokkosSparse_spgemm.hpp>
15 #include <KokkosSparse_spadd.hpp>
16 
17 #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
18 
19 #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(3, 6, 99)
20 #include <KokkosSparse_Utils.hpp>
21 using KokkosSparse::sort_crs_matrix;
22 using KokkosSparse::Impl::transpose_matrix;
23 #else
24 #include <KokkosKernels_Sorting.hpp>
25 using KokkosKernels::sort_crs_matrix;
26 using KokkosKernels::Impl::transpose_matrix;
27 #endif
28 
29 static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat); /* Forward declaration */
30 
31 /* MatAssemblyEnd_SeqAIJKokkos() happens when we finalized nonzeros of the matrix, either after
32    we assembled the matrix on host, or after we directly produced the matrix data on device (ex., through MatMatMult).
33    In the latter case, it is important to set a_dual's sync state correctly.
34  */
35 static PetscErrorCode MatAssemblyEnd_SeqAIJKokkos(Mat A, MatAssemblyType mode) {
36   Mat_SeqAIJ       *aijseq;
37   Mat_SeqAIJKokkos *aijkok;
38 
39   PetscFunctionBegin;
40   if (mode == MAT_FLUSH_ASSEMBLY) PetscFunctionReturn(0);
41   PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
42 
43   aijseq = static_cast<Mat_SeqAIJ *>(A->data);
44   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
45 
46   /* If aijkok does not exist, we just copy i, j to device.
47      If aijkok already exists, but the device's nonzero pattern does not match with the host's, we assume the latest data is on host.
48      In both cases, we build a new aijkok structure.
49   */
50   if (!aijkok || aijkok->nonzerostate != A->nonzerostate) { /* aijkok might not exist yet or nonzero pattern has changed */
51     delete aijkok;
52     aijkok   = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aijseq->nz, aijseq->i, aijseq->j, aijseq->a, A->nonzerostate, PETSC_FALSE /*don't copy mat values to device*/);
53     A->spptr = aijkok;
54   }
55 
56   if (aijkok->device_mat_d.data()) {
57     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
58   }
59   PetscFunctionReturn(0);
60 }
61 
62 /* Sync CSR data to device if not yet */
63 PETSC_INTERN PetscErrorCode MatSeqAIJKokkosSyncDevice(Mat A) {
64   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
65 
66   PetscFunctionBegin;
67   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Cann't sync factorized matrix from host to device");
68   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
69   if (aijkok->a_dual.need_sync_device()) {
70     aijkok->a_dual.sync_device();
71     aijkok->transpose_updated = PETSC_FALSE; /* values of the transpose is out-of-date */
72     aijkok->hermitian_updated = PETSC_FALSE;
73   }
74   PetscFunctionReturn(0);
75 }
76 
77 /* Mark the CSR data on device as modified */
78 PETSC_INTERN PetscErrorCode MatSeqAIJKokkosModifyDevice(Mat A) {
79   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
80 
81   PetscFunctionBegin;
82   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Not supported for factorized matries");
83   aijkok->a_dual.clear_sync_state();
84   aijkok->a_dual.modify_device();
85   aijkok->transpose_updated = PETSC_FALSE;
86   aijkok->hermitian_updated = PETSC_FALSE;
87   PetscCall(MatSeqAIJInvalidateDiagonal(A));
88   PetscCall(PetscObjectStateIncrease((PetscObject)A));
89   PetscFunctionReturn(0);
90 }
91 
92 static PetscErrorCode MatSeqAIJKokkosSyncHost(Mat A) {
93   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
94 
95   PetscFunctionBegin;
96   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
97   /* We do not expect one needs factors on host  */
98   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Cann't sync factorized matrix from device to host");
99   PetscCheck(aijkok, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing AIJKOK");
100   aijkok->a_dual.sync_host();
101   PetscFunctionReturn(0);
102 }
103 
104 static PetscErrorCode MatSeqAIJGetArray_SeqAIJKokkos(Mat A, PetscScalar *array[]) {
105   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
106 
107   PetscFunctionBegin;
108   /* aijkok contains valid pointers only if the host's nonzerostate matches with the device's.
109     Calling MatSeqAIJSetPreallocation() or MatSetValues() on host, where aijseq->{i,j,a} might be
110     reallocated, will lead to stale {i,j,a}_dual in aijkok. In both operations, the hosts's nonzerostate
111     must have been updated. The stale aijkok will be rebuilt during MatAssemblyEnd.
112   */
113   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
114     aijkok->a_dual.sync_host();
115     *array = aijkok->a_dual.view_host().data();
116   } else { /* Happens when calling MatSetValues on a newly created matrix */
117     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
118   }
119   PetscFunctionReturn(0);
120 }
121 
122 static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJKokkos(Mat A, PetscScalar *array[]) {
123   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
124 
125   PetscFunctionBegin;
126   if (aijkok && A->nonzerostate == aijkok->nonzerostate) aijkok->a_dual.modify_host();
127   PetscFunctionReturn(0);
128 }
129 
130 static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[]) {
131   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
132 
133   PetscFunctionBegin;
134   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
135     aijkok->a_dual.sync_host();
136     *array = aijkok->a_dual.view_host().data();
137   } else {
138     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
139   }
140   PetscFunctionReturn(0);
141 }
142 
143 static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[]) {
144   PetscFunctionBegin;
145   *array = NULL;
146   PetscFunctionReturn(0);
147 }
148 
149 static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[]) {
150   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
151 
152   PetscFunctionBegin;
153   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
154     *array = aijkok->a_dual.view_host().data();
155   } else { /* Ex. happens with MatZeroEntries on a preallocated but not assembled matrix */
156     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
157   }
158   PetscFunctionReturn(0);
159 }
160 
161 static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[]) {
162   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
163 
164   PetscFunctionBegin;
165   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
166     aijkok->a_dual.clear_sync_state();
167     aijkok->a_dual.modify_host();
168   }
169   PetscFunctionReturn(0);
170 }
171 
172 static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJKokkos(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype) {
173   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
174 
175   PetscFunctionBegin;
176   PetscCheck(aijkok != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "aijkok is NULL");
177 
178   if (i) *i = aijkok->i_device_data();
179   if (j) *j = aijkok->j_device_data();
180   if (a) {
181     aijkok->a_dual.sync_device();
182     *a = aijkok->a_device_data();
183   }
184   if (mtype) *mtype = PETSC_MEMTYPE_KOKKOS;
185   PetscFunctionReturn(0);
186 }
187 
188 // MatSeqAIJKokkosSetDeviceMat takes a PetscSplitCSRDataStructure with device data and copies it to the device. Note, "deep_copy" here is really a shallow copy
189 PetscErrorCode MatSeqAIJKokkosSetDeviceMat(Mat A, PetscSplitCSRDataStructure h_mat) {
190   Mat_SeqAIJKokkos                            *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
191   Kokkos::View<SplitCSRMat, Kokkos::HostSpace> h_mat_k(h_mat);
192 
193   PetscFunctionBegin;
194   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
195   aijkok->device_mat_d = create_mirror(DefaultMemorySpace(), h_mat_k);
196   Kokkos::deep_copy(aijkok->device_mat_d, h_mat_k);
197   PetscFunctionReturn(0);
198 }
199 
200 // MatSeqAIJKokkosGetDeviceMat gets the device if it is here, otherwise it creates a place for it and returns NULL
201 PetscErrorCode MatSeqAIJKokkosGetDeviceMat(Mat A, PetscSplitCSRDataStructure *d_mat) {
202   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
203 
204   PetscFunctionBegin;
205   if (aijkok && aijkok->device_mat_d.data()) {
206     *d_mat = aijkok->device_mat_d.data();
207   } else {
208     PetscCall(MatSeqAIJKokkosSyncDevice(A)); // create aijkok (we are making d_mat now so make a place for it)
209     *d_mat = NULL;
210   }
211   PetscFunctionReturn(0);
212 }
213 
214 /* Generate the transpose on device and cache it internally */
215 static PetscErrorCode MatSeqAIJKokkosGenerateTranspose_Private(Mat A, KokkosCsrMatrix **csrmatT) {
216   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
217 
218   PetscFunctionBegin;
219   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
220   if (!aijkok->csrmatT.nnz() || !aijkok->transpose_updated) { /* Generate At for the first time OR just update its values */
221     /* FIXME: KK does not separate symbolic/numeric transpose. We could have a permutation array to help value-only update */
222     PetscCallCXX(aijkok->a_dual.sync_device());
223     PetscCallCXX(aijkok->csrmatT = transpose_matrix(aijkok->csrmat));
224     PetscCallCXX(sort_crs_matrix(aijkok->csrmatT));
225     aijkok->transpose_updated = PETSC_TRUE;
226   }
227   *csrmatT = &aijkok->csrmatT;
228   PetscFunctionReturn(0);
229 }
230 
231 /* Generate the Hermitian on device and cache it internally */
232 static PetscErrorCode MatSeqAIJKokkosGenerateHermitian_Private(Mat A, KokkosCsrMatrix **csrmatH) {
233   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
234 
235   PetscFunctionBegin;
236   PetscCall(PetscLogGpuTimeBegin());
237   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
238   if (!aijkok->csrmatH.nnz() || !aijkok->hermitian_updated) { /* Generate Ah for the first time OR just update its values */
239     PetscCallCXX(aijkok->a_dual.sync_device());
240     PetscCallCXX(aijkok->csrmatH = transpose_matrix(aijkok->csrmat));
241     PetscCallCXX(sort_crs_matrix(aijkok->csrmatH));
242 #if defined(PETSC_USE_COMPLEX)
243     const auto &a = aijkok->csrmatH.values;
244     Kokkos::parallel_for(
245       a.extent(0), KOKKOS_LAMBDA(MatRowMapType i) { a(i) = PetscConj(a(i)); });
246 #endif
247     aijkok->hermitian_updated = PETSC_TRUE;
248   }
249   *csrmatH = &aijkok->csrmatH;
250   PetscCall(PetscLogGpuTimeEnd());
251   PetscFunctionReturn(0);
252 }
253 
254 /* y = A x */
255 static PetscErrorCode MatMult_SeqAIJKokkos(Mat A, Vec xx, Vec yy) {
256   Mat_SeqAIJKokkos          *aijkok;
257   ConstPetscScalarKokkosView xv;
258   PetscScalarKokkosView      yv;
259 
260   PetscFunctionBegin;
261   PetscCall(PetscLogGpuTimeBegin());
262   PetscCall(MatSeqAIJKokkosSyncDevice(A));
263   PetscCall(VecGetKokkosView(xx, &xv));
264   PetscCall(VecGetKokkosViewWrite(yy, &yv));
265   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
266   KokkosSparse::spmv("N", 1.0 /*alpha*/, aijkok->csrmat, xv, 0.0 /*beta*/, yv); /* y = alpha A x + beta y */
267   PetscCall(VecRestoreKokkosView(xx, &xv));
268   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
269   /* 2.0*nnz - numRows seems more accurate here but assumes there are no zero-rows. So a little sloppy here. */
270   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
271   PetscCall(PetscLogGpuTimeEnd());
272   PetscFunctionReturn(0);
273 }
274 
275 /* y = A^T x */
276 static PetscErrorCode MatMultTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy) {
277   Mat_SeqAIJKokkos          *aijkok;
278   const char                *mode;
279   ConstPetscScalarKokkosView xv;
280   PetscScalarKokkosView      yv;
281   KokkosCsrMatrix           *csrmat;
282 
283   PetscFunctionBegin;
284   PetscCall(PetscLogGpuTimeBegin());
285   PetscCall(MatSeqAIJKokkosSyncDevice(A));
286   PetscCall(VecGetKokkosView(xx, &xv));
287   PetscCall(VecGetKokkosViewWrite(yy, &yv));
288   if (A->form_explicit_transpose) {
289     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
290     mode = "N";
291   } else {
292     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
293     csrmat = &aijkok->csrmat;
294     mode   = "T";
295   }
296   KokkosSparse::spmv(mode, 1.0 /*alpha*/, *csrmat, xv, 0.0 /*beta*/, yv); /* y = alpha A^T x + beta y */
297   PetscCall(VecRestoreKokkosView(xx, &xv));
298   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
299   PetscCall(PetscLogGpuFlops(2.0 * csrmat->nnz()));
300   PetscCall(PetscLogGpuTimeEnd());
301   PetscFunctionReturn(0);
302 }
303 
304 /* y = A^H x */
305 static PetscErrorCode MatMultHermitianTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy) {
306   Mat_SeqAIJKokkos          *aijkok;
307   const char                *mode;
308   ConstPetscScalarKokkosView xv;
309   PetscScalarKokkosView      yv;
310   KokkosCsrMatrix           *csrmat;
311 
312   PetscFunctionBegin;
313   PetscCall(PetscLogGpuTimeBegin());
314   PetscCall(MatSeqAIJKokkosSyncDevice(A));
315   PetscCall(VecGetKokkosView(xx, &xv));
316   PetscCall(VecGetKokkosViewWrite(yy, &yv));
317   if (A->form_explicit_transpose) {
318     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
319     mode = "N";
320   } else {
321     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
322     csrmat = &aijkok->csrmat;
323     mode   = "C";
324   }
325   KokkosSparse::spmv(mode, 1.0 /*alpha*/, *csrmat, xv, 0.0 /*beta*/, yv); /* y = alpha A^H x + beta y */
326   PetscCall(VecRestoreKokkosView(xx, &xv));
327   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
328   PetscCall(PetscLogGpuFlops(2.0 * csrmat->nnz()));
329   PetscCall(PetscLogGpuTimeEnd());
330   PetscFunctionReturn(0);
331 }
332 
333 /* z = A x + y */
334 static PetscErrorCode MatMultAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz) {
335   Mat_SeqAIJKokkos          *aijkok;
336   ConstPetscScalarKokkosView xv, yv;
337   PetscScalarKokkosView      zv;
338 
339   PetscFunctionBegin;
340   PetscCall(PetscLogGpuTimeBegin());
341   PetscCall(MatSeqAIJKokkosSyncDevice(A));
342   PetscCall(VecGetKokkosView(xx, &xv));
343   PetscCall(VecGetKokkosView(yy, &yv));
344   PetscCall(VecGetKokkosViewWrite(zz, &zv));
345   if (zz != yy) Kokkos::deep_copy(zv, yv);
346   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
347   KokkosSparse::spmv("N", 1.0 /*alpha*/, aijkok->csrmat, xv, 1.0 /*beta*/, zv); /* z = alpha A x + beta z */
348   PetscCall(VecRestoreKokkosView(xx, &xv));
349   PetscCall(VecRestoreKokkosView(yy, &yv));
350   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
351   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
352   PetscCall(PetscLogGpuTimeEnd());
353   PetscFunctionReturn(0);
354 }
355 
356 /* z = A^T x + y */
357 static PetscErrorCode MatMultTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz) {
358   Mat_SeqAIJKokkos          *aijkok;
359   const char                *mode;
360   ConstPetscScalarKokkosView xv, yv;
361   PetscScalarKokkosView      zv;
362   KokkosCsrMatrix           *csrmat;
363 
364   PetscFunctionBegin;
365   PetscCall(PetscLogGpuTimeBegin());
366   PetscCall(MatSeqAIJKokkosSyncDevice(A));
367   PetscCall(VecGetKokkosView(xx, &xv));
368   PetscCall(VecGetKokkosView(yy, &yv));
369   PetscCall(VecGetKokkosViewWrite(zz, &zv));
370   if (zz != yy) Kokkos::deep_copy(zv, yv);
371   if (A->form_explicit_transpose) {
372     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
373     mode = "N";
374   } else {
375     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
376     csrmat = &aijkok->csrmat;
377     mode   = "T";
378   }
379   KokkosSparse::spmv(mode, 1.0 /*alpha*/, *csrmat, xv, 1.0 /*beta*/, zv); /* z = alpha A^T x + beta z */
380   PetscCall(VecRestoreKokkosView(xx, &xv));
381   PetscCall(VecRestoreKokkosView(yy, &yv));
382   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
383   PetscCall(PetscLogGpuFlops(2.0 * csrmat->nnz()));
384   PetscCall(PetscLogGpuTimeEnd());
385   PetscFunctionReturn(0);
386 }
387 
388 /* z = A^H x + y */
389 static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz) {
390   Mat_SeqAIJKokkos          *aijkok;
391   const char                *mode;
392   ConstPetscScalarKokkosView xv, yv;
393   PetscScalarKokkosView      zv;
394   KokkosCsrMatrix           *csrmat;
395 
396   PetscFunctionBegin;
397   PetscCall(PetscLogGpuTimeBegin());
398   PetscCall(MatSeqAIJKokkosSyncDevice(A));
399   PetscCall(VecGetKokkosView(xx, &xv));
400   PetscCall(VecGetKokkosView(yy, &yv));
401   PetscCall(VecGetKokkosViewWrite(zz, &zv));
402   if (zz != yy) Kokkos::deep_copy(zv, yv);
403   if (A->form_explicit_transpose) {
404     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
405     mode = "N";
406   } else {
407     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
408     csrmat = &aijkok->csrmat;
409     mode   = "C";
410   }
411   KokkosSparse::spmv(mode, 1.0 /*alpha*/, *csrmat, xv, 1.0 /*beta*/, zv); /* z = alpha A^H x + beta z */
412   PetscCall(VecRestoreKokkosView(xx, &xv));
413   PetscCall(VecRestoreKokkosView(yy, &yv));
414   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
415   PetscCall(PetscLogGpuFlops(2.0 * csrmat->nnz()));
416   PetscCall(PetscLogGpuTimeEnd());
417   PetscFunctionReturn(0);
418 }
419 
420 PetscErrorCode MatSetOption_SeqAIJKokkos(Mat A, MatOption op, PetscBool flg) {
421   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
422 
423   PetscFunctionBegin;
424   switch (op) {
425   case MAT_FORM_EXPLICIT_TRANSPOSE:
426     /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */
427     if (A->form_explicit_transpose && !flg && aijkok) PetscCall(aijkok->DestroyMatTranspose());
428     A->form_explicit_transpose = flg;
429     break;
430   default: PetscCall(MatSetOption_SeqAIJ(A, op, flg)); break;
431   }
432   PetscFunctionReturn(0);
433 }
434 
435 /* Depending on reuse, either build a new mat, or use the existing mat */
436 PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) {
437   Mat_SeqAIJ *aseq;
438 
439   PetscFunctionBegin;
440   PetscCall(PetscKokkosInitializeCheck());
441   if (reuse == MAT_INITIAL_MATRIX) {                      /* Build a brand new mat */
442     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));  /* the returned newmat is a SeqAIJKokkos */
443   } else if (reuse == MAT_REUSE_MATRIX) {                 /* Reuse the mat created before */
444     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); /* newmat is already a SeqAIJKokkos */
445   } else if (reuse == MAT_INPLACE_MATRIX) {               /* newmat is A */
446     PetscCheck(A == *newmat, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "A != *newmat with MAT_INPLACE_MATRIX");
447     PetscCall(PetscFree(A->defaultvectype));
448     PetscCall(PetscStrallocpy(VECKOKKOS, &A->defaultvectype)); /* Allocate and copy the string */
449     PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATSEQAIJKOKKOS));
450     PetscCall(MatSetOps_SeqAIJKokkos(A));
451     aseq = static_cast<Mat_SeqAIJ *>(A->data);
452     if (A->assembled) { /* Copy i, j (but not values) to device for an assembled matrix if not yet */
453       PetscCheck(!A->spptr, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Expect NULL (Mat_SeqAIJKokkos*)A->spptr");
454       A->spptr = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aseq->nz, aseq->i, aseq->j, aseq->a, A->nonzerostate, PETSC_FALSE);
455     }
456   }
457   PetscFunctionReturn(0);
458 }
459 
460 /* MatDuplicate always creates a new matrix. MatDuplicate can be called either on an assembled matrix or
461    an unassembled matrix, even though MAT_COPY_VALUES is not allowed for unassembled matrix.
462  */
463 static PetscErrorCode MatDuplicate_SeqAIJKokkos(Mat A, MatDuplicateOption dupOption, Mat *B) {
464   Mat_SeqAIJ       *bseq;
465   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok;
466   Mat               mat;
467 
468   PetscFunctionBegin;
469   /* Do not copy values on host as A's latest values might be on device. We don't want to do sync blindly */
470   PetscCall(MatDuplicate_SeqAIJ(A, MAT_DO_NOT_COPY_VALUES, B));
471   mat = *B;
472   if (A->assembled) {
473     bseq = static_cast<Mat_SeqAIJ *>(mat->data);
474     bkok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, bseq->nz, bseq->i, bseq->j, bseq->a, mat->nonzerostate, PETSC_FALSE);
475     bkok->a_dual.clear_sync_state(); /* Clear B's sync state as it will be decided below */
476     /* Now copy values to B if needed */
477     if (dupOption == MAT_COPY_VALUES) {
478       if (akok->a_dual.need_sync_device()) {
479         Kokkos::deep_copy(bkok->a_dual.view_host(), akok->a_dual.view_host());
480         bkok->a_dual.modify_host();
481       } else { /* If device has the latest data, we only copy data on device */
482         Kokkos::deep_copy(bkok->a_dual.view_device(), akok->a_dual.view_device());
483         bkok->a_dual.modify_device();
484       }
485     } else { /* MAT_DO_NOT_COPY_VALUES or MAT_SHARE_NONZERO_PATTERN. B's values should be zeroed */
486       /* B's values on host should be already zeroed by MatDuplicate_SeqAIJ() */
487       bkok->a_dual.modify_host();
488     }
489     mat->spptr = bkok;
490   }
491 
492   PetscCall(PetscFree(mat->defaultvectype));
493   PetscCall(PetscStrallocpy(VECKOKKOS, &mat->defaultvectype)); /* Allocate and copy the string */
494   PetscCall(PetscObjectChangeTypeName((PetscObject)mat, MATSEQAIJKOKKOS));
495   PetscCall(MatSetOps_SeqAIJKokkos(mat));
496   PetscFunctionReturn(0);
497 }
498 
499 static PetscErrorCode MatTranspose_SeqAIJKokkos(Mat A, MatReuse reuse, Mat *B) {
500   Mat               At;
501   KokkosCsrMatrix  *internT;
502   Mat_SeqAIJKokkos *atkok, *bkok;
503 
504   PetscFunctionBegin;
505   if (reuse == MAT_REUSE_MATRIX) PetscCall(MatTransposeCheckNonzeroState_Private(A, *B));
506   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &internT)); /* Generate a transpose internally */
507   if (reuse == MAT_INITIAL_MATRIX || reuse == MAT_INPLACE_MATRIX) {
508     /* Deep copy internT, as we want to isolate the internal transpose */
509     PetscCallCXX(atkok = new Mat_SeqAIJKokkos(KokkosCsrMatrix("csrmat", *internT)));
510     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PetscObjectComm((PetscObject)A), atkok, &At));
511     if (reuse == MAT_INITIAL_MATRIX) *B = At;
512     else PetscCall(MatHeaderReplace(A, &At)); /* Replace A with At inplace */
513   } else {                                    /* MAT_REUSE_MATRIX, just need to copy values to B on device */
514     if ((*B)->assembled) {
515       bkok = static_cast<Mat_SeqAIJKokkos *>((*B)->spptr);
516       PetscCallCXX(Kokkos::deep_copy(bkok->a_dual.view_device(), internT->values));
517       PetscCall(MatSeqAIJKokkosModifyDevice(*B));
518     } else if ((*B)->preallocated) { /* It is ok for B to be only preallocated, as needed in MatTranspose_MPIAIJ */
519       Mat_SeqAIJ             *bseq = static_cast<Mat_SeqAIJ *>((*B)->data);
520       MatScalarKokkosViewHost a_h(bseq->a, internT->nnz()); /* bseq->nz = 0 if unassembled */
521       MatColIdxKokkosViewHost j_h(bseq->j, internT->nnz());
522       PetscCallCXX(Kokkos::deep_copy(a_h, internT->values));
523       PetscCallCXX(Kokkos::deep_copy(j_h, internT->graph.entries));
524     } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "B must be assembled or preallocated");
525   }
526   PetscFunctionReturn(0);
527 }
528 
529 static PetscErrorCode MatDestroy_SeqAIJKokkos(Mat A) {
530   Mat_SeqAIJKokkos *aijkok;
531 
532   PetscFunctionBegin;
533   if (A->factortype == MAT_FACTOR_NONE) {
534     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
535     delete aijkok;
536   } else {
537     delete static_cast<Mat_SeqAIJKokkosTriFactors *>(A->spptr);
538   }
539   A->spptr = NULL;
540   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
541   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
542   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
543   PetscCall(MatDestroy_SeqAIJ(A));
544   PetscFunctionReturn(0);
545 }
546 
547 /*MC
548    MATSEQAIJKOKKOS - MATAIJKOKKOS = "(seq)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos
549 
550    A matrix type type using Kokkos-Kernels CrsMatrix type for portability across different device types
551 
552    Options Database Keys:
553 .  -mat_type aijkokkos - sets the matrix type to `MATSEQAIJKOKKOS` during a call to `MatSetFromOptions()`
554 
555   Level: beginner
556 
557 .seealso: `MatCreateSeqAIJKokkos()`, `MATMPIAIJKOKKOS`
558 M*/
559 PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJKokkos(Mat A) {
560   PetscFunctionBegin;
561   PetscCall(PetscKokkosInitializeCheck());
562   PetscCall(MatCreate_SeqAIJ(A));
563   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
564   PetscFunctionReturn(0);
565 }
566 
567 /* Merge A, B into a matrix C. A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n) */
568 PetscErrorCode MatSeqAIJKokkosMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C) {
569   Mat_SeqAIJ         *a, *b;
570   Mat_SeqAIJKokkos   *akok, *bkok, *ckok;
571   MatScalarKokkosView aa, ba, ca;
572   MatRowMapKokkosView ai, bi, ci;
573   MatColIdxKokkosView aj, bj, cj;
574   PetscInt            m, n, nnz, aN;
575 
576   PetscFunctionBegin;
577   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
578   PetscValidHeaderSpecific(B, MAT_CLASSID, 2);
579   PetscValidPointer(C, 4);
580   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
581   PetscCheckTypeName(B, MATSEQAIJKOKKOS);
582   PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n);
583   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
584 
585   PetscCall(MatSeqAIJKokkosSyncDevice(A));
586   PetscCall(MatSeqAIJKokkosSyncDevice(B));
587   a    = static_cast<Mat_SeqAIJ *>(A->data);
588   b    = static_cast<Mat_SeqAIJ *>(B->data);
589   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
590   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
591   aa   = akok->a_dual.view_device();
592   ai   = akok->i_dual.view_device();
593   ba   = bkok->a_dual.view_device();
594   bi   = bkok->i_dual.view_device();
595   m    = A->rmap->n; /* M, N and nnz of C */
596   n    = A->cmap->n + B->cmap->n;
597   nnz  = a->nz + b->nz;
598   aN   = A->cmap->n; /* N of A */
599   if (reuse == MAT_INITIAL_MATRIX) {
600     aj           = akok->j_dual.view_device();
601     bj           = bkok->j_dual.view_device();
602     auto ca_dual = MatScalarKokkosDualView("a", aa.extent(0) + ba.extent(0));
603     auto ci_dual = MatRowMapKokkosDualView("i", ai.extent(0));
604     auto cj_dual = MatColIdxKokkosDualView("j", aj.extent(0) + bj.extent(0));
605     ca           = ca_dual.view_device();
606     ci           = ci_dual.view_device();
607     cj           = cj_dual.view_device();
608 
609     /* Concatenate A and B in parallel using Kokkos hierarchical parallelism */
610     Kokkos::parallel_for(
611       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
612         PetscInt i       = t.league_rank(); /* row i */
613         PetscInt coffset = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
614 
615         Kokkos::single(Kokkos::PerTeam(t), [=]() { /* this side effect only happens once per whole team */
616                                                    ci(i) = coffset;
617                                                    if (i == m - 1) ci(m) = ai(m) + bi(m);
618         });
619 
620         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
621           if (k < alen) {
622             ca(coffset + k) = aa(ai(i) + k);
623             cj(coffset + k) = aj(ai(i) + k);
624           } else {
625             ca(coffset + k) = ba(bi(i) + k - alen);
626             cj(coffset + k) = bj(bi(i) + k - alen) + aN; /* Entries in B get new column indices in C */
627           }
628         });
629       });
630     ca_dual.modify_device();
631     ci_dual.modify_device();
632     cj_dual.modify_device();
633     PetscCallCXX(ckok = new Mat_SeqAIJKokkos(m, n, nnz, ci_dual, cj_dual, ca_dual));
634     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, C));
635   } else if (reuse == MAT_REUSE_MATRIX) {
636     PetscValidHeaderSpecific(*C, MAT_CLASSID, 4);
637     PetscCheckTypeName(*C, MATSEQAIJKOKKOS);
638     ckok = static_cast<Mat_SeqAIJKokkos *>((*C)->spptr);
639     ca   = ckok->a_dual.view_device();
640     ci   = ckok->i_dual.view_device();
641 
642     Kokkos::parallel_for(
643       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
644         PetscInt i    = t.league_rank(); /* row i */
645         PetscInt alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
646         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
647           if (k < alen) ca(ci(i) + k) = aa(ai(i) + k);
648           else ca(ci(i) + k) = ba(bi(i) + k - alen);
649         });
650       });
651     PetscCall(MatSeqAIJKokkosModifyDevice(*C));
652   }
653   PetscFunctionReturn(0);
654 }
655 
656 static PetscErrorCode MatProductDataDestroy_SeqAIJKokkos(void *pdata) {
657   PetscFunctionBegin;
658   delete static_cast<MatProductData_SeqAIJKokkos *>(pdata);
659   PetscFunctionReturn(0);
660 }
661 
662 static PetscErrorCode MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos(Mat C) {
663   Mat_Product                 *product = C->product;
664   Mat                          A, B;
665   bool                         transA, transB; /* use bool, since KK needs this type */
666   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
667   Mat_SeqAIJ                  *c;
668   MatProductData_SeqAIJKokkos *pdata;
669   KokkosCsrMatrix             *csrmatA, *csrmatB;
670 
671   PetscFunctionBegin;
672   MatCheckProduct(C, 1);
673   PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data empty");
674   pdata = static_cast<MatProductData_SeqAIJKokkos *>(C->product->data);
675 
676   if (pdata->reusesym) {           /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
677     pdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
678     PetscFunctionReturn(0);
679   }
680 
681   switch (product->type) {
682   case MATPRODUCT_AB:
683     transA = false;
684     transB = false;
685     break;
686   case MATPRODUCT_AtB:
687     transA = true;
688     transB = false;
689     break;
690   case MATPRODUCT_ABt:
691     transA = false;
692     transB = true;
693     break;
694   default: SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
695   }
696 
697   A = product->A;
698   B = product->B;
699   PetscCall(MatSeqAIJKokkosSyncDevice(A));
700   PetscCall(MatSeqAIJKokkosSyncDevice(B));
701   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
702   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
703   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
704 
705   PetscCheck(ckok, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Device data structure spptr is empty");
706 
707   csrmatA = &akok->csrmat;
708   csrmatB = &bkok->csrmat;
709 
710   /* TODO: Once KK spgemm implements transpose, we can get rid of the explicit transpose here */
711   if (transA) {
712     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
713     transA = false;
714   }
715 
716   if (transB) {
717     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
718     transB = false;
719   }
720   PetscCall(PetscLogGpuTimeBegin());
721   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, *csrmatA, transA, *csrmatB, transB, ckok->csrmat));
722   auto spgemmHandle = pdata->kh.get_spgemm_handle();
723   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(ckok->csrmat)); /* without sort, mat_tests-ex62_14_seqaijkokkos fails */
724 
725   PetscCall(PetscLogGpuTimeEnd());
726   PetscCall(MatSeqAIJKokkosModifyDevice(C));
727   /* shorter version of MatAssemblyEnd_SeqAIJ */
728   c = (Mat_SeqAIJ *)C->data;
729   PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded,%" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz));
730   PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n"));
731   PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax));
732   c->reallocs         = 0;
733   C->info.mallocs     = 0;
734   C->info.nz_unneeded = 0;
735   C->assembled = C->was_assembled = PETSC_TRUE;
736   C->num_ass++;
737   PetscFunctionReturn(0);
738 }
739 
740 static PetscErrorCode MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos(Mat C) {
741   Mat_Product                 *product = C->product;
742   MatProductType               ptype;
743   Mat                          A, B;
744   bool                         transA, transB;
745   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
746   MatProductData_SeqAIJKokkos *pdata;
747   MPI_Comm                     comm;
748   KokkosCsrMatrix             *csrmatA, *csrmatB, csrmatC;
749 
750   PetscFunctionBegin;
751   MatCheckProduct(C, 1);
752   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
753   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
754   A = product->A;
755   B = product->B;
756   PetscCall(MatSeqAIJKokkosSyncDevice(A));
757   PetscCall(MatSeqAIJKokkosSyncDevice(B));
758   akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
759   bkok    = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
760   csrmatA = &akok->csrmat;
761   csrmatB = &bkok->csrmat;
762 
763   ptype = product->type;
764   switch (ptype) {
765   case MATPRODUCT_AB:
766     transA = false;
767     transB = false;
768     break;
769   case MATPRODUCT_AtB:
770     transA = true;
771     transB = false;
772     break;
773   case MATPRODUCT_ABt:
774     transA = false;
775     transB = true;
776     break;
777   default: SETERRQ(comm, PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
778   }
779 
780   product->data = pdata = new MatProductData_SeqAIJKokkos();
781   pdata->kh.set_team_work_size(16);
782   pdata->kh.set_dynamic_scheduling(true);
783   pdata->reusesym = product->api_user;
784 
785   /* TODO: add command line options to select spgemm algorithms */
786   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; /* default alg is TPL if enabled, otherwise KK */
787 
788   /* CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 */
789 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
790 #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
791   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
792 #endif
793 #endif
794 
795   pdata->kh.create_spgemm_handle(spgemm_alg);
796 
797   PetscCall(PetscLogGpuTimeBegin());
798   /* TODO: Get rid of the explicit transpose once KK-spgemm implements the transpose option */
799   if (transA) {
800     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
801     transA = false;
802   }
803 
804   if (transB) {
805     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
806     transB = false;
807   }
808 
809   PetscCallCXX(KokkosSparse::spgemm_symbolic(pdata->kh, *csrmatA, transA, *csrmatB, transB, csrmatC));
810 
811   /* spgemm_symbolic() only populates C's rowmap, but not C's column indices.
812     So we have to do a fake spgemm_numeric() here to get csrmatC.j_d setup, before
813     calling new Mat_SeqAIJKokkos().
814     TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
815   */
816   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, *csrmatA, transA, *csrmatB, transB, csrmatC));
817   /* Query if KK outputs a sorted matrix. If not, we need to sort it */
818   auto spgemmHandle = pdata->kh.get_spgemm_handle();
819   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(csrmatC)); /* sort_option defaults to -1 in KK!*/
820   PetscCall(PetscLogGpuTimeEnd());
821 
822   PetscCallCXX(ckok = new Mat_SeqAIJKokkos(csrmatC));
823   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(C, ckok));
824   C->product->destroy = MatProductDataDestroy_SeqAIJKokkos;
825   PetscFunctionReturn(0);
826 }
827 
828 /* handles sparse matrix matrix ops */
829 static PetscErrorCode MatProductSetFromOptions_SeqAIJKokkos(Mat mat) {
830   Mat_Product *product = mat->product;
831   PetscBool    Biskok = PETSC_FALSE, Ciskok = PETSC_TRUE;
832 
833   PetscFunctionBegin;
834   MatCheckProduct(mat, 1);
835   PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJKOKKOS, &Biskok));
836   if (product->type == MATPRODUCT_ABC) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJKOKKOS, &Ciskok));
837   if (Biskok && Ciskok) {
838     switch (product->type) {
839     case MATPRODUCT_AB:
840     case MATPRODUCT_AtB:
841     case MATPRODUCT_ABt: mat->ops->productsymbolic = MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos; break;
842     case MATPRODUCT_PtAP:
843     case MATPRODUCT_RARt:
844     case MATPRODUCT_ABC: mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic; break;
845     default: break;
846     }
847   } else { /* fallback for AIJ */
848     PetscCall(MatProductSetFromOptions_SeqAIJ(mat));
849   }
850   PetscFunctionReturn(0);
851 }
852 
853 static PetscErrorCode MatScale_SeqAIJKokkos(Mat A, PetscScalar a) {
854   Mat_SeqAIJKokkos *aijkok;
855 
856   PetscFunctionBegin;
857   PetscCall(PetscLogGpuTimeBegin());
858   PetscCall(MatSeqAIJKokkosSyncDevice(A));
859   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
860   KokkosBlas::scal(aijkok->a_dual.view_device(), a, aijkok->a_dual.view_device());
861   PetscCall(MatSeqAIJKokkosModifyDevice(A));
862   PetscCall(PetscLogGpuFlops(aijkok->a_dual.extent(0)));
863   PetscCall(PetscLogGpuTimeEnd());
864   PetscFunctionReturn(0);
865 }
866 
867 static PetscErrorCode MatZeroEntries_SeqAIJKokkos(Mat A) {
868   Mat_SeqAIJKokkos *aijkok;
869 
870   PetscFunctionBegin;
871   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
872   if (aijkok) { /* Only zero the device if data is already there */
873     KokkosBlas::fill(aijkok->a_dual.view_device(), 0.0);
874     PetscCall(MatSeqAIJKokkosModifyDevice(A));
875   } else { /* Might be preallocated but not assembled */
876     PetscCall(MatZeroEntries_SeqAIJ(A));
877   }
878   PetscFunctionReturn(0);
879 }
880 
881 static PetscErrorCode MatGetDiagonal_SeqAIJKokkos(Mat A, Vec x) {
882   Mat_SeqAIJ           *aijseq;
883   Mat_SeqAIJKokkos     *aijkok;
884   PetscInt              n;
885   PetscScalarKokkosView xv;
886 
887   PetscFunctionBegin;
888   PetscCall(VecGetLocalSize(x, &n));
889   PetscCheck(n == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Nonconforming matrix and vector");
890   PetscCheck(A->factortype == MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_SUP, "MatGetDiagonal_SeqAIJKokkos not supported on factored matrices");
891 
892   PetscCall(MatSeqAIJKokkosSyncDevice(A));
893   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
894 
895   if (A->rmap->n && aijkok->diag_dual.extent(0) == 0) { /* Set the diagonal pointer if not already */
896     PetscCall(MatMarkDiagonal_SeqAIJ(A));
897     aijseq = static_cast<Mat_SeqAIJ *>(A->data);
898     aijkok->SetDiagonal(aijseq->diag);
899   }
900 
901   const auto &Aa    = aijkok->a_dual.view_device();
902   const auto &Ai    = aijkok->i_dual.view_device();
903   const auto &Adiag = aijkok->diag_dual.view_device();
904 
905   PetscCall(VecGetKokkosViewWrite(x, &xv));
906   Kokkos::parallel_for(
907     n, KOKKOS_LAMBDA(const PetscInt i) {
908       if (Adiag(i) < Ai(i + 1)) xv(i) = Aa(Adiag(i));
909       else xv(i) = 0;
910     });
911   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
912   PetscFunctionReturn(0);
913 }
914 
915 /* Get a Kokkos View from a mat of type MatSeqAIJKokkos */
916 PetscErrorCode MatSeqAIJGetKokkosView(Mat A, ConstMatScalarKokkosView *kv) {
917   Mat_SeqAIJKokkos *aijkok;
918 
919   PetscFunctionBegin;
920   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
921   PetscValidPointer(kv, 2);
922   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
923   PetscCall(MatSeqAIJKokkosSyncDevice(A));
924   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
925   *kv    = aijkok->a_dual.view_device();
926   PetscFunctionReturn(0);
927 }
928 
929 PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, ConstMatScalarKokkosView *kv) {
930   PetscFunctionBegin;
931   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
932   PetscValidPointer(kv, 2);
933   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
934   PetscFunctionReturn(0);
935 }
936 
937 PetscErrorCode MatSeqAIJGetKokkosView(Mat A, MatScalarKokkosView *kv) {
938   Mat_SeqAIJKokkos *aijkok;
939 
940   PetscFunctionBegin;
941   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
942   PetscValidPointer(kv, 2);
943   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
944   PetscCall(MatSeqAIJKokkosSyncDevice(A));
945   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
946   *kv    = aijkok->a_dual.view_device();
947   PetscFunctionReturn(0);
948 }
949 
950 PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, MatScalarKokkosView *kv) {
951   PetscFunctionBegin;
952   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
953   PetscValidPointer(kv, 2);
954   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
955   PetscCall(MatSeqAIJKokkosModifyDevice(A));
956   PetscFunctionReturn(0);
957 }
958 
959 PetscErrorCode MatSeqAIJGetKokkosViewWrite(Mat A, MatScalarKokkosView *kv) {
960   Mat_SeqAIJKokkos *aijkok;
961 
962   PetscFunctionBegin;
963   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
964   PetscValidPointer(kv, 2);
965   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
966   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
967   *kv    = aijkok->a_dual.view_device();
968   PetscFunctionReturn(0);
969 }
970 
971 PetscErrorCode MatSeqAIJRestoreKokkosViewWrite(Mat A, MatScalarKokkosView *kv) {
972   PetscFunctionBegin;
973   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
974   PetscValidPointer(kv, 2);
975   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
976   PetscCall(MatSeqAIJKokkosModifyDevice(A));
977   PetscFunctionReturn(0);
978 }
979 
980 /* Computes Y += alpha X */
981 static PetscErrorCode MatAXPY_SeqAIJKokkos(Mat Y, PetscScalar alpha, Mat X, MatStructure pattern) {
982   Mat_SeqAIJ              *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data;
983   Mat_SeqAIJKokkos        *xkok, *ykok, *zkok;
984   ConstMatScalarKokkosView Xa;
985   MatScalarKokkosView      Ya;
986 
987   PetscFunctionBegin;
988   PetscCheckTypeName(Y, MATSEQAIJKOKKOS);
989   PetscCheckTypeName(X, MATSEQAIJKOKKOS);
990   PetscCall(MatSeqAIJKokkosSyncDevice(Y));
991   PetscCall(MatSeqAIJKokkosSyncDevice(X));
992   PetscCall(PetscLogGpuTimeBegin());
993 
994   if (pattern != SAME_NONZERO_PATTERN && x->nz == y->nz) {
995     /* We could compare on device, but have to get the comparison result on host. So compare on host instead. */
996     PetscBool e;
997     PetscCall(PetscArraycmp(x->i, y->i, Y->rmap->n + 1, &e));
998     if (e) {
999       PetscCall(PetscArraycmp(x->j, y->j, y->nz, &e));
1000       if (e) pattern = SAME_NONZERO_PATTERN;
1001     }
1002   }
1003 
1004   /* cusparseDcsrgeam2() computes C = alpha A + beta B. If one knew sparsity pattern of C, one can skip
1005     cusparseScsrgeam2_bufferSizeExt() / cusparseXcsrgeam2Nnz(), and directly call cusparseScsrgeam2().
1006     If X is SUBSET_NONZERO_PATTERN of Y, we could take advantage of this cusparse feature. However,
1007     KokkosSparse::spadd(alpha,A,beta,B,C) has symbolic and numeric phases, MatAXPY does not.
1008   */
1009   ykok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1010   xkok = static_cast<Mat_SeqAIJKokkos *>(X->spptr);
1011   Xa   = xkok->a_dual.view_device();
1012   Ya   = ykok->a_dual.view_device();
1013 
1014   if (pattern == SAME_NONZERO_PATTERN) {
1015     KokkosBlas::axpy(alpha, Xa, Ya);
1016     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1017   } else if (pattern == SUBSET_NONZERO_PATTERN) {
1018     MatRowMapKokkosView Xi = xkok->i_dual.view_device(), Yi = ykok->i_dual.view_device();
1019     MatColIdxKokkosView Xj = xkok->j_dual.view_device(), Yj = ykok->j_dual.view_device();
1020 
1021     Kokkos::parallel_for(
1022       Kokkos::TeamPolicy<>(Y->rmap->n, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1023         PetscInt i = t.league_rank();              /* row i */
1024         Kokkos::single(Kokkos::PerTeam(t), [=]() { /* Only one thread works in a team */
1025                                                    PetscInt p, q = Yi(i);
1026                                                    for (p = Xi(i); p < Xi(i + 1); p++) {          /* For each nonzero on row i of X */
1027                                                      while (Xj(p) != Yj(q) && q < Yi(i + 1)) q++; /* find the matching nonzero on row i of Y */
1028                                                      if (Xj(p) == Yj(q)) {                        /* Found it */
1029                                                        Ya(q) += alpha * Xa(p);
1030                                                        q++;
1031                                                      } else {
1032                                                        /* If not found, it indicates the input is wrong (X is not a SUBSET_NONZERO_PATTERN of Y).
1033                Just insert a NaN at the beginning of row i if it is not empty, to make the result wrong.
1034             */
1035                                                        if (Yi(i) != Yi(i + 1))
1036                                                          Ya(Yi(i)) =
1037 #if PETSC_PKG_KOKKOS_VERSION_GE(3, 6, 99)
1038                                                            Kokkos::nan("1"); /* auto promote the double NaN if needed */
1039 #else
1040               Kokkos::Experimental::nan("1");
1041 #endif
1042                                                      }
1043                                                    }
1044         });
1045       });
1046     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1047   } else { /* different nonzero patterns */
1048     Mat             Z;
1049     KokkosCsrMatrix zcsr;
1050     KernelHandle    kh;
1051     kh.create_spadd_handle(false);
1052     KokkosSparse::spadd_symbolic(&kh, xkok->csrmat, ykok->csrmat, zcsr);
1053     KokkosSparse::spadd_numeric(&kh, alpha, xkok->csrmat, (PetscScalar)1.0, ykok->csrmat, zcsr);
1054     zkok = new Mat_SeqAIJKokkos(zcsr);
1055     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, zkok, &Z));
1056     PetscCall(MatHeaderReplace(Y, &Z));
1057     kh.destroy_spadd_handle();
1058   }
1059   PetscCall(PetscLogGpuTimeEnd());
1060   PetscCall(PetscLogGpuFlops(xkok->a_dual.extent(0) * 2)); /* Because we scaled X and then added it to Y */
1061   PetscFunctionReturn(0);
1062 }
1063 
1064 static PetscErrorCode MatSetPreallocationCOO_SeqAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) {
1065   Mat_SeqAIJKokkos *akok;
1066   Mat_SeqAIJ       *aseq;
1067 
1068   PetscFunctionBegin;
1069   PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, coo_i, coo_j));
1070   aseq = static_cast<Mat_SeqAIJ *>(mat->data);
1071   akok = static_cast<Mat_SeqAIJKokkos *>(mat->spptr);
1072   delete akok;
1073   mat->spptr = akok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, aseq->nz, aseq->i, aseq->j, aseq->a, mat->nonzerostate + 1, PETSC_FALSE);
1074   PetscCall(MatZeroEntries_SeqAIJKokkos(mat));
1075   akok->SetUpCOO(aseq);
1076   PetscFunctionReturn(0);
1077 }
1078 
1079 static PetscErrorCode MatSetValuesCOO_SeqAIJKokkos(Mat A, const PetscScalar v[], InsertMode imode) {
1080   Mat_SeqAIJ                 *aseq = static_cast<Mat_SeqAIJ *>(A->data);
1081   Mat_SeqAIJKokkos           *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1082   PetscCount                  Annz = aseq->nz;
1083   const PetscCountKokkosView &jmap = akok->jmap_d;
1084   const PetscCountKokkosView &perm = akok->perm_d;
1085   MatScalarKokkosView         Aa;
1086   ConstMatScalarKokkosView    kv;
1087   PetscMemType                memtype;
1088 
1089   PetscFunctionBegin;
1090   PetscCall(PetscGetMemType(v, &memtype));
1091   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */
1092     kv = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatScalarKokkosViewHost(v, aseq->coo_n));
1093   } else {
1094     kv = ConstMatScalarKokkosView(v, aseq->coo_n); /* Directly use v[]'s memory */
1095   }
1096 
1097   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1098   else PetscCall(MatSeqAIJGetKokkosView(A, &Aa));                             /* read & write matrix values */
1099 
1100   Kokkos::parallel_for(
1101     Annz, KOKKOS_LAMBDA(const PetscCount i) {
1102       PetscScalar sum = 0.0;
1103       for (PetscCount k = jmap(i); k < jmap(i + 1); k++) sum += kv(perm(k));
1104       Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1105     });
1106 
1107   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
1108   else PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1109   PetscFunctionReturn(0);
1110 }
1111 
1112 PETSC_INTERN PetscErrorCode MatSeqAIJMoveDiagonalValuesFront_SeqAIJKokkos(Mat A, const PetscInt *diag) {
1113   Mat_SeqAIJKokkos          *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1114   MatScalarKokkosView        Aa;
1115   const MatRowMapKokkosView &Ai = akok->i_dual.view_device();
1116   PetscInt                   m  = A->rmap->n;
1117   ConstMatRowMapKokkosView   Adiag(diag, m); /* diag is a device pointer */
1118 
1119   PetscFunctionBegin;
1120   PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa));
1121   Kokkos::parallel_for(
1122     m, KOKKOS_LAMBDA(const PetscInt i) {
1123       PetscScalar tmp;
1124       if (Adiag(i) >= Ai(i) && Adiag(i) < Ai(i + 1)) { /* The diagonal element exists */
1125         tmp          = Aa(Ai(i));
1126         Aa(Ai(i))    = Aa(Adiag(i));
1127         Aa(Adiag(i)) = tmp;
1128       }
1129     });
1130   PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
1131   PetscFunctionReturn(0);
1132 }
1133 
1134 static PetscErrorCode MatLUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info) {
1135   PetscFunctionBegin;
1136   PetscCall(MatSeqAIJKokkosSyncHost(A));
1137   PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info));
1138   B->offloadmask = PETSC_OFFLOAD_CPU;
1139   PetscFunctionReturn(0);
1140 }
1141 
1142 static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat A) {
1143   Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
1144 
1145   PetscFunctionBegin;
1146   A->offloadmask = PETSC_OFFLOAD_KOKKOS; /* We do not really use this flag */
1147   A->boundtocpu  = PETSC_FALSE;
1148 
1149   A->ops->assemblyend               = MatAssemblyEnd_SeqAIJKokkos;
1150   A->ops->destroy                   = MatDestroy_SeqAIJKokkos;
1151   A->ops->duplicate                 = MatDuplicate_SeqAIJKokkos;
1152   A->ops->axpy                      = MatAXPY_SeqAIJKokkos;
1153   A->ops->scale                     = MatScale_SeqAIJKokkos;
1154   A->ops->zeroentries               = MatZeroEntries_SeqAIJKokkos;
1155   A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJKokkos;
1156   A->ops->mult                      = MatMult_SeqAIJKokkos;
1157   A->ops->multadd                   = MatMultAdd_SeqAIJKokkos;
1158   A->ops->multtranspose             = MatMultTranspose_SeqAIJKokkos;
1159   A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJKokkos;
1160   A->ops->multhermitiantranspose    = MatMultHermitianTranspose_SeqAIJKokkos;
1161   A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJKokkos;
1162   A->ops->productnumeric            = MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos;
1163   A->ops->transpose                 = MatTranspose_SeqAIJKokkos;
1164   A->ops->setoption                 = MatSetOption_SeqAIJKokkos;
1165   A->ops->getdiagonal               = MatGetDiagonal_SeqAIJKokkos;
1166   a->ops->getarray                  = MatSeqAIJGetArray_SeqAIJKokkos;
1167   a->ops->restorearray              = MatSeqAIJRestoreArray_SeqAIJKokkos;
1168   a->ops->getarrayread              = MatSeqAIJGetArrayRead_SeqAIJKokkos;
1169   a->ops->restorearrayread          = MatSeqAIJRestoreArrayRead_SeqAIJKokkos;
1170   a->ops->getarraywrite             = MatSeqAIJGetArrayWrite_SeqAIJKokkos;
1171   a->ops->restorearraywrite         = MatSeqAIJRestoreArrayWrite_SeqAIJKokkos;
1172   a->ops->getcsrandmemtype          = MatSeqAIJGetCSRAndMemType_SeqAIJKokkos;
1173 
1174   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJKokkos));
1175   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJKokkos));
1176   PetscFunctionReturn(0);
1177 }
1178 
1179 PETSC_INTERN PetscErrorCode MatSetSeqAIJKokkosWithCSRMatrix(Mat A, Mat_SeqAIJKokkos *akok) {
1180   Mat_SeqAIJ *aseq;
1181   PetscInt    i, m, n;
1182 
1183   PetscFunctionBegin;
1184   PetscCheck(!A->spptr, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A->spptr is supposed to be empty");
1185 
1186   m = akok->nrows();
1187   n = akok->ncols();
1188   PetscCall(MatSetSizes(A, m, n, m, n));
1189   PetscCall(MatSetType(A, MATSEQAIJKOKKOS));
1190 
1191   /* Set up data structures of A as a MATSEQAIJ */
1192   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(A, MAT_SKIP_ALLOCATION, NULL));
1193   aseq = (Mat_SeqAIJ *)(A)->data;
1194 
1195   akok->i_dual.sync_host(); /* We always need sync'ed i, j on host */
1196   akok->j_dual.sync_host();
1197 
1198   aseq->i            = akok->i_host_data();
1199   aseq->j            = akok->j_host_data();
1200   aseq->a            = akok->a_host_data();
1201   aseq->nonew        = -1; /*this indicates that inserting a new value in the matrix that generates a new nonzero is an error*/
1202   aseq->singlemalloc = PETSC_FALSE;
1203   aseq->free_a       = PETSC_FALSE;
1204   aseq->free_ij      = PETSC_FALSE;
1205   aseq->nz           = akok->nnz();
1206   aseq->maxnz        = aseq->nz;
1207 
1208   PetscCall(PetscMalloc1(m, &aseq->imax));
1209   PetscCall(PetscMalloc1(m, &aseq->ilen));
1210   for (i = 0; i < m; i++) aseq->ilen[i] = aseq->imax[i] = aseq->i[i + 1] - aseq->i[i];
1211 
1212   /* It is critical to set the nonzerostate, as we use it to check if sparsity pattern (hence data) has changed on host in MatAssemblyEnd */
1213   akok->nonzerostate = A->nonzerostate;
1214   A->spptr           = akok; /* Set A->spptr before MatAssembly so that A->spptr won't be allocated again there */
1215   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
1216   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
1217   PetscFunctionReturn(0);
1218 }
1219 
1220 /* Crete a SEQAIJKOKKOS matrix with a Mat_SeqAIJKokkos data structure
1221 
1222    Note we have names like MatSeqAIJSetPreallocationCSR, so I use capitalized CSR
1223  */
1224 PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithCSRMatrix(MPI_Comm comm, Mat_SeqAIJKokkos *akok, Mat *A) {
1225   PetscFunctionBegin;
1226   PetscCall(MatCreate(comm, A));
1227   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
1228   PetscFunctionReturn(0);
1229 }
1230 
1231 /* --------------------------------------------------------------------------------*/
1232 /*@C
1233    MatCreateSeqAIJKokkos - Creates a sparse matrix in `MATSEQAIJKOKKOS` (compressed row) format
1234    (the default parallel PETSc format). This matrix will ultimately be handled by
1235    Kokkos for calculations. For good matrix
1236    assembly performance the user should preallocate the matrix storage by setting
1237    the parameter nz (or the array nnz).  By setting these parameters accurately,
1238    performance during matrix assembly can be increased by more than a factor of 50.
1239 
1240    Collective
1241 
1242    Input Parameters:
1243 +  comm - MPI communicator, set to `PETSC_COMM_SELF`
1244 .  m - number of rows
1245 .  n - number of columns
1246 .  nz - number of nonzeros per row (same for all rows)
1247 -  nnz - array containing the number of nonzeros in the various rows
1248          (possibly different for each row) or NULL
1249 
1250    Output Parameter:
1251 .  A - the matrix
1252 
1253    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1254    MatXXXXSetPreallocation() paradgm instead of this routine directly.
1255    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
1256 
1257    Notes:
1258    If nnz is given then nz is ignored
1259 
1260    The AIJ format, also called
1261    compressed row storage, is fully compatible with standard Fortran 77
1262    storage.  That is, the stored row and column indices can begin at
1263    either one (as in Fortran) or zero.  See the users' manual for details.
1264 
1265    Specify the preallocated storage with either nz or nnz (not both).
1266    Set nz = `PETSC_DEFAULT` and nnz = NULL for PETSc to control dynamic memory
1267    allocation.  For large problems you MUST preallocate memory or you
1268    will get TERRIBLE performance, see the users' manual chapter on matrices.
1269 
1270    By default, this format uses inodes (identical nodes) when possible, to
1271    improve numerical efficiency of matrix-vector products and solves. We
1272    search for consecutive rows with the same nonzero structure, thereby
1273    reusing matrix information to achieve increased efficiency.
1274 
1275    Level: intermediate
1276 
1277 .seealso: `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MatCreateAIJ()`
1278 @*/
1279 PetscErrorCode MatCreateSeqAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A) {
1280   PetscFunctionBegin;
1281   PetscCall(PetscKokkosInitializeCheck());
1282   PetscCall(MatCreate(comm, A));
1283   PetscCall(MatSetSizes(*A, m, n, m, n));
1284   PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1285   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
1286   PetscFunctionReturn(0);
1287 }
1288 
1289 typedef Kokkos::TeamPolicy<>::member_type team_member;
1290 //
1291 // This factorization exploits block diagonal matrices with "Nf" (not used).
1292 // Use -pc_factor_mat_ordering_type rcm to order decouple blocks of size N/Nf for this optimization
1293 //
1294 static PetscErrorCode                     MatLUFactorNumeric_SeqAIJKOKKOSDEVICE(Mat B, Mat A, const MatFactorInfo *info) {
1295                       Mat_SeqAIJ       *b      = (Mat_SeqAIJ *)B->data;
1296                       Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *baijkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
1297                       IS                isrow = b->row, isicol = b->icol;
1298                       const PetscInt   *r_h, *ic_h;
1299                       const PetscInt n = A->rmap->n, *ai_d = aijkok->i_dual.view_device().data(), *aj_d = aijkok->j_dual.view_device().data(), *bi_d = baijkok->i_dual.view_device().data(), *bj_d = baijkok->j_dual.view_device().data(), *bdiag_d = baijkok->diag_d.data();
1300                       const PetscScalar *aa_d = aijkok->a_dual.view_device().data();
1301                       PetscScalar       *ba_d = baijkok->a_dual.view_device().data();
1302                       PetscBool          row_identity, col_identity;
1303                       PetscInt           nc, Nf = 1, nVec = 32; // should be a parameter, Nf is batch size - not used
1304 
1305                       PetscFunctionBegin;
1306                       PetscCheck(A->rmap->n == n, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "square matrices only supported %" PetscInt_FMT " %" PetscInt_FMT, A->rmap->n, n);
1307                       PetscCall(MatIsStructurallySymmetric(A, &row_identity));
1308                       PetscCheck(row_identity, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "structurally symmetric matrices only supported");
1309                       PetscCall(ISGetIndices(isrow, &r_h));
1310                       PetscCall(ISGetIndices(isicol, &ic_h));
1311                       PetscCall(ISGetSize(isicol, &nc));
1312                       PetscCall(PetscLogGpuTimeBegin());
1313                       PetscCall(MatSeqAIJKokkosSyncDevice(A));
1314                       {
1315 #define KOKKOS_SHARED_LEVEL 1
1316     using scr_mem_t    = Kokkos::DefaultExecutionSpace::scratch_memory_space;
1317     using sizet_scr_t  = Kokkos::View<size_t, scr_mem_t>;
1318     using scalar_scr_t = Kokkos::View<PetscScalar, scr_mem_t>;
1319     const Kokkos::View<const PetscInt *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_r_k(r_h, n);
1320     Kokkos::View<PetscInt *, Kokkos::LayoutLeft>                                                                         d_r_k("r", n);
1321     const Kokkos::View<const PetscInt *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_ic_k(ic_h, nc);
1322     Kokkos::View<PetscInt *, Kokkos::LayoutLeft>                                                                         d_ic_k("ic", nc);
1323     size_t                                                                                                               flops_h = 0.0;
1324     Kokkos::View<size_t, Kokkos::HostSpace>                                                                              h_flops_k(&flops_h);
1325     Kokkos::View<size_t>                                                                                                 d_flops_k("flops");
1326     const int                                                                                                            conc = Kokkos::DefaultExecutionSpace().concurrency(), team_size = conc > 1 ? 16 : 1; // 8*32 = 256
1327     const int                                                                                                            nloc = n / Nf, Ni = (conc > 8) ? 1 /* some intelegent number of SMs -- but need league_barrier */ : 1;
1328     Kokkos::deep_copy(d_flops_k, h_flops_k);
1329     Kokkos::deep_copy(d_r_k, h_r_k);
1330     Kokkos::deep_copy(d_ic_k, h_ic_k);
1331     // Fill A --> fact
1332     Kokkos::parallel_for(
1333       Kokkos::TeamPolicy<>(Nf * Ni, team_size, nVec), KOKKOS_LAMBDA(const team_member team) {
1334         const PetscInt  field = team.league_rank() / Ni, field_block = team.league_rank() % Ni; // use grid.x/y in CUDA
1335         const PetscInt  nloc_i = (nloc / Ni + !!(nloc % Ni)), start_i = field * nloc + field_block * nloc_i, end_i = (start_i + nloc_i) > (field + 1) * nloc ? (field + 1) * nloc : (start_i + nloc_i);
1336         const PetscInt *ic = d_ic_k.data(), *r = d_r_k.data();
1337         // zero rows of B
1338         Kokkos::parallel_for(Kokkos::TeamVectorRange(team, start_i, end_i), [=](const int &rowb) {
1339           PetscInt     nzbL = bi_d[rowb + 1] - bi_d[rowb], nzbU = bdiag_d[rowb] - bdiag_d[rowb + 1]; // with diag
1340           PetscScalar *baL = ba_d + bi_d[rowb];
1341           PetscScalar *baU = ba_d + bdiag_d[rowb + 1] + 1;
1342           /* zero (unfactored row) */
1343           for (int j = 0; j < nzbL; j++) baL[j] = 0;
1344           for (int j = 0; j < nzbU; j++) baU[j] = 0;
1345         });
1346         // copy A into B
1347         Kokkos::parallel_for(Kokkos::TeamVectorRange(team, start_i, end_i), [=](const int &rowb) {
1348           PetscInt           rowa = r[rowb], nza = ai_d[rowa + 1] - ai_d[rowa];
1349           const PetscScalar *av    = aa_d + ai_d[rowa];
1350           const PetscInt    *ajtmp = aj_d + ai_d[rowa];
1351           /* load in initial (unfactored row) */
1352           for (int j = 0; j < nza; j++) {
1353             PetscInt    colb = ic[ajtmp[j]];
1354             PetscScalar vala = av[j];
1355             if (colb == rowb) {
1356               *(ba_d + bdiag_d[rowb]) = vala;
1357             } else {
1358               const PetscInt *pbj = bj_d + ((colb > rowb) ? bdiag_d[rowb + 1] + 1 : bi_d[rowb]);
1359               PetscScalar    *pba = ba_d + ((colb > rowb) ? bdiag_d[rowb + 1] + 1 : bi_d[rowb]);
1360               PetscInt        nz = (colb > rowb) ? bdiag_d[rowb] - (bdiag_d[rowb + 1] + 1) : bi_d[rowb + 1] - bi_d[rowb], set = 0;
1361               for (int j = 0; j < nz; j++) {
1362                 if (pbj[j] == colb) {
1363                   pba[j] = vala;
1364                   set++;
1365                   break;
1366                 }
1367               }
1368 #if !defined(PETSC_HAVE_SYCL)
1369               if (set != 1) printf("\t\t\t ERROR DID NOT SET ?????\n");
1370 #endif
1371             }
1372           }
1373         });
1374       });
1375     Kokkos::fence();
1376 
1377     Kokkos::parallel_for(
1378       Kokkos::TeamPolicy<>(Nf * Ni, team_size, nVec).set_scratch_size(KOKKOS_SHARED_LEVEL, Kokkos::PerThread(sizet_scr_t::shmem_size() + scalar_scr_t::shmem_size()), Kokkos::PerTeam(sizet_scr_t::shmem_size())), KOKKOS_LAMBDA(const team_member team) {
1379         sizet_scr_t    colkIdx(team.thread_scratch(KOKKOS_SHARED_LEVEL));
1380         scalar_scr_t   L_ki(team.thread_scratch(KOKKOS_SHARED_LEVEL));
1381         sizet_scr_t    flops(team.team_scratch(KOKKOS_SHARED_LEVEL));
1382         const PetscInt field = team.league_rank() / Ni, field_block_idx = team.league_rank() % Ni; // use grid.x/y in CUDA
1383         const PetscInt start = field * nloc, end = start + nloc;
1384         Kokkos::single(Kokkos::PerTeam(team), [=]() { flops() = 0; });
1385         // A22 panel update for each row A(1,:) and col A(:,1)
1386         for (int ii = start; ii < end - 1; ii++) {
1387           const PetscInt    *bjUi = bj_d + bdiag_d[ii + 1] + 1, nzUi = bdiag_d[ii] - (bdiag_d[ii + 1] + 1); // vector, and vector size, of column indices of U(i,(i+1):end)
1388           const PetscScalar *baUi    = ba_d + bdiag_d[ii + 1] + 1;                                          // vector of data  U(i,i+1:end)
1389           const PetscInt     nUi_its = nzUi / Ni + !!(nzUi % Ni);
1390           const PetscScalar  Bii     = *(ba_d + bdiag_d[ii]); // diagonal in its special place
1391           Kokkos::parallel_for(Kokkos::TeamThreadRange(team, nUi_its), [=](const int j) {
1392             PetscInt kIdx = j * Ni + field_block_idx;
1393             if (kIdx >= nzUi) /* void */
1394               ;
1395             else {
1396               const PetscInt  myk = bjUi[kIdx];                // assume symmetric structure, need a transposed meta-data here in general
1397               const PetscInt *pjL = bj_d + bi_d[myk];          // look for L(myk,ii) in start of row
1398               const PetscInt  nzL = bi_d[myk + 1] - bi_d[myk]; // size of L_k(:)
1399               size_t          st_idx;
1400               // find and do L(k,i) = A(:k,i) / A(i,i)
1401               Kokkos::single(Kokkos::PerThread(team), [&]() { colkIdx() = PETSC_MAX_INT; });
1402               // get column, there has got to be a better way
1403               Kokkos::parallel_reduce(
1404                 Kokkos::ThreadVectorRange(team, nzL),
1405                 [&](const int &j, size_t &idx) {
1406                   if (pjL[j] == ii) {
1407                     PetscScalar *pLki = ba_d + bi_d[myk] + j;
1408                     idx               = j;           // output
1409                     *pLki             = *pLki / Bii; // column scaling:  L(k,i) = A(:k,i) / A(i,i)
1410                   }
1411                 },
1412                 st_idx);
1413               Kokkos::single(Kokkos::PerThread(team), [=]() {
1414                 colkIdx() = st_idx;
1415                 L_ki()    = *(ba_d + bi_d[myk] + st_idx);
1416               });
1417 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
1418               if (colkIdx() == PETSC_MAX_INT) printf("\t\t\t\t\t\t\tERROR: failed to find L_ki(%d,%d)\n", (int)myk, ii); // uses a register
1419 #endif
1420               // active row k, do  A_kj -= Lki * U_ij; j \in U(i,:) j != i
1421               // U(i+1,:end)
1422               Kokkos::parallel_for(Kokkos::ThreadVectorRange(team, nzUi), [=](const int &uiIdx) { // index into i (U)
1423                 PetscScalar Uij = baUi[uiIdx];
1424                 PetscInt    col = bjUi[uiIdx];
1425                 if (col == myk) {
1426                   // A_kk = A_kk - L_ki * U_ij(k)
1427                   PetscScalar *Akkv = (ba_d + bdiag_d[myk]); // diagonal in its special place
1428                   *Akkv             = *Akkv - L_ki() * Uij;  // UiK
1429                 } else {
1430                   PetscScalar    *start, *end, *pAkjv = NULL;
1431                   PetscInt        high, low;
1432                   const PetscInt *startj;
1433                   if (col < myk) { // L
1434                     PetscScalar *pLki = ba_d + bi_d[myk] + colkIdx();
1435                     PetscInt     idx  = (pLki + 1) - (ba_d + bi_d[myk]); // index into row
1436                     start             = pLki + 1;                        // start at pLki+1, A22(myk,1)
1437                     startj            = bj_d + bi_d[myk] + idx;
1438                     end               = ba_d + bi_d[myk + 1];
1439                   } else {
1440                     PetscInt idx = bdiag_d[myk + 1] + 1;
1441                     start        = ba_d + idx;
1442                     startj       = bj_d + idx;
1443                     end          = ba_d + bdiag_d[myk];
1444                   }
1445                   // search for 'col', use bisection search - TODO
1446                   low  = 0;
1447                   high = (PetscInt)(end - start);
1448                   while (high - low > 5) {
1449                     int t = (low + high) / 2;
1450                     if (startj[t] > col) high = t;
1451                     else low = t;
1452                   }
1453                   for (pAkjv = start + low; pAkjv < start + high; pAkjv++) {
1454                     if (startj[pAkjv - start] == col) break;
1455                   }
1456 #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
1457                   if (pAkjv == start + high) printf("\t\t\t\t\t\t\t\t\t\t\tERROR: *** failed to find Akj(%d,%d)\n", (int)myk, (int)col); // uses a register
1458 #endif
1459                   *pAkjv = *pAkjv - L_ki() * Uij; // A_kj = A_kj - L_ki * U_ij
1460                 }
1461               });
1462             }
1463           });
1464           team.team_barrier(); // this needs to be a league barrier to use more that one SM per block
1465           if (field_block_idx == 0) Kokkos::single(Kokkos::PerTeam(team), [&]() { Kokkos::atomic_add(flops.data(), (size_t)(2 * (nzUi * nzUi) + 2)); });
1466         } /* endof for (i=0; i<n; i++) { */
1467         Kokkos::single(Kokkos::PerTeam(team), [=]() {
1468           Kokkos::atomic_add(&d_flops_k(), flops());
1469           flops() = 0;
1470         });
1471       });
1472     Kokkos::fence();
1473     Kokkos::deep_copy(h_flops_k, d_flops_k);
1474     PetscCall(PetscLogGpuFlops((PetscLogDouble)h_flops_k()));
1475     Kokkos::parallel_for(
1476       Kokkos::TeamPolicy<>(Nf * Ni, 1, 256), KOKKOS_LAMBDA(const team_member team) {
1477         const PetscInt lg_rank = team.league_rank(), field = lg_rank / Ni;                            //, field_offset = lg_rank%Ni;
1478         const PetscInt start = field * nloc, end = start + nloc, n_its = (nloc / Ni + !!(nloc % Ni)); // 1/Ni iters
1479         /* Invert diagonal for simpler triangular solves */
1480         Kokkos::parallel_for(Kokkos::TeamVectorRange(team, n_its), [=](int outer_index) {
1481           int i = start + outer_index * Ni + lg_rank % Ni;
1482           if (i < end) {
1483             PetscScalar *pv = ba_d + bdiag_d[i];
1484             *pv             = 1.0 / (*pv);
1485           }
1486         });
1487       });
1488   }
1489   PetscCall(PetscLogGpuTimeEnd());
1490   PetscCall(ISRestoreIndices(isicol, &ic_h));
1491   PetscCall(ISRestoreIndices(isrow, &r_h));
1492 
1493   PetscCall(ISIdentity(isrow, &row_identity));
1494   PetscCall(ISIdentity(isicol, &col_identity));
1495   if (b->inode.size) {
1496     B->ops->solve = MatSolve_SeqAIJ_Inode;
1497   } else if (row_identity && col_identity) {
1498     B->ops->solve = MatSolve_SeqAIJ_NaturalOrdering;
1499   } else {
1500     B->ops->solve = MatSolve_SeqAIJ; // at least this needs to be in Kokkos
1501   }
1502   B->offloadmask = PETSC_OFFLOAD_GPU;
1503   PetscCall(MatSeqAIJKokkosSyncHost(B));          // solve on CPU
1504   B->ops->solveadd          = MatSolveAdd_SeqAIJ; // and this
1505   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJ;
1506   B->ops->solvetransposeadd = MatSolveTransposeAdd_SeqAIJ;
1507   B->ops->matsolve          = MatMatSolve_SeqAIJ;
1508   B->assembled              = PETSC_TRUE;
1509   B->preallocated           = PETSC_TRUE;
1510 
1511   PetscFunctionReturn(0);
1512 }
1513 
1514 static PetscErrorCode MatLUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info) {
1515   PetscFunctionBegin;
1516   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
1517   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
1518   PetscFunctionReturn(0);
1519 }
1520 
1521 static PetscErrorCode MatSeqAIJKokkosSymbolicSolveCheck(Mat A) {
1522   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1523 
1524   PetscFunctionBegin;
1525   if (!factors->sptrsv_symbolic_completed) {
1526     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d);
1527     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d);
1528     factors->sptrsv_symbolic_completed = PETSC_TRUE;
1529   }
1530   PetscFunctionReturn(0);
1531 }
1532 
1533 /* Check if we need to update factors etc for transpose solve */
1534 static PetscErrorCode MatSeqAIJKokkosTransposeSolveCheck(Mat A) {
1535   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1536   MatColIdxType               n       = A->rmap->n;
1537 
1538   PetscFunctionBegin;
1539   if (!factors->transpose_updated) { /* TODO: KK needs to provide functions to do numeric transpose only */
1540     /* Update L^T and do sptrsv symbolic */
1541     factors->iLt_d = MatRowMapKokkosView("factors->iLt_d", n + 1);
1542     Kokkos::deep_copy(factors->iLt_d, 0); /* KK requires 0 */
1543     factors->jLt_d = MatColIdxKokkosView("factors->jLt_d", factors->jL_d.extent(0));
1544     factors->aLt_d = MatScalarKokkosView("factors->aLt_d", factors->aL_d.extent(0));
1545 
1546     transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iL_d, factors->jL_d, factors->aL_d,
1547                                                                                                                                                                                                               factors->iLt_d, factors->jLt_d, factors->aLt_d);
1548 
1549     /* TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices.
1550       We have to sort the indices, until KK provides finer control options.
1551     */
1552     sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iLt_d, factors->jLt_d, factors->aLt_d);
1553 
1554     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d);
1555 
1556     /* Update U^T and do sptrsv symbolic */
1557     factors->iUt_d = MatRowMapKokkosView("factors->iUt_d", n + 1);
1558     Kokkos::deep_copy(factors->iUt_d, 0); /* KK requires 0 */
1559     factors->jUt_d = MatColIdxKokkosView("factors->jUt_d", factors->jU_d.extent(0));
1560     factors->aUt_d = MatScalarKokkosView("factors->aUt_d", factors->aU_d.extent(0));
1561 
1562     transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iU_d, factors->jU_d, factors->aU_d,
1563                                                                                                                                                                                                               factors->iUt_d, factors->jUt_d, factors->aUt_d);
1564 
1565     /* Sort indices. See comments above */
1566     sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iUt_d, factors->jUt_d, factors->aUt_d);
1567 
1568     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d);
1569     factors->transpose_updated = PETSC_TRUE;
1570   }
1571   PetscFunctionReturn(0);
1572 }
1573 
1574 /* Solve Ax = b, with A = LU */
1575 static PetscErrorCode MatSolve_SeqAIJKokkos(Mat A, Vec b, Vec x) {
1576   ConstPetscScalarKokkosView  bv;
1577   PetscScalarKokkosView       xv;
1578   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1579 
1580   PetscFunctionBegin;
1581   PetscCall(PetscLogGpuTimeBegin());
1582   PetscCall(MatSeqAIJKokkosSymbolicSolveCheck(A));
1583   PetscCall(VecGetKokkosView(b, &bv));
1584   PetscCall(VecGetKokkosViewWrite(x, &xv));
1585   /* Solve L tmpv = b */
1586   PetscCallCXX(KokkosSparse::Experimental::sptrsv_solve(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d, bv, factors->workVector));
1587   /* Solve Ux = tmpv */
1588   PetscCallCXX(KokkosSparse::Experimental::sptrsv_solve(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, factors->workVector, xv));
1589   PetscCall(VecRestoreKokkosView(b, &bv));
1590   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
1591   PetscCall(PetscLogGpuTimeEnd());
1592   PetscFunctionReturn(0);
1593 }
1594 
1595 /* Solve A^T x = b, where A^T = U^T L^T */
1596 static PetscErrorCode MatSolveTranspose_SeqAIJKokkos(Mat A, Vec b, Vec x) {
1597   ConstPetscScalarKokkosView  bv;
1598   PetscScalarKokkosView       xv;
1599   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1600 
1601   PetscFunctionBegin;
1602   PetscCall(PetscLogGpuTimeBegin());
1603   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A));
1604   PetscCall(VecGetKokkosView(b, &bv));
1605   PetscCall(VecGetKokkosViewWrite(x, &xv));
1606   /* Solve U^T tmpv = b */
1607   KokkosSparse::Experimental::sptrsv_solve(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, bv, factors->workVector);
1608 
1609   /* Solve L^T x = tmpv */
1610   KokkosSparse::Experimental::sptrsv_solve(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d, factors->workVector, xv);
1611   PetscCall(VecRestoreKokkosView(b, &bv));
1612   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
1613   PetscCall(PetscLogGpuTimeEnd());
1614   PetscFunctionReturn(0);
1615 }
1616 
1617 static PetscErrorCode MatILUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info) {
1618   Mat_SeqAIJKokkos           *aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
1619   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
1620   PetscInt                    fill_lev = info->levels;
1621 
1622   PetscFunctionBegin;
1623   PetscCall(PetscLogGpuTimeBegin());
1624   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1625 
1626   auto a_d = aijkok->a_dual.view_device();
1627   auto i_d = aijkok->i_dual.view_device();
1628   auto j_d = aijkok->j_dual.view_device();
1629 
1630   KokkosSparse::Experimental::spiluk_numeric(&factors->kh, fill_lev, i_d, j_d, a_d, factors->iL_d, factors->jL_d, factors->aL_d, factors->iU_d, factors->jU_d, factors->aU_d);
1631 
1632   B->assembled              = PETSC_TRUE;
1633   B->preallocated           = PETSC_TRUE;
1634   B->ops->solve             = MatSolve_SeqAIJKokkos;
1635   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJKokkos;
1636   B->ops->matsolve          = NULL;
1637   B->ops->matsolvetranspose = NULL;
1638   B->offloadmask            = PETSC_OFFLOAD_GPU;
1639 
1640   /* Once the factors' value changed, we need to update their transpose and sptrsv handle */
1641   factors->transpose_updated         = PETSC_FALSE;
1642   factors->sptrsv_symbolic_completed = PETSC_FALSE;
1643   /* TODO: log flops, but how to know that? */
1644   PetscCall(PetscLogGpuTimeEnd());
1645   PetscFunctionReturn(0);
1646 }
1647 
1648 static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info) {
1649   Mat_SeqAIJKokkos           *aijkok;
1650   Mat_SeqAIJ                 *b;
1651   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
1652   PetscInt                    fill_lev = info->levels;
1653   PetscInt                    nnzA     = ((Mat_SeqAIJ *)A->data)->nz, nnzL, nnzU;
1654   PetscInt                    n        = A->rmap->n;
1655 
1656   PetscFunctionBegin;
1657   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1658   /* Rebuild factors */
1659   if (factors) {
1660     factors->Destroy();
1661   } /* Destroy the old if it exists */
1662   else {
1663     B->spptr = factors = new Mat_SeqAIJKokkosTriFactors(n);
1664   }
1665 
1666   /* Create a spiluk handle and then do symbolic factorization */
1667   nnzL = nnzU = PetscRealIntMultTruncate(info->fill, nnzA);
1668   factors->kh.create_spiluk_handle(KokkosSparse::Experimental::SPILUKAlgorithm::SEQLVLSCHD_TP1, n, nnzL, nnzU);
1669 
1670   auto spiluk_handle = factors->kh.get_spiluk_handle();
1671 
1672   Kokkos::realloc(factors->iL_d, n + 1); /* Free old arrays and realloc */
1673   Kokkos::realloc(factors->jL_d, spiluk_handle->get_nnzL());
1674   Kokkos::realloc(factors->iU_d, n + 1);
1675   Kokkos::realloc(factors->jU_d, spiluk_handle->get_nnzU());
1676 
1677   aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
1678   auto i_d = aijkok->i_dual.view_device();
1679   auto j_d = aijkok->j_dual.view_device();
1680   KokkosSparse::Experimental::spiluk_symbolic(&factors->kh, fill_lev, i_d, j_d, factors->iL_d, factors->jL_d, factors->iU_d, factors->jU_d);
1681   /* TODO: if spiluk_symbolic is asynchronous, do we need to sync before calling get_nnzL()? */
1682 
1683   Kokkos::resize(factors->jL_d, spiluk_handle->get_nnzL()); /* Shrink or expand, and retain old value */
1684   Kokkos::resize(factors->jU_d, spiluk_handle->get_nnzU());
1685   Kokkos::realloc(factors->aL_d, spiluk_handle->get_nnzL()); /* No need to retain old value */
1686   Kokkos::realloc(factors->aU_d, spiluk_handle->get_nnzU());
1687 
1688   /* TODO: add options to select sptrsv algorithms */
1689   /* Create sptrsv handles for L, U and their transpose */
1690 #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
1691   auto sptrsv_alg = KokkosSparse::Experimental::SPTRSVAlgorithm::SPTRSV_CUSPARSE;
1692 #else
1693   auto sptrsv_alg = KokkosSparse::Experimental::SPTRSVAlgorithm::SEQLVLSCHD_TP1;
1694 #endif
1695 
1696   factors->khL.create_sptrsv_handle(sptrsv_alg, n, true /* L is lower tri */);
1697   factors->khU.create_sptrsv_handle(sptrsv_alg, n, false /* U is not lower tri */);
1698   factors->khLt.create_sptrsv_handle(sptrsv_alg, n, false /* L^T is not lower tri */);
1699   factors->khUt.create_sptrsv_handle(sptrsv_alg, n, true /* U^T is lower tri */);
1700 
1701   /* Fill fields of the factor matrix B */
1702   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(B, MAT_SKIP_ALLOCATION, NULL));
1703   b     = (Mat_SeqAIJ *)B->data;
1704   b->nz = b->maxnz          = spiluk_handle->get_nnzL() + spiluk_handle->get_nnzU();
1705   B->info.fill_ratio_given  = info->fill;
1706   B->info.fill_ratio_needed = ((PetscReal)b->nz) / ((PetscReal)nnzA);
1707 
1708   B->offloadmask          = PETSC_OFFLOAD_GPU;
1709   B->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJKokkos;
1710   PetscFunctionReturn(0);
1711 }
1712 
1713 static PetscErrorCode MatLUFactorSymbolic_SeqAIJKOKKOSDEVICE(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info) {
1714   Mat_SeqAIJ    *b     = (Mat_SeqAIJ *)B->data;
1715   const PetscInt nrows = A->rmap->n;
1716 
1717   PetscFunctionBegin;
1718   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
1719   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKOKKOSDEVICE;
1720   // move B data into Kokkos
1721   PetscCall(MatSeqAIJKokkosSyncDevice(B)); // create aijkok
1722   PetscCall(MatSeqAIJKokkosSyncDevice(A)); // create aijkok
1723   {
1724     Mat_SeqAIJKokkos *baijkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
1725     if (!baijkok->diag_d.extent(0)) {
1726       const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_diag(b->diag, nrows + 1);
1727       baijkok->diag_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_diag));
1728       Kokkos::deep_copy(baijkok->diag_d, h_diag);
1729     }
1730   }
1731   PetscFunctionReturn(0);
1732 }
1733 
1734 static PetscErrorCode MatFactorGetSolverType_SeqAIJKokkos(Mat A, MatSolverType *type) {
1735   PetscFunctionBegin;
1736   *type = MATSOLVERKOKKOS;
1737   PetscFunctionReturn(0);
1738 }
1739 
1740 static PetscErrorCode MatFactorGetSolverType_seqaij_kokkos_device(Mat A, MatSolverType *type) {
1741   PetscFunctionBegin;
1742   *type = MATSOLVERKOKKOSDEVICE;
1743   PetscFunctionReturn(0);
1744 }
1745 
1746 /*MC
1747   MATSOLVERKOKKOS = "Kokkos" - A matrix solver type providing triangular solvers for sequential matrices
1748   on a single GPU of type, `MATSEQAIJKOKKOS`, `MATAIJKOKKOS`.
1749 
1750   Level: beginner
1751 
1752 .seealso: `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJKokkos()`, `MATAIJKOKKOS`, `MatKokkosSetFormat()`, `MatKokkosStorageFormat`, `MatKokkosFormatOperation`
1753 M*/
1754 PETSC_EXTERN PetscErrorCode MatGetFactor_SeqAIJKokkos_Kokkos(Mat A, MatFactorType ftype, Mat *B) /* MatGetFactor_<MatType>_<MatSolverType> */
1755 {
1756   PetscInt n = A->rmap->n;
1757 
1758   PetscFunctionBegin;
1759   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), B));
1760   PetscCall(MatSetSizes(*B, n, n, n, n));
1761   (*B)->factortype = ftype;
1762   PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
1763   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
1764 
1765   if (ftype == MAT_FACTOR_LU) {
1766     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
1767     (*B)->canuseordering        = PETSC_TRUE;
1768     (*B)->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqAIJKokkos;
1769   } else if (ftype == MAT_FACTOR_ILU) {
1770     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
1771     (*B)->canuseordering         = PETSC_FALSE;
1772     (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJKokkos;
1773   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "MatFactorType %s is not supported by MatType SeqAIJKokkos", MatFactorTypes[ftype]);
1774 
1775   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
1776   PetscCall(PetscObjectComposeFunction((PetscObject)(*B), "MatFactorGetSolverType_C", MatFactorGetSolverType_SeqAIJKokkos));
1777   PetscFunctionReturn(0);
1778 }
1779 
1780 PETSC_EXTERN PetscErrorCode MatGetFactor_seqaijkokkos_kokkos_device(Mat A, MatFactorType ftype, Mat *B) {
1781   PetscInt n = A->rmap->n;
1782 
1783   PetscFunctionBegin;
1784   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), B));
1785   PetscCall(MatSetSizes(*B, n, n, n, n));
1786   (*B)->factortype     = ftype;
1787   (*B)->canuseordering = PETSC_TRUE;
1788   PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
1789   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
1790 
1791   if (ftype == MAT_FACTOR_LU) {
1792     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
1793     (*B)->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqAIJKOKKOSDEVICE;
1794   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Factor type not supported for KOKKOS Matrix Types");
1795 
1796   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
1797   PetscCall(PetscObjectComposeFunction((PetscObject)(*B), "MatFactorGetSolverType_C", MatFactorGetSolverType_seqaij_kokkos_device));
1798   PetscFunctionReturn(0);
1799 }
1800 
1801 PETSC_EXTERN PetscErrorCode MatSolverTypeRegister_KOKKOS(void) {
1802   PetscFunctionBegin;
1803   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_SeqAIJKokkos_Kokkos));
1804   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ILU, MatGetFactor_SeqAIJKokkos_Kokkos));
1805   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOSDEVICE, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_seqaijkokkos_kokkos_device));
1806   PetscFunctionReturn(0);
1807 }
1808 
1809 /* Utility to print out a KokkosCsrMatrix for debugging */
1810 PETSC_INTERN PetscErrorCode PrintCsrMatrix(const KokkosCsrMatrix &csrmat) {
1811   const auto        &iv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.graph.row_map);
1812   const auto        &jv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.graph.entries);
1813   const auto        &av = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.values);
1814   const PetscInt    *i  = iv.data();
1815   const PetscInt    *j  = jv.data();
1816   const PetscScalar *a  = av.data();
1817   PetscInt           m = csrmat.numRows(), n = csrmat.numCols(), nnz = csrmat.nnz();
1818 
1819   PetscFunctionBegin;
1820   PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT " x %" PetscInt_FMT " SeqAIJKokkos, with %" PetscInt_FMT " nonzeros\n", m, n, nnz));
1821   for (PetscInt k = 0; k < m; k++) {
1822     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT ": ", k));
1823     for (PetscInt p = i[k]; p < i[k + 1]; p++) PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT "(%.1f), ", j[p], (double)PetscRealPart(a[p])));
1824     PetscCall(PetscPrintf(PETSC_COMM_SELF, "\n"));
1825   }
1826   PetscFunctionReturn(0);
1827 }
1828