xref: /petsc/src/mat/impls/aij/seq/kokkos/aijkok.kokkos.cxx (revision fe59aa6d68c880d4014a5813129926ee5b21e858)
111d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
2076ba34aSJunchao Zhang #include <petscpkg_version.h>
3152b3e56SJunchao Zhang #include <petsc/private/petscimpl.h>
442550becSJunchao Zhang #include <petsc/private/sfimpl.h>
58c3ff71bSJunchao Zhang #include <petscsystypes.h>
68c3ff71bSJunchao Zhang #include <petscerror.h>
78c3ff71bSJunchao Zhang 
88c3ff71bSJunchao Zhang #include <Kokkos_Core.hpp>
9f0cf5187SStefano Zampini #include <KokkosBlas.hpp>
108c3ff71bSJunchao Zhang #include <KokkosSparse_CrsMatrix.hpp>
118c3ff71bSJunchao Zhang #include <KokkosSparse_spmv.hpp>
1286a27549SJunchao Zhang #include <KokkosSparse_spiluk.hpp>
1386a27549SJunchao Zhang #include <KokkosSparse_sptrsv.hpp>
14076ba34aSJunchao Zhang #include <KokkosSparse_spgemm.hpp>
15076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
169d13fa56SJunchao Zhang #include <KokkosBatched_LU_Decl.hpp>
179d13fa56SJunchao Zhang #include <KokkosBatched_InverseLU_Decl.hpp>
1886a27549SJunchao Zhang 
1942550becSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
208c3ff71bSJunchao Zhang 
210e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(3, 7, 0)
22f98996d3SJunchao Zhang   #include <KokkosSparse_Utils.hpp>
23f98996d3SJunchao Zhang using KokkosSparse::sort_crs_matrix;
249371c9d4SSatish Balay using KokkosSparse::Impl::transpose_matrix;
25f98996d3SJunchao Zhang #else
26f98996d3SJunchao Zhang   #include <KokkosKernels_Sorting.hpp>
27f98996d3SJunchao Zhang using KokkosKernels::sort_crs_matrix;
289371c9d4SSatish Balay using KokkosKernels::Impl::transpose_matrix;
29f98996d3SJunchao Zhang #endif
30f98996d3SJunchao Zhang 
318c3ff71bSJunchao Zhang static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat); /* Forward declaration */
328c3ff71bSJunchao Zhang 
33076ba34aSJunchao Zhang /* MatAssemblyEnd_SeqAIJKokkos() happens when we finalized nonzeros of the matrix, either after
34076ba34aSJunchao Zhang    we assembled the matrix on host, or after we directly produced the matrix data on device (ex., through MatMatMult).
35076ba34aSJunchao Zhang    In the latter case, it is important to set a_dual's sync state correctly.
36076ba34aSJunchao Zhang  */
37d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_SeqAIJKokkos(Mat A, MatAssemblyType mode)
38d71ae5a4SJacob Faibussowitsch {
39076ba34aSJunchao Zhang   Mat_SeqAIJ       *aijseq;
40076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
418c3ff71bSJunchao Zhang 
428c3ff71bSJunchao Zhang   PetscFunctionBegin;
433ba16761SJacob Faibussowitsch   if (mode == MAT_FLUSH_ASSEMBLY) PetscFunctionReturn(PETSC_SUCCESS);
449566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
45076ba34aSJunchao Zhang 
46076ba34aSJunchao Zhang   aijseq = static_cast<Mat_SeqAIJ *>(A->data);
47076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
48076ba34aSJunchao Zhang 
49076ba34aSJunchao Zhang   /* If aijkok does not exist, we just copy i, j to device.
50076ba34aSJunchao Zhang      If aijkok already exists, but the device's nonzero pattern does not match with the host's, we assume the latest data is on host.
51076ba34aSJunchao Zhang      In both cases, we build a new aijkok structure.
52076ba34aSJunchao Zhang   */
53076ba34aSJunchao Zhang   if (!aijkok || aijkok->nonzerostate != A->nonzerostate) { /* aijkok might not exist yet or nonzero pattern has changed */
54076ba34aSJunchao Zhang     delete aijkok;
55076ba34aSJunchao Zhang     aijkok   = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aijseq->nz, aijseq->i, aijseq->j, aijseq->a, A->nonzerostate, PETSC_FALSE /*don't copy mat values to device*/);
56076ba34aSJunchao Zhang     A->spptr = aijkok;
57076ba34aSJunchao Zhang   }
58076ba34aSJunchao Zhang 
593ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
608c3ff71bSJunchao Zhang }
618c3ff71bSJunchao Zhang 
6286a27549SJunchao Zhang /* Sync CSR data to device if not yet */
63d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosSyncDevice(Mat A)
64d71ae5a4SJacob Faibussowitsch {
658c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
668c3ff71bSJunchao Zhang 
678c3ff71bSJunchao Zhang   PetscFunctionBegin;
68aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from host to device");
695f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
70076ba34aSJunchao Zhang   if (aijkok->a_dual.need_sync_device()) {
71076ba34aSJunchao Zhang     aijkok->a_dual.sync_device();
72580c7c76SPierre Jolivet     aijkok->transpose_updated = PETSC_FALSE; /* values of the transpose is out-of-date */
7386a27549SJunchao Zhang     aijkok->hermitian_updated = PETSC_FALSE;
748c3ff71bSJunchao Zhang   }
753ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
768c3ff71bSJunchao Zhang }
778c3ff71bSJunchao Zhang 
78076ba34aSJunchao Zhang /* Mark the CSR data on device as modified */
79d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosModifyDevice(Mat A)
80d71ae5a4SJacob Faibussowitsch {
8186a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
8286a27549SJunchao Zhang 
8386a27549SJunchao Zhang   PetscFunctionBegin;
845f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Not supported for factorized matries");
8586a27549SJunchao Zhang   aijkok->a_dual.clear_sync_state();
8686a27549SJunchao Zhang   aijkok->a_dual.modify_device();
8786a27549SJunchao Zhang   aijkok->transpose_updated = PETSC_FALSE;
8886a27549SJunchao Zhang   aijkok->hermitian_updated = PETSC_FALSE;
899566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJInvalidateDiagonal(A));
909566063dSJacob Faibussowitsch   PetscCall(PetscObjectStateIncrease((PetscObject)A));
913ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
9286a27549SJunchao Zhang }
9386a27549SJunchao Zhang 
94d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSyncHost(Mat A)
95d71ae5a4SJacob Faibussowitsch {
96f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
97f0cf5187SStefano Zampini 
98f0cf5187SStefano Zampini   PetscFunctionBegin;
99f0cf5187SStefano Zampini   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10086a27549SJunchao Zhang   /* We do not expect one needs factors on host  */
101aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from device to host");
1025f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing AIJKOK");
103076ba34aSJunchao Zhang   aijkok->a_dual.sync_host();
1043ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
105f0cf5187SStefano Zampini }
106f0cf5187SStefano Zampini 
107d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
108d71ae5a4SJacob Faibussowitsch {
109076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
110f0cf5187SStefano Zampini 
111f0cf5187SStefano Zampini   PetscFunctionBegin;
1125519a089SJose E. Roman   /* aijkok contains valid pointers only if the host's nonzerostate matches with the device's.
1135519a089SJose E. Roman     Calling MatSeqAIJSetPreallocation() or MatSetValues() on host, where aijseq->{i,j,a} might be
1145519a089SJose E. Roman     reallocated, will lead to stale {i,j,a}_dual in aijkok. In both operations, the hosts's nonzerostate
1155519a089SJose E. Roman     must have been updated. The stale aijkok will be rebuilt during MatAssemblyEnd.
1165519a089SJose E. Roman   */
1175519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
118076ba34aSJunchao Zhang     aijkok->a_dual.sync_host();
119076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
120076ba34aSJunchao Zhang   } else { /* Happens when calling MatSetValues on a newly created matrix */
121076ba34aSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
122076ba34aSJunchao Zhang   }
1233ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
124076ba34aSJunchao Zhang }
125076ba34aSJunchao Zhang 
126d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
127d71ae5a4SJacob Faibussowitsch {
128076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
129076ba34aSJunchao Zhang 
130076ba34aSJunchao Zhang   PetscFunctionBegin;
1315519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) aijkok->a_dual.modify_host();
1323ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
133076ba34aSJunchao Zhang }
134076ba34aSJunchao Zhang 
135d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
136d71ae5a4SJacob Faibussowitsch {
137076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
138076ba34aSJunchao Zhang 
139076ba34aSJunchao Zhang   PetscFunctionBegin;
1405519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
141076ba34aSJunchao Zhang     aijkok->a_dual.sync_host();
142076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1432328674fSJunchao Zhang   } else {
1442328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1452328674fSJunchao Zhang   }
1463ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
147076ba34aSJunchao Zhang }
148076ba34aSJunchao Zhang 
149d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
150d71ae5a4SJacob Faibussowitsch {
151076ba34aSJunchao Zhang   PetscFunctionBegin;
152076ba34aSJunchao Zhang   *array = NULL;
1533ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
154076ba34aSJunchao Zhang }
155076ba34aSJunchao Zhang 
156d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
157d71ae5a4SJacob Faibussowitsch {
158076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
159076ba34aSJunchao Zhang 
160076ba34aSJunchao Zhang   PetscFunctionBegin;
1615519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
162076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1632328674fSJunchao Zhang   } else { /* Ex. happens with MatZeroEntries on a preallocated but not assembled matrix */
1642328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1652328674fSJunchao Zhang   }
1663ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
167076ba34aSJunchao Zhang }
168076ba34aSJunchao Zhang 
169d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
170d71ae5a4SJacob Faibussowitsch {
171076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
172076ba34aSJunchao Zhang 
173076ba34aSJunchao Zhang   PetscFunctionBegin;
1745519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
175076ba34aSJunchao Zhang     aijkok->a_dual.clear_sync_state();
176076ba34aSJunchao Zhang     aijkok->a_dual.modify_host();
1772328674fSJunchao Zhang   }
1783ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
179f0cf5187SStefano Zampini }
180f0cf5187SStefano Zampini 
181d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJKokkos(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype)
182d71ae5a4SJacob Faibussowitsch {
1837ee59b9bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1847ee59b9bSJunchao Zhang 
1857ee59b9bSJunchao Zhang   PetscFunctionBegin;
1867ee59b9bSJunchao Zhang   PetscCheck(aijkok != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "aijkok is NULL");
1877ee59b9bSJunchao Zhang 
1887ee59b9bSJunchao Zhang   if (i) *i = aijkok->i_device_data();
1897ee59b9bSJunchao Zhang   if (j) *j = aijkok->j_device_data();
1907ee59b9bSJunchao Zhang   if (a) {
1917ee59b9bSJunchao Zhang     aijkok->a_dual.sync_device();
1927ee59b9bSJunchao Zhang     *a = aijkok->a_device_data();
1937ee59b9bSJunchao Zhang   }
1947ee59b9bSJunchao Zhang   if (mtype) *mtype = PETSC_MEMTYPE_KOKKOS;
1953ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1967ee59b9bSJunchao Zhang }
1977ee59b9bSJunchao Zhang 
1980e3ece09SJunchao Zhang /*
1990e3ece09SJunchao Zhang   Generate the sparsity pattern of a MatSeqAIJKokkos matrix's transpose on device.
2000e3ece09SJunchao Zhang 
2010e3ece09SJunchao Zhang   Input Parameter:
2020e3ece09SJunchao Zhang .  A       - the MATSEQAIJKOKKOS matrix
2030e3ece09SJunchao Zhang 
2040e3ece09SJunchao Zhang   Output Parameters:
2050e3ece09SJunchao Zhang +  perm_d - the permutation array on device, which connects Ta(i) = Aa(perm(i))
206aaa8cc7dSPierre Jolivet -  T_d    - the transpose on device, whose value array is allocated but not initialized
2070e3ece09SJunchao Zhang */
2080e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateTransposeStructure(Mat A, MatRowMapKokkosView &perm_d, KokkosCsrMatrix &T_d)
209d71ae5a4SJacob Faibussowitsch {
2100e3ece09SJunchao Zhang   Mat_SeqAIJ             *aseq = static_cast<Mat_SeqAIJ *>(A->data);
2110e3ece09SJunchao Zhang   PetscInt                nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
2120e3ece09SJunchao Zhang   const PetscInt         *Ai = aseq->i, *Aj = aseq->j;
2137b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost Ti_h(NoInit("Ti"), n + 1);
2140e3ece09SJunchao Zhang   MatRowMapType          *Ti = Ti_h.data();
2157b8d4ba6SJunchao Zhang   MatColIdxKokkosViewHost Tj_h(NoInit("Tj"), nz);
2167b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost perm_h(NoInit("permutation"), nz);
2170e3ece09SJunchao Zhang   PetscInt               *Tj   = Tj_h.data();
2180e3ece09SJunchao Zhang   PetscInt               *perm = perm_h.data();
2190e3ece09SJunchao Zhang   PetscInt               *offset;
220152b3e56SJunchao Zhang 
221152b3e56SJunchao Zhang   PetscFunctionBegin;
2220e3ece09SJunchao Zhang   // Populate Ti
2230e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::deep_copy(Ti_h, 0));
2240e3ece09SJunchao Zhang   Ti++;
2250e3ece09SJunchao Zhang   for (PetscInt i = 0; i < nz; i++) Ti[Aj[i]]++;
2260e3ece09SJunchao Zhang   Ti--;
2270e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) Ti[i + 1] += Ti[i];
2280e3ece09SJunchao Zhang 
2290e3ece09SJunchao Zhang   // Populate Tj and the permutation array
2300e3ece09SJunchao Zhang   PetscCall(PetscCalloc1(n, &offset)); // offset in each T row to fill in its column indices
2310e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
2320e3ece09SJunchao Zhang     for (PetscInt j = Ai[i]; j < Ai[i + 1]; j++) { // A's (i,j) is T's (j,i)
2330e3ece09SJunchao Zhang       PetscInt r    = Aj[j];                       // row r of T
2340e3ece09SJunchao Zhang       PetscInt disp = Ti[r] + offset[r];
2350e3ece09SJunchao Zhang 
2360e3ece09SJunchao Zhang       Tj[disp]   = i; // col i of T
2370e3ece09SJunchao Zhang       perm[disp] = j;
2380e3ece09SJunchao Zhang       offset[r]++;
239076ba34aSJunchao Zhang     }
2400e3ece09SJunchao Zhang   }
2410e3ece09SJunchao Zhang   PetscCall(PetscFree(offset));
2420e3ece09SJunchao Zhang 
2430e3ece09SJunchao Zhang   // Sort each row of T, along with the permutation array
2440e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) PetscCall(PetscSortIntWithArray(Ti[i + 1] - Ti[i], Tj + Ti[i], perm + Ti[i]));
2450e3ece09SJunchao Zhang 
2460e3ece09SJunchao Zhang   // Output perm and T on device
2470e3ece09SJunchao Zhang   auto Ti_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ti_h);
2480e3ece09SJunchao Zhang   auto Tj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Tj_h);
2490e3ece09SJunchao Zhang   PetscCallCXX(T_d = KokkosCsrMatrix("csrmatT", n, m, nz, MatScalarKokkosView("Ta", nz), Ti_d, Tj_d));
2500e3ece09SJunchao Zhang   PetscCallCXX(perm_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), perm_h));
2513ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
252152b3e56SJunchao Zhang }
253152b3e56SJunchao Zhang 
2540e3ece09SJunchao Zhang // Generate the transpose on device and cache it internally
2550e3ece09SJunchao Zhang // Note: KK transpose_matrix() does not have support symbolic/numeric transpose, so we do it on our own
2560e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGenerateTranspose_Private(Mat A, KokkosCsrMatrix *csrmatT)
257d71ae5a4SJacob Faibussowitsch {
2580e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
2590e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
2600e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
2610e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatT;
262152b3e56SJunchao Zhang 
263152b3e56SJunchao Zhang   PetscFunctionBegin;
2640e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
2650e3ece09SJunchao Zhang   PetscCallCXX(akok->a_dual.sync_device()); // Sync A's valeus since we are going to access them on device
2660e3ece09SJunchao Zhang 
2670e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
2680e3ece09SJunchao Zhang 
2690e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE) {
2700e3ece09SJunchao Zhang     *csrmatT = akok->csrmat;
2710e3ece09SJunchao Zhang   } else {
2720e3ece09SJunchao Zhang     // See if we already have a cached transpose and its value is up to date
2730e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
2740e3ece09SJunchao Zhang       if (!akok->transpose_updated) {            // if the value is out of date, update the cached version
2750e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
2760e3ece09SJunchao Zhang         auto       &Ta   = T.values;
2770e3ece09SJunchao Zhang 
2780e3ece09SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(
2790e3ece09SJunchao Zhang           nz, KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = Aa(perm(i)); }));
280076ba34aSJunchao Zhang       }
2810e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
2820e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
2830e3ece09SJunchao Zhang 
2840e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
2850e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
2860e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(
2870e3ece09SJunchao Zhang         nz, KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = Aa(perm(i)); }));
2880e3ece09SJunchao Zhang     }
2890e3ece09SJunchao Zhang     akok->transpose_updated = PETSC_TRUE;
2900e3ece09SJunchao Zhang     *csrmatT                = akok->csrmatT;
2910e3ece09SJunchao Zhang   }
2920e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2930e3ece09SJunchao Zhang }
2940e3ece09SJunchao Zhang 
2950e3ece09SJunchao Zhang // Generate the Hermitian on device and cache it internally
2960e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateHermitian_Private(Mat A, KokkosCsrMatrix *csrmatH)
2970e3ece09SJunchao Zhang {
2980e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
2990e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3000e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
3010e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatH;
3020e3ece09SJunchao Zhang 
3030e3ece09SJunchao Zhang   PetscFunctionBegin;
3040e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
3050e3ece09SJunchao Zhang   PetscCallCXX(akok->a_dual.sync_device()); // Sync A's values since we are going to access them on device
3060e3ece09SJunchao Zhang 
3070e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
3080e3ece09SJunchao Zhang 
3090e3ece09SJunchao Zhang   if (A->hermitian == PETSC_BOOL3_TRUE) {
3100e3ece09SJunchao Zhang     *csrmatH = akok->csrmat;
3110e3ece09SJunchao Zhang   } else {
3120e3ece09SJunchao Zhang     // See if we already have a cached hermitian and its value is up to date
3130e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
3140e3ece09SJunchao Zhang       if (!akok->hermitian_updated) {            // if the value is out of date, update the cached version
3150e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
3160e3ece09SJunchao Zhang         auto       &Ta   = T.values;
3170e3ece09SJunchao Zhang 
3180e3ece09SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(
3190e3ece09SJunchao Zhang           nz, KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = PetscConj(Aa(perm(i))); }));
3200e3ece09SJunchao Zhang       }
3210e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
3220e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
3230e3ece09SJunchao Zhang 
3240e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
3250e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
3260e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(
3270e3ece09SJunchao Zhang         nz, KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = PetscConj(Aa(perm(i))); }));
3280e3ece09SJunchao Zhang     }
3290e3ece09SJunchao Zhang     akok->hermitian_updated = PETSC_TRUE;
3300e3ece09SJunchao Zhang     *csrmatH                = akok->csrmatH;
3310e3ece09SJunchao Zhang   }
3323ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
333152b3e56SJunchao Zhang }
334a587d139SMark 
3358c3ff71bSJunchao Zhang /* y = A x */
336d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMult_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
337d71ae5a4SJacob Faibussowitsch {
3388c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
339152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
340152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
3418c3ff71bSJunchao Zhang 
3428c3ff71bSJunchao Zhang   PetscFunctionBegin;
3439566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3449566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3459566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3469566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
3478c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3489d52486cSJunchao Zhang   PetscCallCXX(KokkosSparse::spmv("N", 1.0 /*alpha*/, aijkok->csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A x + beta y */
3499566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3509566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
351076ba34aSJunchao Zhang   /* 2.0*nnz - numRows seems more accurate here but assumes there are no zero-rows. So a little sloppy here. */
3529566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
3539566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
3543ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
3558c3ff71bSJunchao Zhang }
3568c3ff71bSJunchao Zhang 
3578c3ff71bSJunchao Zhang /* y = A^T x */
358d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
359d71ae5a4SJacob Faibussowitsch {
3608c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
361152b3e56SJunchao Zhang   const char                *mode;
362152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
363152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
3640e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
3658c3ff71bSJunchao Zhang 
3668c3ff71bSJunchao Zhang   PetscFunctionBegin;
3679566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3689566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3699566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3709566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
371152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
3729566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
373152b3e56SJunchao Zhang     mode = "N";
374152b3e56SJunchao Zhang   } else {
375076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3760e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
377152b3e56SJunchao Zhang     mode   = "T";
378152b3e56SJunchao Zhang   }
3790e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^T x + beta y */
3809566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3819566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
3820e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
3839566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
3843ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
3858c3ff71bSJunchao Zhang }
3868c3ff71bSJunchao Zhang 
3878c3ff71bSJunchao Zhang /* y = A^H x */
388d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
389d71ae5a4SJacob Faibussowitsch {
3908c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
391152b3e56SJunchao Zhang   const char                *mode;
392152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
393152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
3940e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
3958c3ff71bSJunchao Zhang 
3968c3ff71bSJunchao Zhang   PetscFunctionBegin;
3979566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3989566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3999566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4009566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
401152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4029566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
403152b3e56SJunchao Zhang     mode = "N";
404152b3e56SJunchao Zhang   } else {
405076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4060e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
407152b3e56SJunchao Zhang     mode   = "C";
408152b3e56SJunchao Zhang   }
4090e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^H x + beta y */
4109566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4119566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
4120e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4139566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4143ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4158c3ff71bSJunchao Zhang }
4168c3ff71bSJunchao Zhang 
4178c3ff71bSJunchao Zhang /* z = A x + y */
418d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
419d71ae5a4SJacob Faibussowitsch {
4208c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
421152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
422152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4238c3ff71bSJunchao Zhang 
4248c3ff71bSJunchao Zhang   PetscFunctionBegin;
4259566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4269566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4279566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4289566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
4299566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
4308c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
4318c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4329d52486cSJunchao Zhang   PetscCallCXX(KokkosSparse::spmv("N", 1.0 /*alpha*/, aijkok->csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A x + beta z */
4339566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4349566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
4359566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
4369566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
4379566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4383ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4398c3ff71bSJunchao Zhang }
4408c3ff71bSJunchao Zhang 
4418c3ff71bSJunchao Zhang /* z = A^T x + y */
442d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
443d71ae5a4SJacob Faibussowitsch {
4448c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
445152b3e56SJunchao Zhang   const char                *mode;
446152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
447152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4480e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4498c3ff71bSJunchao Zhang 
4508c3ff71bSJunchao Zhang   PetscFunctionBegin;
4519566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4529566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4539566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4549566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
4559566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
4568c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
457152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4589566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
459152b3e56SJunchao Zhang     mode = "N";
460152b3e56SJunchao Zhang   } else {
461076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4620e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
463152b3e56SJunchao Zhang     mode   = "T";
464152b3e56SJunchao Zhang   }
4650e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^T x + beta z */
4669566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4679566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
4689566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
4690e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4709566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4713ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4728c3ff71bSJunchao Zhang }
4738c3ff71bSJunchao Zhang 
4748c3ff71bSJunchao Zhang /* z = A^H x + y */
475d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
476d71ae5a4SJacob Faibussowitsch {
4778c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
478152b3e56SJunchao Zhang   const char                *mode;
479152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
480152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4810e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4828c3ff71bSJunchao Zhang 
4838c3ff71bSJunchao Zhang   PetscFunctionBegin;
4849566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4859566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4869566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4879566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
4889566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
4898c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
490152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4919566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
492152b3e56SJunchao Zhang     mode = "N";
493152b3e56SJunchao Zhang   } else {
494076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4950e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
496152b3e56SJunchao Zhang     mode   = "C";
497152b3e56SJunchao Zhang   }
4980e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^H x + beta z */
4999566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
5009566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
5019566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
5020e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
5039566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
5043ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
505152b3e56SJunchao Zhang }
506152b3e56SJunchao Zhang 
507d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSetOption_SeqAIJKokkos(Mat A, MatOption op, PetscBool flg)
508d71ae5a4SJacob Faibussowitsch {
509152b3e56SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
510152b3e56SJunchao Zhang 
511152b3e56SJunchao Zhang   PetscFunctionBegin;
512152b3e56SJunchao Zhang   switch (op) {
513152b3e56SJunchao Zhang   case MAT_FORM_EXPLICIT_TRANSPOSE:
514152b3e56SJunchao Zhang     /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */
5159566063dSJacob Faibussowitsch     if (A->form_explicit_transpose && !flg && aijkok) PetscCall(aijkok->DestroyMatTranspose());
516152b3e56SJunchao Zhang     A->form_explicit_transpose = flg;
517152b3e56SJunchao Zhang     break;
518d71ae5a4SJacob Faibussowitsch   default:
519d71ae5a4SJacob Faibussowitsch     PetscCall(MatSetOption_SeqAIJ(A, op, flg));
520d71ae5a4SJacob Faibussowitsch     break;
521152b3e56SJunchao Zhang   }
5223ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5238c3ff71bSJunchao Zhang }
5248c3ff71bSJunchao Zhang 
525076ba34aSJunchao Zhang /* Depending on reuse, either build a new mat, or use the existing mat */
526d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
527d71ae5a4SJacob Faibussowitsch {
528076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
5298c3ff71bSJunchao Zhang 
5308c3ff71bSJunchao Zhang   PetscFunctionBegin;
5319566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
532076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {                      /* Build a brand new mat */
5339566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));  /* the returned newmat is a SeqAIJKokkos */
5348c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {                 /* Reuse the mat created before */
5359566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); /* newmat is already a SeqAIJKokkos */
536076ba34aSJunchao Zhang   } else if (reuse == MAT_INPLACE_MATRIX) {               /* newmat is A */
5375f80ce2aSJacob Faibussowitsch     PetscCheck(A == *newmat, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "A != *newmat with MAT_INPLACE_MATRIX");
5389566063dSJacob Faibussowitsch     PetscCall(PetscFree(A->defaultvectype));
5399566063dSJacob Faibussowitsch     PetscCall(PetscStrallocpy(VECKOKKOS, &A->defaultvectype)); /* Allocate and copy the string */
5409566063dSJacob Faibussowitsch     PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATSEQAIJKOKKOS));
5419566063dSJacob Faibussowitsch     PetscCall(MatSetOps_SeqAIJKokkos(A));
542076ba34aSJunchao Zhang     aseq = static_cast<Mat_SeqAIJ *>(A->data);
543394ed5ebSJunchao Zhang     if (A->assembled) { /* Copy i, j (but not values) to device for an assembled matrix if not yet */
5445f80ce2aSJacob Faibussowitsch       PetscCheck(!A->spptr, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Expect NULL (Mat_SeqAIJKokkos*)A->spptr");
545076ba34aSJunchao Zhang       A->spptr = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aseq->nz, aseq->i, aseq->j, aseq->a, A->nonzerostate, PETSC_FALSE);
5468c3ff71bSJunchao Zhang     }
547076ba34aSJunchao Zhang   }
5483ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5498c3ff71bSJunchao Zhang }
5508c3ff71bSJunchao Zhang 
551076ba34aSJunchao Zhang /* MatDuplicate always creates a new matrix. MatDuplicate can be called either on an assembled matrix or
552076ba34aSJunchao Zhang    an unassembled matrix, even though MAT_COPY_VALUES is not allowed for unassembled matrix.
553076ba34aSJunchao Zhang  */
554d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDuplicate_SeqAIJKokkos(Mat A, MatDuplicateOption dupOption, Mat *B)
555d71ae5a4SJacob Faibussowitsch {
556076ba34aSJunchao Zhang   Mat_SeqAIJ       *bseq;
557076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok;
558076ba34aSJunchao Zhang   Mat               mat;
5598c3ff71bSJunchao Zhang 
5608c3ff71bSJunchao Zhang   PetscFunctionBegin;
561076ba34aSJunchao Zhang   /* Do not copy values on host as A's latest values might be on device. We don't want to do sync blindly */
5629566063dSJacob Faibussowitsch   PetscCall(MatDuplicate_SeqAIJ(A, MAT_DO_NOT_COPY_VALUES, B));
563076ba34aSJunchao Zhang   mat = *B;
5642c4ab24aSJunchao Zhang   if (A->preallocated) {
565076ba34aSJunchao Zhang     bseq = static_cast<Mat_SeqAIJ *>(mat->data);
566076ba34aSJunchao Zhang     bkok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, bseq->nz, bseq->i, bseq->j, bseq->a, mat->nonzerostate, PETSC_FALSE);
567076ba34aSJunchao Zhang     bkok->a_dual.clear_sync_state(); /* Clear B's sync state as it will be decided below */
568076ba34aSJunchao Zhang     /* Now copy values to B if needed */
569076ba34aSJunchao Zhang     if (dupOption == MAT_COPY_VALUES) {
570076ba34aSJunchao Zhang       if (akok->a_dual.need_sync_device()) {
571076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_host(), akok->a_dual.view_host());
572076ba34aSJunchao Zhang         bkok->a_dual.modify_host();
573076ba34aSJunchao Zhang       } else { /* If device has the latest data, we only copy data on device */
574076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_device(), akok->a_dual.view_device());
575076ba34aSJunchao Zhang         bkok->a_dual.modify_device();
576076ba34aSJunchao Zhang       }
577076ba34aSJunchao Zhang     } else { /* MAT_DO_NOT_COPY_VALUES or MAT_SHARE_NONZERO_PATTERN. B's values should be zeroed */
578076ba34aSJunchao Zhang       /* B's values on host should be already zeroed by MatDuplicate_SeqAIJ() */
579076ba34aSJunchao Zhang       bkok->a_dual.modify_host();
580076ba34aSJunchao Zhang     }
581076ba34aSJunchao Zhang     mat->spptr = bkok;
582076ba34aSJunchao Zhang   }
583076ba34aSJunchao Zhang 
5849566063dSJacob Faibussowitsch   PetscCall(PetscFree(mat->defaultvectype));
5859566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &mat->defaultvectype)); /* Allocate and copy the string */
5869566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)mat, MATSEQAIJKOKKOS));
5879566063dSJacob Faibussowitsch   PetscCall(MatSetOps_SeqAIJKokkos(mat));
5883ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5898c3ff71bSJunchao Zhang }
5908c3ff71bSJunchao Zhang 
591d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatTranspose_SeqAIJKokkos(Mat A, MatReuse reuse, Mat *B)
592d71ae5a4SJacob Faibussowitsch {
5930ecb592aSJunchao Zhang   Mat               At;
5940e3ece09SJunchao Zhang   KokkosCsrMatrix   internT;
5950ecb592aSJunchao Zhang   Mat_SeqAIJKokkos *atkok, *bkok;
5960ecb592aSJunchao Zhang 
5970ecb592aSJunchao Zhang   PetscFunctionBegin;
5987fb60732SBarry Smith   if (reuse == MAT_REUSE_MATRIX) PetscCall(MatTransposeCheckNonzeroState_Private(A, *B));
5999566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &internT)); /* Generate a transpose internally */
6000ecb592aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX || reuse == MAT_INPLACE_MATRIX) {
601ff751488SJunchao Zhang     /* Deep copy internT, as we want to isolate the internal transpose */
6020e3ece09SJunchao Zhang     PetscCallCXX(atkok = new Mat_SeqAIJKokkos(KokkosCsrMatrix("csrmat", internT)));
6039566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PetscObjectComm((PetscObject)A), atkok, &At));
6040ecb592aSJunchao Zhang     if (reuse == MAT_INITIAL_MATRIX) *B = At;
6059566063dSJacob Faibussowitsch     else PetscCall(MatHeaderReplace(A, &At)); /* Replace A with At inplace */
6060ecb592aSJunchao Zhang   } else {                                    /* MAT_REUSE_MATRIX, just need to copy values to B on device */
6070ecb592aSJunchao Zhang     if ((*B)->assembled) {
6080ecb592aSJunchao Zhang       bkok = static_cast<Mat_SeqAIJKokkos *>((*B)->spptr);
6090e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(bkok->a_dual.view_device(), internT.values));
6109566063dSJacob Faibussowitsch       PetscCall(MatSeqAIJKokkosModifyDevice(*B));
6110ecb592aSJunchao Zhang     } else if ((*B)->preallocated) { /* It is ok for B to be only preallocated, as needed in MatTranspose_MPIAIJ */
6120ecb592aSJunchao Zhang       Mat_SeqAIJ             *bseq = static_cast<Mat_SeqAIJ *>((*B)->data);
6130e3ece09SJunchao Zhang       MatScalarKokkosViewHost a_h(bseq->a, internT.nnz()); /* bseq->nz = 0 if unassembled */
6140e3ece09SJunchao Zhang       MatColIdxKokkosViewHost j_h(bseq->j, internT.nnz());
6150e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(a_h, internT.values));
6160e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(j_h, internT.graph.entries));
6170ecb592aSJunchao Zhang     } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "B must be assembled or preallocated");
6180ecb592aSJunchao Zhang   }
6193ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6200ecb592aSJunchao Zhang }
6210ecb592aSJunchao Zhang 
622d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDestroy_SeqAIJKokkos(Mat A)
623d71ae5a4SJacob Faibussowitsch {
62486a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
6258c3ff71bSJunchao Zhang 
6268c3ff71bSJunchao Zhang   PetscFunctionBegin;
62786a27549SJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) {
62886a27549SJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
6298c3ff71bSJunchao Zhang     delete aijkok;
63086a27549SJunchao Zhang   } else {
63186a27549SJunchao Zhang     delete static_cast<Mat_SeqAIJKokkosTriFactors *>(A->spptr);
63286a27549SJunchao Zhang   }
633cbc6b225SStefano Zampini   A->spptr = NULL;
6349566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
6359566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
6369566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
6379566063dSJacob Faibussowitsch   PetscCall(MatDestroy_SeqAIJ(A));
6383ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6398c3ff71bSJunchao Zhang }
6408c3ff71bSJunchao Zhang 
6413f3ba80aSJunchao Zhang /*MC
6423f3ba80aSJunchao Zhang    MATSEQAIJKOKKOS - MATAIJKOKKOS = "(seq)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos
6433f3ba80aSJunchao Zhang 
6443f3ba80aSJunchao Zhang    A matrix type type using Kokkos-Kernels CrsMatrix type for portability across different device types
6453f3ba80aSJunchao Zhang 
6462ef1f0ffSBarry Smith    Options Database Key:
64711a5261eSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATSEQAIJKOKKOS` during a call to `MatSetFromOptions()`
6483f3ba80aSJunchao Zhang 
6493f3ba80aSJunchao Zhang   Level: beginner
6503f3ba80aSJunchao Zhang 
6511cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MatCreateSeqAIJKokkos()`, `MATMPIAIJKOKKOS`
6523f3ba80aSJunchao Zhang M*/
653d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJKokkos(Mat A)
654d71ae5a4SJacob Faibussowitsch {
65586a27549SJunchao Zhang   PetscFunctionBegin;
6569566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
6579566063dSJacob Faibussowitsch   PetscCall(MatCreate_SeqAIJ(A));
6589566063dSJacob Faibussowitsch   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
6593ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
66086a27549SJunchao Zhang }
66186a27549SJunchao Zhang 
662076ba34aSJunchao Zhang /* Merge A, B into a matrix C. A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n) */
663d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C)
664d71ae5a4SJacob Faibussowitsch {
665076ba34aSJunchao Zhang   Mat_SeqAIJ         *a, *b;
666076ba34aSJunchao Zhang   Mat_SeqAIJKokkos   *akok, *bkok, *ckok;
667076ba34aSJunchao Zhang   MatScalarKokkosView aa, ba, ca;
668076ba34aSJunchao Zhang   MatRowMapKokkosView ai, bi, ci;
669076ba34aSJunchao Zhang   MatColIdxKokkosView aj, bj, cj;
670076ba34aSJunchao Zhang   PetscInt            m, n, nnz, aN;
671a3f881fbSStefano Zampini 
672a3f881fbSStefano Zampini   PetscFunctionBegin;
673076ba34aSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
674076ba34aSJunchao Zhang   PetscValidHeaderSpecific(B, MAT_CLASSID, 2);
675076ba34aSJunchao Zhang   PetscValidPointer(C, 4);
676076ba34aSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
677076ba34aSJunchao Zhang   PetscCheckTypeName(B, MATSEQAIJKOKKOS);
6785f80ce2aSJacob Faibussowitsch   PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n);
6795f80ce2aSJacob Faibussowitsch   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
680076ba34aSJunchao Zhang 
6819566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
6829566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
683076ba34aSJunchao Zhang   a    = static_cast<Mat_SeqAIJ *>(A->data);
684076ba34aSJunchao Zhang   b    = static_cast<Mat_SeqAIJ *>(B->data);
685076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
686076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
687076ba34aSJunchao Zhang   aa   = akok->a_dual.view_device();
688076ba34aSJunchao Zhang   ai   = akok->i_dual.view_device();
689076ba34aSJunchao Zhang   ba   = bkok->a_dual.view_device();
690076ba34aSJunchao Zhang   bi   = bkok->i_dual.view_device();
691076ba34aSJunchao Zhang   m    = A->rmap->n; /* M, N and nnz of C */
692076ba34aSJunchao Zhang   n    = A->cmap->n + B->cmap->n;
693076ba34aSJunchao Zhang   nnz  = a->nz + b->nz;
694076ba34aSJunchao Zhang   aN   = A->cmap->n; /* N of A */
695076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
696076ba34aSJunchao Zhang     aj           = akok->j_dual.view_device();
697076ba34aSJunchao Zhang     bj           = bkok->j_dual.view_device();
698076ba34aSJunchao Zhang     auto ca_dual = MatScalarKokkosDualView("a", aa.extent(0) + ba.extent(0));
699076ba34aSJunchao Zhang     auto ci_dual = MatRowMapKokkosDualView("i", ai.extent(0));
700076ba34aSJunchao Zhang     auto cj_dual = MatColIdxKokkosDualView("j", aj.extent(0) + bj.extent(0));
701076ba34aSJunchao Zhang     ca           = ca_dual.view_device();
702076ba34aSJunchao Zhang     ci           = ci_dual.view_device();
703076ba34aSJunchao Zhang     cj           = cj_dual.view_device();
704076ba34aSJunchao Zhang 
705076ba34aSJunchao Zhang     /* Concatenate A and B in parallel using Kokkos hierarchical parallelism */
7069371c9d4SSatish Balay     Kokkos::parallel_for(
7079371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
708076ba34aSJunchao Zhang         PetscInt i       = t.league_rank(); /* row i */
709076ba34aSJunchao Zhang         PetscInt coffset = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
710076ba34aSJunchao Zhang 
711076ba34aSJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() { /* this side effect only happens once per whole team */
712076ba34aSJunchao Zhang                                                    ci(i) = coffset;
713076ba34aSJunchao Zhang                                                    if (i == m - 1) ci(m) = ai(m) + bi(m);
714076ba34aSJunchao Zhang         });
715076ba34aSJunchao Zhang 
716076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
717076ba34aSJunchao Zhang           if (k < alen) {
718076ba34aSJunchao Zhang             ca(coffset + k) = aa(ai(i) + k);
719076ba34aSJunchao Zhang             cj(coffset + k) = aj(ai(i) + k);
720076ba34aSJunchao Zhang           } else {
721076ba34aSJunchao Zhang             ca(coffset + k) = ba(bi(i) + k - alen);
722076ba34aSJunchao Zhang             cj(coffset + k) = bj(bi(i) + k - alen) + aN; /* Entries in B get new column indices in C */
723076ba34aSJunchao Zhang           }
724076ba34aSJunchao Zhang         });
725076ba34aSJunchao Zhang       });
726076ba34aSJunchao Zhang     ca_dual.modify_device();
727076ba34aSJunchao Zhang     ci_dual.modify_device();
728076ba34aSJunchao Zhang     cj_dual.modify_device();
7299566063dSJacob Faibussowitsch     PetscCallCXX(ckok = new Mat_SeqAIJKokkos(m, n, nnz, ci_dual, cj_dual, ca_dual));
7309566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, C));
731076ba34aSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
732076ba34aSJunchao Zhang     PetscValidHeaderSpecific(*C, MAT_CLASSID, 4);
733076ba34aSJunchao Zhang     PetscCheckTypeName(*C, MATSEQAIJKOKKOS);
734076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos *>((*C)->spptr);
735076ba34aSJunchao Zhang     ca   = ckok->a_dual.view_device();
736076ba34aSJunchao Zhang     ci   = ckok->i_dual.view_device();
737076ba34aSJunchao Zhang 
7389371c9d4SSatish Balay     Kokkos::parallel_for(
7399371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
740076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* row i */
741076ba34aSJunchao Zhang         PetscInt alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
742076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
743076ba34aSJunchao Zhang           if (k < alen) ca(ci(i) + k) = aa(ai(i) + k);
744076ba34aSJunchao Zhang           else ca(ci(i) + k) = ba(bi(i) + k - alen);
745076ba34aSJunchao Zhang         });
746076ba34aSJunchao Zhang       });
7479566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(*C));
748076ba34aSJunchao Zhang   }
7493ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
750076ba34aSJunchao Zhang }
751076ba34aSJunchao Zhang 
752d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_SeqAIJKokkos(void *pdata)
753d71ae5a4SJacob Faibussowitsch {
754076ba34aSJunchao Zhang   PetscFunctionBegin;
755076ba34aSJunchao Zhang   delete static_cast<MatProductData_SeqAIJKokkos *>(pdata);
7563ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
757a3f881fbSStefano Zampini }
758a3f881fbSStefano Zampini 
759d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos(Mat C)
760d71ae5a4SJacob Faibussowitsch {
761a3f881fbSStefano Zampini   Mat_Product                 *product = C->product;
762a3f881fbSStefano Zampini   Mat                          A, B;
763076ba34aSJunchao Zhang   bool                         transA, transB; /* use bool, since KK needs this type */
764a3f881fbSStefano Zampini   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
765a3f881fbSStefano Zampini   Mat_SeqAIJ                  *c;
766076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
7670e3ece09SJunchao Zhang   KokkosCsrMatrix              csrmatA, csrmatB;
768a3f881fbSStefano Zampini 
769a3f881fbSStefano Zampini   PetscFunctionBegin;
770a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
7715f80ce2aSJacob Faibussowitsch   PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data empty");
772076ba34aSJunchao Zhang   pdata = static_cast<MatProductData_SeqAIJKokkos *>(C->product->data);
773076ba34aSJunchao Zhang 
7740e3ece09SJunchao Zhang   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
7750e3ece09SJunchao Zhang   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
7760e3ece09SJunchao Zhang   // we still do numeric.
7770e3ece09SJunchao Zhang   if (pdata->reusesym) { // numeric reuses results from symbolic
7780e3ece09SJunchao Zhang     pdata->reusesym = PETSC_FALSE;
7793ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
780076ba34aSJunchao Zhang   }
781076ba34aSJunchao Zhang 
782076ba34aSJunchao Zhang   switch (product->type) {
7839371c9d4SSatish Balay   case MATPRODUCT_AB:
7849371c9d4SSatish Balay     transA = false;
7859371c9d4SSatish Balay     transB = false;
7869371c9d4SSatish Balay     break;
7879371c9d4SSatish Balay   case MATPRODUCT_AtB:
7889371c9d4SSatish Balay     transA = true;
7899371c9d4SSatish Balay     transB = false;
7909371c9d4SSatish Balay     break;
7919371c9d4SSatish Balay   case MATPRODUCT_ABt:
7929371c9d4SSatish Balay     transA = false;
7939371c9d4SSatish Balay     transB = true;
7949371c9d4SSatish Balay     break;
795d71ae5a4SJacob Faibussowitsch   default:
796d71ae5a4SJacob Faibussowitsch     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
797076ba34aSJunchao Zhang   }
798076ba34aSJunchao Zhang 
799a3f881fbSStefano Zampini   A = product->A;
800a3f881fbSStefano Zampini   B = product->B;
8019566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
8029566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
803a3f881fbSStefano Zampini   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
804a3f881fbSStefano Zampini   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
805a3f881fbSStefano Zampini   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
806076ba34aSJunchao Zhang 
8075f80ce2aSJacob Faibussowitsch   PetscCheck(ckok, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Device data structure spptr is empty");
808076ba34aSJunchao Zhang 
8090e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
8100e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
811076ba34aSJunchao Zhang 
812076ba34aSJunchao Zhang   /* TODO: Once KK spgemm implements transpose, we can get rid of the explicit transpose here */
813076ba34aSJunchao Zhang   if (transA) {
8149566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
815076ba34aSJunchao Zhang     transA = false;
816a3f881fbSStefano Zampini   }
817a3f881fbSStefano Zampini 
818076ba34aSJunchao Zhang   if (transB) {
8199566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
820076ba34aSJunchao Zhang     transB = false;
821076ba34aSJunchao Zhang   }
8229566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
8230e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, ckok->csrmat));
8240e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
825866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
826866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(ckok->csrmat)); /* without sort, mat_tests-ex62_14_seqaijkokkos fails */
827e944a159SJunchao Zhang #endif
828866eb059SJunchao Zhang 
8299566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
8309566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(C));
831a3f881fbSStefano Zampini   /* shorter version of MatAssemblyEnd_SeqAIJ */
832a3f881fbSStefano Zampini   c = (Mat_SeqAIJ *)C->data;
8339566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded,%" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz));
8349566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n"));
8359566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax));
836a3f881fbSStefano Zampini   c->reallocs         = 0;
837076ba34aSJunchao Zhang   C->info.mallocs     = 0;
838a3f881fbSStefano Zampini   C->info.nz_unneeded = 0;
839a3f881fbSStefano Zampini   C->assembled = C->was_assembled = PETSC_TRUE;
840a3f881fbSStefano Zampini   C->num_ass++;
8413ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
842a3f881fbSStefano Zampini }
843a3f881fbSStefano Zampini 
844d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos(Mat C)
845d71ae5a4SJacob Faibussowitsch {
846076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
847076ba34aSJunchao Zhang   MatProductType               ptype;
848076ba34aSJunchao Zhang   Mat                          A, B;
849076ba34aSJunchao Zhang   bool                         transA, transB;
850076ba34aSJunchao Zhang   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
851076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
852076ba34aSJunchao Zhang   MPI_Comm                     comm;
8530e3ece09SJunchao Zhang   KokkosCsrMatrix              csrmatA, csrmatB, csrmatC;
854a3f881fbSStefano Zampini 
855a3f881fbSStefano Zampini   PetscFunctionBegin;
856a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
8579566063dSJacob Faibussowitsch   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
8585f80ce2aSJacob Faibussowitsch   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
859a3f881fbSStefano Zampini   A = product->A;
860a3f881fbSStefano Zampini   B = product->B;
8619566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
8629566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
863a3f881fbSStefano Zampini   akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
864a3f881fbSStefano Zampini   bkok    = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
8650e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
8660e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
867076ba34aSJunchao Zhang 
868a3f881fbSStefano Zampini   ptype = product->type;
8690e3ece09SJunchao Zhang   // Take advantage of the symmetry if true
8700e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) {
8710e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
8720e3ece09SJunchao Zhang     product->symbolic_used_the_fact_A_is_symmetric = PETSC_TRUE;
8730e3ece09SJunchao Zhang   }
8740e3ece09SJunchao Zhang   if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) {
8750e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
8760e3ece09SJunchao Zhang     product->symbolic_used_the_fact_B_is_symmetric = PETSC_TRUE;
8770e3ece09SJunchao Zhang   }
8780e3ece09SJunchao Zhang 
879a3f881fbSStefano Zampini   switch (ptype) {
8809371c9d4SSatish Balay   case MATPRODUCT_AB:
8819371c9d4SSatish Balay     transA = false;
8829371c9d4SSatish Balay     transB = false;
8839371c9d4SSatish Balay     break;
8849371c9d4SSatish Balay   case MATPRODUCT_AtB:
8859371c9d4SSatish Balay     transA = true;
8869371c9d4SSatish Balay     transB = false;
8879371c9d4SSatish Balay     break;
8889371c9d4SSatish Balay   case MATPRODUCT_ABt:
8899371c9d4SSatish Balay     transA = false;
8909371c9d4SSatish Balay     transB = true;
8919371c9d4SSatish Balay     break;
892d71ae5a4SJacob Faibussowitsch   default:
893d71ae5a4SJacob Faibussowitsch     SETERRQ(comm, PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
894a3f881fbSStefano Zampini   }
8950e3ece09SJunchao Zhang   PetscCallCXX(product->data = pdata = new MatProductData_SeqAIJKokkos());
896076ba34aSJunchao Zhang   pdata->reusesym = product->api_user;
897a3f881fbSStefano Zampini 
898076ba34aSJunchao Zhang   /* TODO: add command line options to select spgemm algorithms */
899866eb059SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; /* default alg is TPL if enabled, otherwise KK */
900866eb059SJunchao Zhang 
901866eb059SJunchao Zhang   /* CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 */
902866eb059SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
903866eb059SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
904866eb059SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
905866eb059SJunchao Zhang   #endif
906866eb059SJunchao Zhang #endif
9070e3ece09SJunchao Zhang   PetscCallCXX(pdata->kh.create_spgemm_handle(spgemm_alg));
908076ba34aSJunchao Zhang 
9099566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
910076ba34aSJunchao Zhang   /* TODO: Get rid of the explicit transpose once KK-spgemm implements the transpose option */
911076ba34aSJunchao Zhang   if (transA) {
9129566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
913076ba34aSJunchao Zhang     transA = false;
914076ba34aSJunchao Zhang   }
915076ba34aSJunchao Zhang 
916076ba34aSJunchao Zhang   if (transB) {
9179566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
918076ba34aSJunchao Zhang     transB = false;
919076ba34aSJunchao Zhang   }
920076ba34aSJunchao Zhang 
9210e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
922076ba34aSJunchao Zhang   /* spgemm_symbolic() only populates C's rowmap, but not C's column indices.
923076ba34aSJunchao Zhang     So we have to do a fake spgemm_numeric() here to get csrmatC.j_d setup, before
924076ba34aSJunchao Zhang     calling new Mat_SeqAIJKokkos().
925076ba34aSJunchao Zhang     TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
926076ba34aSJunchao Zhang   */
9270e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
9280e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
929866eb059SJunchao Zhang   /* Query if KK outputs a sorted matrix. If not, we need to sort it */
930866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
931866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(csrmatC)); /* sort_option defaults to -1 in KK!*/
932e944a159SJunchao Zhang #endif
9339566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
934076ba34aSJunchao Zhang 
9359566063dSJacob Faibussowitsch   PetscCallCXX(ckok = new Mat_SeqAIJKokkos(csrmatC));
9369566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(C, ckok));
937076ba34aSJunchao Zhang   C->product->destroy = MatProductDataDestroy_SeqAIJKokkos;
9383ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
939a3f881fbSStefano Zampini }
940a3f881fbSStefano Zampini 
941a3f881fbSStefano Zampini /* handles sparse matrix matrix ops */
942d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSetFromOptions_SeqAIJKokkos(Mat mat)
943d71ae5a4SJacob Faibussowitsch {
944076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
945a3f881fbSStefano Zampini   PetscBool    Biskok = PETSC_FALSE, Ciskok = PETSC_TRUE;
946a3f881fbSStefano Zampini 
947a3f881fbSStefano Zampini   PetscFunctionBegin;
948a3f881fbSStefano Zampini   MatCheckProduct(mat, 1);
9499566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJKOKKOS, &Biskok));
95048a46eb9SPierre Jolivet   if (product->type == MATPRODUCT_ABC) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJKOKKOS, &Ciskok));
951a3f881fbSStefano Zampini   if (Biskok && Ciskok) {
952a3f881fbSStefano Zampini     switch (product->type) {
953a3f881fbSStefano Zampini     case MATPRODUCT_AB:
954a3f881fbSStefano Zampini     case MATPRODUCT_AtB:
955d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABt:
956d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos;
957d71ae5a4SJacob Faibussowitsch       break;
958a3f881fbSStefano Zampini     case MATPRODUCT_PtAP:
959a3f881fbSStefano Zampini     case MATPRODUCT_RARt:
960d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABC:
961d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
962d71ae5a4SJacob Faibussowitsch       break;
963d71ae5a4SJacob Faibussowitsch     default:
964d71ae5a4SJacob Faibussowitsch       break;
965a3f881fbSStefano Zampini     }
966a3f881fbSStefano Zampini   } else { /* fallback for AIJ */
9679566063dSJacob Faibussowitsch     PetscCall(MatProductSetFromOptions_SeqAIJ(mat));
968a3f881fbSStefano Zampini   }
9693ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
970a3f881fbSStefano Zampini }
971a587d139SMark 
972d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatScale_SeqAIJKokkos(Mat A, PetscScalar a)
973d71ae5a4SJacob Faibussowitsch {
974f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok;
975f0cf5187SStefano Zampini 
976f0cf5187SStefano Zampini   PetscFunctionBegin;
9779566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
9789566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
979f0cf5187SStefano Zampini   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
980076ba34aSJunchao Zhang   KokkosBlas::scal(aijkok->a_dual.view_device(), a, aijkok->a_dual.view_device());
9819566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
9829566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(aijkok->a_dual.extent(0)));
9839566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
9843ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
985f0cf5187SStefano Zampini }
986f0cf5187SStefano Zampini 
987d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatZeroEntries_SeqAIJKokkos(Mat A)
988d71ae5a4SJacob Faibussowitsch {
989076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
990a587d139SMark 
991a587d139SMark   PetscFunctionBegin;
992076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
9932328674fSJunchao Zhang   if (aijkok) { /* Only zero the device if data is already there */
994076ba34aSJunchao Zhang     KokkosBlas::fill(aijkok->a_dual.view_device(), 0.0);
9959566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(A));
9962328674fSJunchao Zhang   } else { /* Might be preallocated but not assembled */
9979566063dSJacob Faibussowitsch     PetscCall(MatZeroEntries_SeqAIJ(A));
9982328674fSJunchao Zhang   }
9993ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1000a587d139SMark }
1001a587d139SMark 
1002d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatGetDiagonal_SeqAIJKokkos(Mat A, Vec x)
1003d71ae5a4SJacob Faibussowitsch {
1004f78ce678SMark Adams   Mat_SeqAIJ           *aijseq;
1005f78ce678SMark Adams   Mat_SeqAIJKokkos     *aijkok;
1006f78ce678SMark Adams   PetscInt              n;
1007f78ce678SMark Adams   PetscScalarKokkosView xv;
1008f78ce678SMark Adams 
1009f78ce678SMark Adams   PetscFunctionBegin;
1010f78ce678SMark Adams   PetscCall(VecGetLocalSize(x, &n));
1011f78ce678SMark Adams   PetscCheck(n == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Nonconforming matrix and vector");
1012f78ce678SMark Adams   PetscCheck(A->factortype == MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_SUP, "MatGetDiagonal_SeqAIJKokkos not supported on factored matrices");
1013f78ce678SMark Adams 
1014f78ce678SMark Adams   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1015f78ce678SMark Adams   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1016f78ce678SMark Adams 
1017f78ce678SMark Adams   if (A->rmap->n && aijkok->diag_dual.extent(0) == 0) { /* Set the diagonal pointer if not already */
1018f78ce678SMark Adams     PetscCall(MatMarkDiagonal_SeqAIJ(A));
1019f78ce678SMark Adams     aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1020f78ce678SMark Adams     aijkok->SetDiagonal(aijseq->diag);
1021f78ce678SMark Adams   }
1022f78ce678SMark Adams 
1023f78ce678SMark Adams   const auto &Aa    = aijkok->a_dual.view_device();
1024f78ce678SMark Adams   const auto &Ai    = aijkok->i_dual.view_device();
1025f78ce678SMark Adams   const auto &Adiag = aijkok->diag_dual.view_device();
1026f78ce678SMark Adams 
1027f78ce678SMark Adams   PetscCall(VecGetKokkosViewWrite(x, &xv));
10289371c9d4SSatish Balay   Kokkos::parallel_for(
10299371c9d4SSatish Balay     n, KOKKOS_LAMBDA(const PetscInt i) {
1030f78ce678SMark Adams       if (Adiag(i) < Ai(i + 1)) xv(i) = Aa(Adiag(i));
1031f78ce678SMark Adams       else xv(i) = 0;
1032f78ce678SMark Adams     });
1033f78ce678SMark Adams   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
10343ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1035f78ce678SMark Adams }
1036f78ce678SMark Adams 
1037db78de30SJunchao Zhang /* Get a Kokkos View from a mat of type MatSeqAIJKokkos */
1038d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1039d71ae5a4SJacob Faibussowitsch {
1040db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1041db78de30SJunchao Zhang 
1042db78de30SJunchao Zhang   PetscFunctionBegin;
1043db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1044db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1045db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10469566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1047db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1048076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
10493ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1050db78de30SJunchao Zhang }
1051db78de30SJunchao Zhang 
1052d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1053d71ae5a4SJacob Faibussowitsch {
1054db78de30SJunchao Zhang   PetscFunctionBegin;
1055db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1056db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1057db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10583ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1059db78de30SJunchao Zhang }
1060db78de30SJunchao Zhang 
1061d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, MatScalarKokkosView *kv)
1062d71ae5a4SJacob Faibussowitsch {
1063db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1064db78de30SJunchao Zhang 
1065db78de30SJunchao Zhang   PetscFunctionBegin;
1066db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1067db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1068db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10699566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1070db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1071076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
10723ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1073db78de30SJunchao Zhang }
1074db78de30SJunchao Zhang 
1075d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, MatScalarKokkosView *kv)
1076d71ae5a4SJacob Faibussowitsch {
1077db78de30SJunchao Zhang   PetscFunctionBegin;
1078db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1079db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1080db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10819566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
10823ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1083db78de30SJunchao Zhang }
1084db78de30SJunchao Zhang 
1085d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1086d71ae5a4SJacob Faibussowitsch {
1087db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1088db78de30SJunchao Zhang 
1089db78de30SJunchao Zhang   PetscFunctionBegin;
1090db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1091db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1092db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1093db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1094076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
10953ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1096db78de30SJunchao Zhang }
1097db78de30SJunchao Zhang 
1098d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1099d71ae5a4SJacob Faibussowitsch {
1100db78de30SJunchao Zhang   PetscFunctionBegin;
1101db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1102db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1103db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11049566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
11053ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1106db78de30SJunchao Zhang }
1107db78de30SJunchao Zhang 
1108c17cf699SJunchao Zhang /* Computes Y += alpha X */
1109d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAXPY_SeqAIJKokkos(Mat Y, PetscScalar alpha, Mat X, MatStructure pattern)
1110d71ae5a4SJacob Faibussowitsch {
1111a587d139SMark   Mat_SeqAIJ              *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data;
1112c17cf699SJunchao Zhang   Mat_SeqAIJKokkos        *xkok, *ykok, *zkok;
1113c17cf699SJunchao Zhang   ConstMatScalarKokkosView Xa;
1114c17cf699SJunchao Zhang   MatScalarKokkosView      Ya;
1115a587d139SMark 
1116a587d139SMark   PetscFunctionBegin;
1117c17cf699SJunchao Zhang   PetscCheckTypeName(Y, MATSEQAIJKOKKOS);
1118c17cf699SJunchao Zhang   PetscCheckTypeName(X, MATSEQAIJKOKKOS);
11199566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(Y));
11209566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(X));
11219566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1122db78de30SJunchao Zhang 
1123c17cf699SJunchao Zhang   if (pattern != SAME_NONZERO_PATTERN && x->nz == y->nz) {
1124a587d139SMark     PetscBool e;
11259566063dSJacob Faibussowitsch     PetscCall(PetscArraycmp(x->i, y->i, Y->rmap->n + 1, &e));
1126a587d139SMark     if (e) {
11279566063dSJacob Faibussowitsch       PetscCall(PetscArraycmp(x->j, y->j, y->nz, &e));
1128c17cf699SJunchao Zhang       if (e) pattern = SAME_NONZERO_PATTERN;
1129a587d139SMark     }
1130a587d139SMark   }
1131db78de30SJunchao Zhang 
1132c17cf699SJunchao Zhang   /* cusparseDcsrgeam2() computes C = alpha A + beta B. If one knew sparsity pattern of C, one can skip
1133c17cf699SJunchao Zhang     cusparseScsrgeam2_bufferSizeExt() / cusparseXcsrgeam2Nnz(), and directly call cusparseScsrgeam2().
1134c17cf699SJunchao Zhang     If X is SUBSET_NONZERO_PATTERN of Y, we could take advantage of this cusparse feature. However,
1135c17cf699SJunchao Zhang     KokkosSparse::spadd(alpha,A,beta,B,C) has symbolic and numeric phases, MatAXPY does not.
1136c17cf699SJunchao Zhang   */
1137c17cf699SJunchao Zhang   ykok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1138c17cf699SJunchao Zhang   xkok = static_cast<Mat_SeqAIJKokkos *>(X->spptr);
1139c17cf699SJunchao Zhang   Xa   = xkok->a_dual.view_device();
1140c17cf699SJunchao Zhang   Ya   = ykok->a_dual.view_device();
1141c17cf699SJunchao Zhang 
1142c17cf699SJunchao Zhang   if (pattern == SAME_NONZERO_PATTERN) {
1143c17cf699SJunchao Zhang     KokkosBlas::axpy(alpha, Xa, Ya);
11449566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1145c17cf699SJunchao Zhang   } else if (pattern == SUBSET_NONZERO_PATTERN) {
1146c17cf699SJunchao Zhang     MatRowMapKokkosView Xi = xkok->i_dual.view_device(), Yi = ykok->i_dual.view_device();
1147c17cf699SJunchao Zhang     MatColIdxKokkosView Xj = xkok->j_dual.view_device(), Yj = ykok->j_dual.view_device();
1148c17cf699SJunchao Zhang 
11499371c9d4SSatish Balay     Kokkos::parallel_for(
11509371c9d4SSatish Balay       Kokkos::TeamPolicy<>(Y->rmap->n, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
11510e3ece09SJunchao Zhang         PetscInt i = t.league_rank(); // row i
11520e3ece09SJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() {
11530e3ece09SJunchao Zhang           // Only one thread works in a team
1154c17cf699SJunchao Zhang           PetscInt p, q = Yi(i);
11550e3ece09SJunchao Zhang           for (p = Xi(i); p < Xi(i + 1); p++) {          // For each nonzero on row i of X,
11560e3ece09SJunchao Zhang             while (Xj(p) != Yj(q) && q < Yi(i + 1)) q++; // find the matching nonzero on row i of Y.
11570e3ece09SJunchao Zhang             if (Xj(p) == Yj(q)) {                        // Found it
1158c17cf699SJunchao Zhang               Ya(q) += alpha * Xa(p);
1159c17cf699SJunchao Zhang               q++;
1160a587d139SMark             } else {
11610e3ece09SJunchao Zhang             // If not found, it indicates the input is wrong (X is not a SUBSET_NONZERO_PATTERN of Y).
11620e3ece09SJunchao Zhang             // Just insert a NaN at the beginning of row i if it is not empty, to make the result wrong.
11630e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_VERSION_GE(3, 7, 0)
11640e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::ArithTraits<PetscScalar>::nan();
11658b8b16f9SJunchao Zhang #else
11660e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::Experimental::nan("1");
11678b8b16f9SJunchao Zhang #endif
1168a587d139SMark             }
1169c17cf699SJunchao Zhang           }
1170c17cf699SJunchao Zhang         });
1171c17cf699SJunchao Zhang       });
11729566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
11730e3ece09SJunchao Zhang   } else { // different nonzero patterns
1174c17cf699SJunchao Zhang     Mat             Z;
1175c17cf699SJunchao Zhang     KokkosCsrMatrix zcsr;
1176c17cf699SJunchao Zhang     KernelHandle    kh;
11770e3ece09SJunchao Zhang     kh.create_spadd_handle(true); // X, Y are sorted
1178c17cf699SJunchao Zhang     KokkosSparse::spadd_symbolic(&kh, xkok->csrmat, ykok->csrmat, zcsr);
1179c17cf699SJunchao Zhang     KokkosSparse::spadd_numeric(&kh, alpha, xkok->csrmat, (PetscScalar)1.0, ykok->csrmat, zcsr);
1180c17cf699SJunchao Zhang     zkok = new Mat_SeqAIJKokkos(zcsr);
11819566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, zkok, &Z));
11829566063dSJacob Faibussowitsch     PetscCall(MatHeaderReplace(Y, &Z));
1183c17cf699SJunchao Zhang     kh.destroy_spadd_handle();
1184c17cf699SJunchao Zhang   }
11859566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
11860e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(xkok->a_dual.extent(0) * 2)); // Because we scaled X and then added it to Y
11873ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1188a587d139SMark }
1189a587d139SMark 
11902c4ab24aSJunchao Zhang struct MatCOOStruct_SeqAIJKokkos {
11912c4ab24aSJunchao Zhang   PetscCount           n;
11922c4ab24aSJunchao Zhang   PetscCount           Atot;
11932c4ab24aSJunchao Zhang   PetscInt             nz;
11942c4ab24aSJunchao Zhang   PetscCountKokkosView jmap;
11952c4ab24aSJunchao Zhang   PetscCountKokkosView perm;
11962c4ab24aSJunchao Zhang 
11972c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos(const MatCOOStruct_SeqAIJ *coo_h)
11982c4ab24aSJunchao Zhang   {
11992c4ab24aSJunchao Zhang     nz   = coo_h->nz;
12002c4ab24aSJunchao Zhang     n    = coo_h->n;
12012c4ab24aSJunchao Zhang     Atot = coo_h->Atot;
12022c4ab24aSJunchao Zhang     jmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->jmap, nz + 1));
12032c4ab24aSJunchao Zhang     perm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->perm, Atot));
12042c4ab24aSJunchao Zhang   }
12052c4ab24aSJunchao Zhang };
12062c4ab24aSJunchao Zhang 
12072c4ab24aSJunchao Zhang static PetscErrorCode MatCOOStructDestroy_SeqAIJKokkos(void *data)
12082c4ab24aSJunchao Zhang {
12092c4ab24aSJunchao Zhang   PetscFunctionBegin;
12102c4ab24aSJunchao Zhang   PetscCallCXX(delete static_cast<MatCOOStruct_SeqAIJKokkos *>(data));
12112c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
12122c4ab24aSJunchao Zhang }
12132c4ab24aSJunchao Zhang 
1214d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_SeqAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1215d71ae5a4SJacob Faibussowitsch {
121642550becSJunchao Zhang   Mat_SeqAIJKokkos          *akok;
121742550becSJunchao Zhang   Mat_SeqAIJ                *aseq;
12182c4ab24aSJunchao Zhang   PetscContainer             container_h, container_d;
12192c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJ       *coo_h;
12202c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo_d;
122142550becSJunchao Zhang 
122242550becSJunchao Zhang   PetscFunctionBegin;
12239566063dSJacob Faibussowitsch   PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, coo_i, coo_j));
1224394ed5ebSJunchao Zhang   aseq = static_cast<Mat_SeqAIJ *>(mat->data);
122542550becSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(mat->spptr);
1226cbc6b225SStefano Zampini   delete akok;
1227cbc6b225SStefano Zampini   mat->spptr = akok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, aseq->nz, aseq->i, aseq->j, aseq->a, mat->nonzerostate + 1, PETSC_FALSE);
12289566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries_SeqAIJKokkos(mat));
12292c4ab24aSJunchao Zhang 
12302c4ab24aSJunchao Zhang   // Copy the COO struct to device
12312c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
12322c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
12332c4ab24aSJunchao Zhang   PetscCallCXX(coo_d = new MatCOOStruct_SeqAIJKokkos(coo_h));
12342c4ab24aSJunchao Zhang 
12352c4ab24aSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
12362c4ab24aSJunchao Zhang   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
12372c4ab24aSJunchao Zhang   PetscCall(PetscContainerSetPointer(container_d, coo_d));
12382c4ab24aSJunchao Zhang   PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_SeqAIJKokkos));
12392c4ab24aSJunchao Zhang   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
12402c4ab24aSJunchao Zhang   PetscCall(PetscContainerDestroy(&container_d));
12413ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
124242550becSJunchao Zhang }
124342550becSJunchao Zhang 
1244d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_SeqAIJKokkos(Mat A, const PetscScalar v[], InsertMode imode)
1245d71ae5a4SJacob Faibussowitsch {
124642550becSJunchao Zhang   MatScalarKokkosView        Aa;
124742550becSJunchao Zhang   ConstMatScalarKokkosView   kv;
124842550becSJunchao Zhang   PetscMemType               memtype;
12492c4ab24aSJunchao Zhang   PetscContainer             container;
12502c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo;
125142550becSJunchao Zhang 
125242550becSJunchao Zhang   PetscFunctionBegin;
12532c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)A, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
12542c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
12552c4ab24aSJunchao Zhang 
12562c4ab24aSJunchao Zhang   const auto &n    = coo->n;
12572c4ab24aSJunchao Zhang   const auto &Annz = coo->nz;
12582c4ab24aSJunchao Zhang   const auto &jmap = coo->jmap;
12592c4ab24aSJunchao Zhang   const auto &perm = coo->perm;
12602c4ab24aSJunchao Zhang 
12619566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype));
126242550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */
12632c4ab24aSJunchao Zhang     kv = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatScalarKokkosViewHost(v, n));
126442550becSJunchao Zhang   } else {
12652c4ab24aSJunchao Zhang     kv = ConstMatScalarKokkosView(v, n); /* Directly use v[]'s memory */
126642550becSJunchao Zhang   }
126742550becSJunchao Zhang 
1268c7b718f4SJunchao Zhang   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1269c7b718f4SJunchao Zhang   else PetscCall(MatSeqAIJGetKokkosView(A, &Aa));                             /* read & write matrix values */
127042550becSJunchao Zhang 
12719371c9d4SSatish Balay   Kokkos::parallel_for(
12729371c9d4SSatish Balay     Annz, KOKKOS_LAMBDA(const PetscCount i) {
1273c7b718f4SJunchao Zhang       PetscScalar sum = 0.0;
1274c7b718f4SJunchao Zhang       for (PetscCount k = jmap(i); k < jmap(i + 1); k++) sum += kv(perm(k));
1275c7b718f4SJunchao Zhang       Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1276c7b718f4SJunchao Zhang     });
1277394ed5ebSJunchao Zhang 
12789566063dSJacob Faibussowitsch   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
12799566063dSJacob Faibussowitsch   else PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
12803ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
128142550becSJunchao Zhang }
128242550becSJunchao Zhang 
1283d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatLUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1284d71ae5a4SJacob Faibussowitsch {
12858f7e8f9dSMark Adams   PetscFunctionBegin;
12869566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncHost(A));
12879566063dSJacob Faibussowitsch   PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info));
12888f7e8f9dSMark Adams   B->offloadmask = PETSC_OFFLOAD_CPU;
12893ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
12908f7e8f9dSMark Adams }
12918f7e8f9dSMark Adams 
1292d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat A)
1293d71ae5a4SJacob Faibussowitsch {
1294076ba34aSJunchao Zhang   Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
1295076ba34aSJunchao Zhang 
12968c3ff71bSJunchao Zhang   PetscFunctionBegin;
1297076ba34aSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_KOKKOS; /* We do not really use this flag */
12986f3d89d0SStefano Zampini   A->boundtocpu  = PETSC_FALSE;
12996f3d89d0SStefano Zampini 
13008c3ff71bSJunchao Zhang   A->ops->assemblyend               = MatAssemblyEnd_SeqAIJKokkos;
13018c3ff71bSJunchao Zhang   A->ops->destroy                   = MatDestroy_SeqAIJKokkos;
13028c3ff71bSJunchao Zhang   A->ops->duplicate                 = MatDuplicate_SeqAIJKokkos;
1303a587d139SMark   A->ops->axpy                      = MatAXPY_SeqAIJKokkos;
1304f0cf5187SStefano Zampini   A->ops->scale                     = MatScale_SeqAIJKokkos;
1305a587d139SMark   A->ops->zeroentries               = MatZeroEntries_SeqAIJKokkos;
1306076ba34aSJunchao Zhang   A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJKokkos;
13078c3ff71bSJunchao Zhang   A->ops->mult                      = MatMult_SeqAIJKokkos;
13088c3ff71bSJunchao Zhang   A->ops->multadd                   = MatMultAdd_SeqAIJKokkos;
13098c3ff71bSJunchao Zhang   A->ops->multtranspose             = MatMultTranspose_SeqAIJKokkos;
13108c3ff71bSJunchao Zhang   A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJKokkos;
13118c3ff71bSJunchao Zhang   A->ops->multhermitiantranspose    = MatMultHermitianTranspose_SeqAIJKokkos;
13128c3ff71bSJunchao Zhang   A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJKokkos;
1313076ba34aSJunchao Zhang   A->ops->productnumeric            = MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos;
13140ecb592aSJunchao Zhang   A->ops->transpose                 = MatTranspose_SeqAIJKokkos;
1315152b3e56SJunchao Zhang   A->ops->setoption                 = MatSetOption_SeqAIJKokkos;
1316f78ce678SMark Adams   A->ops->getdiagonal               = MatGetDiagonal_SeqAIJKokkos;
1317076ba34aSJunchao Zhang   a->ops->getarray                  = MatSeqAIJGetArray_SeqAIJKokkos;
1318076ba34aSJunchao Zhang   a->ops->restorearray              = MatSeqAIJRestoreArray_SeqAIJKokkos;
1319076ba34aSJunchao Zhang   a->ops->getarrayread              = MatSeqAIJGetArrayRead_SeqAIJKokkos;
1320076ba34aSJunchao Zhang   a->ops->restorearrayread          = MatSeqAIJRestoreArrayRead_SeqAIJKokkos;
1321076ba34aSJunchao Zhang   a->ops->getarraywrite             = MatSeqAIJGetArrayWrite_SeqAIJKokkos;
1322076ba34aSJunchao Zhang   a->ops->restorearraywrite         = MatSeqAIJRestoreArrayWrite_SeqAIJKokkos;
13237ee59b9bSJunchao Zhang   a->ops->getcsrandmemtype          = MatSeqAIJGetCSRAndMemType_SeqAIJKokkos;
132442550becSJunchao Zhang 
13259566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJKokkos));
13269566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJKokkos));
13273ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1328076ba34aSJunchao Zhang }
1329076ba34aSJunchao Zhang 
13309d13fa56SJunchao Zhang /*
13319d13fa56SJunchao Zhang    Extract the (prescribled) diagonal blocks of the matrix and then invert them
13329d13fa56SJunchao Zhang 
13339d13fa56SJunchao Zhang   Input Parameters:
13349d13fa56SJunchao Zhang +  A       - the MATSEQAIJKOKKOS matrix
13359d13fa56SJunchao Zhang .  bs      - block sizes in 'csr' format, i.e., the i-th block has size bs(i+1) - bs(i)
13369d13fa56SJunchao Zhang .  bs2     - square of block sizes in 'csr' format, i.e., the i-th block should be stored at offset bs2(i) in diagVal[]
13379d13fa56SJunchao Zhang .  blkMap  - map row ids to block ids, i.e., row i belongs to the block blkMap(i)
13389d13fa56SJunchao Zhang -  work    - a pre-allocated work buffer (as big as diagVal) for use by this routine
13399d13fa56SJunchao Zhang 
13409d13fa56SJunchao Zhang   Output Parameter:
13419d13fa56SJunchao Zhang .  diagVal - the (pre-allocated) buffer to store the inverted blocks (each block is stored in column-major order)
13429d13fa56SJunchao Zhang */
13439d13fa56SJunchao Zhang PETSC_INTERN PetscErrorCode MatInvertVariableBlockDiagonal_SeqAIJKokkos(Mat A, const PetscIntKokkosView &bs, const PetscIntKokkosView &bs2, const PetscIntKokkosView &blkMap, PetscScalarKokkosView &work, PetscScalarKokkosView &diagVal)
13449d13fa56SJunchao Zhang {
13459d13fa56SJunchao Zhang   Mat_SeqAIJKokkos *akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
13469d13fa56SJunchao Zhang   PetscInt          N       = A->rmap->n;
13479d13fa56SJunchao Zhang   PetscInt          nblocks = bs.extent(0) - 1;
13489d13fa56SJunchao Zhang 
13499d13fa56SJunchao Zhang   PetscFunctionBegin;
13509d13fa56SJunchao Zhang   // Set the diagonal pointer on device if not already
13519d13fa56SJunchao Zhang   if (N && akok->diag_dual.extent(0) == 0) {
13529d13fa56SJunchao Zhang     PetscCall(MatMarkDiagonal_SeqAIJ(A));
13539d13fa56SJunchao Zhang     akok->SetDiagonal(static_cast<Mat_SeqAIJ *>(A->data)->diag);
13549d13fa56SJunchao Zhang   }
13559d13fa56SJunchao Zhang 
13569d13fa56SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A)); // Since we'll access A's value on device
13579d13fa56SJunchao Zhang 
13589d13fa56SJunchao Zhang   // Pull out the diagonal blocks of the matrix and then invert the blocks
13599d13fa56SJunchao Zhang   auto Aa    = akok->a_dual.view_device();
13609d13fa56SJunchao Zhang   auto Ai    = akok->i_dual.view_device();
13619d13fa56SJunchao Zhang   auto Aj    = akok->j_dual.view_device();
13629d13fa56SJunchao Zhang   auto Adiag = akok->diag_dual.view_device();
13639d13fa56SJunchao Zhang   // TODO: how to tune the team size?
13649d13fa56SJunchao Zhang #if defined(KOKKOS_ENABLE_DEFAULT_DEVICE_TYPE_HOST)
13659d13fa56SJunchao Zhang   auto ts = Kokkos::AUTO();
13669d13fa56SJunchao Zhang #else
13679d13fa56SJunchao Zhang   auto ts         = 16; // improved performance 30% over Kokkos::AUTO() with CUDA, but failed with "Kokkos::abort: Requested Team Size is too large!" on CPUs
13689d13fa56SJunchao Zhang #endif
13699d13fa56SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
13709d13fa56SJunchao Zhang     Kokkos::TeamPolicy<>(nblocks, ts), KOKKOS_LAMBDA(const KokkosTeamMemberType &teamMember) {
13719d13fa56SJunchao Zhang       const PetscInt bid    = teamMember.league_rank();                                                   // block id
13729d13fa56SJunchao Zhang       const PetscInt rstart = bs(bid);                                                                    // this block starts from this row
13739d13fa56SJunchao Zhang       const PetscInt m      = bs(bid + 1) - bs(bid);                                                      // size of this block
13749d13fa56SJunchao Zhang       const auto    &B      = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft>(&diagVal(bs2(bid)), m, m); // column-major order
13759d13fa56SJunchao Zhang       const auto    &W      = PetscScalarKokkosView(&work(bs2(bid)), m * m);
13769d13fa56SJunchao Zhang 
13779d13fa56SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(teamMember, m), [=](const PetscInt &r) { // r-th row in B
13789d13fa56SJunchao Zhang         PetscInt i = rstart + r;                                                            // i-th row in A
13799d13fa56SJunchao Zhang 
13809d13fa56SJunchao Zhang         if (Ai(i) <= Adiag(i) && Adiag(i) < Ai(i + 1)) { // if the diagonal exists (common case)
13819d13fa56SJunchao Zhang           PetscInt first = Adiag(i) - r;                 // we start to check nonzeros from here along this row
13829d13fa56SJunchao Zhang 
13839d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) {                   // walk n steps to see what column indices we will meet
13849d13fa56SJunchao Zhang             if (first + c < Ai(i) || first + c >= Ai(i + 1)) { // this entry (first+c) is out of range of this row, in other words, its value is zero
13859d13fa56SJunchao Zhang               B(r, c) = 0.0;
13869d13fa56SJunchao Zhang             } else if (Aj(first + c) == rstart + c) { // this entry is right on the (rstart+c) column
13879d13fa56SJunchao Zhang               B(r, c) = Aa(first + c);
13889d13fa56SJunchao Zhang             } else { // this entry does not show up in the CSR
13899d13fa56SJunchao Zhang               B(r, c) = 0.0;
13909d13fa56SJunchao Zhang             }
13919d13fa56SJunchao Zhang           }
13929d13fa56SJunchao Zhang         } else { // rare case that the diagonal does not exist
13939d13fa56SJunchao Zhang           const PetscInt begin = Ai(i);
13949d13fa56SJunchao Zhang           const PetscInt end   = Ai(i + 1);
13959d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) B(r, c) = 0.0;
13969d13fa56SJunchao Zhang           for (PetscInt j = begin; j < end; j++) { // scan the whole row; could use binary search but this is a rare case so we did not.
13979d13fa56SJunchao Zhang             if (rstart <= Aj(j) && Aj(j) < rstart + m) B(r, Aj(j) - rstart) = Aa(j);
13989d13fa56SJunchao Zhang             else if (Aj(j) >= rstart + m) break;
13999d13fa56SJunchao Zhang           }
14009d13fa56SJunchao Zhang         }
14019d13fa56SJunchao Zhang       });
14029d13fa56SJunchao Zhang 
14039d13fa56SJunchao Zhang       // LU-decompose B (w/o pivoting) and then invert B
14049d13fa56SJunchao Zhang       KokkosBatched::TeamLU<KokkosTeamMemberType, KokkosBatched::Algo::LU::Unblocked>::invoke(teamMember, B, 0.0);
14059d13fa56SJunchao Zhang       KokkosBatched::TeamInverseLU<KokkosTeamMemberType, KokkosBatched::Algo::InverseLU::Unblocked>::invoke(teamMember, B, W);
14069d13fa56SJunchao Zhang     }));
14079d13fa56SJunchao Zhang   // PetscLogGpuFlops() is done in the caller PCSetUp_VPBJacobi_Kokkos as we don't want to compute the flops in kernels
14089d13fa56SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
14099d13fa56SJunchao Zhang }
14109d13fa56SJunchao Zhang 
1411d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSetSeqAIJKokkosWithCSRMatrix(Mat A, Mat_SeqAIJKokkos *akok)
1412d71ae5a4SJacob Faibussowitsch {
1413076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
1414076ba34aSJunchao Zhang   PetscInt    i, m, n;
1415076ba34aSJunchao Zhang 
1416076ba34aSJunchao Zhang   PetscFunctionBegin;
14175f80ce2aSJacob Faibussowitsch   PetscCheck(!A->spptr, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A->spptr is supposed to be empty");
1418076ba34aSJunchao Zhang 
1419076ba34aSJunchao Zhang   m = akok->nrows();
1420076ba34aSJunchao Zhang   n = akok->ncols();
14219566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(A, m, n, m, n));
14229566063dSJacob Faibussowitsch   PetscCall(MatSetType(A, MATSEQAIJKOKKOS));
1423076ba34aSJunchao Zhang 
1424076ba34aSJunchao Zhang   /* Set up data structures of A as a MATSEQAIJ */
14259566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(A, MAT_SKIP_ALLOCATION, NULL));
1426076ba34aSJunchao Zhang   aseq = (Mat_SeqAIJ *)(A)->data;
1427076ba34aSJunchao Zhang 
1428076ba34aSJunchao Zhang   akok->i_dual.sync_host(); /* We always need sync'ed i, j on host */
1429076ba34aSJunchao Zhang   akok->j_dual.sync_host();
1430076ba34aSJunchao Zhang 
1431076ba34aSJunchao Zhang   aseq->i            = akok->i_host_data();
1432076ba34aSJunchao Zhang   aseq->j            = akok->j_host_data();
1433076ba34aSJunchao Zhang   aseq->a            = akok->a_host_data();
1434076ba34aSJunchao Zhang   aseq->nonew        = -1; /*this indicates that inserting a new value in the matrix that generates a new nonzero is an error*/
1435076ba34aSJunchao Zhang   aseq->singlemalloc = PETSC_FALSE;
1436076ba34aSJunchao Zhang   aseq->free_a       = PETSC_FALSE;
1437076ba34aSJunchao Zhang   aseq->free_ij      = PETSC_FALSE;
1438076ba34aSJunchao Zhang   aseq->nz           = akok->nnz();
1439076ba34aSJunchao Zhang   aseq->maxnz        = aseq->nz;
1440076ba34aSJunchao Zhang 
14419566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->imax));
14429566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->ilen));
1443ad540459SPierre Jolivet   for (i = 0; i < m; i++) aseq->ilen[i] = aseq->imax[i] = aseq->i[i + 1] - aseq->i[i];
1444076ba34aSJunchao Zhang 
1445076ba34aSJunchao Zhang   /* It is critical to set the nonzerostate, as we use it to check if sparsity pattern (hence data) has changed on host in MatAssemblyEnd */
1446076ba34aSJunchao Zhang   akok->nonzerostate = A->nonzerostate;
1447ff751488SJunchao Zhang   A->spptr           = akok; /* Set A->spptr before MatAssembly so that A->spptr won't be allocated again there */
14489566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
14499566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
14503ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1451076ba34aSJunchao Zhang }
1452076ba34aSJunchao Zhang 
14530e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetKokkosCsrMatrix(Mat A, KokkosCsrMatrix *csr)
14540e3ece09SJunchao Zhang {
14550e3ece09SJunchao Zhang   PetscFunctionBegin;
14560e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
14570e3ece09SJunchao Zhang   *csr = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
14580e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
14590e3ece09SJunchao Zhang }
14600e3ece09SJunchao Zhang 
14610e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithKokkosCsrMatrix(MPI_Comm comm, KokkosCsrMatrix csr, Mat *A)
14620e3ece09SJunchao Zhang {
14630e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok;
14640e3ece09SJunchao Zhang   PetscFunctionBegin;
14650e3ece09SJunchao Zhang   PetscCallCXX(akok = new Mat_SeqAIJKokkos(csr));
14660e3ece09SJunchao Zhang   PetscCall(MatCreate(comm, A));
14670e3ece09SJunchao Zhang   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
14680e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
14690e3ece09SJunchao Zhang }
14700e3ece09SJunchao Zhang 
1471076ba34aSJunchao Zhang /* Crete a SEQAIJKOKKOS matrix with a Mat_SeqAIJKokkos data structure
1472076ba34aSJunchao Zhang 
1473076ba34aSJunchao Zhang    Note we have names like MatSeqAIJSetPreallocationCSR, so I use capitalized CSR
1474076ba34aSJunchao Zhang  */
1475d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithCSRMatrix(MPI_Comm comm, Mat_SeqAIJKokkos *akok, Mat *A)
1476d71ae5a4SJacob Faibussowitsch {
1477076ba34aSJunchao Zhang   PetscFunctionBegin;
14789566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
14799566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
14803ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
14818c3ff71bSJunchao Zhang }
14828c3ff71bSJunchao Zhang 
1483152b3e56SJunchao Zhang /*@C
148411a5261eSBarry Smith   MatCreateSeqAIJKokkos - Creates a sparse matrix in `MATSEQAIJKOKKOS` (compressed row) format
14858c3ff71bSJunchao Zhang   (the default parallel PETSc format). This matrix will ultimately be handled by
148620f4b53cSBarry Smith   Kokkos for calculations.
14878c3ff71bSJunchao Zhang 
14888c3ff71bSJunchao Zhang   Collective
14898c3ff71bSJunchao Zhang 
14908c3ff71bSJunchao Zhang   Input Parameters:
149111a5261eSBarry Smith + comm - MPI communicator, set to `PETSC_COMM_SELF`
14928c3ff71bSJunchao Zhang . m    - number of rows
14938c3ff71bSJunchao Zhang . n    - number of columns
149420f4b53cSBarry Smith . nz   - number of nonzeros per row (same for all rows), ignored if `nnz` is provided
149520f4b53cSBarry Smith - nnz  - array containing the number of nonzeros in the various rows (possibly different for each row) or `NULL`
14968c3ff71bSJunchao Zhang 
14978c3ff71bSJunchao Zhang   Output Parameter:
14988c3ff71bSJunchao Zhang . A - the matrix
14998c3ff71bSJunchao Zhang 
15002ef1f0ffSBarry Smith   Level: intermediate
15012ef1f0ffSBarry Smith 
15022ef1f0ffSBarry Smith   Notes:
150311a5261eSBarry Smith   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
15048c3ff71bSJunchao Zhang   MatXXXXSetPreallocation() paradgm instead of this routine directly.
150511a5261eSBarry Smith   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
15068c3ff71bSJunchao Zhang 
150711a5261eSBarry Smith   The AIJ format, also called
15082ef1f0ffSBarry Smith   compressed row storage, is fully compatible with standard Fortran
15098c3ff71bSJunchao Zhang   storage.  That is, the stored row and column indices can begin at
151020f4b53cSBarry Smith   either one (as in Fortran) or zero.
15118c3ff71bSJunchao Zhang 
15122ef1f0ffSBarry Smith   Specify the preallocated storage with either `nz` or `nnz` (not both).
15132ef1f0ffSBarry Smith   Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory
15142ef1f0ffSBarry Smith   allocation.
15158c3ff71bSJunchao Zhang 
1516*fe59aa6dSJacob Faibussowitsch .seealso: [](ch_matrices), `Mat`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
15178c3ff71bSJunchao Zhang @*/
1518d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateSeqAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A)
1519d71ae5a4SJacob Faibussowitsch {
15208c3ff71bSJunchao Zhang   PetscFunctionBegin;
15219566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
15229566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
15239566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, m, n));
15249566063dSJacob Faibussowitsch   PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
15259566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
15263ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
15278c3ff71bSJunchao Zhang }
1528930e68a5SMark Adams 
1529d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatLUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1530d71ae5a4SJacob Faibussowitsch {
1531930e68a5SMark Adams   PetscFunctionBegin;
15329566063dSJacob Faibussowitsch   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
153386a27549SJunchao Zhang   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
15343ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
153586a27549SJunchao Zhang }
153686a27549SJunchao Zhang 
1537d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSymbolicSolveCheck(Mat A)
1538d71ae5a4SJacob Faibussowitsch {
153986a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
154086a27549SJunchao Zhang 
154186a27549SJunchao Zhang   PetscFunctionBegin;
154286a27549SJunchao Zhang   if (!factors->sptrsv_symbolic_completed) {
154386a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d);
154486a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d);
154586a27549SJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_TRUE;
154686a27549SJunchao Zhang   }
15473ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
154886a27549SJunchao Zhang }
154986a27549SJunchao Zhang 
155086a27549SJunchao Zhang /* Check if we need to update factors etc for transpose solve */
1551d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosTransposeSolveCheck(Mat A)
1552d71ae5a4SJacob Faibussowitsch {
155386a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1554076ba34aSJunchao Zhang   MatColIdxType               n       = A->rmap->n;
155586a27549SJunchao Zhang 
155686a27549SJunchao Zhang   PetscFunctionBegin;
155786a27549SJunchao Zhang   if (!factors->transpose_updated) { /* TODO: KK needs to provide functions to do numeric transpose only */
155886a27549SJunchao Zhang     /* Update L^T and do sptrsv symbolic */
15597b8d4ba6SJunchao Zhang     factors->iLt_d = MatRowMapKokkosView("factors->iLt_d", n + 1); // KK requires 0
15607b8d4ba6SJunchao Zhang     factors->jLt_d = MatColIdxKokkosView(NoInit("factors->jLt_d"), factors->jL_d.extent(0));
15617b8d4ba6SJunchao Zhang     factors->aLt_d = MatScalarKokkosView(NoInit("factors->aLt_d"), factors->aL_d.extent(0));
156286a27549SJunchao Zhang 
15639371c9d4SSatish Balay     transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iL_d, factors->jL_d, factors->aL_d,
156486a27549SJunchao Zhang                                                                                                                                                                                                               factors->iLt_d, factors->jLt_d, factors->aLt_d);
156586a27549SJunchao Zhang 
156686a27549SJunchao Zhang     /* TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices.
156786a27549SJunchao Zhang       We have to sort the indices, until KK provides finer control options.
156886a27549SJunchao Zhang     */
15699371c9d4SSatish Balay     sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iLt_d, factors->jLt_d, factors->aLt_d);
157086a27549SJunchao Zhang 
157186a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d);
157286a27549SJunchao Zhang 
157386a27549SJunchao Zhang     /* Update U^T and do sptrsv symbolic */
15747b8d4ba6SJunchao Zhang     factors->iUt_d = MatRowMapKokkosView("factors->iUt_d", n + 1); // KK requires 0
15757b8d4ba6SJunchao Zhang     factors->jUt_d = MatColIdxKokkosView(NoInit("factors->jUt_d"), factors->jU_d.extent(0));
15767b8d4ba6SJunchao Zhang     factors->aUt_d = MatScalarKokkosView(NoInit("factors->aUt_d"), factors->aU_d.extent(0));
157786a27549SJunchao Zhang 
15789371c9d4SSatish Balay     transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iU_d, factors->jU_d, factors->aU_d,
157986a27549SJunchao Zhang                                                                                                                                                                                                               factors->iUt_d, factors->jUt_d, factors->aUt_d);
158086a27549SJunchao Zhang 
158186a27549SJunchao Zhang     /* Sort indices. See comments above */
15829371c9d4SSatish Balay     sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iUt_d, factors->jUt_d, factors->aUt_d);
158386a27549SJunchao Zhang 
158486a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d);
158586a27549SJunchao Zhang     factors->transpose_updated = PETSC_TRUE;
158686a27549SJunchao Zhang   }
15873ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
158886a27549SJunchao Zhang }
158986a27549SJunchao Zhang 
159086a27549SJunchao Zhang /* Solve Ax = b, with A = LU */
1591d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSolve_SeqAIJKokkos(Mat A, Vec b, Vec x)
1592d71ae5a4SJacob Faibussowitsch {
159386a27549SJunchao Zhang   ConstPetscScalarKokkosView  bv;
159486a27549SJunchao Zhang   PetscScalarKokkosView       xv;
159586a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
159686a27549SJunchao Zhang 
159786a27549SJunchao Zhang   PetscFunctionBegin;
15989566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
15999566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSymbolicSolveCheck(A));
16009566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(b, &bv));
16019566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(x, &xv));
160286a27549SJunchao Zhang   /* Solve L tmpv = b */
16039566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::Experimental::sptrsv_solve(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d, bv, factors->workVector));
160486a27549SJunchao Zhang   /* Solve Ux = tmpv */
16059566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::Experimental::sptrsv_solve(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, factors->workVector, xv));
16069566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(b, &bv));
16079566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
16089566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
16093ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
161086a27549SJunchao Zhang }
161186a27549SJunchao Zhang 
1612076ba34aSJunchao Zhang /* Solve A^T x = b, where A^T = U^T L^T */
1613d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSolveTranspose_SeqAIJKokkos(Mat A, Vec b, Vec x)
1614d71ae5a4SJacob Faibussowitsch {
161586a27549SJunchao Zhang   ConstPetscScalarKokkosView  bv;
161686a27549SJunchao Zhang   PetscScalarKokkosView       xv;
161786a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
161886a27549SJunchao Zhang 
161986a27549SJunchao Zhang   PetscFunctionBegin;
16209566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
16219566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A));
16229566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(b, &bv));
16239566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(x, &xv));
162486a27549SJunchao Zhang   /* Solve U^T tmpv = b */
162586a27549SJunchao Zhang   KokkosSparse::Experimental::sptrsv_solve(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, bv, factors->workVector);
162686a27549SJunchao Zhang 
162786a27549SJunchao Zhang   /* Solve L^T x = tmpv */
162886a27549SJunchao Zhang   KokkosSparse::Experimental::sptrsv_solve(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d, factors->workVector, xv);
16299566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(b, &bv));
16309566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
16319566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
16323ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
163386a27549SJunchao Zhang }
163486a27549SJunchao Zhang 
1635d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatILUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1636d71ae5a4SJacob Faibussowitsch {
163786a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
163886a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
163986a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
164086a27549SJunchao Zhang 
164186a27549SJunchao Zhang   PetscFunctionBegin;
16429566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
16439566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1644076ba34aSJunchao Zhang 
1645076ba34aSJunchao Zhang   auto a_d = aijkok->a_dual.view_device();
1646076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
1647076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
1648076ba34aSJunchao Zhang 
1649076ba34aSJunchao Zhang   KokkosSparse::Experimental::spiluk_numeric(&factors->kh, fill_lev, i_d, j_d, a_d, factors->iL_d, factors->jL_d, factors->aL_d, factors->iU_d, factors->jU_d, factors->aU_d);
165086a27549SJunchao Zhang 
165186a27549SJunchao Zhang   B->assembled              = PETSC_TRUE;
165286a27549SJunchao Zhang   B->preallocated           = PETSC_TRUE;
165386a27549SJunchao Zhang   B->ops->solve             = MatSolve_SeqAIJKokkos;
165486a27549SJunchao Zhang   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJKokkos;
165586a27549SJunchao Zhang   B->ops->matsolve          = NULL;
165686a27549SJunchao Zhang   B->ops->matsolvetranspose = NULL;
165786a27549SJunchao Zhang   B->offloadmask            = PETSC_OFFLOAD_GPU;
165886a27549SJunchao Zhang 
165986a27549SJunchao Zhang   /* Once the factors' value changed, we need to update their transpose and sptrsv handle */
166086a27549SJunchao Zhang   factors->transpose_updated         = PETSC_FALSE;
166186a27549SJunchao Zhang   factors->sptrsv_symbolic_completed = PETSC_FALSE;
1662eeadb341SJunchao Zhang   /* TODO: log flops, but how to know that? */
16639566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
16643ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
166586a27549SJunchao Zhang }
166686a27549SJunchao Zhang 
1667d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1668d71ae5a4SJacob Faibussowitsch {
166986a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok;
167086a27549SJunchao Zhang   Mat_SeqAIJ                 *b;
167186a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
167286a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
167386a27549SJunchao Zhang   PetscInt                    nnzA     = ((Mat_SeqAIJ *)A->data)->nz, nnzL, nnzU;
167486a27549SJunchao Zhang   PetscInt                    n        = A->rmap->n;
167586a27549SJunchao Zhang 
167686a27549SJunchao Zhang   PetscFunctionBegin;
16779566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
167886a27549SJunchao Zhang   /* Rebuild factors */
16799371c9d4SSatish Balay   if (factors) {
16809371c9d4SSatish Balay     factors->Destroy();
16819371c9d4SSatish Balay   } /* Destroy the old if it exists */
16829371c9d4SSatish Balay   else {
16839371c9d4SSatish Balay     B->spptr = factors = new Mat_SeqAIJKokkosTriFactors(n);
16849371c9d4SSatish Balay   }
168586a27549SJunchao Zhang 
168686a27549SJunchao Zhang   /* Create a spiluk handle and then do symbolic factorization */
168786a27549SJunchao Zhang   nnzL = nnzU = PetscRealIntMultTruncate(info->fill, nnzA);
168886a27549SJunchao Zhang   factors->kh.create_spiluk_handle(KokkosSparse::Experimental::SPILUKAlgorithm::SEQLVLSCHD_TP1, n, nnzL, nnzU);
168986a27549SJunchao Zhang 
169086a27549SJunchao Zhang   auto spiluk_handle = factors->kh.get_spiluk_handle();
169186a27549SJunchao Zhang 
169286a27549SJunchao Zhang   Kokkos::realloc(factors->iL_d, n + 1); /* Free old arrays and realloc */
169386a27549SJunchao Zhang   Kokkos::realloc(factors->jL_d, spiluk_handle->get_nnzL());
169486a27549SJunchao Zhang   Kokkos::realloc(factors->iU_d, n + 1);
169586a27549SJunchao Zhang   Kokkos::realloc(factors->jU_d, spiluk_handle->get_nnzU());
169686a27549SJunchao Zhang 
169786a27549SJunchao Zhang   aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
1698076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
1699076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
1700076ba34aSJunchao Zhang   KokkosSparse::Experimental::spiluk_symbolic(&factors->kh, fill_lev, i_d, j_d, factors->iL_d, factors->jL_d, factors->iU_d, factors->jU_d);
170186a27549SJunchao Zhang   /* TODO: if spiluk_symbolic is asynchronous, do we need to sync before calling get_nnzL()? */
170286a27549SJunchao Zhang 
170386a27549SJunchao Zhang   Kokkos::resize(factors->jL_d, spiluk_handle->get_nnzL()); /* Shrink or expand, and retain old value */
170486a27549SJunchao Zhang   Kokkos::resize(factors->jU_d, spiluk_handle->get_nnzU());
170586a27549SJunchao Zhang   Kokkos::realloc(factors->aL_d, spiluk_handle->get_nnzL()); /* No need to retain old value */
170686a27549SJunchao Zhang   Kokkos::realloc(factors->aU_d, spiluk_handle->get_nnzU());
170786a27549SJunchao Zhang 
170886a27549SJunchao Zhang   /* TODO: add options to select sptrsv algorithms */
170986a27549SJunchao Zhang   /* Create sptrsv handles for L, U and their transpose */
171086a27549SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
171186a27549SJunchao Zhang   auto sptrsv_alg = KokkosSparse::Experimental::SPTRSVAlgorithm::SPTRSV_CUSPARSE;
171286a27549SJunchao Zhang #else
171386a27549SJunchao Zhang   auto sptrsv_alg = KokkosSparse::Experimental::SPTRSVAlgorithm::SEQLVLSCHD_TP1;
171486a27549SJunchao Zhang #endif
171586a27549SJunchao Zhang 
171686a27549SJunchao Zhang   factors->khL.create_sptrsv_handle(sptrsv_alg, n, true /* L is lower tri */);
171786a27549SJunchao Zhang   factors->khU.create_sptrsv_handle(sptrsv_alg, n, false /* U is not lower tri */);
171886a27549SJunchao Zhang   factors->khLt.create_sptrsv_handle(sptrsv_alg, n, false /* L^T is not lower tri */);
171986a27549SJunchao Zhang   factors->khUt.create_sptrsv_handle(sptrsv_alg, n, true /* U^T is lower tri */);
172086a27549SJunchao Zhang 
172186a27549SJunchao Zhang   /* Fill fields of the factor matrix B */
17229566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(B, MAT_SKIP_ALLOCATION, NULL));
172386a27549SJunchao Zhang   b     = (Mat_SeqAIJ *)B->data;
172486a27549SJunchao Zhang   b->nz = b->maxnz          = spiluk_handle->get_nnzL() + spiluk_handle->get_nnzU();
172586a27549SJunchao Zhang   B->info.fill_ratio_given  = info->fill;
1726a1e4e190SStefano Zampini   B->info.fill_ratio_needed = nnzA > 0 ? ((PetscReal)b->nz) / ((PetscReal)nnzA) : 1.0;
172786a27549SJunchao Zhang 
172886a27549SJunchao Zhang   B->offloadmask          = PETSC_OFFLOAD_GPU;
172986a27549SJunchao Zhang   B->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJKokkos;
17303ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1731930e68a5SMark Adams }
1732930e68a5SMark Adams 
1733d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatFactorGetSolverType_SeqAIJKokkos(Mat A, MatSolverType *type)
1734d71ae5a4SJacob Faibussowitsch {
1735930e68a5SMark Adams   PetscFunctionBegin;
1736930e68a5SMark Adams   *type = MATSOLVERKOKKOS;
17373ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1738930e68a5SMark Adams }
1739930e68a5SMark Adams 
1740930e68a5SMark Adams /*MC
174186a27549SJunchao Zhang   MATSOLVERKOKKOS = "Kokkos" - A matrix solver type providing triangular solvers for sequential matrices
174211a5261eSBarry Smith   on a single GPU of type, `MATSEQAIJKOKKOS`, `MATAIJKOKKOS`.
1743930e68a5SMark Adams 
1744930e68a5SMark Adams   Level: beginner
1745930e68a5SMark Adams 
17461cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJKokkos()`, `MATAIJKOKKOS`, `MatKokkosSetFormat()`, `MatKokkosStorageFormat`, `MatKokkosFormatOperation`
1747930e68a5SMark Adams M*/
174886a27549SJunchao Zhang PETSC_EXTERN PetscErrorCode MatGetFactor_SeqAIJKokkos_Kokkos(Mat A, MatFactorType ftype, Mat *B) /* MatGetFactor_<MatType>_<MatSolverType> */
1749930e68a5SMark Adams {
1750930e68a5SMark Adams   PetscInt n = A->rmap->n;
1751930e68a5SMark Adams 
1752930e68a5SMark Adams   PetscFunctionBegin;
17539566063dSJacob Faibussowitsch   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), B));
17549566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*B, n, n, n, n));
1755930e68a5SMark Adams   (*B)->factortype = ftype;
17569566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
17579566063dSJacob Faibussowitsch   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
1758930e68a5SMark Adams 
17598f7e8f9dSMark Adams   if (ftype == MAT_FACTOR_LU) {
17609566063dSJacob Faibussowitsch     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
176186a27549SJunchao Zhang     (*B)->canuseordering        = PETSC_TRUE;
176286a27549SJunchao Zhang     (*B)->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqAIJKokkos;
176386a27549SJunchao Zhang   } else if (ftype == MAT_FACTOR_ILU) {
17649566063dSJacob Faibussowitsch     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
176586a27549SJunchao Zhang     (*B)->canuseordering         = PETSC_FALSE;
176686a27549SJunchao Zhang     (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJKokkos;
176798921bdaSJacob Faibussowitsch   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "MatFactorType %s is not supported by MatType SeqAIJKokkos", MatFactorTypes[ftype]);
1768930e68a5SMark Adams 
17699566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
17709566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)(*B), "MatFactorGetSolverType_C", MatFactorGetSolverType_SeqAIJKokkos));
17713ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1772930e68a5SMark Adams }
17738f7e8f9dSMark Adams 
1774d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatSolverTypeRegister_KOKKOS(void)
1775d71ae5a4SJacob Faibussowitsch {
177686a27549SJunchao Zhang   PetscFunctionBegin;
17779566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_SeqAIJKokkos_Kokkos));
17789566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ILU, MatGetFactor_SeqAIJKokkos_Kokkos));
17793ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
178086a27549SJunchao Zhang }
178186a27549SJunchao Zhang 
1782076ba34aSJunchao Zhang /* Utility to print out a KokkosCsrMatrix for debugging */
1783d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode PrintCsrMatrix(const KokkosCsrMatrix &csrmat)
1784d71ae5a4SJacob Faibussowitsch {
1785076ba34aSJunchao Zhang   const auto        &iv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.graph.row_map);
1786076ba34aSJunchao Zhang   const auto        &jv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.graph.entries);
1787076ba34aSJunchao Zhang   const auto        &av = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.values);
1788076ba34aSJunchao Zhang   const PetscInt    *i  = iv.data();
1789076ba34aSJunchao Zhang   const PetscInt    *j  = jv.data();
1790076ba34aSJunchao Zhang   const PetscScalar *a  = av.data();
1791076ba34aSJunchao Zhang   PetscInt           m = csrmat.numRows(), n = csrmat.numCols(), nnz = csrmat.nnz();
1792076ba34aSJunchao Zhang 
1793076ba34aSJunchao Zhang   PetscFunctionBegin;
17949566063dSJacob Faibussowitsch   PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT " x %" PetscInt_FMT " SeqAIJKokkos, with %" PetscInt_FMT " nonzeros\n", m, n, nnz));
1795076ba34aSJunchao Zhang   for (PetscInt k = 0; k < m; k++) {
17969566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT ": ", k));
179748a46eb9SPierre Jolivet     for (PetscInt p = i[k]; p < i[k + 1]; p++) PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT "(%.1f), ", j[p], (double)PetscRealPart(a[p])));
17989566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "\n"));
1799076ba34aSJunchao Zhang   }
18003ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1801076ba34aSJunchao Zhang }
1802