xref: /petsc/src/mat/impls/aij/seq/kokkos/aijkok.kokkos.cxx (revision 0e3ece09e3140f5ddabf8db2b6f5fd48b6ec6274)
111d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
2076ba34aSJunchao Zhang #include <petscpkg_version.h>
3152b3e56SJunchao Zhang #include <petsc/private/petscimpl.h>
442550becSJunchao Zhang #include <petsc/private/sfimpl.h>
58c3ff71bSJunchao Zhang #include <petscsystypes.h>
68c3ff71bSJunchao Zhang #include <petscerror.h>
78c3ff71bSJunchao Zhang 
88c3ff71bSJunchao Zhang #include <Kokkos_Core.hpp>
9f0cf5187SStefano Zampini #include <KokkosBlas.hpp>
108c3ff71bSJunchao Zhang #include <KokkosSparse_CrsMatrix.hpp>
118c3ff71bSJunchao Zhang #include <KokkosSparse_spmv.hpp>
1286a27549SJunchao Zhang #include <KokkosSparse_spiluk.hpp>
1386a27549SJunchao Zhang #include <KokkosSparse_sptrsv.hpp>
14076ba34aSJunchao Zhang #include <KokkosSparse_spgemm.hpp>
15076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
1686a27549SJunchao Zhang 
1742550becSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
188c3ff71bSJunchao Zhang 
19*0e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(3, 7, 0)
20f98996d3SJunchao Zhang   #include <KokkosSparse_Utils.hpp>
21f98996d3SJunchao Zhang using KokkosSparse::sort_crs_matrix;
229371c9d4SSatish Balay using KokkosSparse::Impl::transpose_matrix;
23f98996d3SJunchao Zhang #else
24f98996d3SJunchao Zhang   #include <KokkosKernels_Sorting.hpp>
25f98996d3SJunchao Zhang using KokkosKernels::sort_crs_matrix;
269371c9d4SSatish Balay using KokkosKernels::Impl::transpose_matrix;
27f98996d3SJunchao Zhang #endif
28f98996d3SJunchao Zhang 
298c3ff71bSJunchao Zhang static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat); /* Forward declaration */
308c3ff71bSJunchao Zhang 
31076ba34aSJunchao Zhang /* MatAssemblyEnd_SeqAIJKokkos() happens when we finalized nonzeros of the matrix, either after
32076ba34aSJunchao Zhang    we assembled the matrix on host, or after we directly produced the matrix data on device (ex., through MatMatMult).
33076ba34aSJunchao Zhang    In the latter case, it is important to set a_dual's sync state correctly.
34076ba34aSJunchao Zhang  */
35d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_SeqAIJKokkos(Mat A, MatAssemblyType mode)
36d71ae5a4SJacob Faibussowitsch {
37076ba34aSJunchao Zhang   Mat_SeqAIJ       *aijseq;
38076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
398c3ff71bSJunchao Zhang 
408c3ff71bSJunchao Zhang   PetscFunctionBegin;
413ba16761SJacob Faibussowitsch   if (mode == MAT_FLUSH_ASSEMBLY) PetscFunctionReturn(PETSC_SUCCESS);
429566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
43076ba34aSJunchao Zhang 
44076ba34aSJunchao Zhang   aijseq = static_cast<Mat_SeqAIJ *>(A->data);
45076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
46076ba34aSJunchao Zhang 
47076ba34aSJunchao Zhang   /* If aijkok does not exist, we just copy i, j to device.
48076ba34aSJunchao Zhang      If aijkok already exists, but the device's nonzero pattern does not match with the host's, we assume the latest data is on host.
49076ba34aSJunchao Zhang      In both cases, we build a new aijkok structure.
50076ba34aSJunchao Zhang   */
51076ba34aSJunchao Zhang   if (!aijkok || aijkok->nonzerostate != A->nonzerostate) { /* aijkok might not exist yet or nonzero pattern has changed */
52076ba34aSJunchao Zhang     delete aijkok;
53076ba34aSJunchao Zhang     aijkok   = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aijseq->nz, aijseq->i, aijseq->j, aijseq->a, A->nonzerostate, PETSC_FALSE /*don't copy mat values to device*/);
54076ba34aSJunchao Zhang     A->spptr = aijkok;
55076ba34aSJunchao Zhang   }
56076ba34aSJunchao Zhang 
575519a089SJose E. Roman   if (aijkok->device_mat_d.data()) {
58a587d139SMark     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
59a587d139SMark   }
603ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
618c3ff71bSJunchao Zhang }
628c3ff71bSJunchao Zhang 
6386a27549SJunchao Zhang /* Sync CSR data to device if not yet */
64d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosSyncDevice(Mat A)
65d71ae5a4SJacob Faibussowitsch {
668c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
678c3ff71bSJunchao Zhang 
688c3ff71bSJunchao Zhang   PetscFunctionBegin;
695f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Cann't sync factorized matrix from host to device");
705f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
71076ba34aSJunchao Zhang   if (aijkok->a_dual.need_sync_device()) {
72076ba34aSJunchao Zhang     aijkok->a_dual.sync_device();
73580c7c76SPierre Jolivet     aijkok->transpose_updated = PETSC_FALSE; /* values of the transpose is out-of-date */
7486a27549SJunchao Zhang     aijkok->hermitian_updated = PETSC_FALSE;
758c3ff71bSJunchao Zhang   }
763ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
778c3ff71bSJunchao Zhang }
788c3ff71bSJunchao Zhang 
79076ba34aSJunchao Zhang /* Mark the CSR data on device as modified */
80d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosModifyDevice(Mat A)
81d71ae5a4SJacob Faibussowitsch {
8286a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
8386a27549SJunchao Zhang 
8486a27549SJunchao Zhang   PetscFunctionBegin;
855f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Not supported for factorized matries");
8686a27549SJunchao Zhang   aijkok->a_dual.clear_sync_state();
8786a27549SJunchao Zhang   aijkok->a_dual.modify_device();
8886a27549SJunchao Zhang   aijkok->transpose_updated = PETSC_FALSE;
8986a27549SJunchao Zhang   aijkok->hermitian_updated = PETSC_FALSE;
909566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJInvalidateDiagonal(A));
919566063dSJacob Faibussowitsch   PetscCall(PetscObjectStateIncrease((PetscObject)A));
923ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
9386a27549SJunchao Zhang }
9486a27549SJunchao Zhang 
95d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSyncHost(Mat A)
96d71ae5a4SJacob Faibussowitsch {
97f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
98f0cf5187SStefano Zampini 
99f0cf5187SStefano Zampini   PetscFunctionBegin;
100f0cf5187SStefano Zampini   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10186a27549SJunchao Zhang   /* We do not expect one needs factors on host  */
1025f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Cann't sync factorized matrix from device to host");
1035f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing AIJKOK");
104076ba34aSJunchao Zhang   aijkok->a_dual.sync_host();
1053ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
106f0cf5187SStefano Zampini }
107f0cf5187SStefano Zampini 
108d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
109d71ae5a4SJacob Faibussowitsch {
110076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
111f0cf5187SStefano Zampini 
112f0cf5187SStefano Zampini   PetscFunctionBegin;
1135519a089SJose E. Roman   /* aijkok contains valid pointers only if the host's nonzerostate matches with the device's.
1145519a089SJose E. Roman     Calling MatSeqAIJSetPreallocation() or MatSetValues() on host, where aijseq->{i,j,a} might be
1155519a089SJose E. Roman     reallocated, will lead to stale {i,j,a}_dual in aijkok. In both operations, the hosts's nonzerostate
1165519a089SJose E. Roman     must have been updated. The stale aijkok will be rebuilt during MatAssemblyEnd.
1175519a089SJose E. Roman   */
1185519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
119076ba34aSJunchao Zhang     aijkok->a_dual.sync_host();
120076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
121076ba34aSJunchao Zhang   } else { /* Happens when calling MatSetValues on a newly created matrix */
122076ba34aSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
123076ba34aSJunchao Zhang   }
1243ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
125076ba34aSJunchao Zhang }
126076ba34aSJunchao Zhang 
127d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
128d71ae5a4SJacob Faibussowitsch {
129076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
130076ba34aSJunchao Zhang 
131076ba34aSJunchao Zhang   PetscFunctionBegin;
1325519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) aijkok->a_dual.modify_host();
1333ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
134076ba34aSJunchao Zhang }
135076ba34aSJunchao Zhang 
136d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
137d71ae5a4SJacob Faibussowitsch {
138076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
139076ba34aSJunchao Zhang 
140076ba34aSJunchao Zhang   PetscFunctionBegin;
1415519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
142076ba34aSJunchao Zhang     aijkok->a_dual.sync_host();
143076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1442328674fSJunchao Zhang   } else {
1452328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1462328674fSJunchao Zhang   }
1473ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
148076ba34aSJunchao Zhang }
149076ba34aSJunchao Zhang 
150d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
151d71ae5a4SJacob Faibussowitsch {
152076ba34aSJunchao Zhang   PetscFunctionBegin;
153076ba34aSJunchao Zhang   *array = NULL;
1543ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
155076ba34aSJunchao Zhang }
156076ba34aSJunchao Zhang 
157d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
158d71ae5a4SJacob Faibussowitsch {
159076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
160076ba34aSJunchao Zhang 
161076ba34aSJunchao Zhang   PetscFunctionBegin;
1625519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
163076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1642328674fSJunchao Zhang   } else { /* Ex. happens with MatZeroEntries on a preallocated but not assembled matrix */
1652328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1662328674fSJunchao Zhang   }
1673ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
168076ba34aSJunchao Zhang }
169076ba34aSJunchao Zhang 
170d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
171d71ae5a4SJacob Faibussowitsch {
172076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
173076ba34aSJunchao Zhang 
174076ba34aSJunchao Zhang   PetscFunctionBegin;
1755519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
176076ba34aSJunchao Zhang     aijkok->a_dual.clear_sync_state();
177076ba34aSJunchao Zhang     aijkok->a_dual.modify_host();
1782328674fSJunchao Zhang   }
1793ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
180f0cf5187SStefano Zampini }
181f0cf5187SStefano Zampini 
182d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJKokkos(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype)
183d71ae5a4SJacob Faibussowitsch {
1847ee59b9bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1857ee59b9bSJunchao Zhang 
1867ee59b9bSJunchao Zhang   PetscFunctionBegin;
1877ee59b9bSJunchao Zhang   PetscCheck(aijkok != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "aijkok is NULL");
1887ee59b9bSJunchao Zhang 
1897ee59b9bSJunchao Zhang   if (i) *i = aijkok->i_device_data();
1907ee59b9bSJunchao Zhang   if (j) *j = aijkok->j_device_data();
1917ee59b9bSJunchao Zhang   if (a) {
1927ee59b9bSJunchao Zhang     aijkok->a_dual.sync_device();
1937ee59b9bSJunchao Zhang     *a = aijkok->a_device_data();
1947ee59b9bSJunchao Zhang   }
1957ee59b9bSJunchao Zhang   if (mtype) *mtype = PETSC_MEMTYPE_KOKKOS;
1963ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1977ee59b9bSJunchao Zhang }
1987ee59b9bSJunchao Zhang 
199a587d139SMark // MatSeqAIJKokkosSetDeviceMat takes a PetscSplitCSRDataStructure with device data and copies it to the device. Note, "deep_copy" here is really a shallow copy
200d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosSetDeviceMat(Mat A, PetscSplitCSRDataStructure h_mat)
201d71ae5a4SJacob Faibussowitsch {
202042217e8SBarry Smith   Mat_SeqAIJKokkos                            *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
203042217e8SBarry Smith   Kokkos::View<SplitCSRMat, Kokkos::HostSpace> h_mat_k(h_mat);
204a587d139SMark 
205a587d139SMark   PetscFunctionBegin;
2065f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
207152b3e56SJunchao Zhang   aijkok->device_mat_d = create_mirror(DefaultMemorySpace(), h_mat_k);
208a587d139SMark   Kokkos::deep_copy(aijkok->device_mat_d, h_mat_k);
2093ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
210a587d139SMark }
211a587d139SMark 
212a587d139SMark // MatSeqAIJKokkosGetDeviceMat gets the device if it is here, otherwise it creates a place for it and returns NULL
213d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosGetDeviceMat(Mat A, PetscSplitCSRDataStructure *d_mat)
214d71ae5a4SJacob Faibussowitsch {
215042217e8SBarry Smith   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
216a587d139SMark 
217a587d139SMark   PetscFunctionBegin;
218a587d139SMark   if (aijkok && aijkok->device_mat_d.data()) {
219a587d139SMark     *d_mat = aijkok->device_mat_d.data();
220a587d139SMark   } else {
2219566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosSyncDevice(A)); // create aijkok (we are making d_mat now so make a place for it)
222a587d139SMark     *d_mat = NULL;
223a587d139SMark   }
2243ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
225a587d139SMark }
226076ba34aSJunchao Zhang 
227*0e3ece09SJunchao Zhang /*
228*0e3ece09SJunchao Zhang   Generate the sparsity pattern of a MatSeqAIJKokkos matrix's transpose on device.
229*0e3ece09SJunchao Zhang 
230*0e3ece09SJunchao Zhang   Input Parameter:
231*0e3ece09SJunchao Zhang .  A       - the MATSEQAIJKOKKOS matrix
232*0e3ece09SJunchao Zhang 
233*0e3ece09SJunchao Zhang   Output Parameters:
234*0e3ece09SJunchao Zhang +  perm_d - the permutation array on device, which connects Ta(i) = Aa(perm(i))
235*0e3ece09SJunchao Zhang -  T_d    - the transpose on device, whose value array is allcoated but not initialized
236*0e3ece09SJunchao Zhang */
237*0e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateTransposeStructure(Mat A, MatRowMapKokkosView &perm_d, KokkosCsrMatrix &T_d)
238d71ae5a4SJacob Faibussowitsch {
239*0e3ece09SJunchao Zhang   Mat_SeqAIJ             *aseq = static_cast<Mat_SeqAIJ *>(A->data);
240*0e3ece09SJunchao Zhang   PetscInt                nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
241*0e3ece09SJunchao Zhang   const PetscInt         *Ai = aseq->i, *Aj = aseq->j;
242*0e3ece09SJunchao Zhang   MatRowMapKokkosViewHost Ti_h("Ti", n + 1);
243*0e3ece09SJunchao Zhang   MatRowMapType          *Ti = Ti_h.data();
244*0e3ece09SJunchao Zhang   MatColIdxKokkosViewHost Tj_h("Tj", nz);
245*0e3ece09SJunchao Zhang   MatRowMapKokkosViewHost perm_h("permutation", nz);
246*0e3ece09SJunchao Zhang   PetscInt               *Tj   = Tj_h.data();
247*0e3ece09SJunchao Zhang   PetscInt               *perm = perm_h.data();
248*0e3ece09SJunchao Zhang   PetscInt               *offset;
249152b3e56SJunchao Zhang 
250152b3e56SJunchao Zhang   PetscFunctionBegin;
251*0e3ece09SJunchao Zhang   // Populate Ti
252*0e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::deep_copy(Ti_h, 0));
253*0e3ece09SJunchao Zhang   Ti++;
254*0e3ece09SJunchao Zhang   for (PetscInt i = 0; i < nz; i++) Ti[Aj[i]]++;
255*0e3ece09SJunchao Zhang   Ti--;
256*0e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) Ti[i + 1] += Ti[i];
257*0e3ece09SJunchao Zhang 
258*0e3ece09SJunchao Zhang   // Populate Tj and the permutation array
259*0e3ece09SJunchao Zhang   PetscCall(PetscCalloc1(n, &offset)); // offset in each T row to fill in its column indices
260*0e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
261*0e3ece09SJunchao Zhang     for (PetscInt j = Ai[i]; j < Ai[i + 1]; j++) { // A's (i,j) is T's (j,i)
262*0e3ece09SJunchao Zhang       PetscInt r    = Aj[j];                       // row r of T
263*0e3ece09SJunchao Zhang       PetscInt disp = Ti[r] + offset[r];
264*0e3ece09SJunchao Zhang 
265*0e3ece09SJunchao Zhang       Tj[disp]   = i; // col i of T
266*0e3ece09SJunchao Zhang       perm[disp] = j;
267*0e3ece09SJunchao Zhang       offset[r]++;
268076ba34aSJunchao Zhang     }
269*0e3ece09SJunchao Zhang   }
270*0e3ece09SJunchao Zhang   PetscCall(PetscFree(offset));
271*0e3ece09SJunchao Zhang 
272*0e3ece09SJunchao Zhang   // Sort each row of T, along with the permutation array
273*0e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) PetscCall(PetscSortIntWithArray(Ti[i + 1] - Ti[i], Tj + Ti[i], perm + Ti[i]));
274*0e3ece09SJunchao Zhang 
275*0e3ece09SJunchao Zhang   // Output perm and T on device
276*0e3ece09SJunchao Zhang   auto Ti_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ti_h);
277*0e3ece09SJunchao Zhang   auto Tj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Tj_h);
278*0e3ece09SJunchao Zhang   PetscCallCXX(T_d = KokkosCsrMatrix("csrmatT", n, m, nz, MatScalarKokkosView("Ta", nz), Ti_d, Tj_d));
279*0e3ece09SJunchao Zhang   PetscCallCXX(perm_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), perm_h));
2803ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
281152b3e56SJunchao Zhang }
282152b3e56SJunchao Zhang 
283*0e3ece09SJunchao Zhang // Generate the transpose on device and cache it internally
284*0e3ece09SJunchao Zhang // Note: KK transpose_matrix() does not have support symbolic/numeric transpose, so we do it on our own
285*0e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGenerateTranspose_Private(Mat A, KokkosCsrMatrix *csrmatT)
286d71ae5a4SJacob Faibussowitsch {
287*0e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
288*0e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
289*0e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
290*0e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatT;
291152b3e56SJunchao Zhang 
292152b3e56SJunchao Zhang   PetscFunctionBegin;
293*0e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
294*0e3ece09SJunchao Zhang   PetscCallCXX(akok->a_dual.sync_device()); // Sync A's valeus since we are going to access them on device
295*0e3ece09SJunchao Zhang 
296*0e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
297*0e3ece09SJunchao Zhang 
298*0e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE) {
299*0e3ece09SJunchao Zhang     *csrmatT = akok->csrmat;
300*0e3ece09SJunchao Zhang   } else {
301*0e3ece09SJunchao Zhang     // See if we already have a cached transpose and its value is up to date
302*0e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
303*0e3ece09SJunchao Zhang       if (!akok->transpose_updated) {            // if the value is out of date, update the cached version
304*0e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
305*0e3ece09SJunchao Zhang         auto       &Ta   = T.values;
306*0e3ece09SJunchao Zhang 
307*0e3ece09SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(
308*0e3ece09SJunchao Zhang           nz, KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = Aa(perm(i)); }));
309076ba34aSJunchao Zhang       }
310*0e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
311*0e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
312*0e3ece09SJunchao Zhang 
313*0e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
314*0e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
315*0e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(
316*0e3ece09SJunchao Zhang         nz, KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = Aa(perm(i)); }));
317*0e3ece09SJunchao Zhang     }
318*0e3ece09SJunchao Zhang     akok->transpose_updated = PETSC_TRUE;
319*0e3ece09SJunchao Zhang     *csrmatT                = akok->csrmatT;
320*0e3ece09SJunchao Zhang   }
321*0e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
322*0e3ece09SJunchao Zhang }
323*0e3ece09SJunchao Zhang 
324*0e3ece09SJunchao Zhang // Generate the Hermitian on device and cache it internally
325*0e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateHermitian_Private(Mat A, KokkosCsrMatrix *csrmatH)
326*0e3ece09SJunchao Zhang {
327*0e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
328*0e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
329*0e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
330*0e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatH;
331*0e3ece09SJunchao Zhang 
332*0e3ece09SJunchao Zhang   PetscFunctionBegin;
333*0e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
334*0e3ece09SJunchao Zhang   PetscCallCXX(akok->a_dual.sync_device()); // Sync A's values since we are going to access them on device
335*0e3ece09SJunchao Zhang 
336*0e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
337*0e3ece09SJunchao Zhang 
338*0e3ece09SJunchao Zhang   if (A->hermitian == PETSC_BOOL3_TRUE) {
339*0e3ece09SJunchao Zhang     *csrmatH = akok->csrmat;
340*0e3ece09SJunchao Zhang   } else {
341*0e3ece09SJunchao Zhang     // See if we already have a cached hermitian and its value is up to date
342*0e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
343*0e3ece09SJunchao Zhang       if (!akok->hermitian_updated) {            // if the value is out of date, update the cached version
344*0e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
345*0e3ece09SJunchao Zhang         auto       &Ta   = T.values;
346*0e3ece09SJunchao Zhang 
347*0e3ece09SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(
348*0e3ece09SJunchao Zhang           nz, KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = PetscConj(Aa(perm(i))); }));
349*0e3ece09SJunchao Zhang       }
350*0e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
351*0e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
352*0e3ece09SJunchao Zhang 
353*0e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
354*0e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
355*0e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(
356*0e3ece09SJunchao Zhang         nz, KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = PetscConj(Aa(perm(i))); }));
357*0e3ece09SJunchao Zhang     }
358*0e3ece09SJunchao Zhang     akok->hermitian_updated = PETSC_TRUE;
359*0e3ece09SJunchao Zhang     *csrmatH                = akok->csrmatH;
360*0e3ece09SJunchao Zhang   }
3613ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
362152b3e56SJunchao Zhang }
363a587d139SMark 
3648c3ff71bSJunchao Zhang /* y = A x */
365d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMult_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
366d71ae5a4SJacob Faibussowitsch {
3678c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
368152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
369152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
3708c3ff71bSJunchao Zhang 
3718c3ff71bSJunchao Zhang   PetscFunctionBegin;
3729566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3739566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3749566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3759566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
3768c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3779d52486cSJunchao Zhang   PetscCallCXX(KokkosSparse::spmv("N", 1.0 /*alpha*/, aijkok->csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A x + beta y */
3789566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3799566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
380076ba34aSJunchao Zhang   /* 2.0*nnz - numRows seems more accurate here but assumes there are no zero-rows. So a little sloppy here. */
3819566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
3829566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
3833ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
3848c3ff71bSJunchao Zhang }
3858c3ff71bSJunchao Zhang 
3868c3ff71bSJunchao Zhang /* y = A^T x */
387d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
388d71ae5a4SJacob Faibussowitsch {
3898c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
390152b3e56SJunchao Zhang   const char                *mode;
391152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
392152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
393*0e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
3948c3ff71bSJunchao Zhang 
3958c3ff71bSJunchao Zhang   PetscFunctionBegin;
3969566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3979566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3989566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3999566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
400152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4019566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
402152b3e56SJunchao Zhang     mode = "N";
403152b3e56SJunchao Zhang   } else {
404076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
405*0e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
406152b3e56SJunchao Zhang     mode   = "T";
407152b3e56SJunchao Zhang   }
408*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^T x + beta y */
4099566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4109566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
411*0e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4129566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4133ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4148c3ff71bSJunchao Zhang }
4158c3ff71bSJunchao Zhang 
4168c3ff71bSJunchao Zhang /* y = A^H x */
417d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
418d71ae5a4SJacob Faibussowitsch {
4198c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
420152b3e56SJunchao Zhang   const char                *mode;
421152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
422152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
423*0e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4248c3ff71bSJunchao Zhang 
4258c3ff71bSJunchao Zhang   PetscFunctionBegin;
4269566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4279566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4289566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4299566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
430152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4319566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
432152b3e56SJunchao Zhang     mode = "N";
433152b3e56SJunchao Zhang   } else {
434076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
435*0e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
436152b3e56SJunchao Zhang     mode   = "C";
437152b3e56SJunchao Zhang   }
438*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^H x + beta y */
4399566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4409566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
441*0e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4429566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4433ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4448c3ff71bSJunchao Zhang }
4458c3ff71bSJunchao Zhang 
4468c3ff71bSJunchao Zhang /* z = A x + y */
447d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
448d71ae5a4SJacob Faibussowitsch {
4498c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
450152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
451152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4528c3ff71bSJunchao Zhang 
4538c3ff71bSJunchao Zhang   PetscFunctionBegin;
4549566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4559566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4569566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4579566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
4589566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
4598c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
4608c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4619d52486cSJunchao Zhang   PetscCallCXX(KokkosSparse::spmv("N", 1.0 /*alpha*/, aijkok->csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A x + beta z */
4629566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4639566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
4649566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
4659566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
4669566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4673ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4688c3ff71bSJunchao Zhang }
4698c3ff71bSJunchao Zhang 
4708c3ff71bSJunchao Zhang /* z = A^T x + y */
471d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
472d71ae5a4SJacob Faibussowitsch {
4738c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
474152b3e56SJunchao Zhang   const char                *mode;
475152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
476152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
477*0e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4788c3ff71bSJunchao Zhang 
4798c3ff71bSJunchao Zhang   PetscFunctionBegin;
4809566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4819566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4829566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4839566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
4849566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
4858c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
486152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4879566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
488152b3e56SJunchao Zhang     mode = "N";
489152b3e56SJunchao Zhang   } else {
490076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
491*0e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
492152b3e56SJunchao Zhang     mode   = "T";
493152b3e56SJunchao Zhang   }
494*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^T x + beta z */
4959566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4969566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
4979566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
498*0e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4999566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
5003ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5018c3ff71bSJunchao Zhang }
5028c3ff71bSJunchao Zhang 
5038c3ff71bSJunchao Zhang /* z = A^H x + y */
504d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
505d71ae5a4SJacob Faibussowitsch {
5068c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
507152b3e56SJunchao Zhang   const char                *mode;
508152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
509152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
510*0e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
5118c3ff71bSJunchao Zhang 
5128c3ff71bSJunchao Zhang   PetscFunctionBegin;
5139566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
5149566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
5159566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
5169566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
5179566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
5188c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
519152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
5209566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
521152b3e56SJunchao Zhang     mode = "N";
522152b3e56SJunchao Zhang   } else {
523076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
524*0e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
525152b3e56SJunchao Zhang     mode   = "C";
526152b3e56SJunchao Zhang   }
527*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^H x + beta z */
5289566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
5299566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
5309566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
531*0e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
5329566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
5333ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
534152b3e56SJunchao Zhang }
535152b3e56SJunchao Zhang 
536d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSetOption_SeqAIJKokkos(Mat A, MatOption op, PetscBool flg)
537d71ae5a4SJacob Faibussowitsch {
538152b3e56SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
539152b3e56SJunchao Zhang 
540152b3e56SJunchao Zhang   PetscFunctionBegin;
541152b3e56SJunchao Zhang   switch (op) {
542152b3e56SJunchao Zhang   case MAT_FORM_EXPLICIT_TRANSPOSE:
543152b3e56SJunchao Zhang     /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */
5449566063dSJacob Faibussowitsch     if (A->form_explicit_transpose && !flg && aijkok) PetscCall(aijkok->DestroyMatTranspose());
545152b3e56SJunchao Zhang     A->form_explicit_transpose = flg;
546152b3e56SJunchao Zhang     break;
547d71ae5a4SJacob Faibussowitsch   default:
548d71ae5a4SJacob Faibussowitsch     PetscCall(MatSetOption_SeqAIJ(A, op, flg));
549d71ae5a4SJacob Faibussowitsch     break;
550152b3e56SJunchao Zhang   }
5513ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5528c3ff71bSJunchao Zhang }
5538c3ff71bSJunchao Zhang 
554076ba34aSJunchao Zhang /* Depending on reuse, either build a new mat, or use the existing mat */
555d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
556d71ae5a4SJacob Faibussowitsch {
557076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
5588c3ff71bSJunchao Zhang 
5598c3ff71bSJunchao Zhang   PetscFunctionBegin;
5609566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
561076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {                      /* Build a brand new mat */
5629566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));  /* the returned newmat is a SeqAIJKokkos */
5638c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {                 /* Reuse the mat created before */
5649566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); /* newmat is already a SeqAIJKokkos */
565076ba34aSJunchao Zhang   } else if (reuse == MAT_INPLACE_MATRIX) {               /* newmat is A */
5665f80ce2aSJacob Faibussowitsch     PetscCheck(A == *newmat, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "A != *newmat with MAT_INPLACE_MATRIX");
5679566063dSJacob Faibussowitsch     PetscCall(PetscFree(A->defaultvectype));
5689566063dSJacob Faibussowitsch     PetscCall(PetscStrallocpy(VECKOKKOS, &A->defaultvectype)); /* Allocate and copy the string */
5699566063dSJacob Faibussowitsch     PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATSEQAIJKOKKOS));
5709566063dSJacob Faibussowitsch     PetscCall(MatSetOps_SeqAIJKokkos(A));
571076ba34aSJunchao Zhang     aseq = static_cast<Mat_SeqAIJ *>(A->data);
572394ed5ebSJunchao Zhang     if (A->assembled) { /* Copy i, j (but not values) to device for an assembled matrix if not yet */
5735f80ce2aSJacob Faibussowitsch       PetscCheck(!A->spptr, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Expect NULL (Mat_SeqAIJKokkos*)A->spptr");
574076ba34aSJunchao Zhang       A->spptr = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aseq->nz, aseq->i, aseq->j, aseq->a, A->nonzerostate, PETSC_FALSE);
5758c3ff71bSJunchao Zhang     }
576076ba34aSJunchao Zhang   }
5773ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5788c3ff71bSJunchao Zhang }
5798c3ff71bSJunchao Zhang 
580076ba34aSJunchao Zhang /* MatDuplicate always creates a new matrix. MatDuplicate can be called either on an assembled matrix or
581076ba34aSJunchao Zhang    an unassembled matrix, even though MAT_COPY_VALUES is not allowed for unassembled matrix.
582076ba34aSJunchao Zhang  */
583d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDuplicate_SeqAIJKokkos(Mat A, MatDuplicateOption dupOption, Mat *B)
584d71ae5a4SJacob Faibussowitsch {
585076ba34aSJunchao Zhang   Mat_SeqAIJ       *bseq;
586076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok;
587076ba34aSJunchao Zhang   Mat               mat;
5888c3ff71bSJunchao Zhang 
5898c3ff71bSJunchao Zhang   PetscFunctionBegin;
590076ba34aSJunchao Zhang   /* Do not copy values on host as A's latest values might be on device. We don't want to do sync blindly */
5919566063dSJacob Faibussowitsch   PetscCall(MatDuplicate_SeqAIJ(A, MAT_DO_NOT_COPY_VALUES, B));
592076ba34aSJunchao Zhang   mat = *B;
593076ba34aSJunchao Zhang   if (A->assembled) {
594076ba34aSJunchao Zhang     bseq = static_cast<Mat_SeqAIJ *>(mat->data);
595076ba34aSJunchao Zhang     bkok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, bseq->nz, bseq->i, bseq->j, bseq->a, mat->nonzerostate, PETSC_FALSE);
596076ba34aSJunchao Zhang     bkok->a_dual.clear_sync_state(); /* Clear B's sync state as it will be decided below */
597076ba34aSJunchao Zhang     /* Now copy values to B if needed */
598076ba34aSJunchao Zhang     if (dupOption == MAT_COPY_VALUES) {
599076ba34aSJunchao Zhang       if (akok->a_dual.need_sync_device()) {
600076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_host(), akok->a_dual.view_host());
601076ba34aSJunchao Zhang         bkok->a_dual.modify_host();
602076ba34aSJunchao Zhang       } else { /* If device has the latest data, we only copy data on device */
603076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_device(), akok->a_dual.view_device());
604076ba34aSJunchao Zhang         bkok->a_dual.modify_device();
605076ba34aSJunchao Zhang       }
606076ba34aSJunchao Zhang     } else { /* MAT_DO_NOT_COPY_VALUES or MAT_SHARE_NONZERO_PATTERN. B's values should be zeroed */
607076ba34aSJunchao Zhang       /* B's values on host should be already zeroed by MatDuplicate_SeqAIJ() */
608076ba34aSJunchao Zhang       bkok->a_dual.modify_host();
609076ba34aSJunchao Zhang     }
610076ba34aSJunchao Zhang     mat->spptr = bkok;
611076ba34aSJunchao Zhang   }
612076ba34aSJunchao Zhang 
6139566063dSJacob Faibussowitsch   PetscCall(PetscFree(mat->defaultvectype));
6149566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &mat->defaultvectype)); /* Allocate and copy the string */
6159566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)mat, MATSEQAIJKOKKOS));
6169566063dSJacob Faibussowitsch   PetscCall(MatSetOps_SeqAIJKokkos(mat));
6173ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6188c3ff71bSJunchao Zhang }
6198c3ff71bSJunchao Zhang 
620d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatTranspose_SeqAIJKokkos(Mat A, MatReuse reuse, Mat *B)
621d71ae5a4SJacob Faibussowitsch {
6220ecb592aSJunchao Zhang   Mat               At;
623*0e3ece09SJunchao Zhang   KokkosCsrMatrix   internT;
6240ecb592aSJunchao Zhang   Mat_SeqAIJKokkos *atkok, *bkok;
6250ecb592aSJunchao Zhang 
6260ecb592aSJunchao Zhang   PetscFunctionBegin;
6277fb60732SBarry Smith   if (reuse == MAT_REUSE_MATRIX) PetscCall(MatTransposeCheckNonzeroState_Private(A, *B));
6289566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &internT)); /* Generate a transpose internally */
6290ecb592aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX || reuse == MAT_INPLACE_MATRIX) {
630ff751488SJunchao Zhang     /* Deep copy internT, as we want to isolate the internal transpose */
631*0e3ece09SJunchao Zhang     PetscCallCXX(atkok = new Mat_SeqAIJKokkos(KokkosCsrMatrix("csrmat", internT)));
6329566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PetscObjectComm((PetscObject)A), atkok, &At));
6330ecb592aSJunchao Zhang     if (reuse == MAT_INITIAL_MATRIX) *B = At;
6349566063dSJacob Faibussowitsch     else PetscCall(MatHeaderReplace(A, &At)); /* Replace A with At inplace */
6350ecb592aSJunchao Zhang   } else {                                    /* MAT_REUSE_MATRIX, just need to copy values to B on device */
6360ecb592aSJunchao Zhang     if ((*B)->assembled) {
6370ecb592aSJunchao Zhang       bkok = static_cast<Mat_SeqAIJKokkos *>((*B)->spptr);
638*0e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(bkok->a_dual.view_device(), internT.values));
6399566063dSJacob Faibussowitsch       PetscCall(MatSeqAIJKokkosModifyDevice(*B));
6400ecb592aSJunchao Zhang     } else if ((*B)->preallocated) { /* It is ok for B to be only preallocated, as needed in MatTranspose_MPIAIJ */
6410ecb592aSJunchao Zhang       Mat_SeqAIJ             *bseq = static_cast<Mat_SeqAIJ *>((*B)->data);
642*0e3ece09SJunchao Zhang       MatScalarKokkosViewHost a_h(bseq->a, internT.nnz()); /* bseq->nz = 0 if unassembled */
643*0e3ece09SJunchao Zhang       MatColIdxKokkosViewHost j_h(bseq->j, internT.nnz());
644*0e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(a_h, internT.values));
645*0e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(j_h, internT.graph.entries));
6460ecb592aSJunchao Zhang     } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "B must be assembled or preallocated");
6470ecb592aSJunchao Zhang   }
6483ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6490ecb592aSJunchao Zhang }
6500ecb592aSJunchao Zhang 
651d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDestroy_SeqAIJKokkos(Mat A)
652d71ae5a4SJacob Faibussowitsch {
65386a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
6548c3ff71bSJunchao Zhang 
6558c3ff71bSJunchao Zhang   PetscFunctionBegin;
65686a27549SJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) {
65786a27549SJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
6588c3ff71bSJunchao Zhang     delete aijkok;
65986a27549SJunchao Zhang   } else {
66086a27549SJunchao Zhang     delete static_cast<Mat_SeqAIJKokkosTriFactors *>(A->spptr);
66186a27549SJunchao Zhang   }
662cbc6b225SStefano Zampini   A->spptr = NULL;
6639566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
6649566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
6659566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
6669566063dSJacob Faibussowitsch   PetscCall(MatDestroy_SeqAIJ(A));
6673ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6688c3ff71bSJunchao Zhang }
6698c3ff71bSJunchao Zhang 
6703f3ba80aSJunchao Zhang /*MC
6713f3ba80aSJunchao Zhang    MATSEQAIJKOKKOS - MATAIJKOKKOS = "(seq)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos
6723f3ba80aSJunchao Zhang 
6733f3ba80aSJunchao Zhang    A matrix type type using Kokkos-Kernels CrsMatrix type for portability across different device types
6743f3ba80aSJunchao Zhang 
6752ef1f0ffSBarry Smith    Options Database Key:
67611a5261eSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATSEQAIJKOKKOS` during a call to `MatSetFromOptions()`
6773f3ba80aSJunchao Zhang 
6783f3ba80aSJunchao Zhang   Level: beginner
6793f3ba80aSJunchao Zhang 
6802ef1f0ffSBarry Smith .seealso: [](chapter_matrices), `Mat`, `MatCreateSeqAIJKokkos()`, `MATMPIAIJKOKKOS`
6813f3ba80aSJunchao Zhang M*/
682d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJKokkos(Mat A)
683d71ae5a4SJacob Faibussowitsch {
68486a27549SJunchao Zhang   PetscFunctionBegin;
6859566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
6869566063dSJacob Faibussowitsch   PetscCall(MatCreate_SeqAIJ(A));
6879566063dSJacob Faibussowitsch   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
6883ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
68986a27549SJunchao Zhang }
69086a27549SJunchao Zhang 
691076ba34aSJunchao Zhang /* Merge A, B into a matrix C. A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n) */
692d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C)
693d71ae5a4SJacob Faibussowitsch {
694076ba34aSJunchao Zhang   Mat_SeqAIJ         *a, *b;
695076ba34aSJunchao Zhang   Mat_SeqAIJKokkos   *akok, *bkok, *ckok;
696076ba34aSJunchao Zhang   MatScalarKokkosView aa, ba, ca;
697076ba34aSJunchao Zhang   MatRowMapKokkosView ai, bi, ci;
698076ba34aSJunchao Zhang   MatColIdxKokkosView aj, bj, cj;
699076ba34aSJunchao Zhang   PetscInt            m, n, nnz, aN;
700a3f881fbSStefano Zampini 
701a3f881fbSStefano Zampini   PetscFunctionBegin;
702076ba34aSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
703076ba34aSJunchao Zhang   PetscValidHeaderSpecific(B, MAT_CLASSID, 2);
704076ba34aSJunchao Zhang   PetscValidPointer(C, 4);
705076ba34aSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
706076ba34aSJunchao Zhang   PetscCheckTypeName(B, MATSEQAIJKOKKOS);
7075f80ce2aSJacob Faibussowitsch   PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n);
7085f80ce2aSJacob Faibussowitsch   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
709076ba34aSJunchao Zhang 
7109566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
7119566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
712076ba34aSJunchao Zhang   a    = static_cast<Mat_SeqAIJ *>(A->data);
713076ba34aSJunchao Zhang   b    = static_cast<Mat_SeqAIJ *>(B->data);
714076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
715076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
716076ba34aSJunchao Zhang   aa   = akok->a_dual.view_device();
717076ba34aSJunchao Zhang   ai   = akok->i_dual.view_device();
718076ba34aSJunchao Zhang   ba   = bkok->a_dual.view_device();
719076ba34aSJunchao Zhang   bi   = bkok->i_dual.view_device();
720076ba34aSJunchao Zhang   m    = A->rmap->n; /* M, N and nnz of C */
721076ba34aSJunchao Zhang   n    = A->cmap->n + B->cmap->n;
722076ba34aSJunchao Zhang   nnz  = a->nz + b->nz;
723076ba34aSJunchao Zhang   aN   = A->cmap->n; /* N of A */
724076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
725076ba34aSJunchao Zhang     aj           = akok->j_dual.view_device();
726076ba34aSJunchao Zhang     bj           = bkok->j_dual.view_device();
727076ba34aSJunchao Zhang     auto ca_dual = MatScalarKokkosDualView("a", aa.extent(0) + ba.extent(0));
728076ba34aSJunchao Zhang     auto ci_dual = MatRowMapKokkosDualView("i", ai.extent(0));
729076ba34aSJunchao Zhang     auto cj_dual = MatColIdxKokkosDualView("j", aj.extent(0) + bj.extent(0));
730076ba34aSJunchao Zhang     ca           = ca_dual.view_device();
731076ba34aSJunchao Zhang     ci           = ci_dual.view_device();
732076ba34aSJunchao Zhang     cj           = cj_dual.view_device();
733076ba34aSJunchao Zhang 
734076ba34aSJunchao Zhang     /* Concatenate A and B in parallel using Kokkos hierarchical parallelism */
7359371c9d4SSatish Balay     Kokkos::parallel_for(
7369371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
737076ba34aSJunchao Zhang         PetscInt i       = t.league_rank(); /* row i */
738076ba34aSJunchao Zhang         PetscInt coffset = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
739076ba34aSJunchao Zhang 
740076ba34aSJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() { /* this side effect only happens once per whole team */
741076ba34aSJunchao Zhang                                                    ci(i) = coffset;
742076ba34aSJunchao Zhang                                                    if (i == m - 1) ci(m) = ai(m) + bi(m);
743076ba34aSJunchao Zhang         });
744076ba34aSJunchao Zhang 
745076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
746076ba34aSJunchao Zhang           if (k < alen) {
747076ba34aSJunchao Zhang             ca(coffset + k) = aa(ai(i) + k);
748076ba34aSJunchao Zhang             cj(coffset + k) = aj(ai(i) + k);
749076ba34aSJunchao Zhang           } else {
750076ba34aSJunchao Zhang             ca(coffset + k) = ba(bi(i) + k - alen);
751076ba34aSJunchao Zhang             cj(coffset + k) = bj(bi(i) + k - alen) + aN; /* Entries in B get new column indices in C */
752076ba34aSJunchao Zhang           }
753076ba34aSJunchao Zhang         });
754076ba34aSJunchao Zhang       });
755076ba34aSJunchao Zhang     ca_dual.modify_device();
756076ba34aSJunchao Zhang     ci_dual.modify_device();
757076ba34aSJunchao Zhang     cj_dual.modify_device();
7589566063dSJacob Faibussowitsch     PetscCallCXX(ckok = new Mat_SeqAIJKokkos(m, n, nnz, ci_dual, cj_dual, ca_dual));
7599566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, C));
760076ba34aSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
761076ba34aSJunchao Zhang     PetscValidHeaderSpecific(*C, MAT_CLASSID, 4);
762076ba34aSJunchao Zhang     PetscCheckTypeName(*C, MATSEQAIJKOKKOS);
763076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos *>((*C)->spptr);
764076ba34aSJunchao Zhang     ca   = ckok->a_dual.view_device();
765076ba34aSJunchao Zhang     ci   = ckok->i_dual.view_device();
766076ba34aSJunchao Zhang 
7679371c9d4SSatish Balay     Kokkos::parallel_for(
7689371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
769076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* row i */
770076ba34aSJunchao Zhang         PetscInt alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
771076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
772076ba34aSJunchao Zhang           if (k < alen) ca(ci(i) + k) = aa(ai(i) + k);
773076ba34aSJunchao Zhang           else ca(ci(i) + k) = ba(bi(i) + k - alen);
774076ba34aSJunchao Zhang         });
775076ba34aSJunchao Zhang       });
7769566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(*C));
777076ba34aSJunchao Zhang   }
7783ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
779076ba34aSJunchao Zhang }
780076ba34aSJunchao Zhang 
781d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_SeqAIJKokkos(void *pdata)
782d71ae5a4SJacob Faibussowitsch {
783076ba34aSJunchao Zhang   PetscFunctionBegin;
784076ba34aSJunchao Zhang   delete static_cast<MatProductData_SeqAIJKokkos *>(pdata);
7853ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
786a3f881fbSStefano Zampini }
787a3f881fbSStefano Zampini 
788d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos(Mat C)
789d71ae5a4SJacob Faibussowitsch {
790a3f881fbSStefano Zampini   Mat_Product                 *product = C->product;
791a3f881fbSStefano Zampini   Mat                          A, B;
792076ba34aSJunchao Zhang   bool                         transA, transB; /* use bool, since KK needs this type */
793a3f881fbSStefano Zampini   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
794a3f881fbSStefano Zampini   Mat_SeqAIJ                  *c;
795076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
796*0e3ece09SJunchao Zhang   KokkosCsrMatrix              csrmatA, csrmatB;
797a3f881fbSStefano Zampini 
798a3f881fbSStefano Zampini   PetscFunctionBegin;
799a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
8005f80ce2aSJacob Faibussowitsch   PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data empty");
801076ba34aSJunchao Zhang   pdata = static_cast<MatProductData_SeqAIJKokkos *>(C->product->data);
802076ba34aSJunchao Zhang 
803*0e3ece09SJunchao Zhang   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
804*0e3ece09SJunchao Zhang   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
805*0e3ece09SJunchao Zhang   // we still do numeric.
806*0e3ece09SJunchao Zhang   if (pdata->reusesym) { // numeric reuses results from symbolic
807*0e3ece09SJunchao Zhang     pdata->reusesym = PETSC_FALSE;
8083ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
809076ba34aSJunchao Zhang   }
810076ba34aSJunchao Zhang 
811076ba34aSJunchao Zhang   switch (product->type) {
8129371c9d4SSatish Balay   case MATPRODUCT_AB:
8139371c9d4SSatish Balay     transA = false;
8149371c9d4SSatish Balay     transB = false;
8159371c9d4SSatish Balay     break;
8169371c9d4SSatish Balay   case MATPRODUCT_AtB:
8179371c9d4SSatish Balay     transA = true;
8189371c9d4SSatish Balay     transB = false;
8199371c9d4SSatish Balay     break;
8209371c9d4SSatish Balay   case MATPRODUCT_ABt:
8219371c9d4SSatish Balay     transA = false;
8229371c9d4SSatish Balay     transB = true;
8239371c9d4SSatish Balay     break;
824d71ae5a4SJacob Faibussowitsch   default:
825d71ae5a4SJacob Faibussowitsch     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
826076ba34aSJunchao Zhang   }
827076ba34aSJunchao Zhang 
828a3f881fbSStefano Zampini   A = product->A;
829a3f881fbSStefano Zampini   B = product->B;
8309566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
8319566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
832a3f881fbSStefano Zampini   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
833a3f881fbSStefano Zampini   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
834a3f881fbSStefano Zampini   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
835076ba34aSJunchao Zhang 
8365f80ce2aSJacob Faibussowitsch   PetscCheck(ckok, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Device data structure spptr is empty");
837076ba34aSJunchao Zhang 
838*0e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
839*0e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
840076ba34aSJunchao Zhang 
841076ba34aSJunchao Zhang   /* TODO: Once KK spgemm implements transpose, we can get rid of the explicit transpose here */
842076ba34aSJunchao Zhang   if (transA) {
8439566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
844076ba34aSJunchao Zhang     transA = false;
845a3f881fbSStefano Zampini   }
846a3f881fbSStefano Zampini 
847076ba34aSJunchao Zhang   if (transB) {
8489566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
849076ba34aSJunchao Zhang     transB = false;
850076ba34aSJunchao Zhang   }
8519566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
852*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, ckok->csrmat));
853*0e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
854866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
855866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(ckok->csrmat)); /* without sort, mat_tests-ex62_14_seqaijkokkos fails */
856e944a159SJunchao Zhang #endif
857866eb059SJunchao Zhang 
8589566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
8599566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(C));
860a3f881fbSStefano Zampini   /* shorter version of MatAssemblyEnd_SeqAIJ */
861a3f881fbSStefano Zampini   c = (Mat_SeqAIJ *)C->data;
8629566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded,%" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz));
8639566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n"));
8649566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax));
865a3f881fbSStefano Zampini   c->reallocs         = 0;
866076ba34aSJunchao Zhang   C->info.mallocs     = 0;
867a3f881fbSStefano Zampini   C->info.nz_unneeded = 0;
868a3f881fbSStefano Zampini   C->assembled = C->was_assembled = PETSC_TRUE;
869a3f881fbSStefano Zampini   C->num_ass++;
8703ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
871a3f881fbSStefano Zampini }
872a3f881fbSStefano Zampini 
873d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos(Mat C)
874d71ae5a4SJacob Faibussowitsch {
875076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
876076ba34aSJunchao Zhang   MatProductType               ptype;
877076ba34aSJunchao Zhang   Mat                          A, B;
878076ba34aSJunchao Zhang   bool                         transA, transB;
879076ba34aSJunchao Zhang   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
880076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
881076ba34aSJunchao Zhang   MPI_Comm                     comm;
882*0e3ece09SJunchao Zhang   KokkosCsrMatrix              csrmatA, csrmatB, csrmatC;
883a3f881fbSStefano Zampini 
884a3f881fbSStefano Zampini   PetscFunctionBegin;
885a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
8869566063dSJacob Faibussowitsch   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
8875f80ce2aSJacob Faibussowitsch   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
888a3f881fbSStefano Zampini   A = product->A;
889a3f881fbSStefano Zampini   B = product->B;
8909566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
8919566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
892a3f881fbSStefano Zampini   akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
893a3f881fbSStefano Zampini   bkok    = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
894*0e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
895*0e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
896076ba34aSJunchao Zhang 
897a3f881fbSStefano Zampini   ptype = product->type;
898*0e3ece09SJunchao Zhang   // Take advantage of the symmetry if true
899*0e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) {
900*0e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
901*0e3ece09SJunchao Zhang     product->symbolic_used_the_fact_A_is_symmetric = PETSC_TRUE;
902*0e3ece09SJunchao Zhang   }
903*0e3ece09SJunchao Zhang   if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) {
904*0e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
905*0e3ece09SJunchao Zhang     product->symbolic_used_the_fact_B_is_symmetric = PETSC_TRUE;
906*0e3ece09SJunchao Zhang   }
907*0e3ece09SJunchao Zhang 
908a3f881fbSStefano Zampini   switch (ptype) {
9099371c9d4SSatish Balay   case MATPRODUCT_AB:
9109371c9d4SSatish Balay     transA = false;
9119371c9d4SSatish Balay     transB = false;
9129371c9d4SSatish Balay     break;
9139371c9d4SSatish Balay   case MATPRODUCT_AtB:
9149371c9d4SSatish Balay     transA = true;
9159371c9d4SSatish Balay     transB = false;
9169371c9d4SSatish Balay     break;
9179371c9d4SSatish Balay   case MATPRODUCT_ABt:
9189371c9d4SSatish Balay     transA = false;
9199371c9d4SSatish Balay     transB = true;
9209371c9d4SSatish Balay     break;
921d71ae5a4SJacob Faibussowitsch   default:
922d71ae5a4SJacob Faibussowitsch     SETERRQ(comm, PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
923a3f881fbSStefano Zampini   }
924*0e3ece09SJunchao Zhang   PetscCallCXX(product->data = pdata = new MatProductData_SeqAIJKokkos());
925076ba34aSJunchao Zhang   pdata->reusesym = product->api_user;
926a3f881fbSStefano Zampini 
927076ba34aSJunchao Zhang   /* TODO: add command line options to select spgemm algorithms */
928866eb059SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; /* default alg is TPL if enabled, otherwise KK */
929866eb059SJunchao Zhang 
930866eb059SJunchao Zhang   /* CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 */
931866eb059SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
932866eb059SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
933866eb059SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
934866eb059SJunchao Zhang   #endif
935866eb059SJunchao Zhang #endif
936*0e3ece09SJunchao Zhang   PetscCallCXX(pdata->kh.create_spgemm_handle(spgemm_alg));
937076ba34aSJunchao Zhang 
9389566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
939076ba34aSJunchao Zhang   /* TODO: Get rid of the explicit transpose once KK-spgemm implements the transpose option */
940076ba34aSJunchao Zhang   if (transA) {
9419566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
942076ba34aSJunchao Zhang     transA = false;
943076ba34aSJunchao Zhang   }
944076ba34aSJunchao Zhang 
945076ba34aSJunchao Zhang   if (transB) {
9469566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
947076ba34aSJunchao Zhang     transB = false;
948076ba34aSJunchao Zhang   }
949076ba34aSJunchao Zhang 
950*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
951076ba34aSJunchao Zhang   /* spgemm_symbolic() only populates C's rowmap, but not C's column indices.
952076ba34aSJunchao Zhang     So we have to do a fake spgemm_numeric() here to get csrmatC.j_d setup, before
953076ba34aSJunchao Zhang     calling new Mat_SeqAIJKokkos().
954076ba34aSJunchao Zhang     TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
955076ba34aSJunchao Zhang   */
956*0e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
957*0e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
958866eb059SJunchao Zhang   /* Query if KK outputs a sorted matrix. If not, we need to sort it */
959866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
960866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(csrmatC)); /* sort_option defaults to -1 in KK!*/
961e944a159SJunchao Zhang #endif
9629566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
963076ba34aSJunchao Zhang 
9649566063dSJacob Faibussowitsch   PetscCallCXX(ckok = new Mat_SeqAIJKokkos(csrmatC));
9659566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(C, ckok));
966076ba34aSJunchao Zhang   C->product->destroy = MatProductDataDestroy_SeqAIJKokkos;
9673ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
968a3f881fbSStefano Zampini }
969a3f881fbSStefano Zampini 
970a3f881fbSStefano Zampini /* handles sparse matrix matrix ops */
971d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSetFromOptions_SeqAIJKokkos(Mat mat)
972d71ae5a4SJacob Faibussowitsch {
973076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
974a3f881fbSStefano Zampini   PetscBool    Biskok = PETSC_FALSE, Ciskok = PETSC_TRUE;
975a3f881fbSStefano Zampini 
976a3f881fbSStefano Zampini   PetscFunctionBegin;
977a3f881fbSStefano Zampini   MatCheckProduct(mat, 1);
9789566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJKOKKOS, &Biskok));
97948a46eb9SPierre Jolivet   if (product->type == MATPRODUCT_ABC) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJKOKKOS, &Ciskok));
980a3f881fbSStefano Zampini   if (Biskok && Ciskok) {
981a3f881fbSStefano Zampini     switch (product->type) {
982a3f881fbSStefano Zampini     case MATPRODUCT_AB:
983a3f881fbSStefano Zampini     case MATPRODUCT_AtB:
984d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABt:
985d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos;
986d71ae5a4SJacob Faibussowitsch       break;
987a3f881fbSStefano Zampini     case MATPRODUCT_PtAP:
988a3f881fbSStefano Zampini     case MATPRODUCT_RARt:
989d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABC:
990d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
991d71ae5a4SJacob Faibussowitsch       break;
992d71ae5a4SJacob Faibussowitsch     default:
993d71ae5a4SJacob Faibussowitsch       break;
994a3f881fbSStefano Zampini     }
995a3f881fbSStefano Zampini   } else { /* fallback for AIJ */
9969566063dSJacob Faibussowitsch     PetscCall(MatProductSetFromOptions_SeqAIJ(mat));
997a3f881fbSStefano Zampini   }
9983ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
999a3f881fbSStefano Zampini }
1000a587d139SMark 
1001d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatScale_SeqAIJKokkos(Mat A, PetscScalar a)
1002d71ae5a4SJacob Faibussowitsch {
1003f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok;
1004f0cf5187SStefano Zampini 
1005f0cf5187SStefano Zampini   PetscFunctionBegin;
10069566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
10079566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1008f0cf5187SStefano Zampini   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1009076ba34aSJunchao Zhang   KokkosBlas::scal(aijkok->a_dual.view_device(), a, aijkok->a_dual.view_device());
10109566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
10119566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(aijkok->a_dual.extent(0)));
10129566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
10133ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1014f0cf5187SStefano Zampini }
1015f0cf5187SStefano Zampini 
1016d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatZeroEntries_SeqAIJKokkos(Mat A)
1017d71ae5a4SJacob Faibussowitsch {
1018076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1019a587d139SMark 
1020a587d139SMark   PetscFunctionBegin;
1021076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
10222328674fSJunchao Zhang   if (aijkok) { /* Only zero the device if data is already there */
1023076ba34aSJunchao Zhang     KokkosBlas::fill(aijkok->a_dual.view_device(), 0.0);
10249566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(A));
10252328674fSJunchao Zhang   } else { /* Might be preallocated but not assembled */
10269566063dSJacob Faibussowitsch     PetscCall(MatZeroEntries_SeqAIJ(A));
10272328674fSJunchao Zhang   }
10283ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1029a587d139SMark }
1030a587d139SMark 
1031d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatGetDiagonal_SeqAIJKokkos(Mat A, Vec x)
1032d71ae5a4SJacob Faibussowitsch {
1033f78ce678SMark Adams   Mat_SeqAIJ           *aijseq;
1034f78ce678SMark Adams   Mat_SeqAIJKokkos     *aijkok;
1035f78ce678SMark Adams   PetscInt              n;
1036f78ce678SMark Adams   PetscScalarKokkosView xv;
1037f78ce678SMark Adams 
1038f78ce678SMark Adams   PetscFunctionBegin;
1039f78ce678SMark Adams   PetscCall(VecGetLocalSize(x, &n));
1040f78ce678SMark Adams   PetscCheck(n == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Nonconforming matrix and vector");
1041f78ce678SMark Adams   PetscCheck(A->factortype == MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_SUP, "MatGetDiagonal_SeqAIJKokkos not supported on factored matrices");
1042f78ce678SMark Adams 
1043f78ce678SMark Adams   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1044f78ce678SMark Adams   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1045f78ce678SMark Adams 
1046f78ce678SMark Adams   if (A->rmap->n && aijkok->diag_dual.extent(0) == 0) { /* Set the diagonal pointer if not already */
1047f78ce678SMark Adams     PetscCall(MatMarkDiagonal_SeqAIJ(A));
1048f78ce678SMark Adams     aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1049f78ce678SMark Adams     aijkok->SetDiagonal(aijseq->diag);
1050f78ce678SMark Adams   }
1051f78ce678SMark Adams 
1052f78ce678SMark Adams   const auto &Aa    = aijkok->a_dual.view_device();
1053f78ce678SMark Adams   const auto &Ai    = aijkok->i_dual.view_device();
1054f78ce678SMark Adams   const auto &Adiag = aijkok->diag_dual.view_device();
1055f78ce678SMark Adams 
1056f78ce678SMark Adams   PetscCall(VecGetKokkosViewWrite(x, &xv));
10579371c9d4SSatish Balay   Kokkos::parallel_for(
10589371c9d4SSatish Balay     n, KOKKOS_LAMBDA(const PetscInt i) {
1059f78ce678SMark Adams       if (Adiag(i) < Ai(i + 1)) xv(i) = Aa(Adiag(i));
1060f78ce678SMark Adams       else xv(i) = 0;
1061f78ce678SMark Adams     });
1062f78ce678SMark Adams   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
10633ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1064f78ce678SMark Adams }
1065f78ce678SMark Adams 
1066db78de30SJunchao Zhang /* Get a Kokkos View from a mat of type MatSeqAIJKokkos */
1067d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1068d71ae5a4SJacob Faibussowitsch {
1069db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1070db78de30SJunchao Zhang 
1071db78de30SJunchao Zhang   PetscFunctionBegin;
1072db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1073db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1074db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10759566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1076db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1077076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
10783ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1079db78de30SJunchao Zhang }
1080db78de30SJunchao Zhang 
1081d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1082d71ae5a4SJacob Faibussowitsch {
1083db78de30SJunchao Zhang   PetscFunctionBegin;
1084db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1085db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1086db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10873ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1088db78de30SJunchao Zhang }
1089db78de30SJunchao Zhang 
1090d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, MatScalarKokkosView *kv)
1091d71ae5a4SJacob Faibussowitsch {
1092db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1093db78de30SJunchao Zhang 
1094db78de30SJunchao Zhang   PetscFunctionBegin;
1095db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1096db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1097db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10989566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1099db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1100076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
11013ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1102db78de30SJunchao Zhang }
1103db78de30SJunchao Zhang 
1104d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, MatScalarKokkosView *kv)
1105d71ae5a4SJacob Faibussowitsch {
1106db78de30SJunchao Zhang   PetscFunctionBegin;
1107db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1108db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1109db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11109566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
11113ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1112db78de30SJunchao Zhang }
1113db78de30SJunchao Zhang 
1114d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1115d71ae5a4SJacob Faibussowitsch {
1116db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1117db78de30SJunchao Zhang 
1118db78de30SJunchao Zhang   PetscFunctionBegin;
1119db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1120db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1121db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1122db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1123076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
11243ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1125db78de30SJunchao Zhang }
1126db78de30SJunchao Zhang 
1127d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1128d71ae5a4SJacob Faibussowitsch {
1129db78de30SJunchao Zhang   PetscFunctionBegin;
1130db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1131db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1132db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11339566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
11343ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1135db78de30SJunchao Zhang }
1136db78de30SJunchao Zhang 
1137c17cf699SJunchao Zhang /* Computes Y += alpha X */
1138d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAXPY_SeqAIJKokkos(Mat Y, PetscScalar alpha, Mat X, MatStructure pattern)
1139d71ae5a4SJacob Faibussowitsch {
1140a587d139SMark   Mat_SeqAIJ              *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data;
1141c17cf699SJunchao Zhang   Mat_SeqAIJKokkos        *xkok, *ykok, *zkok;
1142c17cf699SJunchao Zhang   ConstMatScalarKokkosView Xa;
1143c17cf699SJunchao Zhang   MatScalarKokkosView      Ya;
1144a587d139SMark 
1145a587d139SMark   PetscFunctionBegin;
1146c17cf699SJunchao Zhang   PetscCheckTypeName(Y, MATSEQAIJKOKKOS);
1147c17cf699SJunchao Zhang   PetscCheckTypeName(X, MATSEQAIJKOKKOS);
11489566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(Y));
11499566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(X));
11509566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1151db78de30SJunchao Zhang 
1152c17cf699SJunchao Zhang   if (pattern != SAME_NONZERO_PATTERN && x->nz == y->nz) {
1153a587d139SMark     PetscBool e;
11549566063dSJacob Faibussowitsch     PetscCall(PetscArraycmp(x->i, y->i, Y->rmap->n + 1, &e));
1155a587d139SMark     if (e) {
11569566063dSJacob Faibussowitsch       PetscCall(PetscArraycmp(x->j, y->j, y->nz, &e));
1157c17cf699SJunchao Zhang       if (e) pattern = SAME_NONZERO_PATTERN;
1158a587d139SMark     }
1159a587d139SMark   }
1160db78de30SJunchao Zhang 
1161c17cf699SJunchao Zhang   /* cusparseDcsrgeam2() computes C = alpha A + beta B. If one knew sparsity pattern of C, one can skip
1162c17cf699SJunchao Zhang     cusparseScsrgeam2_bufferSizeExt() / cusparseXcsrgeam2Nnz(), and directly call cusparseScsrgeam2().
1163c17cf699SJunchao Zhang     If X is SUBSET_NONZERO_PATTERN of Y, we could take advantage of this cusparse feature. However,
1164c17cf699SJunchao Zhang     KokkosSparse::spadd(alpha,A,beta,B,C) has symbolic and numeric phases, MatAXPY does not.
1165c17cf699SJunchao Zhang   */
1166c17cf699SJunchao Zhang   ykok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1167c17cf699SJunchao Zhang   xkok = static_cast<Mat_SeqAIJKokkos *>(X->spptr);
1168c17cf699SJunchao Zhang   Xa   = xkok->a_dual.view_device();
1169c17cf699SJunchao Zhang   Ya   = ykok->a_dual.view_device();
1170c17cf699SJunchao Zhang 
1171c17cf699SJunchao Zhang   if (pattern == SAME_NONZERO_PATTERN) {
1172c17cf699SJunchao Zhang     KokkosBlas::axpy(alpha, Xa, Ya);
11739566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1174c17cf699SJunchao Zhang   } else if (pattern == SUBSET_NONZERO_PATTERN) {
1175c17cf699SJunchao Zhang     MatRowMapKokkosView Xi = xkok->i_dual.view_device(), Yi = ykok->i_dual.view_device();
1176c17cf699SJunchao Zhang     MatColIdxKokkosView Xj = xkok->j_dual.view_device(), Yj = ykok->j_dual.view_device();
1177c17cf699SJunchao Zhang 
11789371c9d4SSatish Balay     Kokkos::parallel_for(
11799371c9d4SSatish Balay       Kokkos::TeamPolicy<>(Y->rmap->n, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1180*0e3ece09SJunchao Zhang         PetscInt i = t.league_rank(); // row i
1181*0e3ece09SJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() {
1182*0e3ece09SJunchao Zhang           // Only one thread works in a team
1183c17cf699SJunchao Zhang           PetscInt p, q = Yi(i);
1184*0e3ece09SJunchao Zhang           for (p = Xi(i); p < Xi(i + 1); p++) {          // For each nonzero on row i of X,
1185*0e3ece09SJunchao Zhang             while (Xj(p) != Yj(q) && q < Yi(i + 1)) q++; // find the matching nonzero on row i of Y.
1186*0e3ece09SJunchao Zhang             if (Xj(p) == Yj(q)) {                        // Found it
1187c17cf699SJunchao Zhang               Ya(q) += alpha * Xa(p);
1188c17cf699SJunchao Zhang               q++;
1189a587d139SMark             } else {
1190*0e3ece09SJunchao Zhang             // If not found, it indicates the input is wrong (X is not a SUBSET_NONZERO_PATTERN of Y).
1191*0e3ece09SJunchao Zhang             // Just insert a NaN at the beginning of row i if it is not empty, to make the result wrong.
1192*0e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_VERSION_GE(3, 7, 0)
1193*0e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::ArithTraits<PetscScalar>::nan();
11948b8b16f9SJunchao Zhang #else
1195*0e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::Experimental::nan("1");
11968b8b16f9SJunchao Zhang #endif
1197a587d139SMark             }
1198c17cf699SJunchao Zhang           }
1199c17cf699SJunchao Zhang         });
1200c17cf699SJunchao Zhang       });
12019566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1202*0e3ece09SJunchao Zhang   } else { // different nonzero patterns
1203c17cf699SJunchao Zhang     Mat             Z;
1204c17cf699SJunchao Zhang     KokkosCsrMatrix zcsr;
1205c17cf699SJunchao Zhang     KernelHandle    kh;
1206*0e3ece09SJunchao Zhang     kh.create_spadd_handle(true); // X, Y are sorted
1207c17cf699SJunchao Zhang     KokkosSparse::spadd_symbolic(&kh, xkok->csrmat, ykok->csrmat, zcsr);
1208c17cf699SJunchao Zhang     KokkosSparse::spadd_numeric(&kh, alpha, xkok->csrmat, (PetscScalar)1.0, ykok->csrmat, zcsr);
1209c17cf699SJunchao Zhang     zkok = new Mat_SeqAIJKokkos(zcsr);
12109566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, zkok, &Z));
12119566063dSJacob Faibussowitsch     PetscCall(MatHeaderReplace(Y, &Z));
1212c17cf699SJunchao Zhang     kh.destroy_spadd_handle();
1213c17cf699SJunchao Zhang   }
12149566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
1215*0e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(xkok->a_dual.extent(0) * 2)); // Because we scaled X and then added it to Y
12163ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1217a587d139SMark }
1218a587d139SMark 
1219d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_SeqAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1220d71ae5a4SJacob Faibussowitsch {
122142550becSJunchao Zhang   Mat_SeqAIJKokkos *akok;
122242550becSJunchao Zhang   Mat_SeqAIJ       *aseq;
122342550becSJunchao Zhang 
122442550becSJunchao Zhang   PetscFunctionBegin;
12259566063dSJacob Faibussowitsch   PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, coo_i, coo_j));
1226394ed5ebSJunchao Zhang   aseq = static_cast<Mat_SeqAIJ *>(mat->data);
122742550becSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(mat->spptr);
1228cbc6b225SStefano Zampini   delete akok;
1229cbc6b225SStefano Zampini   mat->spptr = akok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, aseq->nz, aseq->i, aseq->j, aseq->a, mat->nonzerostate + 1, PETSC_FALSE);
12309566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries_SeqAIJKokkos(mat));
1231394ed5ebSJunchao Zhang   akok->SetUpCOO(aseq);
12323ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
123342550becSJunchao Zhang }
123442550becSJunchao Zhang 
1235d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_SeqAIJKokkos(Mat A, const PetscScalar v[], InsertMode imode)
1236d71ae5a4SJacob Faibussowitsch {
123742550becSJunchao Zhang   Mat_SeqAIJ                 *aseq = static_cast<Mat_SeqAIJ *>(A->data);
123842550becSJunchao Zhang   Mat_SeqAIJKokkos           *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1239394ed5ebSJunchao Zhang   PetscCount                  Annz = aseq->nz;
1240394ed5ebSJunchao Zhang   const PetscCountKokkosView &jmap = akok->jmap_d;
1241394ed5ebSJunchao Zhang   const PetscCountKokkosView &perm = akok->perm_d;
124242550becSJunchao Zhang   MatScalarKokkosView         Aa;
124342550becSJunchao Zhang   ConstMatScalarKokkosView    kv;
124442550becSJunchao Zhang   PetscMemType                memtype;
124542550becSJunchao Zhang 
124642550becSJunchao Zhang   PetscFunctionBegin;
12479566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype));
124842550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */
1249394ed5ebSJunchao Zhang     kv = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatScalarKokkosViewHost(v, aseq->coo_n));
125042550becSJunchao Zhang   } else {
1251394ed5ebSJunchao Zhang     kv = ConstMatScalarKokkosView(v, aseq->coo_n); /* Directly use v[]'s memory */
125242550becSJunchao Zhang   }
125342550becSJunchao Zhang 
1254c7b718f4SJunchao Zhang   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1255c7b718f4SJunchao Zhang   else PetscCall(MatSeqAIJGetKokkosView(A, &Aa));                             /* read & write matrix values */
125642550becSJunchao Zhang 
12579371c9d4SSatish Balay   Kokkos::parallel_for(
12589371c9d4SSatish Balay     Annz, KOKKOS_LAMBDA(const PetscCount i) {
1259c7b718f4SJunchao Zhang       PetscScalar sum = 0.0;
1260c7b718f4SJunchao Zhang       for (PetscCount k = jmap(i); k < jmap(i + 1); k++) sum += kv(perm(k));
1261c7b718f4SJunchao Zhang       Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1262c7b718f4SJunchao Zhang     });
1263394ed5ebSJunchao Zhang 
12649566063dSJacob Faibussowitsch   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
12659566063dSJacob Faibussowitsch   else PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
12663ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
126742550becSJunchao Zhang }
126842550becSJunchao Zhang 
1269d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJMoveDiagonalValuesFront_SeqAIJKokkos(Mat A, const PetscInt *diag)
1270d71ae5a4SJacob Faibussowitsch {
12715fbaff96SJunchao Zhang   Mat_SeqAIJKokkos          *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
12725fbaff96SJunchao Zhang   MatScalarKokkosView        Aa;
12735fbaff96SJunchao Zhang   const MatRowMapKokkosView &Ai = akok->i_dual.view_device();
12745fbaff96SJunchao Zhang   PetscInt                   m  = A->rmap->n;
12755fbaff96SJunchao Zhang   ConstMatRowMapKokkosView   Adiag(diag, m); /* diag is a device pointer */
12765fbaff96SJunchao Zhang 
12775fbaff96SJunchao Zhang   PetscFunctionBegin;
12785fbaff96SJunchao Zhang   PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa));
12799371c9d4SSatish Balay   Kokkos::parallel_for(
12809371c9d4SSatish Balay     m, KOKKOS_LAMBDA(const PetscInt i) {
12815fbaff96SJunchao Zhang       PetscScalar tmp;
12825fbaff96SJunchao Zhang       if (Adiag(i) >= Ai(i) && Adiag(i) < Ai(i + 1)) { /* The diagonal element exists */
12835fbaff96SJunchao Zhang         tmp          = Aa(Ai(i));
12845fbaff96SJunchao Zhang         Aa(Ai(i))    = Aa(Adiag(i));
12855fbaff96SJunchao Zhang         Aa(Adiag(i)) = tmp;
12865fbaff96SJunchao Zhang       }
12875fbaff96SJunchao Zhang     });
12885fbaff96SJunchao Zhang   PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
12893ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
12905fbaff96SJunchao Zhang }
12915fbaff96SJunchao Zhang 
1292d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatLUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1293d71ae5a4SJacob Faibussowitsch {
12948f7e8f9dSMark Adams   PetscFunctionBegin;
12959566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncHost(A));
12969566063dSJacob Faibussowitsch   PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info));
12978f7e8f9dSMark Adams   B->offloadmask = PETSC_OFFLOAD_CPU;
12983ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
12998f7e8f9dSMark Adams }
13008f7e8f9dSMark Adams 
1301d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat A)
1302d71ae5a4SJacob Faibussowitsch {
1303076ba34aSJunchao Zhang   Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
1304076ba34aSJunchao Zhang 
13058c3ff71bSJunchao Zhang   PetscFunctionBegin;
1306076ba34aSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_KOKKOS; /* We do not really use this flag */
13076f3d89d0SStefano Zampini   A->boundtocpu  = PETSC_FALSE;
13086f3d89d0SStefano Zampini 
13098c3ff71bSJunchao Zhang   A->ops->assemblyend               = MatAssemblyEnd_SeqAIJKokkos;
13108c3ff71bSJunchao Zhang   A->ops->destroy                   = MatDestroy_SeqAIJKokkos;
13118c3ff71bSJunchao Zhang   A->ops->duplicate                 = MatDuplicate_SeqAIJKokkos;
1312a587d139SMark   A->ops->axpy                      = MatAXPY_SeqAIJKokkos;
1313f0cf5187SStefano Zampini   A->ops->scale                     = MatScale_SeqAIJKokkos;
1314a587d139SMark   A->ops->zeroentries               = MatZeroEntries_SeqAIJKokkos;
1315076ba34aSJunchao Zhang   A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJKokkos;
13168c3ff71bSJunchao Zhang   A->ops->mult                      = MatMult_SeqAIJKokkos;
13178c3ff71bSJunchao Zhang   A->ops->multadd                   = MatMultAdd_SeqAIJKokkos;
13188c3ff71bSJunchao Zhang   A->ops->multtranspose             = MatMultTranspose_SeqAIJKokkos;
13198c3ff71bSJunchao Zhang   A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJKokkos;
13208c3ff71bSJunchao Zhang   A->ops->multhermitiantranspose    = MatMultHermitianTranspose_SeqAIJKokkos;
13218c3ff71bSJunchao Zhang   A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJKokkos;
1322076ba34aSJunchao Zhang   A->ops->productnumeric            = MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos;
13230ecb592aSJunchao Zhang   A->ops->transpose                 = MatTranspose_SeqAIJKokkos;
1324152b3e56SJunchao Zhang   A->ops->setoption                 = MatSetOption_SeqAIJKokkos;
1325f78ce678SMark Adams   A->ops->getdiagonal               = MatGetDiagonal_SeqAIJKokkos;
1326076ba34aSJunchao Zhang   a->ops->getarray                  = MatSeqAIJGetArray_SeqAIJKokkos;
1327076ba34aSJunchao Zhang   a->ops->restorearray              = MatSeqAIJRestoreArray_SeqAIJKokkos;
1328076ba34aSJunchao Zhang   a->ops->getarrayread              = MatSeqAIJGetArrayRead_SeqAIJKokkos;
1329076ba34aSJunchao Zhang   a->ops->restorearrayread          = MatSeqAIJRestoreArrayRead_SeqAIJKokkos;
1330076ba34aSJunchao Zhang   a->ops->getarraywrite             = MatSeqAIJGetArrayWrite_SeqAIJKokkos;
1331076ba34aSJunchao Zhang   a->ops->restorearraywrite         = MatSeqAIJRestoreArrayWrite_SeqAIJKokkos;
13327ee59b9bSJunchao Zhang   a->ops->getcsrandmemtype          = MatSeqAIJGetCSRAndMemType_SeqAIJKokkos;
133342550becSJunchao Zhang 
13349566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJKokkos));
13359566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJKokkos));
13363ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1337076ba34aSJunchao Zhang }
1338076ba34aSJunchao Zhang 
1339d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSetSeqAIJKokkosWithCSRMatrix(Mat A, Mat_SeqAIJKokkos *akok)
1340d71ae5a4SJacob Faibussowitsch {
1341076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
1342076ba34aSJunchao Zhang   PetscInt    i, m, n;
1343076ba34aSJunchao Zhang 
1344076ba34aSJunchao Zhang   PetscFunctionBegin;
13455f80ce2aSJacob Faibussowitsch   PetscCheck(!A->spptr, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A->spptr is supposed to be empty");
1346076ba34aSJunchao Zhang 
1347076ba34aSJunchao Zhang   m = akok->nrows();
1348076ba34aSJunchao Zhang   n = akok->ncols();
13499566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(A, m, n, m, n));
13509566063dSJacob Faibussowitsch   PetscCall(MatSetType(A, MATSEQAIJKOKKOS));
1351076ba34aSJunchao Zhang 
1352076ba34aSJunchao Zhang   /* Set up data structures of A as a MATSEQAIJ */
13539566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(A, MAT_SKIP_ALLOCATION, NULL));
1354076ba34aSJunchao Zhang   aseq = (Mat_SeqAIJ *)(A)->data;
1355076ba34aSJunchao Zhang 
1356076ba34aSJunchao Zhang   akok->i_dual.sync_host(); /* We always need sync'ed i, j on host */
1357076ba34aSJunchao Zhang   akok->j_dual.sync_host();
1358076ba34aSJunchao Zhang 
1359076ba34aSJunchao Zhang   aseq->i            = akok->i_host_data();
1360076ba34aSJunchao Zhang   aseq->j            = akok->j_host_data();
1361076ba34aSJunchao Zhang   aseq->a            = akok->a_host_data();
1362076ba34aSJunchao Zhang   aseq->nonew        = -1; /*this indicates that inserting a new value in the matrix that generates a new nonzero is an error*/
1363076ba34aSJunchao Zhang   aseq->singlemalloc = PETSC_FALSE;
1364076ba34aSJunchao Zhang   aseq->free_a       = PETSC_FALSE;
1365076ba34aSJunchao Zhang   aseq->free_ij      = PETSC_FALSE;
1366076ba34aSJunchao Zhang   aseq->nz           = akok->nnz();
1367076ba34aSJunchao Zhang   aseq->maxnz        = aseq->nz;
1368076ba34aSJunchao Zhang 
13699566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->imax));
13709566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->ilen));
1371ad540459SPierre Jolivet   for (i = 0; i < m; i++) aseq->ilen[i] = aseq->imax[i] = aseq->i[i + 1] - aseq->i[i];
1372076ba34aSJunchao Zhang 
1373076ba34aSJunchao Zhang   /* It is critical to set the nonzerostate, as we use it to check if sparsity pattern (hence data) has changed on host in MatAssemblyEnd */
1374076ba34aSJunchao Zhang   akok->nonzerostate = A->nonzerostate;
1375ff751488SJunchao Zhang   A->spptr           = akok; /* Set A->spptr before MatAssembly so that A->spptr won't be allocated again there */
13769566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
13779566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
13783ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1379076ba34aSJunchao Zhang }
1380076ba34aSJunchao Zhang 
1381*0e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetKokkosCsrMatrix(Mat A, KokkosCsrMatrix *csr)
1382*0e3ece09SJunchao Zhang {
1383*0e3ece09SJunchao Zhang   PetscFunctionBegin;
1384*0e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1385*0e3ece09SJunchao Zhang   *csr = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
1386*0e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1387*0e3ece09SJunchao Zhang }
1388*0e3ece09SJunchao Zhang 
1389*0e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithKokkosCsrMatrix(MPI_Comm comm, KokkosCsrMatrix csr, Mat *A)
1390*0e3ece09SJunchao Zhang {
1391*0e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok;
1392*0e3ece09SJunchao Zhang   PetscFunctionBegin;
1393*0e3ece09SJunchao Zhang   PetscCallCXX(akok = new Mat_SeqAIJKokkos(csr));
1394*0e3ece09SJunchao Zhang   PetscCall(MatCreate(comm, A));
1395*0e3ece09SJunchao Zhang   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
1396*0e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1397*0e3ece09SJunchao Zhang }
1398*0e3ece09SJunchao Zhang 
1399076ba34aSJunchao Zhang /* Crete a SEQAIJKOKKOS matrix with a Mat_SeqAIJKokkos data structure
1400076ba34aSJunchao Zhang 
1401076ba34aSJunchao Zhang    Note we have names like MatSeqAIJSetPreallocationCSR, so I use capitalized CSR
1402076ba34aSJunchao Zhang  */
1403d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithCSRMatrix(MPI_Comm comm, Mat_SeqAIJKokkos *akok, Mat *A)
1404d71ae5a4SJacob Faibussowitsch {
1405076ba34aSJunchao Zhang   PetscFunctionBegin;
14069566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
14079566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
14083ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
14098c3ff71bSJunchao Zhang }
14108c3ff71bSJunchao Zhang 
1411152b3e56SJunchao Zhang /*@C
141211a5261eSBarry Smith    MatCreateSeqAIJKokkos - Creates a sparse matrix in `MATSEQAIJKOKKOS` (compressed row) format
14138c3ff71bSJunchao Zhang    (the default parallel PETSc format). This matrix will ultimately be handled by
14148c3ff71bSJunchao Zhang    Kokkos for calculations. For good matrix
14158c3ff71bSJunchao Zhang    assembly performance the user should preallocate the matrix storage by setting
14168c3ff71bSJunchao Zhang    the parameter nz (or the array nnz).  By setting these parameters accurately,
14178c3ff71bSJunchao Zhang    performance during matrix assembly can be increased by more than a factor of 50.
14188c3ff71bSJunchao Zhang 
14198c3ff71bSJunchao Zhang    Collective
14208c3ff71bSJunchao Zhang 
14218c3ff71bSJunchao Zhang    Input Parameters:
142211a5261eSBarry Smith +  comm - MPI communicator, set to `PETSC_COMM_SELF`
14238c3ff71bSJunchao Zhang .  m - number of rows
14248c3ff71bSJunchao Zhang .  n - number of columns
14258c3ff71bSJunchao Zhang .  nz - number of nonzeros per row (same for all rows)
14268c3ff71bSJunchao Zhang -  nnz - array containing the number of nonzeros in the various rows
14272ef1f0ffSBarry Smith          (possibly different for each row) or `NULL`
14288c3ff71bSJunchao Zhang 
14298c3ff71bSJunchao Zhang    Output Parameter:
14308c3ff71bSJunchao Zhang .  A - the matrix
14318c3ff71bSJunchao Zhang 
14322ef1f0ffSBarry Smith    Level: intermediate
14332ef1f0ffSBarry Smith 
14342ef1f0ffSBarry Smith    Notes:
143511a5261eSBarry Smith    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
   MatXXXXSetPreallocation() paradigm instead of this routine directly.
143711a5261eSBarry Smith    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
14388c3ff71bSJunchao Zhang 
14392ef1f0ffSBarry Smith    If `nnz` is given then `nz` is ignored
14408c3ff71bSJunchao Zhang 
144111a5261eSBarry Smith    The AIJ format, also called
14422ef1f0ffSBarry Smith    compressed row storage, is fully compatible with standard Fortran
14438c3ff71bSJunchao Zhang    storage.  That is, the stored row and column indices can begin at
14448c3ff71bSJunchao Zhang    either one (as in Fortran) or zero.  See the users' manual for details.
14458c3ff71bSJunchao Zhang 
14462ef1f0ffSBarry Smith    Specify the preallocated storage with either `nz` or `nnz` (not both).
14472ef1f0ffSBarry Smith    Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory
14482ef1f0ffSBarry Smith    allocation.
14498c3ff71bSJunchao Zhang 
14508c3ff71bSJunchao Zhang    By default, this format uses inodes (identical nodes) when possible, to
14518c3ff71bSJunchao Zhang    improve numerical efficiency of matrix-vector products and solves. We
14528c3ff71bSJunchao Zhang    search for consecutive rows with the same nonzero structure, thereby
14538c3ff71bSJunchao Zhang    reusing matrix information to achieve increased efficiency.
14548c3ff71bSJunchao Zhang 
.seealso: [](chapter_matrices), `Mat`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
14568c3ff71bSJunchao Zhang @*/
1457d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateSeqAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A)
1458d71ae5a4SJacob Faibussowitsch {
14598c3ff71bSJunchao Zhang   PetscFunctionBegin;
14609566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
14619566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
14629566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, m, n));
14639566063dSJacob Faibussowitsch   PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
14649566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
14653ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
14668c3ff71bSJunchao Zhang }
1467930e68a5SMark Adams 
14688f7e8f9dSMark Adams typedef Kokkos::TeamPolicy<>::member_type team_member;
14698f7e8f9dSMark Adams //
147046804e07SMark Adams // This factorization exploits block diagonal matrices with "Nf" (not used).
14718f7e8f9dSMark Adams // Use -pc_factor_mat_ordering_type rcm to order decouple blocks of size N/Nf for this optimization
14728f7e8f9dSMark Adams //
1473d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatLUFactorNumeric_SeqAIJKOKKOSDEVICE(Mat B, Mat A, const MatFactorInfo *info)
1474d71ae5a4SJacob Faibussowitsch {
14758f7e8f9dSMark Adams   Mat_SeqAIJ       *b      = (Mat_SeqAIJ *)B->data;
14768f7e8f9dSMark Adams   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *baijkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
14778f7e8f9dSMark Adams   IS                isrow = b->row, isicol = b->icol;
14788f7e8f9dSMark Adams   const PetscInt   *r_h, *ic_h;
1479300d22a6SJunchao Zhang   const PetscInt n = A->rmap->n, *ai_d = aijkok->i_dual.view_device().data(), *aj_d = aijkok->j_dual.view_device().data(), *bi_d = baijkok->i_dual.view_device().data(), *bj_d = baijkok->j_dual.view_device().data(), *bdiag_d = baijkok->diag_d.data();
1480076ba34aSJunchao Zhang   const PetscScalar *aa_d = aijkok->a_dual.view_device().data();
1481076ba34aSJunchao Zhang   PetscScalar       *ba_d = baijkok->a_dual.view_device().data();
14828f7e8f9dSMark Adams   PetscBool          row_identity, col_identity;
148346804e07SMark Adams   PetscInt           nc, Nf = 1, nVec = 32; // should be a parameter, Nf is batch size - not used
1484930e68a5SMark Adams 
1485930e68a5SMark Adams   PetscFunctionBegin;
14862c71b3e2SJacob Faibussowitsch   PetscCheck(A->rmap->n == n, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "square matrices only supported %" PetscInt_FMT " %" PetscInt_FMT, A->rmap->n, n);
1487b94d7dedSBarry Smith   PetscCall(MatIsStructurallySymmetric(A, &row_identity));
14882c71b3e2SJacob Faibussowitsch   PetscCheck(row_identity, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "structurally symmetric matrices only supported");
14899566063dSJacob Faibussowitsch   PetscCall(ISGetIndices(isrow, &r_h));
14909566063dSJacob Faibussowitsch   PetscCall(ISGetIndices(isicol, &ic_h));
14919566063dSJacob Faibussowitsch   PetscCall(ISGetSize(isicol, &nc));
14929566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
14939566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
14948f7e8f9dSMark Adams   {
14958f7e8f9dSMark Adams #define KOKKOS_SHARED_LEVEL 1
14968f7e8f9dSMark Adams     using scr_mem_t    = Kokkos::DefaultExecutionSpace::scratch_memory_space;
14978f7e8f9dSMark Adams     using sizet_scr_t  = Kokkos::View<size_t, scr_mem_t>;
14988f7e8f9dSMark Adams     using scalar_scr_t = Kokkos::View<PetscScalar, scr_mem_t>;
14998f7e8f9dSMark Adams     const Kokkos::View<const PetscInt *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_r_k(r_h, n);
15008f7e8f9dSMark Adams     Kokkos::View<PetscInt *, Kokkos::LayoutLeft>                                                                         d_r_k("r", n);
15018f7e8f9dSMark Adams     const Kokkos::View<const PetscInt *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_ic_k(ic_h, nc);
15028f7e8f9dSMark Adams     Kokkos::View<PetscInt *, Kokkos::LayoutLeft>                                                                         d_ic_k("ic", nc);
15038f7e8f9dSMark Adams     size_t                                                                                                               flops_h = 0.0;
15048f7e8f9dSMark Adams     Kokkos::View<size_t, Kokkos::HostSpace>                                                                              h_flops_k(&flops_h);
15058f7e8f9dSMark Adams     Kokkos::View<size_t>                                                                                                 d_flops_k("flops");
15068f7e8f9dSMark Adams     const int                                                                                                            conc = Kokkos::DefaultExecutionSpace().concurrency(), team_size = conc > 1 ? 16 : 1; // 8*32 = 256
1507da81f932SPierre Jolivet     const int                                                                                                            nloc = n / Nf, Ni = (conc > 8) ? 1 /* some intelligent number of SMs -- but need league_barrier */ : 1;
15088f7e8f9dSMark Adams     Kokkos::deep_copy(d_flops_k, h_flops_k);
15098f7e8f9dSMark Adams     Kokkos::deep_copy(d_r_k, h_r_k);
15108f7e8f9dSMark Adams     Kokkos::deep_copy(d_ic_k, h_ic_k);
15118f7e8f9dSMark Adams     // Fill A --> fact
15129371c9d4SSatish Balay     Kokkos::parallel_for(
15139371c9d4SSatish Balay       Kokkos::TeamPolicy<>(Nf * Ni, team_size, nVec), KOKKOS_LAMBDA(const team_member team) {
1514042217e8SBarry Smith         const PetscInt  field = team.league_rank() / Ni, field_block = team.league_rank() % Ni; // use grid.x/y in CUDA
15158f7e8f9dSMark Adams         const PetscInt  nloc_i = (nloc / Ni + !!(nloc % Ni)), start_i = field * nloc + field_block * nloc_i, end_i = (start_i + nloc_i) > (field + 1) * nloc ? (field + 1) * nloc : (start_i + nloc_i);
15168f7e8f9dSMark Adams         const PetscInt *ic = d_ic_k.data(), *r = d_r_k.data();
15178f7e8f9dSMark Adams         // zero rows of B
15188f7e8f9dSMark Adams         Kokkos::parallel_for(Kokkos::TeamVectorRange(team, start_i, end_i), [=](const int &rowb) {
15198f7e8f9dSMark Adams           PetscInt     nzbL = bi_d[rowb + 1] - bi_d[rowb], nzbU = bdiag_d[rowb] - bdiag_d[rowb + 1]; // with diag
15208f7e8f9dSMark Adams           PetscScalar *baL = ba_d + bi_d[rowb];
15218f7e8f9dSMark Adams           PetscScalar *baU = ba_d + bdiag_d[rowb + 1] + 1;
15228f7e8f9dSMark Adams           /* zero (unfactored row) */
15238f7e8f9dSMark Adams           for (int j = 0; j < nzbL; j++) baL[j] = 0;
15248f7e8f9dSMark Adams           for (int j = 0; j < nzbU; j++) baU[j] = 0;
15258f7e8f9dSMark Adams         });
15268f7e8f9dSMark Adams         // copy A into B
15278f7e8f9dSMark Adams         Kokkos::parallel_for(Kokkos::TeamVectorRange(team, start_i, end_i), [=](const int &rowb) {
15288f7e8f9dSMark Adams           PetscInt           rowa = r[rowb], nza = ai_d[rowa + 1] - ai_d[rowa];
15298f7e8f9dSMark Adams           const PetscScalar *av    = aa_d + ai_d[rowa];
15308f7e8f9dSMark Adams           const PetscInt    *ajtmp = aj_d + ai_d[rowa];
15318f7e8f9dSMark Adams           /* load in initial (unfactored row) */
15328f7e8f9dSMark Adams           for (int j = 0; j < nza; j++) {
15338f7e8f9dSMark Adams             PetscInt    colb = ic[ajtmp[j]];
15348f7e8f9dSMark Adams             PetscScalar vala = av[j];
15358f7e8f9dSMark Adams             if (colb == rowb) {
15368f7e8f9dSMark Adams               *(ba_d + bdiag_d[rowb]) = vala;
15378f7e8f9dSMark Adams             } else {
15388f7e8f9dSMark Adams               const PetscInt *pbj = bj_d + ((colb > rowb) ? bdiag_d[rowb + 1] + 1 : bi_d[rowb]);
15398f7e8f9dSMark Adams               PetscScalar    *pba = ba_d + ((colb > rowb) ? bdiag_d[rowb + 1] + 1 : bi_d[rowb]);
15408f7e8f9dSMark Adams               PetscInt        nz = (colb > rowb) ? bdiag_d[rowb] - (bdiag_d[rowb + 1] + 1) : bi_d[rowb + 1] - bi_d[rowb], set = 0;
15418f7e8f9dSMark Adams               for (int j = 0; j < nz; j++) {
15428f7e8f9dSMark Adams                 if (pbj[j] == colb) {
15438f7e8f9dSMark Adams                   pba[j] = vala;
15448f7e8f9dSMark Adams                   set++;
15458f7e8f9dSMark Adams                   break;
15468f7e8f9dSMark Adams                 }
15478f7e8f9dSMark Adams               }
15488f1da0b2SJunchao Zhang #if !defined(PETSC_HAVE_SYCL)
15498f7e8f9dSMark Adams               if (set != 1) printf("\t\t\t ERROR DID NOT SET ?????\n");
15508f1da0b2SJunchao Zhang #endif
15518f7e8f9dSMark Adams             }
15528f7e8f9dSMark Adams           }
15538f7e8f9dSMark Adams         });
15548f7e8f9dSMark Adams       });
15558f7e8f9dSMark Adams     Kokkos::fence();
1556930e68a5SMark Adams 
15579371c9d4SSatish Balay     Kokkos::parallel_for(
15589371c9d4SSatish Balay       Kokkos::TeamPolicy<>(Nf * Ni, team_size, nVec).set_scratch_size(KOKKOS_SHARED_LEVEL, Kokkos::PerThread(sizet_scr_t::shmem_size() + scalar_scr_t::shmem_size()), Kokkos::PerTeam(sizet_scr_t::shmem_size())), KOKKOS_LAMBDA(const team_member team) {
15598f7e8f9dSMark Adams         sizet_scr_t    colkIdx(team.thread_scratch(KOKKOS_SHARED_LEVEL));
15608f7e8f9dSMark Adams         scalar_scr_t   L_ki(team.thread_scratch(KOKKOS_SHARED_LEVEL));
15618f7e8f9dSMark Adams         sizet_scr_t    flops(team.team_scratch(KOKKOS_SHARED_LEVEL));
1562042217e8SBarry Smith         const PetscInt field = team.league_rank() / Ni, field_block_idx = team.league_rank() % Ni; // use grid.x/y in CUDA
15638f7e8f9dSMark Adams         const PetscInt start = field * nloc, end = start + nloc;
15648f7e8f9dSMark Adams         Kokkos::single(Kokkos::PerTeam(team), [=]() { flops() = 0; });
15658f7e8f9dSMark Adams         // A22 panel update for each row A(1,:) and col A(:,1)
15668f7e8f9dSMark Adams         for (int ii = start; ii < end - 1; ii++) {
15678f7e8f9dSMark Adams           const PetscInt    *bjUi = bj_d + bdiag_d[ii + 1] + 1, nzUi = bdiag_d[ii] - (bdiag_d[ii + 1] + 1); // vector, and vector size, of column indices of U(i,(i+1):end)
15688f7e8f9dSMark Adams           const PetscScalar *baUi    = ba_d + bdiag_d[ii + 1] + 1;                                          // vector of data  U(i,i+1:end)
15698f7e8f9dSMark Adams           const PetscInt     nUi_its = nzUi / Ni + !!(nzUi % Ni);
15708f7e8f9dSMark Adams           const PetscScalar  Bii     = *(ba_d + bdiag_d[ii]); // diagonal in its special place
15718f7e8f9dSMark Adams           Kokkos::parallel_for(Kokkos::TeamThreadRange(team, nUi_its), [=](const int j) {
15728f7e8f9dSMark Adams             PetscInt kIdx = j * Ni + field_block_idx;
15739371c9d4SSatish Balay             if (kIdx >= nzUi) /* void */
15749371c9d4SSatish Balay               ;
15758f7e8f9dSMark Adams             else {
15768f7e8f9dSMark Adams               const PetscInt  myk = bjUi[kIdx];                // assume symmetric structure, need a transposed meta-data here in general
15778f7e8f9dSMark Adams               const PetscInt *pjL = bj_d + bi_d[myk];          // look for L(myk,ii) in start of row
15788f7e8f9dSMark Adams               const PetscInt  nzL = bi_d[myk + 1] - bi_d[myk]; // size of L_k(:)
15798f7e8f9dSMark Adams               size_t          st_idx;
15808f7e8f9dSMark Adams               // find and do L(k,i) = A(:k,i) / A(i,i)
15818f7e8f9dSMark Adams               Kokkos::single(Kokkos::PerThread(team), [&]() { colkIdx() = PETSC_MAX_INT; });
15828f7e8f9dSMark Adams               // get column, there has got to be a better way
15839371c9d4SSatish Balay               Kokkos::parallel_reduce(
15849371c9d4SSatish Balay                 Kokkos::ThreadVectorRange(team, nzL),
15859371c9d4SSatish Balay                 [&](const int &j, size_t &idx) {
15868f7e8f9dSMark Adams                   if (pjL[j] == ii) {
15878f7e8f9dSMark Adams                     PetscScalar *pLki = ba_d + bi_d[myk] + j;
15888f7e8f9dSMark Adams                     idx               = j;           // output
15898f7e8f9dSMark Adams                     *pLki             = *pLki / Bii; // column scaling:  L(k,i) = A(:k,i) / A(i,i)
15908f7e8f9dSMark Adams                   }
15919371c9d4SSatish Balay                 },
15929371c9d4SSatish Balay                 st_idx);
15939371c9d4SSatish Balay               Kokkos::single(Kokkos::PerThread(team), [=]() {
15949371c9d4SSatish Balay                 colkIdx() = st_idx;
15959371c9d4SSatish Balay                 L_ki()    = *(ba_d + bi_d[myk] + st_idx);
15969371c9d4SSatish Balay               });
15978f1da0b2SJunchao Zhang #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
159899551766SMark Adams               if (colkIdx() == PETSC_MAX_INT) printf("\t\t\t\t\t\t\tERROR: failed to find L_ki(%d,%d)\n", (int)myk, ii); // uses a register
159999551766SMark Adams #endif
160099551766SMark Adams               // active row k, do  A_kj -= Lki * U_ij; j \in U(i,:) j != i
16018f7e8f9dSMark Adams               // U(i+1,:end)
16028f7e8f9dSMark Adams               Kokkos::parallel_for(Kokkos::ThreadVectorRange(team, nzUi), [=](const int &uiIdx) { // index into i (U)
16038f7e8f9dSMark Adams                 PetscScalar Uij = baUi[uiIdx];
16048f7e8f9dSMark Adams                 PetscInt    col = bjUi[uiIdx];
16058f7e8f9dSMark Adams                 if (col == myk) {
16068f7e8f9dSMark Adams                   // A_kk = A_kk - L_ki * U_ij(k)
16078f7e8f9dSMark Adams                   PetscScalar *Akkv = (ba_d + bdiag_d[myk]); // diagonal in its special place
16088f7e8f9dSMark Adams                   *Akkv             = *Akkv - L_ki() * Uij;  // UiK
16098f7e8f9dSMark Adams                 } else {
16108f7e8f9dSMark Adams                   PetscScalar    *start, *end, *pAkjv = NULL;
16118f7e8f9dSMark Adams                   PetscInt        high, low;
16128f7e8f9dSMark Adams                   const PetscInt *startj;
16138f7e8f9dSMark Adams                   if (col < myk) { // L
16148f7e8f9dSMark Adams                     PetscScalar *pLki = ba_d + bi_d[myk] + colkIdx();
16158f7e8f9dSMark Adams                     PetscInt     idx  = (pLki + 1) - (ba_d + bi_d[myk]); // index into row
16168f7e8f9dSMark Adams                     start             = pLki + 1;                        // start at pLki+1, A22(myk,1)
16178f7e8f9dSMark Adams                     startj            = bj_d + bi_d[myk] + idx;
16188f7e8f9dSMark Adams                     end               = ba_d + bi_d[myk + 1];
16198f7e8f9dSMark Adams                   } else {
16208f7e8f9dSMark Adams                     PetscInt idx = bdiag_d[myk + 1] + 1;
16218f7e8f9dSMark Adams                     start        = ba_d + idx;
16228f7e8f9dSMark Adams                     startj       = bj_d + idx;
16238f7e8f9dSMark Adams                     end          = ba_d + bdiag_d[myk];
16248f7e8f9dSMark Adams                   }
16258f7e8f9dSMark Adams                   // search for 'col', use bisection search - TODO
16268f7e8f9dSMark Adams                   low  = 0;
16278f7e8f9dSMark Adams                   high = (PetscInt)(end - start);
16288f7e8f9dSMark Adams                   while (high - low > 5) {
16298f7e8f9dSMark Adams                     int t = (low + high) / 2;
16308f7e8f9dSMark Adams                     if (startj[t] > col) high = t;
16318f7e8f9dSMark Adams                     else low = t;
16328f7e8f9dSMark Adams                   }
16338f7e8f9dSMark Adams                   for (pAkjv = start + low; pAkjv < start + high; pAkjv++) {
16348f7e8f9dSMark Adams                     if (startj[pAkjv - start] == col) break;
16358f7e8f9dSMark Adams                   }
16368f1da0b2SJunchao Zhang #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
163799551766SMark Adams                   if (pAkjv == start + high) printf("\t\t\t\t\t\t\t\t\t\t\tERROR: *** failed to find Akj(%d,%d)\n", (int)myk, (int)col); // uses a register
163899551766SMark Adams #endif
16398f7e8f9dSMark Adams                   *pAkjv = *pAkjv - L_ki() * Uij; // A_kj = A_kj - L_ki * U_ij
16408f7e8f9dSMark Adams                 }
16418f7e8f9dSMark Adams               });
16428f7e8f9dSMark Adams             }
16438f7e8f9dSMark Adams           });
16448f7e8f9dSMark Adams           team.team_barrier(); // this needs to be a league barrier to use more that one SM per block
16458f7e8f9dSMark Adams           if (field_block_idx == 0) Kokkos::single(Kokkos::PerTeam(team), [&]() { Kokkos::atomic_add(flops.data(), (size_t)(2 * (nzUi * nzUi) + 2)); });
16468f7e8f9dSMark Adams         } /* endof for (i=0; i<n; i++) { */
16479371c9d4SSatish Balay         Kokkos::single(Kokkos::PerTeam(team), [=]() {
16489371c9d4SSatish Balay           Kokkos::atomic_add(&d_flops_k(), flops());
16499371c9d4SSatish Balay           flops() = 0;
16509371c9d4SSatish Balay         });
16518f7e8f9dSMark Adams       });
16528f7e8f9dSMark Adams     Kokkos::fence();
16538f7e8f9dSMark Adams     Kokkos::deep_copy(h_flops_k, d_flops_k);
16549566063dSJacob Faibussowitsch     PetscCall(PetscLogGpuFlops((PetscLogDouble)h_flops_k()));
16559371c9d4SSatish Balay     Kokkos::parallel_for(
16569371c9d4SSatish Balay       Kokkos::TeamPolicy<>(Nf * Ni, 1, 256), KOKKOS_LAMBDA(const team_member team) {
16578f7e8f9dSMark Adams         const PetscInt lg_rank = team.league_rank(), field = lg_rank / Ni;                            //, field_offset = lg_rank%Ni;
16588f7e8f9dSMark Adams         const PetscInt start = field * nloc, end = start + nloc, n_its = (nloc / Ni + !!(nloc % Ni)); // 1/Ni iters
16598f7e8f9dSMark Adams         /* Invert diagonal for simpler triangular solves */
16608f7e8f9dSMark Adams         Kokkos::parallel_for(Kokkos::TeamVectorRange(team, n_its), [=](int outer_index) {
16618f7e8f9dSMark Adams           int i = start + outer_index * Ni + lg_rank % Ni;
16628f7e8f9dSMark Adams           if (i < end) {
16638f7e8f9dSMark Adams             PetscScalar *pv = ba_d + bdiag_d[i];
16648f7e8f9dSMark Adams             *pv             = 1.0 / (*pv);
16658f7e8f9dSMark Adams           }
16668f7e8f9dSMark Adams         });
16678f7e8f9dSMark Adams       });
16688f7e8f9dSMark Adams   }
16699566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
16709566063dSJacob Faibussowitsch   PetscCall(ISRestoreIndices(isicol, &ic_h));
16719566063dSJacob Faibussowitsch   PetscCall(ISRestoreIndices(isrow, &r_h));
16728f7e8f9dSMark Adams 
16739566063dSJacob Faibussowitsch   PetscCall(ISIdentity(isrow, &row_identity));
16749566063dSJacob Faibussowitsch   PetscCall(ISIdentity(isicol, &col_identity));
16758f7e8f9dSMark Adams   if (b->inode.size) {
16768f7e8f9dSMark Adams     B->ops->solve = MatSolve_SeqAIJ_Inode;
16778f7e8f9dSMark Adams   } else if (row_identity && col_identity) {
16788f7e8f9dSMark Adams     B->ops->solve = MatSolve_SeqAIJ_NaturalOrdering;
16798f7e8f9dSMark Adams   } else {
16808f7e8f9dSMark Adams     B->ops->solve = MatSolve_SeqAIJ; // at least this needs to be in Kokkos
16818f7e8f9dSMark Adams   }
16828f7e8f9dSMark Adams   B->offloadmask = PETSC_OFFLOAD_GPU;
16839566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncHost(B));          // solve on CPU
16848f7e8f9dSMark Adams   B->ops->solveadd          = MatSolveAdd_SeqAIJ; // and this
16858f7e8f9dSMark Adams   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJ;
16868f7e8f9dSMark Adams   B->ops->solvetransposeadd = MatSolveTransposeAdd_SeqAIJ;
16878f7e8f9dSMark Adams   B->ops->matsolve          = MatMatSolve_SeqAIJ;
16888f7e8f9dSMark Adams   B->assembled              = PETSC_TRUE;
16898f7e8f9dSMark Adams   B->preallocated           = PETSC_TRUE;
16908f7e8f9dSMark Adams 
16913ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1692930e68a5SMark Adams }
1693930e68a5SMark Adams 
1694d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatLUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1695d71ae5a4SJacob Faibussowitsch {
1696930e68a5SMark Adams   PetscFunctionBegin;
16979566063dSJacob Faibussowitsch   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
169886a27549SJunchao Zhang   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
16993ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
170086a27549SJunchao Zhang }
170186a27549SJunchao Zhang 
1702d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSymbolicSolveCheck(Mat A)
1703d71ae5a4SJacob Faibussowitsch {
170486a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
170586a27549SJunchao Zhang 
170686a27549SJunchao Zhang   PetscFunctionBegin;
170786a27549SJunchao Zhang   if (!factors->sptrsv_symbolic_completed) {
170886a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d);
170986a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d);
171086a27549SJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_TRUE;
171186a27549SJunchao Zhang   }
17123ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
171386a27549SJunchao Zhang }
171486a27549SJunchao Zhang 
171586a27549SJunchao Zhang /* Check if we need to update factors etc for transpose solve */
1716d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosTransposeSolveCheck(Mat A)
1717d71ae5a4SJacob Faibussowitsch {
171886a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1719076ba34aSJunchao Zhang   MatColIdxType               n       = A->rmap->n;
172086a27549SJunchao Zhang 
172186a27549SJunchao Zhang   PetscFunctionBegin;
172286a27549SJunchao Zhang   if (!factors->transpose_updated) { /* TODO: KK needs to provide functions to do numeric transpose only */
172386a27549SJunchao Zhang     /* Update L^T and do sptrsv symbolic */
1724076ba34aSJunchao Zhang     factors->iLt_d = MatRowMapKokkosView("factors->iLt_d", n + 1);
172586a27549SJunchao Zhang     Kokkos::deep_copy(factors->iLt_d, 0); /* KK requires 0 */
1726076ba34aSJunchao Zhang     factors->jLt_d = MatColIdxKokkosView("factors->jLt_d", factors->jL_d.extent(0));
1727076ba34aSJunchao Zhang     factors->aLt_d = MatScalarKokkosView("factors->aLt_d", factors->aL_d.extent(0));
172886a27549SJunchao Zhang 
17299371c9d4SSatish Balay     transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iL_d, factors->jL_d, factors->aL_d,
173086a27549SJunchao Zhang                                                                                                                                                                                                               factors->iLt_d, factors->jLt_d, factors->aLt_d);
173186a27549SJunchao Zhang 
173286a27549SJunchao Zhang     /* TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices.
173386a27549SJunchao Zhang       We have to sort the indices, until KK provides finer control options.
173486a27549SJunchao Zhang     */
17359371c9d4SSatish Balay     sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iLt_d, factors->jLt_d, factors->aLt_d);
173686a27549SJunchao Zhang 
173786a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d);
173886a27549SJunchao Zhang 
173986a27549SJunchao Zhang     /* Update U^T and do sptrsv symbolic */
1740076ba34aSJunchao Zhang     factors->iUt_d = MatRowMapKokkosView("factors->iUt_d", n + 1);
174186a27549SJunchao Zhang     Kokkos::deep_copy(factors->iUt_d, 0); /* KK requires 0 */
1742076ba34aSJunchao Zhang     factors->jUt_d = MatColIdxKokkosView("factors->jUt_d", factors->jU_d.extent(0));
1743076ba34aSJunchao Zhang     factors->aUt_d = MatScalarKokkosView("factors->aUt_d", factors->aU_d.extent(0));
174486a27549SJunchao Zhang 
17459371c9d4SSatish Balay     transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iU_d, factors->jU_d, factors->aU_d,
174686a27549SJunchao Zhang                                                                                                                                                                                                               factors->iUt_d, factors->jUt_d, factors->aUt_d);
174786a27549SJunchao Zhang 
174886a27549SJunchao Zhang     /* Sort indices. See comments above */
17499371c9d4SSatish Balay     sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iUt_d, factors->jUt_d, factors->aUt_d);
175086a27549SJunchao Zhang 
175186a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d);
175286a27549SJunchao Zhang     factors->transpose_updated = PETSC_TRUE;
175386a27549SJunchao Zhang   }
17543ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
175586a27549SJunchao Zhang }
175686a27549SJunchao Zhang 
175786a27549SJunchao Zhang /* Solve Ax = b, with A = LU */
1758d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSolve_SeqAIJKokkos(Mat A, Vec b, Vec x)
1759d71ae5a4SJacob Faibussowitsch {
176086a27549SJunchao Zhang   ConstPetscScalarKokkosView  bv;
176186a27549SJunchao Zhang   PetscScalarKokkosView       xv;
176286a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
176386a27549SJunchao Zhang 
176486a27549SJunchao Zhang   PetscFunctionBegin;
17659566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
17669566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSymbolicSolveCheck(A));
17679566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(b, &bv));
17689566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(x, &xv));
176986a27549SJunchao Zhang   /* Solve L tmpv = b */
17709566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::Experimental::sptrsv_solve(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d, bv, factors->workVector));
177186a27549SJunchao Zhang   /* Solve Ux = tmpv */
17729566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::Experimental::sptrsv_solve(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, factors->workVector, xv));
17739566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(b, &bv));
17749566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
17759566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
17763ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
177786a27549SJunchao Zhang }
177886a27549SJunchao Zhang 
1779076ba34aSJunchao Zhang /* Solve A^T x = b, where A^T = U^T L^T */
1780d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSolveTranspose_SeqAIJKokkos(Mat A, Vec b, Vec x)
1781d71ae5a4SJacob Faibussowitsch {
178286a27549SJunchao Zhang   ConstPetscScalarKokkosView  bv;
178386a27549SJunchao Zhang   PetscScalarKokkosView       xv;
178486a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
178586a27549SJunchao Zhang 
178686a27549SJunchao Zhang   PetscFunctionBegin;
17879566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
17889566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A));
17899566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(b, &bv));
17909566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(x, &xv));
179186a27549SJunchao Zhang   /* Solve U^T tmpv = b */
179286a27549SJunchao Zhang   KokkosSparse::Experimental::sptrsv_solve(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, bv, factors->workVector);
179386a27549SJunchao Zhang 
179486a27549SJunchao Zhang   /* Solve L^T x = tmpv */
179586a27549SJunchao Zhang   KokkosSparse::Experimental::sptrsv_solve(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d, factors->workVector, xv);
17969566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(b, &bv));
17979566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
17989566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
17993ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
180086a27549SJunchao Zhang }
180186a27549SJunchao Zhang 
1802d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatILUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1803d71ae5a4SJacob Faibussowitsch {
180486a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
180586a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
180686a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
180786a27549SJunchao Zhang 
180886a27549SJunchao Zhang   PetscFunctionBegin;
18099566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
18109566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1811076ba34aSJunchao Zhang 
1812076ba34aSJunchao Zhang   auto a_d = aijkok->a_dual.view_device();
1813076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
1814076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
1815076ba34aSJunchao Zhang 
1816076ba34aSJunchao Zhang   KokkosSparse::Experimental::spiluk_numeric(&factors->kh, fill_lev, i_d, j_d, a_d, factors->iL_d, factors->jL_d, factors->aL_d, factors->iU_d, factors->jU_d, factors->aU_d);
181786a27549SJunchao Zhang 
181886a27549SJunchao Zhang   B->assembled              = PETSC_TRUE;
181986a27549SJunchao Zhang   B->preallocated           = PETSC_TRUE;
182086a27549SJunchao Zhang   B->ops->solve             = MatSolve_SeqAIJKokkos;
182186a27549SJunchao Zhang   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJKokkos;
182286a27549SJunchao Zhang   B->ops->matsolve          = NULL;
182386a27549SJunchao Zhang   B->ops->matsolvetranspose = NULL;
182486a27549SJunchao Zhang   B->offloadmask            = PETSC_OFFLOAD_GPU;
182586a27549SJunchao Zhang 
182686a27549SJunchao Zhang   /* Once the factors' value changed, we need to update their transpose and sptrsv handle */
182786a27549SJunchao Zhang   factors->transpose_updated         = PETSC_FALSE;
182886a27549SJunchao Zhang   factors->sptrsv_symbolic_completed = PETSC_FALSE;
1829eeadb341SJunchao Zhang   /* TODO: log flops, but how to know that? */
18309566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
18313ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
183286a27549SJunchao Zhang }
183386a27549SJunchao Zhang 
1834d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1835d71ae5a4SJacob Faibussowitsch {
183686a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok;
183786a27549SJunchao Zhang   Mat_SeqAIJ                 *b;
183886a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
183986a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
184086a27549SJunchao Zhang   PetscInt                    nnzA     = ((Mat_SeqAIJ *)A->data)->nz, nnzL, nnzU;
184186a27549SJunchao Zhang   PetscInt                    n        = A->rmap->n;
184286a27549SJunchao Zhang 
184386a27549SJunchao Zhang   PetscFunctionBegin;
18449566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
184586a27549SJunchao Zhang   /* Rebuild factors */
18469371c9d4SSatish Balay   if (factors) {
18479371c9d4SSatish Balay     factors->Destroy();
18489371c9d4SSatish Balay   } /* Destroy the old if it exists */
18499371c9d4SSatish Balay   else {
18509371c9d4SSatish Balay     B->spptr = factors = new Mat_SeqAIJKokkosTriFactors(n);
18519371c9d4SSatish Balay   }
185286a27549SJunchao Zhang 
185386a27549SJunchao Zhang   /* Create a spiluk handle and then do symbolic factorization */
185486a27549SJunchao Zhang   nnzL = nnzU = PetscRealIntMultTruncate(info->fill, nnzA);
185586a27549SJunchao Zhang   factors->kh.create_spiluk_handle(KokkosSparse::Experimental::SPILUKAlgorithm::SEQLVLSCHD_TP1, n, nnzL, nnzU);
185686a27549SJunchao Zhang 
185786a27549SJunchao Zhang   auto spiluk_handle = factors->kh.get_spiluk_handle();
185886a27549SJunchao Zhang 
185986a27549SJunchao Zhang   Kokkos::realloc(factors->iL_d, n + 1); /* Free old arrays and realloc */
186086a27549SJunchao Zhang   Kokkos::realloc(factors->jL_d, spiluk_handle->get_nnzL());
186186a27549SJunchao Zhang   Kokkos::realloc(factors->iU_d, n + 1);
186286a27549SJunchao Zhang   Kokkos::realloc(factors->jU_d, spiluk_handle->get_nnzU());
186386a27549SJunchao Zhang 
186486a27549SJunchao Zhang   aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
1865076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
1866076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
1867076ba34aSJunchao Zhang   KokkosSparse::Experimental::spiluk_symbolic(&factors->kh, fill_lev, i_d, j_d, factors->iL_d, factors->jL_d, factors->iU_d, factors->jU_d);
186886a27549SJunchao Zhang   /* TODO: if spiluk_symbolic is asynchronous, do we need to sync before calling get_nnzL()? */
186986a27549SJunchao Zhang 
187086a27549SJunchao Zhang   Kokkos::resize(factors->jL_d, spiluk_handle->get_nnzL()); /* Shrink or expand, and retain old value */
187186a27549SJunchao Zhang   Kokkos::resize(factors->jU_d, spiluk_handle->get_nnzU());
187286a27549SJunchao Zhang   Kokkos::realloc(factors->aL_d, spiluk_handle->get_nnzL()); /* No need to retain old value */
187386a27549SJunchao Zhang   Kokkos::realloc(factors->aU_d, spiluk_handle->get_nnzU());
187486a27549SJunchao Zhang 
187586a27549SJunchao Zhang   /* TODO: add options to select sptrsv algorithms */
187686a27549SJunchao Zhang   /* Create sptrsv handles for L, U and their transpose */
187786a27549SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
187886a27549SJunchao Zhang   auto sptrsv_alg = KokkosSparse::Experimental::SPTRSVAlgorithm::SPTRSV_CUSPARSE;
187986a27549SJunchao Zhang #else
188086a27549SJunchao Zhang   auto sptrsv_alg = KokkosSparse::Experimental::SPTRSVAlgorithm::SEQLVLSCHD_TP1;
188186a27549SJunchao Zhang #endif
188286a27549SJunchao Zhang 
188386a27549SJunchao Zhang   factors->khL.create_sptrsv_handle(sptrsv_alg, n, true /* L is lower tri */);
188486a27549SJunchao Zhang   factors->khU.create_sptrsv_handle(sptrsv_alg, n, false /* U is not lower tri */);
188586a27549SJunchao Zhang   factors->khLt.create_sptrsv_handle(sptrsv_alg, n, false /* L^T is not lower tri */);
188686a27549SJunchao Zhang   factors->khUt.create_sptrsv_handle(sptrsv_alg, n, true /* U^T is lower tri */);
188786a27549SJunchao Zhang 
188886a27549SJunchao Zhang   /* Fill fields of the factor matrix B */
18899566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(B, MAT_SKIP_ALLOCATION, NULL));
189086a27549SJunchao Zhang   b     = (Mat_SeqAIJ *)B->data;
189186a27549SJunchao Zhang   b->nz = b->maxnz          = spiluk_handle->get_nnzL() + spiluk_handle->get_nnzU();
189286a27549SJunchao Zhang   B->info.fill_ratio_given  = info->fill;
189386a27549SJunchao Zhang   B->info.fill_ratio_needed = ((PetscReal)b->nz) / ((PetscReal)nnzA);
189486a27549SJunchao Zhang 
189586a27549SJunchao Zhang   B->offloadmask          = PETSC_OFFLOAD_GPU;
189686a27549SJunchao Zhang   B->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJKokkos;
18973ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1898930e68a5SMark Adams }
1899930e68a5SMark Adams 
1900d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatLUFactorSymbolic_SeqAIJKOKKOSDEVICE(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1901d71ae5a4SJacob Faibussowitsch {
19028f7e8f9dSMark Adams   Mat_SeqAIJ    *b     = (Mat_SeqAIJ *)B->data;
19038f7e8f9dSMark Adams   const PetscInt nrows = A->rmap->n;
1904930e68a5SMark Adams 
19058f7e8f9dSMark Adams   PetscFunctionBegin;
19069566063dSJacob Faibussowitsch   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
19078f7e8f9dSMark Adams   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKOKKOSDEVICE;
19088f7e8f9dSMark Adams   // move B data into Kokkos
19099566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B)); // create aijkok
19109566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A)); // create aijkok
19118f7e8f9dSMark Adams   {
19128f7e8f9dSMark Adams     Mat_SeqAIJKokkos *baijkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
1913300d22a6SJunchao Zhang     if (!baijkok->diag_d.extent(0)) {
19148f7e8f9dSMark Adams       const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_diag(b->diag, nrows + 1);
1915300d22a6SJunchao Zhang       baijkok->diag_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_diag));
1916300d22a6SJunchao Zhang       Kokkos::deep_copy(baijkok->diag_d, h_diag);
19178f7e8f9dSMark Adams     }
19188f7e8f9dSMark Adams   }
19193ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
19208f7e8f9dSMark Adams }
19218f7e8f9dSMark Adams 
1922d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatFactorGetSolverType_SeqAIJKokkos(Mat A, MatSolverType *type)
1923d71ae5a4SJacob Faibussowitsch {
1924930e68a5SMark Adams   PetscFunctionBegin;
1925930e68a5SMark Adams   *type = MATSOLVERKOKKOS;
19263ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1927930e68a5SMark Adams }
1928930e68a5SMark Adams 
1929d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatFactorGetSolverType_seqaij_kokkos_device(Mat A, MatSolverType *type)
1930d71ae5a4SJacob Faibussowitsch {
19318f7e8f9dSMark Adams   PetscFunctionBegin;
19328f7e8f9dSMark Adams   *type = MATSOLVERKOKKOSDEVICE;
19333ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
19348f7e8f9dSMark Adams }
19358f7e8f9dSMark Adams 
1936930e68a5SMark Adams /*MC
  MATSOLVERKOKKOS = "Kokkos" - A matrix solver type providing triangular solvers for sequential matrices
  of type `MATSEQAIJKOKKOS` or `MATAIJKOKKOS` on a single GPU.
1939930e68a5SMark Adams 
1940930e68a5SMark Adams   Level: beginner
1941930e68a5SMark Adams 
19422ef1f0ffSBarry Smith .seealso: [](chapter_matrices), `Mat`, `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJKokkos()`, `MATAIJKOKKOS`, `MatKokkosSetFormat()`, `MatKokkosStorageFormat`, `MatKokkosFormatOperation`
1943930e68a5SMark Adams M*/
194486a27549SJunchao Zhang PETSC_EXTERN PetscErrorCode MatGetFactor_SeqAIJKokkos_Kokkos(Mat A, MatFactorType ftype, Mat *B) /* MatGetFactor_<MatType>_<MatSolverType> */
1945930e68a5SMark Adams {
1946930e68a5SMark Adams   PetscInt n = A->rmap->n;
1947930e68a5SMark Adams 
1948930e68a5SMark Adams   PetscFunctionBegin;
19499566063dSJacob Faibussowitsch   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), B));
19509566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*B, n, n, n, n));
1951930e68a5SMark Adams   (*B)->factortype = ftype;
19529566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
19539566063dSJacob Faibussowitsch   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
1954930e68a5SMark Adams 
19558f7e8f9dSMark Adams   if (ftype == MAT_FACTOR_LU) {
19569566063dSJacob Faibussowitsch     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
195786a27549SJunchao Zhang     (*B)->canuseordering        = PETSC_TRUE;
195886a27549SJunchao Zhang     (*B)->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqAIJKokkos;
195986a27549SJunchao Zhang   } else if (ftype == MAT_FACTOR_ILU) {
19609566063dSJacob Faibussowitsch     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
196186a27549SJunchao Zhang     (*B)->canuseordering         = PETSC_FALSE;
196286a27549SJunchao Zhang     (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJKokkos;
196398921bdaSJacob Faibussowitsch   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "MatFactorType %s is not supported by MatType SeqAIJKokkos", MatFactorTypes[ftype]);
1964930e68a5SMark Adams 
19659566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
19669566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)(*B), "MatFactorGetSolverType_C", MatFactorGetSolverType_SeqAIJKokkos));
19673ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1968930e68a5SMark Adams }
19698f7e8f9dSMark Adams 
1970d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatGetFactor_seqaijkokkos_kokkos_device(Mat A, MatFactorType ftype, Mat *B)
1971d71ae5a4SJacob Faibussowitsch {
19728f7e8f9dSMark Adams   PetscInt n = A->rmap->n;
19738f7e8f9dSMark Adams 
19748f7e8f9dSMark Adams   PetscFunctionBegin;
19759566063dSJacob Faibussowitsch   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), B));
19769566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*B, n, n, n, n));
19778f7e8f9dSMark Adams   (*B)->factortype     = ftype;
1978f73b0415SBarry Smith   (*B)->canuseordering = PETSC_TRUE;
19799566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
19809566063dSJacob Faibussowitsch   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
19818f7e8f9dSMark Adams 
19828f7e8f9dSMark Adams   if (ftype == MAT_FACTOR_LU) {
19839566063dSJacob Faibussowitsch     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
19848f7e8f9dSMark Adams     (*B)->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqAIJKOKKOSDEVICE;
19858f7e8f9dSMark Adams   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Factor type not supported for KOKKOS Matrix Types");
19868f7e8f9dSMark Adams 
19879566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
19889566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)(*B), "MatFactorGetSolverType_C", MatFactorGetSolverType_seqaij_kokkos_device));
19893ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
19908f7e8f9dSMark Adams }
199186a27549SJunchao Zhang 
1992d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatSolverTypeRegister_KOKKOS(void)
1993d71ae5a4SJacob Faibussowitsch {
199486a27549SJunchao Zhang   PetscFunctionBegin;
19959566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_SeqAIJKokkos_Kokkos));
19969566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ILU, MatGetFactor_SeqAIJKokkos_Kokkos));
19979566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOSDEVICE, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_seqaijkokkos_kokkos_device));
19983ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
199986a27549SJunchao Zhang }
200086a27549SJunchao Zhang 
2001076ba34aSJunchao Zhang /* Utility to print out a KokkosCsrMatrix for debugging */
2002d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode PrintCsrMatrix(const KokkosCsrMatrix &csrmat)
2003d71ae5a4SJacob Faibussowitsch {
2004076ba34aSJunchao Zhang   const auto        &iv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.graph.row_map);
2005076ba34aSJunchao Zhang   const auto        &jv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.graph.entries);
2006076ba34aSJunchao Zhang   const auto        &av = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.values);
2007076ba34aSJunchao Zhang   const PetscInt    *i  = iv.data();
2008076ba34aSJunchao Zhang   const PetscInt    *j  = jv.data();
2009076ba34aSJunchao Zhang   const PetscScalar *a  = av.data();
2010076ba34aSJunchao Zhang   PetscInt           m = csrmat.numRows(), n = csrmat.numCols(), nnz = csrmat.nnz();
2011076ba34aSJunchao Zhang 
2012076ba34aSJunchao Zhang   PetscFunctionBegin;
20139566063dSJacob Faibussowitsch   PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT " x %" PetscInt_FMT " SeqAIJKokkos, with %" PetscInt_FMT " nonzeros\n", m, n, nnz));
2014076ba34aSJunchao Zhang   for (PetscInt k = 0; k < m; k++) {
20159566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT ": ", k));
201648a46eb9SPierre Jolivet     for (PetscInt p = i[k]; p < i[k + 1]; p++) PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT "(%.1f), ", j[p], (double)PetscRealPart(a[p])));
20179566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "\n"));
2018076ba34aSJunchao Zhang   }
20193ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2020076ba34aSJunchao Zhang }
2021