xref: /petsc/src/mat/impls/aij/seq/kokkos/aijkok.kokkos.cxx (revision e36ced11a701f37c2c4188d964659bff56c765a0)
1*e36ced11SJunchao Zhang #include <petsc_kokkos.hpp>
211d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
3076ba34aSJunchao Zhang #include <petscpkg_version.h>
4152b3e56SJunchao Zhang #include <petsc/private/petscimpl.h>
542550becSJunchao Zhang #include <petsc/private/sfimpl.h>
68c3ff71bSJunchao Zhang #include <petscsystypes.h>
78c3ff71bSJunchao Zhang #include <petscerror.h>
88c3ff71bSJunchao Zhang 
98c3ff71bSJunchao Zhang #include <Kokkos_Core.hpp>
10f0cf5187SStefano Zampini #include <KokkosBlas.hpp>
118c3ff71bSJunchao Zhang #include <KokkosSparse_CrsMatrix.hpp>
128c3ff71bSJunchao Zhang #include <KokkosSparse_spmv.hpp>
1386a27549SJunchao Zhang #include <KokkosSparse_spiluk.hpp>
1486a27549SJunchao Zhang #include <KokkosSparse_sptrsv.hpp>
15076ba34aSJunchao Zhang #include <KokkosSparse_spgemm.hpp>
16076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
179d13fa56SJunchao Zhang #include <KokkosBatched_LU_Decl.hpp>
189d13fa56SJunchao Zhang #include <KokkosBatched_InverseLU_Decl.hpp>
1986a27549SJunchao Zhang 
2042550becSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
218c3ff71bSJunchao Zhang 
220e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(3, 7, 0)
23f98996d3SJunchao Zhang   #include <KokkosSparse_Utils.hpp>
24f98996d3SJunchao Zhang using KokkosSparse::sort_crs_matrix;
259371c9d4SSatish Balay using KokkosSparse::Impl::transpose_matrix;
26f98996d3SJunchao Zhang #else
27f98996d3SJunchao Zhang   #include <KokkosKernels_Sorting.hpp>
28f98996d3SJunchao Zhang using KokkosKernels::sort_crs_matrix;
299371c9d4SSatish Balay using KokkosKernels::Impl::transpose_matrix;
30f98996d3SJunchao Zhang #endif
31f98996d3SJunchao Zhang 
328c3ff71bSJunchao Zhang static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat); /* Forward declaration */
338c3ff71bSJunchao Zhang 
34076ba34aSJunchao Zhang /* MatAssemblyEnd_SeqAIJKokkos() happens when we finalized nonzeros of the matrix, either after
35076ba34aSJunchao Zhang    we assembled the matrix on host, or after we directly produced the matrix data on device (ex., through MatMatMult).
36076ba34aSJunchao Zhang    In the latter case, it is important to set a_dual's sync state correctly.
37076ba34aSJunchao Zhang  */
38d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_SeqAIJKokkos(Mat A, MatAssemblyType mode)
39d71ae5a4SJacob Faibussowitsch {
40076ba34aSJunchao Zhang   Mat_SeqAIJ       *aijseq;
41076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
428c3ff71bSJunchao Zhang 
438c3ff71bSJunchao Zhang   PetscFunctionBegin;
443ba16761SJacob Faibussowitsch   if (mode == MAT_FLUSH_ASSEMBLY) PetscFunctionReturn(PETSC_SUCCESS);
459566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
46076ba34aSJunchao Zhang 
47076ba34aSJunchao Zhang   aijseq = static_cast<Mat_SeqAIJ *>(A->data);
48076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
49076ba34aSJunchao Zhang 
50076ba34aSJunchao Zhang   /* If aijkok does not exist, we just copy i, j to device.
51076ba34aSJunchao Zhang      If aijkok already exists, but the device's nonzero pattern does not match with the host's, we assume the latest data is on host.
52076ba34aSJunchao Zhang      In both cases, we build a new aijkok structure.
53076ba34aSJunchao Zhang   */
54076ba34aSJunchao Zhang   if (!aijkok || aijkok->nonzerostate != A->nonzerostate) { /* aijkok might not exist yet or nonzero pattern has changed */
55076ba34aSJunchao Zhang     delete aijkok;
56076ba34aSJunchao Zhang     aijkok   = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aijseq->nz, aijseq->i, aijseq->j, aijseq->a, A->nonzerostate, PETSC_FALSE /*don't copy mat values to device*/);
57076ba34aSJunchao Zhang     A->spptr = aijkok;
58076ba34aSJunchao Zhang   }
59076ba34aSJunchao Zhang 
603ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
618c3ff71bSJunchao Zhang }
628c3ff71bSJunchao Zhang 
6386a27549SJunchao Zhang /* Sync CSR data to device if not yet */
64d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosSyncDevice(Mat A)
65d71ae5a4SJacob Faibussowitsch {
668c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
678c3ff71bSJunchao Zhang 
688c3ff71bSJunchao Zhang   PetscFunctionBegin;
69aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from host to device");
705f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
71076ba34aSJunchao Zhang   if (aijkok->a_dual.need_sync_device()) {
72076ba34aSJunchao Zhang     aijkok->a_dual.sync_device();
73580c7c76SPierre Jolivet     aijkok->transpose_updated = PETSC_FALSE; /* values of the transpose is out-of-date */
7486a27549SJunchao Zhang     aijkok->hermitian_updated = PETSC_FALSE;
758c3ff71bSJunchao Zhang   }
763ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
778c3ff71bSJunchao Zhang }
788c3ff71bSJunchao Zhang 
79076ba34aSJunchao Zhang /* Mark the CSR data on device as modified */
80d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosModifyDevice(Mat A)
81d71ae5a4SJacob Faibussowitsch {
8286a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
8386a27549SJunchao Zhang 
8486a27549SJunchao Zhang   PetscFunctionBegin;
855f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Not supported for factorized matries");
8686a27549SJunchao Zhang   aijkok->a_dual.clear_sync_state();
8786a27549SJunchao Zhang   aijkok->a_dual.modify_device();
8886a27549SJunchao Zhang   aijkok->transpose_updated = PETSC_FALSE;
8986a27549SJunchao Zhang   aijkok->hermitian_updated = PETSC_FALSE;
909566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJInvalidateDiagonal(A));
919566063dSJacob Faibussowitsch   PetscCall(PetscObjectStateIncrease((PetscObject)A));
923ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
9386a27549SJunchao Zhang }
9486a27549SJunchao Zhang 
95d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSyncHost(Mat A)
96d71ae5a4SJacob Faibussowitsch {
97f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
98*e36ced11SJunchao Zhang   auto             &exec   = PetscGetKokkosExecutionSpace();
99f0cf5187SStefano Zampini 
100f0cf5187SStefano Zampini   PetscFunctionBegin;
101f0cf5187SStefano Zampini   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10286a27549SJunchao Zhang   /* We do not expect one needs factors on host  */
103aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from device to host");
1045f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing AIJKOK");
105*e36ced11SJunchao Zhang   PetscCallCXX(aijkok->a_dual.sync_host(exec));
106*e36ced11SJunchao Zhang   PetscCallCXX(exec.fence());
1073ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
108f0cf5187SStefano Zampini }
109f0cf5187SStefano Zampini 
110d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
111d71ae5a4SJacob Faibussowitsch {
112076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
113f0cf5187SStefano Zampini 
114f0cf5187SStefano Zampini   PetscFunctionBegin;
1155519a089SJose E. Roman   /* aijkok contains valid pointers only if the host's nonzerostate matches with the device's.
1165519a089SJose E. Roman     Calling MatSeqAIJSetPreallocation() or MatSetValues() on host, where aijseq->{i,j,a} might be
1175519a089SJose E. Roman     reallocated, will lead to stale {i,j,a}_dual in aijkok. In both operations, the hosts's nonzerostate
1185519a089SJose E. Roman     must have been updated. The stale aijkok will be rebuilt during MatAssemblyEnd.
1195519a089SJose E. Roman   */
1205519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
121*e36ced11SJunchao Zhang     auto &exec = PetscGetKokkosExecutionSpace();
122*e36ced11SJunchao Zhang     PetscCallCXX(aijkok->a_dual.sync_host(exec));
123*e36ced11SJunchao Zhang     PetscCallCXX(exec.fence());
124076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
125076ba34aSJunchao Zhang   } else { /* Happens when calling MatSetValues on a newly created matrix */
126076ba34aSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
127076ba34aSJunchao Zhang   }
1283ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
129076ba34aSJunchao Zhang }
130076ba34aSJunchao Zhang 
131d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
132d71ae5a4SJacob Faibussowitsch {
133076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
134076ba34aSJunchao Zhang 
135076ba34aSJunchao Zhang   PetscFunctionBegin;
1365519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) aijkok->a_dual.modify_host();
1373ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
138076ba34aSJunchao Zhang }
139076ba34aSJunchao Zhang 
140d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
141d71ae5a4SJacob Faibussowitsch {
142076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
143076ba34aSJunchao Zhang 
144076ba34aSJunchao Zhang   PetscFunctionBegin;
1455519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
146*e36ced11SJunchao Zhang     auto &exec = PetscGetKokkosExecutionSpace();
147*e36ced11SJunchao Zhang     PetscCallCXX(aijkok->a_dual.sync_host(exec));
148*e36ced11SJunchao Zhang     PetscCallCXX(exec.fence());
149076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1502328674fSJunchao Zhang   } else {
1512328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1522328674fSJunchao Zhang   }
1533ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
154076ba34aSJunchao Zhang }
155076ba34aSJunchao Zhang 
156d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
157d71ae5a4SJacob Faibussowitsch {
158076ba34aSJunchao Zhang   PetscFunctionBegin;
159076ba34aSJunchao Zhang   *array = NULL;
1603ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
161076ba34aSJunchao Zhang }
162076ba34aSJunchao Zhang 
163d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
164d71ae5a4SJacob Faibussowitsch {
165076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
166076ba34aSJunchao Zhang 
167076ba34aSJunchao Zhang   PetscFunctionBegin;
1685519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
169076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1702328674fSJunchao Zhang   } else { /* Ex. happens with MatZeroEntries on a preallocated but not assembled matrix */
1712328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1722328674fSJunchao Zhang   }
1733ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
174076ba34aSJunchao Zhang }
175076ba34aSJunchao Zhang 
176d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
177d71ae5a4SJacob Faibussowitsch {
178076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
179076ba34aSJunchao Zhang 
180076ba34aSJunchao Zhang   PetscFunctionBegin;
1815519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
182076ba34aSJunchao Zhang     aijkok->a_dual.clear_sync_state();
183076ba34aSJunchao Zhang     aijkok->a_dual.modify_host();
1842328674fSJunchao Zhang   }
1853ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
186f0cf5187SStefano Zampini }
187f0cf5187SStefano Zampini 
188d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJKokkos(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype)
189d71ae5a4SJacob Faibussowitsch {
1907ee59b9bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1917ee59b9bSJunchao Zhang 
1927ee59b9bSJunchao Zhang   PetscFunctionBegin;
1937ee59b9bSJunchao Zhang   PetscCheck(aijkok != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "aijkok is NULL");
1947ee59b9bSJunchao Zhang 
1957ee59b9bSJunchao Zhang   if (i) *i = aijkok->i_device_data();
1967ee59b9bSJunchao Zhang   if (j) *j = aijkok->j_device_data();
1977ee59b9bSJunchao Zhang   if (a) {
1987ee59b9bSJunchao Zhang     aijkok->a_dual.sync_device();
1997ee59b9bSJunchao Zhang     *a = aijkok->a_device_data();
2007ee59b9bSJunchao Zhang   }
2017ee59b9bSJunchao Zhang   if (mtype) *mtype = PETSC_MEMTYPE_KOKKOS;
2023ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2037ee59b9bSJunchao Zhang }
2047ee59b9bSJunchao Zhang 
2050e3ece09SJunchao Zhang /*
2060e3ece09SJunchao Zhang   Generate the sparsity pattern of a MatSeqAIJKokkos matrix's transpose on device.
2070e3ece09SJunchao Zhang 
2080e3ece09SJunchao Zhang   Input Parameter:
2090e3ece09SJunchao Zhang .  A       - the MATSEQAIJKOKKOS matrix
2100e3ece09SJunchao Zhang 
2110e3ece09SJunchao Zhang   Output Parameters:
2120e3ece09SJunchao Zhang +  perm_d - the permutation array on device, which connects Ta(i) = Aa(perm(i))
213aaa8cc7dSPierre Jolivet -  T_d    - the transpose on device, whose value array is allocated but not initialized
2140e3ece09SJunchao Zhang */
2150e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateTransposeStructure(Mat A, MatRowMapKokkosView &perm_d, KokkosCsrMatrix &T_d)
216d71ae5a4SJacob Faibussowitsch {
2170e3ece09SJunchao Zhang   Mat_SeqAIJ             *aseq = static_cast<Mat_SeqAIJ *>(A->data);
2180e3ece09SJunchao Zhang   PetscInt                nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
2190e3ece09SJunchao Zhang   const PetscInt         *Ai = aseq->i, *Aj = aseq->j;
2207b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost Ti_h(NoInit("Ti"), n + 1);
2210e3ece09SJunchao Zhang   MatRowMapType          *Ti = Ti_h.data();
2227b8d4ba6SJunchao Zhang   MatColIdxKokkosViewHost Tj_h(NoInit("Tj"), nz);
2237b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost perm_h(NoInit("permutation"), nz);
2240e3ece09SJunchao Zhang   PetscInt               *Tj   = Tj_h.data();
2250e3ece09SJunchao Zhang   PetscInt               *perm = perm_h.data();
2260e3ece09SJunchao Zhang   PetscInt               *offset;
227152b3e56SJunchao Zhang 
228152b3e56SJunchao Zhang   PetscFunctionBegin;
2290e3ece09SJunchao Zhang   // Populate Ti
2300e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::deep_copy(Ti_h, 0));
2310e3ece09SJunchao Zhang   Ti++;
2320e3ece09SJunchao Zhang   for (PetscInt i = 0; i < nz; i++) Ti[Aj[i]]++;
2330e3ece09SJunchao Zhang   Ti--;
2340e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) Ti[i + 1] += Ti[i];
2350e3ece09SJunchao Zhang 
2360e3ece09SJunchao Zhang   // Populate Tj and the permutation array
2370e3ece09SJunchao Zhang   PetscCall(PetscCalloc1(n, &offset)); // offset in each T row to fill in its column indices
2380e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
2390e3ece09SJunchao Zhang     for (PetscInt j = Ai[i]; j < Ai[i + 1]; j++) { // A's (i,j) is T's (j,i)
2400e3ece09SJunchao Zhang       PetscInt r    = Aj[j];                       // row r of T
2410e3ece09SJunchao Zhang       PetscInt disp = Ti[r] + offset[r];
2420e3ece09SJunchao Zhang 
2430e3ece09SJunchao Zhang       Tj[disp]   = i; // col i of T
2440e3ece09SJunchao Zhang       perm[disp] = j;
2450e3ece09SJunchao Zhang       offset[r]++;
246076ba34aSJunchao Zhang     }
2470e3ece09SJunchao Zhang   }
2480e3ece09SJunchao Zhang   PetscCall(PetscFree(offset));
2490e3ece09SJunchao Zhang 
2500e3ece09SJunchao Zhang   // Sort each row of T, along with the permutation array
2510e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) PetscCall(PetscSortIntWithArray(Ti[i + 1] - Ti[i], Tj + Ti[i], perm + Ti[i]));
2520e3ece09SJunchao Zhang 
2530e3ece09SJunchao Zhang   // Output perm and T on device
2540e3ece09SJunchao Zhang   auto Ti_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ti_h);
2550e3ece09SJunchao Zhang   auto Tj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Tj_h);
2560e3ece09SJunchao Zhang   PetscCallCXX(T_d = KokkosCsrMatrix("csrmatT", n, m, nz, MatScalarKokkosView("Ta", nz), Ti_d, Tj_d));
2570e3ece09SJunchao Zhang   PetscCallCXX(perm_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), perm_h));
2583ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
259152b3e56SJunchao Zhang }
260152b3e56SJunchao Zhang 
2610e3ece09SJunchao Zhang // Generate the transpose on device and cache it internally
2620e3ece09SJunchao Zhang // Note: KK transpose_matrix() does not have support symbolic/numeric transpose, so we do it on our own
2630e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGenerateTranspose_Private(Mat A, KokkosCsrMatrix *csrmatT)
264d71ae5a4SJacob Faibussowitsch {
2650e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
2660e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
2670e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
2680e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatT;
269152b3e56SJunchao Zhang 
270152b3e56SJunchao Zhang   PetscFunctionBegin;
2710e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
272145b44c9SPierre Jolivet   PetscCallCXX(akok->a_dual.sync_device()); // Sync A's values since we are going to access them on device
2730e3ece09SJunchao Zhang 
2740e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
2750e3ece09SJunchao Zhang 
2760e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE) {
2770e3ece09SJunchao Zhang     *csrmatT = akok->csrmat;
2780e3ece09SJunchao Zhang   } else {
2790e3ece09SJunchao Zhang     // See if we already have a cached transpose and its value is up to date
2800e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
2810e3ece09SJunchao Zhang       if (!akok->transpose_updated) {            // if the value is out of date, update the cached version
2820e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
2830e3ece09SJunchao Zhang         auto       &Ta   = T.values;
2840e3ece09SJunchao Zhang 
2850e3ece09SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(
2860e3ece09SJunchao Zhang           nz, KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = Aa(perm(i)); }));
287076ba34aSJunchao Zhang       }
2880e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
2890e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
2900e3ece09SJunchao Zhang 
2910e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
2920e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
2930e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(
2940e3ece09SJunchao Zhang         nz, KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = Aa(perm(i)); }));
2950e3ece09SJunchao Zhang     }
2960e3ece09SJunchao Zhang     akok->transpose_updated = PETSC_TRUE;
2970e3ece09SJunchao Zhang     *csrmatT                = akok->csrmatT;
2980e3ece09SJunchao Zhang   }
2990e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3000e3ece09SJunchao Zhang }
3010e3ece09SJunchao Zhang 
3020e3ece09SJunchao Zhang // Generate the Hermitian on device and cache it internally
3030e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateHermitian_Private(Mat A, KokkosCsrMatrix *csrmatH)
3040e3ece09SJunchao Zhang {
3050e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
3060e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3070e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
3080e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatH;
3090e3ece09SJunchao Zhang 
3100e3ece09SJunchao Zhang   PetscFunctionBegin;
3110e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
3120e3ece09SJunchao Zhang   PetscCallCXX(akok->a_dual.sync_device()); // Sync A's values since we are going to access them on device
3130e3ece09SJunchao Zhang 
3140e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
3150e3ece09SJunchao Zhang 
3160e3ece09SJunchao Zhang   if (A->hermitian == PETSC_BOOL3_TRUE) {
3170e3ece09SJunchao Zhang     *csrmatH = akok->csrmat;
3180e3ece09SJunchao Zhang   } else {
3190e3ece09SJunchao Zhang     // See if we already have a cached hermitian and its value is up to date
3200e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
3210e3ece09SJunchao Zhang       if (!akok->hermitian_updated) {            // if the value is out of date, update the cached version
3220e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
3230e3ece09SJunchao Zhang         auto       &Ta   = T.values;
3240e3ece09SJunchao Zhang 
3250e3ece09SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(
3260e3ece09SJunchao Zhang           nz, KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = PetscConj(Aa(perm(i))); }));
3270e3ece09SJunchao Zhang       }
3280e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
3290e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
3300e3ece09SJunchao Zhang 
3310e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
3320e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
3330e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(
3340e3ece09SJunchao Zhang         nz, KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = PetscConj(Aa(perm(i))); }));
3350e3ece09SJunchao Zhang     }
3360e3ece09SJunchao Zhang     akok->hermitian_updated = PETSC_TRUE;
3370e3ece09SJunchao Zhang     *csrmatH                = akok->csrmatH;
3380e3ece09SJunchao Zhang   }
3393ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
340152b3e56SJunchao Zhang }
341a587d139SMark 
3428c3ff71bSJunchao Zhang /* y = A x */
343d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMult_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
344d71ae5a4SJacob Faibussowitsch {
3458c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
346152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
347152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
3488c3ff71bSJunchao Zhang 
3498c3ff71bSJunchao Zhang   PetscFunctionBegin;
3509566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3519566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3529566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3539566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
3548c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3559d52486cSJunchao Zhang   PetscCallCXX(KokkosSparse::spmv("N", 1.0 /*alpha*/, aijkok->csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A x + beta y */
3569566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3579566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
358076ba34aSJunchao Zhang   /* 2.0*nnz - numRows seems more accurate here but assumes there are no zero-rows. So a little sloppy here. */
3599566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
3609566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
3613ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
3628c3ff71bSJunchao Zhang }
3638c3ff71bSJunchao Zhang 
3648c3ff71bSJunchao Zhang /* y = A^T x */
365d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
366d71ae5a4SJacob Faibussowitsch {
3678c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
368152b3e56SJunchao Zhang   const char                *mode;
369152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
370152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
3710e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
3728c3ff71bSJunchao Zhang 
3738c3ff71bSJunchao Zhang   PetscFunctionBegin;
3749566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3759566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3769566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3779566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
378152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
3799566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
380152b3e56SJunchao Zhang     mode = "N";
381152b3e56SJunchao Zhang   } else {
382076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3830e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
384152b3e56SJunchao Zhang     mode   = "T";
385152b3e56SJunchao Zhang   }
3860e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^T x + beta y */
3879566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3889566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
3890e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
3909566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
3913ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
3928c3ff71bSJunchao Zhang }
3938c3ff71bSJunchao Zhang 
3948c3ff71bSJunchao Zhang /* y = A^H x */
395d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
396d71ae5a4SJacob Faibussowitsch {
3978c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
398152b3e56SJunchao Zhang   const char                *mode;
399152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
400152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
4010e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4028c3ff71bSJunchao Zhang 
4038c3ff71bSJunchao Zhang   PetscFunctionBegin;
4049566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4059566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4069566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4079566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
408152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4099566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
410152b3e56SJunchao Zhang     mode = "N";
411152b3e56SJunchao Zhang   } else {
412076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4130e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
414152b3e56SJunchao Zhang     mode   = "C";
415152b3e56SJunchao Zhang   }
4160e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^H x + beta y */
4179566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4189566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
4190e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4209566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4213ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4228c3ff71bSJunchao Zhang }
4238c3ff71bSJunchao Zhang 
4248c3ff71bSJunchao Zhang /* z = A x + y */
425d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
426d71ae5a4SJacob Faibussowitsch {
4278c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
428152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
429152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4308c3ff71bSJunchao Zhang 
4318c3ff71bSJunchao Zhang   PetscFunctionBegin;
4329566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4339566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4349566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4359566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
4369566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
4378c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
4388c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4399d52486cSJunchao Zhang   PetscCallCXX(KokkosSparse::spmv("N", 1.0 /*alpha*/, aijkok->csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A x + beta z */
4409566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4419566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
4429566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
4439566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
4449566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4453ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4468c3ff71bSJunchao Zhang }
4478c3ff71bSJunchao Zhang 
4488c3ff71bSJunchao Zhang /* z = A^T x + y */
449d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
450d71ae5a4SJacob Faibussowitsch {
4518c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
452152b3e56SJunchao Zhang   const char                *mode;
453152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
454152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4550e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4568c3ff71bSJunchao Zhang 
4578c3ff71bSJunchao Zhang   PetscFunctionBegin;
4589566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4599566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4609566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4619566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
4629566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
4638c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
464152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4659566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
466152b3e56SJunchao Zhang     mode = "N";
467152b3e56SJunchao Zhang   } else {
468076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4690e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
470152b3e56SJunchao Zhang     mode   = "T";
471152b3e56SJunchao Zhang   }
4720e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^T x + beta z */
4739566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4749566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
4759566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
4760e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4779566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4783ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4798c3ff71bSJunchao Zhang }
4808c3ff71bSJunchao Zhang 
4818c3ff71bSJunchao Zhang /* z = A^H x + y */
482d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
483d71ae5a4SJacob Faibussowitsch {
4848c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
485152b3e56SJunchao Zhang   const char                *mode;
486152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
487152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4880e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4898c3ff71bSJunchao Zhang 
4908c3ff71bSJunchao Zhang   PetscFunctionBegin;
4919566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4929566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4939566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4949566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
4959566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
4968c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
497152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4989566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
499152b3e56SJunchao Zhang     mode = "N";
500152b3e56SJunchao Zhang   } else {
501076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
5020e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
503152b3e56SJunchao Zhang     mode   = "C";
504152b3e56SJunchao Zhang   }
5050e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^H x + beta z */
5069566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
5079566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
5089566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
5090e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
5109566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
5113ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
512152b3e56SJunchao Zhang }
513152b3e56SJunchao Zhang 
514d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSetOption_SeqAIJKokkos(Mat A, MatOption op, PetscBool flg)
515d71ae5a4SJacob Faibussowitsch {
516152b3e56SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
517152b3e56SJunchao Zhang 
518152b3e56SJunchao Zhang   PetscFunctionBegin;
519152b3e56SJunchao Zhang   switch (op) {
520152b3e56SJunchao Zhang   case MAT_FORM_EXPLICIT_TRANSPOSE:
521152b3e56SJunchao Zhang     /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */
5229566063dSJacob Faibussowitsch     if (A->form_explicit_transpose && !flg && aijkok) PetscCall(aijkok->DestroyMatTranspose());
523152b3e56SJunchao Zhang     A->form_explicit_transpose = flg;
524152b3e56SJunchao Zhang     break;
525d71ae5a4SJacob Faibussowitsch   default:
526d71ae5a4SJacob Faibussowitsch     PetscCall(MatSetOption_SeqAIJ(A, op, flg));
527d71ae5a4SJacob Faibussowitsch     break;
528152b3e56SJunchao Zhang   }
5293ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5308c3ff71bSJunchao Zhang }
5318c3ff71bSJunchao Zhang 
532076ba34aSJunchao Zhang /* Depending on reuse, either build a new mat, or use the existing mat */
533d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
534d71ae5a4SJacob Faibussowitsch {
535076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
5368c3ff71bSJunchao Zhang 
5378c3ff71bSJunchao Zhang   PetscFunctionBegin;
5389566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
539076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {                      /* Build a brand new mat */
5409566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));  /* the returned newmat is a SeqAIJKokkos */
5418c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {                 /* Reuse the mat created before */
5429566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); /* newmat is already a SeqAIJKokkos */
543076ba34aSJunchao Zhang   } else if (reuse == MAT_INPLACE_MATRIX) {               /* newmat is A */
5445f80ce2aSJacob Faibussowitsch     PetscCheck(A == *newmat, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "A != *newmat with MAT_INPLACE_MATRIX");
5459566063dSJacob Faibussowitsch     PetscCall(PetscFree(A->defaultvectype));
5469566063dSJacob Faibussowitsch     PetscCall(PetscStrallocpy(VECKOKKOS, &A->defaultvectype)); /* Allocate and copy the string */
5479566063dSJacob Faibussowitsch     PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATSEQAIJKOKKOS));
5489566063dSJacob Faibussowitsch     PetscCall(MatSetOps_SeqAIJKokkos(A));
549076ba34aSJunchao Zhang     aseq = static_cast<Mat_SeqAIJ *>(A->data);
550394ed5ebSJunchao Zhang     if (A->assembled) { /* Copy i, j (but not values) to device for an assembled matrix if not yet */
5515f80ce2aSJacob Faibussowitsch       PetscCheck(!A->spptr, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Expect NULL (Mat_SeqAIJKokkos*)A->spptr");
552076ba34aSJunchao Zhang       A->spptr = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aseq->nz, aseq->i, aseq->j, aseq->a, A->nonzerostate, PETSC_FALSE);
5538c3ff71bSJunchao Zhang     }
554076ba34aSJunchao Zhang   }
5553ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5568c3ff71bSJunchao Zhang }
5578c3ff71bSJunchao Zhang 
558076ba34aSJunchao Zhang /* MatDuplicate always creates a new matrix. MatDuplicate can be called either on an assembled matrix or
559076ba34aSJunchao Zhang    an unassembled matrix, even though MAT_COPY_VALUES is not allowed for unassembled matrix.
560076ba34aSJunchao Zhang  */
561d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDuplicate_SeqAIJKokkos(Mat A, MatDuplicateOption dupOption, Mat *B)
562d71ae5a4SJacob Faibussowitsch {
563076ba34aSJunchao Zhang   Mat_SeqAIJ       *bseq;
564076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok;
565076ba34aSJunchao Zhang   Mat               mat;
5668c3ff71bSJunchao Zhang 
5678c3ff71bSJunchao Zhang   PetscFunctionBegin;
568076ba34aSJunchao Zhang   /* Do not copy values on host as A's latest values might be on device. We don't want to do sync blindly */
5699566063dSJacob Faibussowitsch   PetscCall(MatDuplicate_SeqAIJ(A, MAT_DO_NOT_COPY_VALUES, B));
570076ba34aSJunchao Zhang   mat = *B;
5712c4ab24aSJunchao Zhang   if (A->preallocated) {
572076ba34aSJunchao Zhang     bseq = static_cast<Mat_SeqAIJ *>(mat->data);
573076ba34aSJunchao Zhang     bkok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, bseq->nz, bseq->i, bseq->j, bseq->a, mat->nonzerostate, PETSC_FALSE);
574076ba34aSJunchao Zhang     bkok->a_dual.clear_sync_state(); /* Clear B's sync state as it will be decided below */
575076ba34aSJunchao Zhang     /* Now copy values to B if needed */
576076ba34aSJunchao Zhang     if (dupOption == MAT_COPY_VALUES) {
577076ba34aSJunchao Zhang       if (akok->a_dual.need_sync_device()) {
578076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_host(), akok->a_dual.view_host());
579076ba34aSJunchao Zhang         bkok->a_dual.modify_host();
580076ba34aSJunchao Zhang       } else { /* If device has the latest data, we only copy data on device */
581076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_device(), akok->a_dual.view_device());
582076ba34aSJunchao Zhang         bkok->a_dual.modify_device();
583076ba34aSJunchao Zhang       }
584076ba34aSJunchao Zhang     } else { /* MAT_DO_NOT_COPY_VALUES or MAT_SHARE_NONZERO_PATTERN. B's values should be zeroed */
585076ba34aSJunchao Zhang       /* B's values on host should be already zeroed by MatDuplicate_SeqAIJ() */
586076ba34aSJunchao Zhang       bkok->a_dual.modify_host();
587076ba34aSJunchao Zhang     }
588076ba34aSJunchao Zhang     mat->spptr = bkok;
589076ba34aSJunchao Zhang   }
590076ba34aSJunchao Zhang 
5919566063dSJacob Faibussowitsch   PetscCall(PetscFree(mat->defaultvectype));
5929566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &mat->defaultvectype)); /* Allocate and copy the string */
5939566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)mat, MATSEQAIJKOKKOS));
5949566063dSJacob Faibussowitsch   PetscCall(MatSetOps_SeqAIJKokkos(mat));
5953ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5968c3ff71bSJunchao Zhang }
5978c3ff71bSJunchao Zhang 
598d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatTranspose_SeqAIJKokkos(Mat A, MatReuse reuse, Mat *B)
599d71ae5a4SJacob Faibussowitsch {
6000ecb592aSJunchao Zhang   Mat               At;
6010e3ece09SJunchao Zhang   KokkosCsrMatrix   internT;
6020ecb592aSJunchao Zhang   Mat_SeqAIJKokkos *atkok, *bkok;
6030ecb592aSJunchao Zhang 
6040ecb592aSJunchao Zhang   PetscFunctionBegin;
6057fb60732SBarry Smith   if (reuse == MAT_REUSE_MATRIX) PetscCall(MatTransposeCheckNonzeroState_Private(A, *B));
6069566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &internT)); /* Generate a transpose internally */
6070ecb592aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX || reuse == MAT_INPLACE_MATRIX) {
608ff751488SJunchao Zhang     /* Deep copy internT, as we want to isolate the internal transpose */
6090e3ece09SJunchao Zhang     PetscCallCXX(atkok = new Mat_SeqAIJKokkos(KokkosCsrMatrix("csrmat", internT)));
6109566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PetscObjectComm((PetscObject)A), atkok, &At));
6110ecb592aSJunchao Zhang     if (reuse == MAT_INITIAL_MATRIX) *B = At;
6129566063dSJacob Faibussowitsch     else PetscCall(MatHeaderReplace(A, &At)); /* Replace A with At inplace */
6130ecb592aSJunchao Zhang   } else {                                    /* MAT_REUSE_MATRIX, just need to copy values to B on device */
6140ecb592aSJunchao Zhang     if ((*B)->assembled) {
6150ecb592aSJunchao Zhang       bkok = static_cast<Mat_SeqAIJKokkos *>((*B)->spptr);
6160e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(bkok->a_dual.view_device(), internT.values));
6179566063dSJacob Faibussowitsch       PetscCall(MatSeqAIJKokkosModifyDevice(*B));
6180ecb592aSJunchao Zhang     } else if ((*B)->preallocated) { /* It is ok for B to be only preallocated, as needed in MatTranspose_MPIAIJ */
6190ecb592aSJunchao Zhang       Mat_SeqAIJ             *bseq = static_cast<Mat_SeqAIJ *>((*B)->data);
6200e3ece09SJunchao Zhang       MatScalarKokkosViewHost a_h(bseq->a, internT.nnz()); /* bseq->nz = 0 if unassembled */
6210e3ece09SJunchao Zhang       MatColIdxKokkosViewHost j_h(bseq->j, internT.nnz());
6220e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(a_h, internT.values));
6230e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(j_h, internT.graph.entries));
6240ecb592aSJunchao Zhang     } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "B must be assembled or preallocated");
6250ecb592aSJunchao Zhang   }
6263ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6270ecb592aSJunchao Zhang }
6280ecb592aSJunchao Zhang 
629d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDestroy_SeqAIJKokkos(Mat A)
630d71ae5a4SJacob Faibussowitsch {
63186a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
6328c3ff71bSJunchao Zhang 
6338c3ff71bSJunchao Zhang   PetscFunctionBegin;
63486a27549SJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) {
63586a27549SJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
6368c3ff71bSJunchao Zhang     delete aijkok;
63786a27549SJunchao Zhang   } else {
63886a27549SJunchao Zhang     delete static_cast<Mat_SeqAIJKokkosTriFactors *>(A->spptr);
63986a27549SJunchao Zhang   }
640cbc6b225SStefano Zampini   A->spptr = NULL;
6419566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
6429566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
6439566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
6449566063dSJacob Faibussowitsch   PetscCall(MatDestroy_SeqAIJ(A));
6453ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6468c3ff71bSJunchao Zhang }
6478c3ff71bSJunchao Zhang 
6483f3ba80aSJunchao Zhang /*MC
6493f3ba80aSJunchao Zhang    MATSEQAIJKOKKOS - MATAIJKOKKOS = "(seq)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos
6503f3ba80aSJunchao Zhang 
   A matrix type that uses the Kokkos-Kernels CrsMatrix type for portability across different device types
6523f3ba80aSJunchao Zhang 
6532ef1f0ffSBarry Smith    Options Database Key:
65411a5261eSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATSEQAIJKOKKOS` during a call to `MatSetFromOptions()`
6553f3ba80aSJunchao Zhang 
6563f3ba80aSJunchao Zhang   Level: beginner
6573f3ba80aSJunchao Zhang 
6581cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MatCreateSeqAIJKokkos()`, `MATMPIAIJKOKKOS`
6593f3ba80aSJunchao Zhang M*/
660d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJKokkos(Mat A)
661d71ae5a4SJacob Faibussowitsch {
66286a27549SJunchao Zhang   PetscFunctionBegin;
6639566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
6649566063dSJacob Faibussowitsch   PetscCall(MatCreate_SeqAIJ(A));
6659566063dSJacob Faibussowitsch   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
6663ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
66786a27549SJunchao Zhang }
66886a27549SJunchao Zhang 
669076ba34aSJunchao Zhang /* Merge A, B into a matrix C. A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n) */
670d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C)
671d71ae5a4SJacob Faibussowitsch {
672076ba34aSJunchao Zhang   Mat_SeqAIJ         *a, *b;
673076ba34aSJunchao Zhang   Mat_SeqAIJKokkos   *akok, *bkok, *ckok;
674076ba34aSJunchao Zhang   MatScalarKokkosView aa, ba, ca;
675076ba34aSJunchao Zhang   MatRowMapKokkosView ai, bi, ci;
676076ba34aSJunchao Zhang   MatColIdxKokkosView aj, bj, cj;
677076ba34aSJunchao Zhang   PetscInt            m, n, nnz, aN;
678a3f881fbSStefano Zampini 
679a3f881fbSStefano Zampini   PetscFunctionBegin;
680076ba34aSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
681076ba34aSJunchao Zhang   PetscValidHeaderSpecific(B, MAT_CLASSID, 2);
6824f572ea9SToby Isaac   PetscAssertPointer(C, 4);
683076ba34aSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
684076ba34aSJunchao Zhang   PetscCheckTypeName(B, MATSEQAIJKOKKOS);
6855f80ce2aSJacob Faibussowitsch   PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n);
6865f80ce2aSJacob Faibussowitsch   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
687076ba34aSJunchao Zhang 
6889566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
6899566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
690076ba34aSJunchao Zhang   a    = static_cast<Mat_SeqAIJ *>(A->data);
691076ba34aSJunchao Zhang   b    = static_cast<Mat_SeqAIJ *>(B->data);
692076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
693076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
694076ba34aSJunchao Zhang   aa   = akok->a_dual.view_device();
695076ba34aSJunchao Zhang   ai   = akok->i_dual.view_device();
696076ba34aSJunchao Zhang   ba   = bkok->a_dual.view_device();
697076ba34aSJunchao Zhang   bi   = bkok->i_dual.view_device();
698076ba34aSJunchao Zhang   m    = A->rmap->n; /* M, N and nnz of C */
699076ba34aSJunchao Zhang   n    = A->cmap->n + B->cmap->n;
700076ba34aSJunchao Zhang   nnz  = a->nz + b->nz;
701076ba34aSJunchao Zhang   aN   = A->cmap->n; /* N of A */
702076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
703076ba34aSJunchao Zhang     aj           = akok->j_dual.view_device();
704076ba34aSJunchao Zhang     bj           = bkok->j_dual.view_device();
705076ba34aSJunchao Zhang     auto ca_dual = MatScalarKokkosDualView("a", aa.extent(0) + ba.extent(0));
706076ba34aSJunchao Zhang     auto ci_dual = MatRowMapKokkosDualView("i", ai.extent(0));
707076ba34aSJunchao Zhang     auto cj_dual = MatColIdxKokkosDualView("j", aj.extent(0) + bj.extent(0));
708076ba34aSJunchao Zhang     ca           = ca_dual.view_device();
709076ba34aSJunchao Zhang     ci           = ci_dual.view_device();
710076ba34aSJunchao Zhang     cj           = cj_dual.view_device();
711076ba34aSJunchao Zhang 
712076ba34aSJunchao Zhang     /* Concatenate A and B in parallel using Kokkos hierarchical parallelism */
7139371c9d4SSatish Balay     Kokkos::parallel_for(
7149371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
715076ba34aSJunchao Zhang         PetscInt i       = t.league_rank(); /* row i */
716076ba34aSJunchao Zhang         PetscInt coffset = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
717076ba34aSJunchao Zhang 
718076ba34aSJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() { /* this side effect only happens once per whole team */
719076ba34aSJunchao Zhang                                                    ci(i) = coffset;
720076ba34aSJunchao Zhang                                                    if (i == m - 1) ci(m) = ai(m) + bi(m);
721076ba34aSJunchao Zhang         });
722076ba34aSJunchao Zhang 
723076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
724076ba34aSJunchao Zhang           if (k < alen) {
725076ba34aSJunchao Zhang             ca(coffset + k) = aa(ai(i) + k);
726076ba34aSJunchao Zhang             cj(coffset + k) = aj(ai(i) + k);
727076ba34aSJunchao Zhang           } else {
728076ba34aSJunchao Zhang             ca(coffset + k) = ba(bi(i) + k - alen);
729076ba34aSJunchao Zhang             cj(coffset + k) = bj(bi(i) + k - alen) + aN; /* Entries in B get new column indices in C */
730076ba34aSJunchao Zhang           }
731076ba34aSJunchao Zhang         });
732076ba34aSJunchao Zhang       });
733076ba34aSJunchao Zhang     ca_dual.modify_device();
734076ba34aSJunchao Zhang     ci_dual.modify_device();
735076ba34aSJunchao Zhang     cj_dual.modify_device();
7369566063dSJacob Faibussowitsch     PetscCallCXX(ckok = new Mat_SeqAIJKokkos(m, n, nnz, ci_dual, cj_dual, ca_dual));
7379566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, C));
738076ba34aSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
739076ba34aSJunchao Zhang     PetscValidHeaderSpecific(*C, MAT_CLASSID, 4);
740076ba34aSJunchao Zhang     PetscCheckTypeName(*C, MATSEQAIJKOKKOS);
741076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos *>((*C)->spptr);
742076ba34aSJunchao Zhang     ca   = ckok->a_dual.view_device();
743076ba34aSJunchao Zhang     ci   = ckok->i_dual.view_device();
744076ba34aSJunchao Zhang 
7459371c9d4SSatish Balay     Kokkos::parallel_for(
7469371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
747076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* row i */
748076ba34aSJunchao Zhang         PetscInt alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
749076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
750076ba34aSJunchao Zhang           if (k < alen) ca(ci(i) + k) = aa(ai(i) + k);
751076ba34aSJunchao Zhang           else ca(ci(i) + k) = ba(bi(i) + k - alen);
752076ba34aSJunchao Zhang         });
753076ba34aSJunchao Zhang       });
7549566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(*C));
755076ba34aSJunchao Zhang   }
7563ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
757076ba34aSJunchao Zhang }
758076ba34aSJunchao Zhang 
759d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_SeqAIJKokkos(void *pdata)
760d71ae5a4SJacob Faibussowitsch {
761076ba34aSJunchao Zhang   PetscFunctionBegin;
762076ba34aSJunchao Zhang   delete static_cast<MatProductData_SeqAIJKokkos *>(pdata);
7633ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
764a3f881fbSStefano Zampini }
765a3f881fbSStefano Zampini 
766d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos(Mat C)
767d71ae5a4SJacob Faibussowitsch {
768a3f881fbSStefano Zampini   Mat_Product                 *product = C->product;
769a3f881fbSStefano Zampini   Mat                          A, B;
770076ba34aSJunchao Zhang   bool                         transA, transB; /* use bool, since KK needs this type */
771a3f881fbSStefano Zampini   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
772a3f881fbSStefano Zampini   Mat_SeqAIJ                  *c;
773076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
7740e3ece09SJunchao Zhang   KokkosCsrMatrix              csrmatA, csrmatB;
775a3f881fbSStefano Zampini 
776a3f881fbSStefano Zampini   PetscFunctionBegin;
777a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
7785f80ce2aSJacob Faibussowitsch   PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data empty");
779076ba34aSJunchao Zhang   pdata = static_cast<MatProductData_SeqAIJKokkos *>(C->product->data);
780076ba34aSJunchao Zhang 
7810e3ece09SJunchao Zhang   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
7820e3ece09SJunchao Zhang   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
7830e3ece09SJunchao Zhang   // we still do numeric.
7840e3ece09SJunchao Zhang   if (pdata->reusesym) { // numeric reuses results from symbolic
7850e3ece09SJunchao Zhang     pdata->reusesym = PETSC_FALSE;
7863ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
787076ba34aSJunchao Zhang   }
788076ba34aSJunchao Zhang 
789076ba34aSJunchao Zhang   switch (product->type) {
7909371c9d4SSatish Balay   case MATPRODUCT_AB:
7919371c9d4SSatish Balay     transA = false;
7929371c9d4SSatish Balay     transB = false;
7939371c9d4SSatish Balay     break;
7949371c9d4SSatish Balay   case MATPRODUCT_AtB:
7959371c9d4SSatish Balay     transA = true;
7969371c9d4SSatish Balay     transB = false;
7979371c9d4SSatish Balay     break;
7989371c9d4SSatish Balay   case MATPRODUCT_ABt:
7999371c9d4SSatish Balay     transA = false;
8009371c9d4SSatish Balay     transB = true;
8019371c9d4SSatish Balay     break;
802d71ae5a4SJacob Faibussowitsch   default:
803d71ae5a4SJacob Faibussowitsch     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
804076ba34aSJunchao Zhang   }
805076ba34aSJunchao Zhang 
806a3f881fbSStefano Zampini   A = product->A;
807a3f881fbSStefano Zampini   B = product->B;
8089566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
8099566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
810a3f881fbSStefano Zampini   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
811a3f881fbSStefano Zampini   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
812a3f881fbSStefano Zampini   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
813076ba34aSJunchao Zhang 
8145f80ce2aSJacob Faibussowitsch   PetscCheck(ckok, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Device data structure spptr is empty");
815076ba34aSJunchao Zhang 
8160e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
8170e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
818076ba34aSJunchao Zhang 
819076ba34aSJunchao Zhang   /* TODO: Once KK spgemm implements transpose, we can get rid of the explicit transpose here */
820076ba34aSJunchao Zhang   if (transA) {
8219566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
822076ba34aSJunchao Zhang     transA = false;
823a3f881fbSStefano Zampini   }
824a3f881fbSStefano Zampini 
825076ba34aSJunchao Zhang   if (transB) {
8269566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
827076ba34aSJunchao Zhang     transB = false;
828076ba34aSJunchao Zhang   }
8299566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
8300e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, ckok->csrmat));
8310e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
832866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
833866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(ckok->csrmat)); /* without sort, mat_tests-ex62_14_seqaijkokkos fails */
834e944a159SJunchao Zhang #endif
835866eb059SJunchao Zhang 
8369566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
8379566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(C));
838a3f881fbSStefano Zampini   /* shorter version of MatAssemblyEnd_SeqAIJ */
839a3f881fbSStefano Zampini   c = (Mat_SeqAIJ *)C->data;
8409566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded,%" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz));
8419566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n"));
8429566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax));
843a3f881fbSStefano Zampini   c->reallocs         = 0;
844076ba34aSJunchao Zhang   C->info.mallocs     = 0;
845a3f881fbSStefano Zampini   C->info.nz_unneeded = 0;
846a3f881fbSStefano Zampini   C->assembled = C->was_assembled = PETSC_TRUE;
847a3f881fbSStefano Zampini   C->num_ass++;
8483ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
849a3f881fbSStefano Zampini }
850a3f881fbSStefano Zampini 
851d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos(Mat C)
852d71ae5a4SJacob Faibussowitsch {
853076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
854076ba34aSJunchao Zhang   MatProductType               ptype;
855076ba34aSJunchao Zhang   Mat                          A, B;
856076ba34aSJunchao Zhang   bool                         transA, transB;
857076ba34aSJunchao Zhang   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
858076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
859076ba34aSJunchao Zhang   MPI_Comm                     comm;
8600e3ece09SJunchao Zhang   KokkosCsrMatrix              csrmatA, csrmatB, csrmatC;
861a3f881fbSStefano Zampini 
862a3f881fbSStefano Zampini   PetscFunctionBegin;
863a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
8649566063dSJacob Faibussowitsch   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
8655f80ce2aSJacob Faibussowitsch   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
866a3f881fbSStefano Zampini   A = product->A;
867a3f881fbSStefano Zampini   B = product->B;
8689566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
8699566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
870a3f881fbSStefano Zampini   akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
871a3f881fbSStefano Zampini   bkok    = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
8720e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
8730e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
874076ba34aSJunchao Zhang 
875a3f881fbSStefano Zampini   ptype = product->type;
8760e3ece09SJunchao Zhang   // Take advantage of the symmetry if true
8770e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) {
8780e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
8790e3ece09SJunchao Zhang     product->symbolic_used_the_fact_A_is_symmetric = PETSC_TRUE;
8800e3ece09SJunchao Zhang   }
8810e3ece09SJunchao Zhang   if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) {
8820e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
8830e3ece09SJunchao Zhang     product->symbolic_used_the_fact_B_is_symmetric = PETSC_TRUE;
8840e3ece09SJunchao Zhang   }
8850e3ece09SJunchao Zhang 
886a3f881fbSStefano Zampini   switch (ptype) {
8879371c9d4SSatish Balay   case MATPRODUCT_AB:
8889371c9d4SSatish Balay     transA = false;
8899371c9d4SSatish Balay     transB = false;
8909371c9d4SSatish Balay     break;
8919371c9d4SSatish Balay   case MATPRODUCT_AtB:
8929371c9d4SSatish Balay     transA = true;
8939371c9d4SSatish Balay     transB = false;
8949371c9d4SSatish Balay     break;
8959371c9d4SSatish Balay   case MATPRODUCT_ABt:
8969371c9d4SSatish Balay     transA = false;
8979371c9d4SSatish Balay     transB = true;
8989371c9d4SSatish Balay     break;
899d71ae5a4SJacob Faibussowitsch   default:
900d71ae5a4SJacob Faibussowitsch     SETERRQ(comm, PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
901a3f881fbSStefano Zampini   }
9020e3ece09SJunchao Zhang   PetscCallCXX(product->data = pdata = new MatProductData_SeqAIJKokkos());
903076ba34aSJunchao Zhang   pdata->reusesym = product->api_user;
904a3f881fbSStefano Zampini 
905076ba34aSJunchao Zhang   /* TODO: add command line options to select spgemm algorithms */
906866eb059SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; /* default alg is TPL if enabled, otherwise KK */
907866eb059SJunchao Zhang 
908866eb059SJunchao Zhang   /* CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 */
909866eb059SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
910866eb059SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
911866eb059SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
912866eb059SJunchao Zhang   #endif
913866eb059SJunchao Zhang #endif
9140e3ece09SJunchao Zhang   PetscCallCXX(pdata->kh.create_spgemm_handle(spgemm_alg));
915076ba34aSJunchao Zhang 
9169566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
917076ba34aSJunchao Zhang   /* TODO: Get rid of the explicit transpose once KK-spgemm implements the transpose option */
918076ba34aSJunchao Zhang   if (transA) {
9199566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
920076ba34aSJunchao Zhang     transA = false;
921076ba34aSJunchao Zhang   }
922076ba34aSJunchao Zhang 
923076ba34aSJunchao Zhang   if (transB) {
9249566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
925076ba34aSJunchao Zhang     transB = false;
926076ba34aSJunchao Zhang   }
927076ba34aSJunchao Zhang 
9280e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
929076ba34aSJunchao Zhang   /* spgemm_symbolic() only populates C's rowmap, but not C's column indices.
930076ba34aSJunchao Zhang     So we have to do a fake spgemm_numeric() here to get csrmatC.j_d setup, before
931076ba34aSJunchao Zhang     calling new Mat_SeqAIJKokkos().
932076ba34aSJunchao Zhang     TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
933076ba34aSJunchao Zhang   */
9340e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
9350e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
936866eb059SJunchao Zhang   /* Query if KK outputs a sorted matrix. If not, we need to sort it */
937866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
938866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(csrmatC)); /* sort_option defaults to -1 in KK!*/
939e944a159SJunchao Zhang #endif
9409566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
941076ba34aSJunchao Zhang 
9429566063dSJacob Faibussowitsch   PetscCallCXX(ckok = new Mat_SeqAIJKokkos(csrmatC));
9439566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(C, ckok));
944076ba34aSJunchao Zhang   C->product->destroy = MatProductDataDestroy_SeqAIJKokkos;
9453ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
946a3f881fbSStefano Zampini }
947a3f881fbSStefano Zampini 
948a3f881fbSStefano Zampini /* handles sparse matrix matrix ops */
949d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSetFromOptions_SeqAIJKokkos(Mat mat)
950d71ae5a4SJacob Faibussowitsch {
951076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
952a3f881fbSStefano Zampini   PetscBool    Biskok = PETSC_FALSE, Ciskok = PETSC_TRUE;
953a3f881fbSStefano Zampini 
954a3f881fbSStefano Zampini   PetscFunctionBegin;
955a3f881fbSStefano Zampini   MatCheckProduct(mat, 1);
9569566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJKOKKOS, &Biskok));
95748a46eb9SPierre Jolivet   if (product->type == MATPRODUCT_ABC) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJKOKKOS, &Ciskok));
958a3f881fbSStefano Zampini   if (Biskok && Ciskok) {
959a3f881fbSStefano Zampini     switch (product->type) {
960a3f881fbSStefano Zampini     case MATPRODUCT_AB:
961a3f881fbSStefano Zampini     case MATPRODUCT_AtB:
962d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABt:
963d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos;
964d71ae5a4SJacob Faibussowitsch       break;
965a3f881fbSStefano Zampini     case MATPRODUCT_PtAP:
966a3f881fbSStefano Zampini     case MATPRODUCT_RARt:
967d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABC:
968d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
969d71ae5a4SJacob Faibussowitsch       break;
970d71ae5a4SJacob Faibussowitsch     default:
971d71ae5a4SJacob Faibussowitsch       break;
972a3f881fbSStefano Zampini     }
973a3f881fbSStefano Zampini   } else { /* fallback for AIJ */
9749566063dSJacob Faibussowitsch     PetscCall(MatProductSetFromOptions_SeqAIJ(mat));
975a3f881fbSStefano Zampini   }
9763ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
977a3f881fbSStefano Zampini }
978a587d139SMark 
979d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatScale_SeqAIJKokkos(Mat A, PetscScalar a)
980d71ae5a4SJacob Faibussowitsch {
981f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok;
982f0cf5187SStefano Zampini 
983f0cf5187SStefano Zampini   PetscFunctionBegin;
9849566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
9859566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
986f0cf5187SStefano Zampini   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
987076ba34aSJunchao Zhang   KokkosBlas::scal(aijkok->a_dual.view_device(), a, aijkok->a_dual.view_device());
9889566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
9899566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(aijkok->a_dual.extent(0)));
9909566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
9913ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
992f0cf5187SStefano Zampini }
993f0cf5187SStefano Zampini 
994d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatZeroEntries_SeqAIJKokkos(Mat A)
995d71ae5a4SJacob Faibussowitsch {
996076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
997a587d139SMark 
998a587d139SMark   PetscFunctionBegin;
999076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
10002328674fSJunchao Zhang   if (aijkok) { /* Only zero the device if data is already there */
1001076ba34aSJunchao Zhang     KokkosBlas::fill(aijkok->a_dual.view_device(), 0.0);
10029566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(A));
10032328674fSJunchao Zhang   } else { /* Might be preallocated but not assembled */
10049566063dSJacob Faibussowitsch     PetscCall(MatZeroEntries_SeqAIJ(A));
10052328674fSJunchao Zhang   }
10063ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1007a587d139SMark }
1008a587d139SMark 
1009d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatGetDiagonal_SeqAIJKokkos(Mat A, Vec x)
1010d71ae5a4SJacob Faibussowitsch {
1011f78ce678SMark Adams   Mat_SeqAIJ           *aijseq;
1012f78ce678SMark Adams   Mat_SeqAIJKokkos     *aijkok;
1013f78ce678SMark Adams   PetscInt              n;
1014f78ce678SMark Adams   PetscScalarKokkosView xv;
1015f78ce678SMark Adams 
1016f78ce678SMark Adams   PetscFunctionBegin;
1017f78ce678SMark Adams   PetscCall(VecGetLocalSize(x, &n));
1018f78ce678SMark Adams   PetscCheck(n == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Nonconforming matrix and vector");
1019f78ce678SMark Adams   PetscCheck(A->factortype == MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_SUP, "MatGetDiagonal_SeqAIJKokkos not supported on factored matrices");
1020f78ce678SMark Adams 
1021f78ce678SMark Adams   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1022f78ce678SMark Adams   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1023f78ce678SMark Adams 
1024f78ce678SMark Adams   if (A->rmap->n && aijkok->diag_dual.extent(0) == 0) { /* Set the diagonal pointer if not already */
1025f78ce678SMark Adams     PetscCall(MatMarkDiagonal_SeqAIJ(A));
1026f78ce678SMark Adams     aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1027f78ce678SMark Adams     aijkok->SetDiagonal(aijseq->diag);
1028f78ce678SMark Adams   }
1029f78ce678SMark Adams 
1030f78ce678SMark Adams   const auto &Aa    = aijkok->a_dual.view_device();
1031f78ce678SMark Adams   const auto &Ai    = aijkok->i_dual.view_device();
1032f78ce678SMark Adams   const auto &Adiag = aijkok->diag_dual.view_device();
1033f78ce678SMark Adams 
1034f78ce678SMark Adams   PetscCall(VecGetKokkosViewWrite(x, &xv));
10359371c9d4SSatish Balay   Kokkos::parallel_for(
10369371c9d4SSatish Balay     n, KOKKOS_LAMBDA(const PetscInt i) {
1037f78ce678SMark Adams       if (Adiag(i) < Ai(i + 1)) xv(i) = Aa(Adiag(i));
1038f78ce678SMark Adams       else xv(i) = 0;
1039f78ce678SMark Adams     });
1040f78ce678SMark Adams   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
10413ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1042f78ce678SMark Adams }
1043f78ce678SMark Adams 
1044db78de30SJunchao Zhang /* Get a Kokkos View from a mat of type MatSeqAIJKokkos */
1045d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1046d71ae5a4SJacob Faibussowitsch {
1047db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1048db78de30SJunchao Zhang 
1049db78de30SJunchao Zhang   PetscFunctionBegin;
1050db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
10514f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1052db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10539566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1054db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1055076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
10563ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1057db78de30SJunchao Zhang }
1058db78de30SJunchao Zhang 
1059d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1060d71ae5a4SJacob Faibussowitsch {
1061db78de30SJunchao Zhang   PetscFunctionBegin;
1062db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
10634f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1064db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10653ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1066db78de30SJunchao Zhang }
1067db78de30SJunchao Zhang 
1068d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, MatScalarKokkosView *kv)
1069d71ae5a4SJacob Faibussowitsch {
1070db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1071db78de30SJunchao Zhang 
1072db78de30SJunchao Zhang   PetscFunctionBegin;
1073db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
10744f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1075db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10769566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1077db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1078076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
10793ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1080db78de30SJunchao Zhang }
1081db78de30SJunchao Zhang 
1082d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, MatScalarKokkosView *kv)
1083d71ae5a4SJacob Faibussowitsch {
1084db78de30SJunchao Zhang   PetscFunctionBegin;
1085db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
10864f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1087db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10889566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
10893ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1090db78de30SJunchao Zhang }
1091db78de30SJunchao Zhang 
1092d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1093d71ae5a4SJacob Faibussowitsch {
1094db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1095db78de30SJunchao Zhang 
1096db78de30SJunchao Zhang   PetscFunctionBegin;
1097db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
10984f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1099db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1100db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1101076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
11023ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1103db78de30SJunchao Zhang }
1104db78de30SJunchao Zhang 
1105d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1106d71ae5a4SJacob Faibussowitsch {
1107db78de30SJunchao Zhang   PetscFunctionBegin;
1108db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11094f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1110db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11119566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
11123ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1113db78de30SJunchao Zhang }
1114db78de30SJunchao Zhang 
1115c17cf699SJunchao Zhang /* Computes Y += alpha X */
1116d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAXPY_SeqAIJKokkos(Mat Y, PetscScalar alpha, Mat X, MatStructure pattern)
1117d71ae5a4SJacob Faibussowitsch {
1118a587d139SMark   Mat_SeqAIJ              *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data;
1119c17cf699SJunchao Zhang   Mat_SeqAIJKokkos        *xkok, *ykok, *zkok;
1120c17cf699SJunchao Zhang   ConstMatScalarKokkosView Xa;
1121c17cf699SJunchao Zhang   MatScalarKokkosView      Ya;
1122a587d139SMark 
1123a587d139SMark   PetscFunctionBegin;
1124c17cf699SJunchao Zhang   PetscCheckTypeName(Y, MATSEQAIJKOKKOS);
1125c17cf699SJunchao Zhang   PetscCheckTypeName(X, MATSEQAIJKOKKOS);
11269566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(Y));
11279566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(X));
11289566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1129db78de30SJunchao Zhang 
1130c17cf699SJunchao Zhang   if (pattern != SAME_NONZERO_PATTERN && x->nz == y->nz) {
1131a587d139SMark     PetscBool e;
11329566063dSJacob Faibussowitsch     PetscCall(PetscArraycmp(x->i, y->i, Y->rmap->n + 1, &e));
1133a587d139SMark     if (e) {
11349566063dSJacob Faibussowitsch       PetscCall(PetscArraycmp(x->j, y->j, y->nz, &e));
1135c17cf699SJunchao Zhang       if (e) pattern = SAME_NONZERO_PATTERN;
1136a587d139SMark     }
1137a587d139SMark   }
1138db78de30SJunchao Zhang 
1139c17cf699SJunchao Zhang   /* cusparseDcsrgeam2() computes C = alpha A + beta B. If one knew sparsity pattern of C, one can skip
1140c17cf699SJunchao Zhang     cusparseScsrgeam2_bufferSizeExt() / cusparseXcsrgeam2Nnz(), and directly call cusparseScsrgeam2().
1141c17cf699SJunchao Zhang     If X is SUBSET_NONZERO_PATTERN of Y, we could take advantage of this cusparse feature. However,
1142c17cf699SJunchao Zhang     KokkosSparse::spadd(alpha,A,beta,B,C) has symbolic and numeric phases, MatAXPY does not.
1143c17cf699SJunchao Zhang   */
1144c17cf699SJunchao Zhang   ykok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1145c17cf699SJunchao Zhang   xkok = static_cast<Mat_SeqAIJKokkos *>(X->spptr);
1146c17cf699SJunchao Zhang   Xa   = xkok->a_dual.view_device();
1147c17cf699SJunchao Zhang   Ya   = ykok->a_dual.view_device();
1148c17cf699SJunchao Zhang 
1149c17cf699SJunchao Zhang   if (pattern == SAME_NONZERO_PATTERN) {
1150c17cf699SJunchao Zhang     KokkosBlas::axpy(alpha, Xa, Ya);
11519566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1152c17cf699SJunchao Zhang   } else if (pattern == SUBSET_NONZERO_PATTERN) {
1153c17cf699SJunchao Zhang     MatRowMapKokkosView Xi = xkok->i_dual.view_device(), Yi = ykok->i_dual.view_device();
1154c17cf699SJunchao Zhang     MatColIdxKokkosView Xj = xkok->j_dual.view_device(), Yj = ykok->j_dual.view_device();
1155c17cf699SJunchao Zhang 
11569371c9d4SSatish Balay     Kokkos::parallel_for(
11579371c9d4SSatish Balay       Kokkos::TeamPolicy<>(Y->rmap->n, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
11580e3ece09SJunchao Zhang         PetscInt i = t.league_rank(); // row i
11590e3ece09SJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() {
11600e3ece09SJunchao Zhang           // Only one thread works in a team
1161c17cf699SJunchao Zhang           PetscInt p, q = Yi(i);
11620e3ece09SJunchao Zhang           for (p = Xi(i); p < Xi(i + 1); p++) {          // For each nonzero on row i of X,
11630e3ece09SJunchao Zhang             while (Xj(p) != Yj(q) && q < Yi(i + 1)) q++; // find the matching nonzero on row i of Y.
11640e3ece09SJunchao Zhang             if (Xj(p) == Yj(q)) {                        // Found it
1165c17cf699SJunchao Zhang               Ya(q) += alpha * Xa(p);
1166c17cf699SJunchao Zhang               q++;
1167a587d139SMark             } else {
11680e3ece09SJunchao Zhang             // If not found, it indicates the input is wrong (X is not a SUBSET_NONZERO_PATTERN of Y).
11690e3ece09SJunchao Zhang             // Just insert a NaN at the beginning of row i if it is not empty, to make the result wrong.
11700e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_VERSION_GE(3, 7, 0)
11710e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::ArithTraits<PetscScalar>::nan();
11728b8b16f9SJunchao Zhang #else
11730e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::Experimental::nan("1");
11748b8b16f9SJunchao Zhang #endif
1175a587d139SMark             }
1176c17cf699SJunchao Zhang           }
1177c17cf699SJunchao Zhang         });
1178c17cf699SJunchao Zhang       });
11799566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
11800e3ece09SJunchao Zhang   } else { // different nonzero patterns
1181c17cf699SJunchao Zhang     Mat             Z;
1182c17cf699SJunchao Zhang     KokkosCsrMatrix zcsr;
1183c17cf699SJunchao Zhang     KernelHandle    kh;
11840e3ece09SJunchao Zhang     kh.create_spadd_handle(true); // X, Y are sorted
1185c17cf699SJunchao Zhang     KokkosSparse::spadd_symbolic(&kh, xkok->csrmat, ykok->csrmat, zcsr);
1186c17cf699SJunchao Zhang     KokkosSparse::spadd_numeric(&kh, alpha, xkok->csrmat, (PetscScalar)1.0, ykok->csrmat, zcsr);
1187c17cf699SJunchao Zhang     zkok = new Mat_SeqAIJKokkos(zcsr);
11889566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, zkok, &Z));
11899566063dSJacob Faibussowitsch     PetscCall(MatHeaderReplace(Y, &Z));
1190c17cf699SJunchao Zhang     kh.destroy_spadd_handle();
1191c17cf699SJunchao Zhang   }
11929566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
11930e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(xkok->a_dual.extent(0) * 2)); // Because we scaled X and then added it to Y
11943ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1195a587d139SMark }
1196a587d139SMark 
11972c4ab24aSJunchao Zhang struct MatCOOStruct_SeqAIJKokkos {
11982c4ab24aSJunchao Zhang   PetscCount           n;
11992c4ab24aSJunchao Zhang   PetscCount           Atot;
12002c4ab24aSJunchao Zhang   PetscInt             nz;
12012c4ab24aSJunchao Zhang   PetscCountKokkosView jmap;
12022c4ab24aSJunchao Zhang   PetscCountKokkosView perm;
12032c4ab24aSJunchao Zhang 
12042c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos(const MatCOOStruct_SeqAIJ *coo_h)
12052c4ab24aSJunchao Zhang   {
12062c4ab24aSJunchao Zhang     nz   = coo_h->nz;
12072c4ab24aSJunchao Zhang     n    = coo_h->n;
12082c4ab24aSJunchao Zhang     Atot = coo_h->Atot;
12092c4ab24aSJunchao Zhang     jmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->jmap, nz + 1));
12102c4ab24aSJunchao Zhang     perm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->perm, Atot));
12112c4ab24aSJunchao Zhang   }
12122c4ab24aSJunchao Zhang };
12132c4ab24aSJunchao Zhang 
12142c4ab24aSJunchao Zhang static PetscErrorCode MatCOOStructDestroy_SeqAIJKokkos(void *data)
12152c4ab24aSJunchao Zhang {
12162c4ab24aSJunchao Zhang   PetscFunctionBegin;
12172c4ab24aSJunchao Zhang   PetscCallCXX(delete static_cast<MatCOOStruct_SeqAIJKokkos *>(data));
12182c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
12192c4ab24aSJunchao Zhang }
12202c4ab24aSJunchao Zhang 
1221d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_SeqAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1222d71ae5a4SJacob Faibussowitsch {
122342550becSJunchao Zhang   Mat_SeqAIJKokkos          *akok;
122442550becSJunchao Zhang   Mat_SeqAIJ                *aseq;
12252c4ab24aSJunchao Zhang   PetscContainer             container_h, container_d;
12262c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJ       *coo_h;
12272c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo_d;
122842550becSJunchao Zhang 
122942550becSJunchao Zhang   PetscFunctionBegin;
12309566063dSJacob Faibussowitsch   PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, coo_i, coo_j));
1231394ed5ebSJunchao Zhang   aseq = static_cast<Mat_SeqAIJ *>(mat->data);
123242550becSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(mat->spptr);
1233cbc6b225SStefano Zampini   delete akok;
1234cbc6b225SStefano Zampini   mat->spptr = akok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, aseq->nz, aseq->i, aseq->j, aseq->a, mat->nonzerostate + 1, PETSC_FALSE);
12359566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries_SeqAIJKokkos(mat));
12362c4ab24aSJunchao Zhang 
12372c4ab24aSJunchao Zhang   // Copy the COO struct to device
12382c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
12392c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
12402c4ab24aSJunchao Zhang   PetscCallCXX(coo_d = new MatCOOStruct_SeqAIJKokkos(coo_h));
12412c4ab24aSJunchao Zhang 
12422c4ab24aSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
12432c4ab24aSJunchao Zhang   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
12442c4ab24aSJunchao Zhang   PetscCall(PetscContainerSetPointer(container_d, coo_d));
12452c4ab24aSJunchao Zhang   PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_SeqAIJKokkos));
12462c4ab24aSJunchao Zhang   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
12472c4ab24aSJunchao Zhang   PetscCall(PetscContainerDestroy(&container_d));
12483ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
124942550becSJunchao Zhang }
125042550becSJunchao Zhang 
1251d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_SeqAIJKokkos(Mat A, const PetscScalar v[], InsertMode imode)
1252d71ae5a4SJacob Faibussowitsch {
125342550becSJunchao Zhang   MatScalarKokkosView        Aa;
125442550becSJunchao Zhang   ConstMatScalarKokkosView   kv;
125542550becSJunchao Zhang   PetscMemType               memtype;
12562c4ab24aSJunchao Zhang   PetscContainer             container;
12572c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo;
125842550becSJunchao Zhang 
125942550becSJunchao Zhang   PetscFunctionBegin;
12602c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)A, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
12612c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
12622c4ab24aSJunchao Zhang 
12632c4ab24aSJunchao Zhang   const auto &n    = coo->n;
12642c4ab24aSJunchao Zhang   const auto &Annz = coo->nz;
12652c4ab24aSJunchao Zhang   const auto &jmap = coo->jmap;
12662c4ab24aSJunchao Zhang   const auto &perm = coo->perm;
12672c4ab24aSJunchao Zhang 
12689566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype));
126942550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */
12702c4ab24aSJunchao Zhang     kv = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatScalarKokkosViewHost(v, n));
127142550becSJunchao Zhang   } else {
12722c4ab24aSJunchao Zhang     kv = ConstMatScalarKokkosView(v, n); /* Directly use v[]'s memory */
127342550becSJunchao Zhang   }
127442550becSJunchao Zhang 
1275c7b718f4SJunchao Zhang   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1276c7b718f4SJunchao Zhang   else PetscCall(MatSeqAIJGetKokkosView(A, &Aa));                             /* read & write matrix values */
127742550becSJunchao Zhang 
127808bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
12799371c9d4SSatish Balay   Kokkos::parallel_for(
12809371c9d4SSatish Balay     Annz, KOKKOS_LAMBDA(const PetscCount i) {
1281c7b718f4SJunchao Zhang       PetscScalar sum = 0.0;
1282c7b718f4SJunchao Zhang       for (PetscCount k = jmap(i); k < jmap(i + 1); k++) sum += kv(perm(k));
1283c7b718f4SJunchao Zhang       Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1284c7b718f4SJunchao Zhang     });
128508bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1286394ed5ebSJunchao Zhang 
12879566063dSJacob Faibussowitsch   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
12889566063dSJacob Faibussowitsch   else PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
12893ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
129042550becSJunchao Zhang }
129142550becSJunchao Zhang 
1292d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatLUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1293d71ae5a4SJacob Faibussowitsch {
12948f7e8f9dSMark Adams   PetscFunctionBegin;
12959566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncHost(A));
12969566063dSJacob Faibussowitsch   PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info));
12978f7e8f9dSMark Adams   B->offloadmask = PETSC_OFFLOAD_CPU;
12983ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
12998f7e8f9dSMark Adams }
13008f7e8f9dSMark Adams 
1301d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat A)
1302d71ae5a4SJacob Faibussowitsch {
1303076ba34aSJunchao Zhang   Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
1304076ba34aSJunchao Zhang 
13058c3ff71bSJunchao Zhang   PetscFunctionBegin;
1306076ba34aSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_KOKKOS; /* We do not really use this flag */
13076f3d89d0SStefano Zampini   A->boundtocpu  = PETSC_FALSE;
13086f3d89d0SStefano Zampini 
13098c3ff71bSJunchao Zhang   A->ops->assemblyend               = MatAssemblyEnd_SeqAIJKokkos;
13108c3ff71bSJunchao Zhang   A->ops->destroy                   = MatDestroy_SeqAIJKokkos;
13118c3ff71bSJunchao Zhang   A->ops->duplicate                 = MatDuplicate_SeqAIJKokkos;
1312a587d139SMark   A->ops->axpy                      = MatAXPY_SeqAIJKokkos;
1313f0cf5187SStefano Zampini   A->ops->scale                     = MatScale_SeqAIJKokkos;
1314a587d139SMark   A->ops->zeroentries               = MatZeroEntries_SeqAIJKokkos;
1315076ba34aSJunchao Zhang   A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJKokkos;
13168c3ff71bSJunchao Zhang   A->ops->mult                      = MatMult_SeqAIJKokkos;
13178c3ff71bSJunchao Zhang   A->ops->multadd                   = MatMultAdd_SeqAIJKokkos;
13188c3ff71bSJunchao Zhang   A->ops->multtranspose             = MatMultTranspose_SeqAIJKokkos;
13198c3ff71bSJunchao Zhang   A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJKokkos;
13208c3ff71bSJunchao Zhang   A->ops->multhermitiantranspose    = MatMultHermitianTranspose_SeqAIJKokkos;
13218c3ff71bSJunchao Zhang   A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJKokkos;
1322076ba34aSJunchao Zhang   A->ops->productnumeric            = MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos;
13230ecb592aSJunchao Zhang   A->ops->transpose                 = MatTranspose_SeqAIJKokkos;
1324152b3e56SJunchao Zhang   A->ops->setoption                 = MatSetOption_SeqAIJKokkos;
1325f78ce678SMark Adams   A->ops->getdiagonal               = MatGetDiagonal_SeqAIJKokkos;
1326076ba34aSJunchao Zhang   a->ops->getarray                  = MatSeqAIJGetArray_SeqAIJKokkos;
1327076ba34aSJunchao Zhang   a->ops->restorearray              = MatSeqAIJRestoreArray_SeqAIJKokkos;
1328076ba34aSJunchao Zhang   a->ops->getarrayread              = MatSeqAIJGetArrayRead_SeqAIJKokkos;
1329076ba34aSJunchao Zhang   a->ops->restorearrayread          = MatSeqAIJRestoreArrayRead_SeqAIJKokkos;
1330076ba34aSJunchao Zhang   a->ops->getarraywrite             = MatSeqAIJGetArrayWrite_SeqAIJKokkos;
1331076ba34aSJunchao Zhang   a->ops->restorearraywrite         = MatSeqAIJRestoreArrayWrite_SeqAIJKokkos;
13327ee59b9bSJunchao Zhang   a->ops->getcsrandmemtype          = MatSeqAIJGetCSRAndMemType_SeqAIJKokkos;
133342550becSJunchao Zhang 
13349566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJKokkos));
13359566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJKokkos));
13363ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1337076ba34aSJunchao Zhang }
1338076ba34aSJunchao Zhang 
13399d13fa56SJunchao Zhang /*
13409d13fa56SJunchao Zhang    Extract the (prescribled) diagonal blocks of the matrix and then invert them
13419d13fa56SJunchao Zhang 
13429d13fa56SJunchao Zhang   Input Parameters:
13439d13fa56SJunchao Zhang +  A       - the MATSEQAIJKOKKOS matrix
13449d13fa56SJunchao Zhang .  bs      - block sizes in 'csr' format, i.e., the i-th block has size bs(i+1) - bs(i)
13459d13fa56SJunchao Zhang .  bs2     - square of block sizes in 'csr' format, i.e., the i-th block should be stored at offset bs2(i) in diagVal[]
13469d13fa56SJunchao Zhang .  blkMap  - map row ids to block ids, i.e., row i belongs to the block blkMap(i)
13479d13fa56SJunchao Zhang -  work    - a pre-allocated work buffer (as big as diagVal) for use by this routine
13489d13fa56SJunchao Zhang 
13499d13fa56SJunchao Zhang   Output Parameter:
13509d13fa56SJunchao Zhang .  diagVal - the (pre-allocated) buffer to store the inverted blocks (each block is stored in column-major order)
13519d13fa56SJunchao Zhang */
13529d13fa56SJunchao Zhang PETSC_INTERN PetscErrorCode MatInvertVariableBlockDiagonal_SeqAIJKokkos(Mat A, const PetscIntKokkosView &bs, const PetscIntKokkosView &bs2, const PetscIntKokkosView &blkMap, PetscScalarKokkosView &work, PetscScalarKokkosView &diagVal)
13539d13fa56SJunchao Zhang {
13549d13fa56SJunchao Zhang   Mat_SeqAIJKokkos *akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
13559d13fa56SJunchao Zhang   PetscInt          N       = A->rmap->n;
13569d13fa56SJunchao Zhang   PetscInt          nblocks = bs.extent(0) - 1;
13579d13fa56SJunchao Zhang 
13589d13fa56SJunchao Zhang   PetscFunctionBegin;
13599d13fa56SJunchao Zhang   // Set the diagonal pointer on device if not already
13609d13fa56SJunchao Zhang   if (N && akok->diag_dual.extent(0) == 0) {
13619d13fa56SJunchao Zhang     PetscCall(MatMarkDiagonal_SeqAIJ(A));
13629d13fa56SJunchao Zhang     akok->SetDiagonal(static_cast<Mat_SeqAIJ *>(A->data)->diag);
13639d13fa56SJunchao Zhang   }
13649d13fa56SJunchao Zhang 
13659d13fa56SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A)); // Since we'll access A's value on device
13669d13fa56SJunchao Zhang 
13679d13fa56SJunchao Zhang   // Pull out the diagonal blocks of the matrix and then invert the blocks
13689d13fa56SJunchao Zhang   auto Aa    = akok->a_dual.view_device();
13699d13fa56SJunchao Zhang   auto Ai    = akok->i_dual.view_device();
13709d13fa56SJunchao Zhang   auto Aj    = akok->j_dual.view_device();
13719d13fa56SJunchao Zhang   auto Adiag = akok->diag_dual.view_device();
13729d13fa56SJunchao Zhang   // TODO: how to tune the team size?
13739d13fa56SJunchao Zhang #if defined(KOKKOS_ENABLE_DEFAULT_DEVICE_TYPE_HOST)
13749d13fa56SJunchao Zhang   auto ts = Kokkos::AUTO();
13759d13fa56SJunchao Zhang #else
13769d13fa56SJunchao Zhang   auto ts         = 16; // improved performance 30% over Kokkos::AUTO() with CUDA, but failed with "Kokkos::abort: Requested Team Size is too large!" on CPUs
13779d13fa56SJunchao Zhang #endif
13789d13fa56SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
13799d13fa56SJunchao Zhang     Kokkos::TeamPolicy<>(nblocks, ts), KOKKOS_LAMBDA(const KokkosTeamMemberType &teamMember) {
13809d13fa56SJunchao Zhang       const PetscInt bid    = teamMember.league_rank();                                                   // block id
13819d13fa56SJunchao Zhang       const PetscInt rstart = bs(bid);                                                                    // this block starts from this row
13829d13fa56SJunchao Zhang       const PetscInt m      = bs(bid + 1) - bs(bid);                                                      // size of this block
13839d13fa56SJunchao Zhang       const auto    &B      = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft>(&diagVal(bs2(bid)), m, m); // column-major order
13849d13fa56SJunchao Zhang       const auto    &W      = PetscScalarKokkosView(&work(bs2(bid)), m * m);
13859d13fa56SJunchao Zhang 
13869d13fa56SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(teamMember, m), [=](const PetscInt &r) { // r-th row in B
13879d13fa56SJunchao Zhang         PetscInt i = rstart + r;                                                            // i-th row in A
13889d13fa56SJunchao Zhang 
13899d13fa56SJunchao Zhang         if (Ai(i) <= Adiag(i) && Adiag(i) < Ai(i + 1)) { // if the diagonal exists (common case)
13909d13fa56SJunchao Zhang           PetscInt first = Adiag(i) - r;                 // we start to check nonzeros from here along this row
13919d13fa56SJunchao Zhang 
13929d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) {                   // walk n steps to see what column indices we will meet
13939d13fa56SJunchao Zhang             if (first + c < Ai(i) || first + c >= Ai(i + 1)) { // this entry (first+c) is out of range of this row, in other words, its value is zero
13949d13fa56SJunchao Zhang               B(r, c) = 0.0;
13959d13fa56SJunchao Zhang             } else if (Aj(first + c) == rstart + c) { // this entry is right on the (rstart+c) column
13969d13fa56SJunchao Zhang               B(r, c) = Aa(first + c);
13979d13fa56SJunchao Zhang             } else { // this entry does not show up in the CSR
13989d13fa56SJunchao Zhang               B(r, c) = 0.0;
13999d13fa56SJunchao Zhang             }
14009d13fa56SJunchao Zhang           }
14019d13fa56SJunchao Zhang         } else { // rare case that the diagonal does not exist
14029d13fa56SJunchao Zhang           const PetscInt begin = Ai(i);
14039d13fa56SJunchao Zhang           const PetscInt end   = Ai(i + 1);
14049d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) B(r, c) = 0.0;
14059d13fa56SJunchao Zhang           for (PetscInt j = begin; j < end; j++) { // scan the whole row; could use binary search but this is a rare case so we did not.
14069d13fa56SJunchao Zhang             if (rstart <= Aj(j) && Aj(j) < rstart + m) B(r, Aj(j) - rstart) = Aa(j);
14079d13fa56SJunchao Zhang             else if (Aj(j) >= rstart + m) break;
14089d13fa56SJunchao Zhang           }
14099d13fa56SJunchao Zhang         }
14109d13fa56SJunchao Zhang       });
14119d13fa56SJunchao Zhang 
14129d13fa56SJunchao Zhang       // LU-decompose B (w/o pivoting) and then invert B
14139d13fa56SJunchao Zhang       KokkosBatched::TeamLU<KokkosTeamMemberType, KokkosBatched::Algo::LU::Unblocked>::invoke(teamMember, B, 0.0);
14149d13fa56SJunchao Zhang       KokkosBatched::TeamInverseLU<KokkosTeamMemberType, KokkosBatched::Algo::InverseLU::Unblocked>::invoke(teamMember, B, W);
14159d13fa56SJunchao Zhang     }));
14169d13fa56SJunchao Zhang   // PetscLogGpuFlops() is done in the caller PCSetUp_VPBJacobi_Kokkos as we don't want to compute the flops in kernels
14179d13fa56SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
14189d13fa56SJunchao Zhang }
14199d13fa56SJunchao Zhang 
1420d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSetSeqAIJKokkosWithCSRMatrix(Mat A, Mat_SeqAIJKokkos *akok)
1421d71ae5a4SJacob Faibussowitsch {
1422076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
1423076ba34aSJunchao Zhang   PetscInt    i, m, n;
1424*e36ced11SJunchao Zhang   auto       &exec = PetscGetKokkosExecutionSpace();
1425076ba34aSJunchao Zhang 
1426076ba34aSJunchao Zhang   PetscFunctionBegin;
14275f80ce2aSJacob Faibussowitsch   PetscCheck(!A->spptr, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A->spptr is supposed to be empty");
1428076ba34aSJunchao Zhang 
1429076ba34aSJunchao Zhang   m = akok->nrows();
1430076ba34aSJunchao Zhang   n = akok->ncols();
14319566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(A, m, n, m, n));
14329566063dSJacob Faibussowitsch   PetscCall(MatSetType(A, MATSEQAIJKOKKOS));
1433076ba34aSJunchao Zhang 
1434076ba34aSJunchao Zhang   /* Set up data structures of A as a MATSEQAIJ */
14359566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(A, MAT_SKIP_ALLOCATION, NULL));
1436076ba34aSJunchao Zhang   aseq = (Mat_SeqAIJ *)(A)->data;
1437076ba34aSJunchao Zhang 
1438*e36ced11SJunchao Zhang   PetscCallCXX(akok->i_dual.sync_host(exec)); /* We always need sync'ed i, j on host */
1439*e36ced11SJunchao Zhang   PetscCallCXX(akok->j_dual.sync_host(exec));
1440*e36ced11SJunchao Zhang   PetscCallCXX(exec.fence());
1441076ba34aSJunchao Zhang 
1442076ba34aSJunchao Zhang   aseq->i            = akok->i_host_data();
1443076ba34aSJunchao Zhang   aseq->j            = akok->j_host_data();
1444076ba34aSJunchao Zhang   aseq->a            = akok->a_host_data();
1445076ba34aSJunchao Zhang   aseq->nonew        = -1; /*this indicates that inserting a new value in the matrix that generates a new nonzero is an error*/
1446076ba34aSJunchao Zhang   aseq->singlemalloc = PETSC_FALSE;
1447076ba34aSJunchao Zhang   aseq->free_a       = PETSC_FALSE;
1448076ba34aSJunchao Zhang   aseq->free_ij      = PETSC_FALSE;
1449076ba34aSJunchao Zhang   aseq->nz           = akok->nnz();
1450076ba34aSJunchao Zhang   aseq->maxnz        = aseq->nz;
1451076ba34aSJunchao Zhang 
14529566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->imax));
14539566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->ilen));
1454ad540459SPierre Jolivet   for (i = 0; i < m; i++) aseq->ilen[i] = aseq->imax[i] = aseq->i[i + 1] - aseq->i[i];
1455076ba34aSJunchao Zhang 
1456076ba34aSJunchao Zhang   /* It is critical to set the nonzerostate, as we use it to check if sparsity pattern (hence data) has changed on host in MatAssemblyEnd */
1457076ba34aSJunchao Zhang   akok->nonzerostate = A->nonzerostate;
1458ff751488SJunchao Zhang   A->spptr           = akok; /* Set A->spptr before MatAssembly so that A->spptr won't be allocated again there */
14599566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
14609566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
14613ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1462076ba34aSJunchao Zhang }
1463076ba34aSJunchao Zhang 
14640e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetKokkosCsrMatrix(Mat A, KokkosCsrMatrix *csr)
14650e3ece09SJunchao Zhang {
14660e3ece09SJunchao Zhang   PetscFunctionBegin;
14670e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
14680e3ece09SJunchao Zhang   *csr = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
14690e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
14700e3ece09SJunchao Zhang }
14710e3ece09SJunchao Zhang 
14720e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithKokkosCsrMatrix(MPI_Comm comm, KokkosCsrMatrix csr, Mat *A)
14730e3ece09SJunchao Zhang {
14740e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok;
14750e3ece09SJunchao Zhang   PetscFunctionBegin;
14760e3ece09SJunchao Zhang   PetscCallCXX(akok = new Mat_SeqAIJKokkos(csr));
14770e3ece09SJunchao Zhang   PetscCall(MatCreate(comm, A));
14780e3ece09SJunchao Zhang   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
14790e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
14800e3ece09SJunchao Zhang }
14810e3ece09SJunchao Zhang 
/* Create a SEQAIJKOKKOS matrix with a Mat_SeqAIJKokkos data structure
1483076ba34aSJunchao Zhang 
1484076ba34aSJunchao Zhang    Note we have names like MatSeqAIJSetPreallocationCSR, so I use capitalized CSR
1485076ba34aSJunchao Zhang  */
1486d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithCSRMatrix(MPI_Comm comm, Mat_SeqAIJKokkos *akok, Mat *A)
1487d71ae5a4SJacob Faibussowitsch {
1488076ba34aSJunchao Zhang   PetscFunctionBegin;
14899566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
14909566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
14913ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
14928c3ff71bSJunchao Zhang }
14938c3ff71bSJunchao Zhang 
1494152b3e56SJunchao Zhang /*@C
149511a5261eSBarry Smith   MatCreateSeqAIJKokkos - Creates a sparse matrix in `MATSEQAIJKOKKOS` (compressed row) format
14968c3ff71bSJunchao Zhang   (the default parallel PETSc format). This matrix will ultimately be handled by
149720f4b53cSBarry Smith   Kokkos for calculations.
14988c3ff71bSJunchao Zhang 
14998c3ff71bSJunchao Zhang   Collective
15008c3ff71bSJunchao Zhang 
15018c3ff71bSJunchao Zhang   Input Parameters:
150211a5261eSBarry Smith + comm - MPI communicator, set to `PETSC_COMM_SELF`
15038c3ff71bSJunchao Zhang . m    - number of rows
15048c3ff71bSJunchao Zhang . n    - number of columns
150520f4b53cSBarry Smith . nz   - number of nonzeros per row (same for all rows), ignored if `nnz` is provided
150620f4b53cSBarry Smith - nnz  - array containing the number of nonzeros in the various rows (possibly different for each row) or `NULL`
15078c3ff71bSJunchao Zhang 
15088c3ff71bSJunchao Zhang   Output Parameter:
15098c3ff71bSJunchao Zhang . A - the matrix
15108c3ff71bSJunchao Zhang 
15112ef1f0ffSBarry Smith   Level: intermediate
15122ef1f0ffSBarry Smith 
15132ef1f0ffSBarry Smith   Notes:
151411a5261eSBarry Smith   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
  MatXXXXSetPreallocation() paradigm instead of this routine directly.
151611a5261eSBarry Smith   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
15178c3ff71bSJunchao Zhang 
151811a5261eSBarry Smith   The AIJ format, also called
15192ef1f0ffSBarry Smith   compressed row storage, is fully compatible with standard Fortran
15208c3ff71bSJunchao Zhang   storage.  That is, the stored row and column indices can begin at
152120f4b53cSBarry Smith   either one (as in Fortran) or zero.
15228c3ff71bSJunchao Zhang 
15232ef1f0ffSBarry Smith   Specify the preallocated storage with either `nz` or `nnz` (not both).
15242ef1f0ffSBarry Smith   Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory
15252ef1f0ffSBarry Smith   allocation.
15268c3ff71bSJunchao Zhang 
1527fe59aa6dSJacob Faibussowitsch .seealso: [](ch_matrices), `Mat`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
15288c3ff71bSJunchao Zhang @*/
1529d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateSeqAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A)
1530d71ae5a4SJacob Faibussowitsch {
15318c3ff71bSJunchao Zhang   PetscFunctionBegin;
15329566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
15339566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
15349566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, m, n));
15359566063dSJacob Faibussowitsch   PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
15369566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
15373ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
15388c3ff71bSJunchao Zhang }
1539930e68a5SMark Adams 
1540d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatLUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1541d71ae5a4SJacob Faibussowitsch {
1542930e68a5SMark Adams   PetscFunctionBegin;
15439566063dSJacob Faibussowitsch   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
154486a27549SJunchao Zhang   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
15453ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
154686a27549SJunchao Zhang }
154786a27549SJunchao Zhang 
1548d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSymbolicSolveCheck(Mat A)
1549d71ae5a4SJacob Faibussowitsch {
155086a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
155186a27549SJunchao Zhang 
155286a27549SJunchao Zhang   PetscFunctionBegin;
155386a27549SJunchao Zhang   if (!factors->sptrsv_symbolic_completed) {
155486a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d);
155586a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d);
155686a27549SJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_TRUE;
155786a27549SJunchao Zhang   }
15583ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
155986a27549SJunchao Zhang }
156086a27549SJunchao Zhang 
156186a27549SJunchao Zhang /* Check if we need to update factors etc for transpose solve */
1562d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosTransposeSolveCheck(Mat A)
1563d71ae5a4SJacob Faibussowitsch {
156486a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1565076ba34aSJunchao Zhang   MatColIdxType               n       = A->rmap->n;
156686a27549SJunchao Zhang 
156786a27549SJunchao Zhang   PetscFunctionBegin;
156886a27549SJunchao Zhang   if (!factors->transpose_updated) { /* TODO: KK needs to provide functions to do numeric transpose only */
156986a27549SJunchao Zhang     /* Update L^T and do sptrsv symbolic */
15707b8d4ba6SJunchao Zhang     factors->iLt_d = MatRowMapKokkosView("factors->iLt_d", n + 1); // KK requires 0
15717b8d4ba6SJunchao Zhang     factors->jLt_d = MatColIdxKokkosView(NoInit("factors->jLt_d"), factors->jL_d.extent(0));
15727b8d4ba6SJunchao Zhang     factors->aLt_d = MatScalarKokkosView(NoInit("factors->aLt_d"), factors->aL_d.extent(0));
157386a27549SJunchao Zhang 
15749371c9d4SSatish Balay     transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iL_d, factors->jL_d, factors->aL_d,
157586a27549SJunchao Zhang                                                                                                                                                                                                               factors->iLt_d, factors->jLt_d, factors->aLt_d);
157686a27549SJunchao Zhang 
157786a27549SJunchao Zhang     /* TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices.
157886a27549SJunchao Zhang       We have to sort the indices, until KK provides finer control options.
157986a27549SJunchao Zhang     */
15809371c9d4SSatish Balay     sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iLt_d, factors->jLt_d, factors->aLt_d);
158186a27549SJunchao Zhang 
158286a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d);
158386a27549SJunchao Zhang 
158486a27549SJunchao Zhang     /* Update U^T and do sptrsv symbolic */
15857b8d4ba6SJunchao Zhang     factors->iUt_d = MatRowMapKokkosView("factors->iUt_d", n + 1); // KK requires 0
15867b8d4ba6SJunchao Zhang     factors->jUt_d = MatColIdxKokkosView(NoInit("factors->jUt_d"), factors->jU_d.extent(0));
15877b8d4ba6SJunchao Zhang     factors->aUt_d = MatScalarKokkosView(NoInit("factors->aUt_d"), factors->aU_d.extent(0));
158886a27549SJunchao Zhang 
15899371c9d4SSatish Balay     transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iU_d, factors->jU_d, factors->aU_d,
159086a27549SJunchao Zhang                                                                                                                                                                                                               factors->iUt_d, factors->jUt_d, factors->aUt_d);
159186a27549SJunchao Zhang 
159286a27549SJunchao Zhang     /* Sort indices. See comments above */
15939371c9d4SSatish Balay     sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iUt_d, factors->jUt_d, factors->aUt_d);
159486a27549SJunchao Zhang 
159586a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d);
159686a27549SJunchao Zhang     factors->transpose_updated = PETSC_TRUE;
159786a27549SJunchao Zhang   }
15983ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
159986a27549SJunchao Zhang }
160086a27549SJunchao Zhang 
160186a27549SJunchao Zhang /* Solve Ax = b, with A = LU */
1602d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSolve_SeqAIJKokkos(Mat A, Vec b, Vec x)
1603d71ae5a4SJacob Faibussowitsch {
160486a27549SJunchao Zhang   ConstPetscScalarKokkosView  bv;
160586a27549SJunchao Zhang   PetscScalarKokkosView       xv;
160686a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
160786a27549SJunchao Zhang 
160886a27549SJunchao Zhang   PetscFunctionBegin;
16099566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
16109566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSymbolicSolveCheck(A));
16119566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(b, &bv));
16129566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(x, &xv));
161386a27549SJunchao Zhang   /* Solve L tmpv = b */
16149566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::Experimental::sptrsv_solve(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d, bv, factors->workVector));
161586a27549SJunchao Zhang   /* Solve Ux = tmpv */
16169566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::Experimental::sptrsv_solve(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, factors->workVector, xv));
16179566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(b, &bv));
16189566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
16199566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
16203ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
162186a27549SJunchao Zhang }
162286a27549SJunchao Zhang 
1623076ba34aSJunchao Zhang /* Solve A^T x = b, where A^T = U^T L^T */
1624d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSolveTranspose_SeqAIJKokkos(Mat A, Vec b, Vec x)
1625d71ae5a4SJacob Faibussowitsch {
162686a27549SJunchao Zhang   ConstPetscScalarKokkosView  bv;
162786a27549SJunchao Zhang   PetscScalarKokkosView       xv;
162886a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
162986a27549SJunchao Zhang 
163086a27549SJunchao Zhang   PetscFunctionBegin;
16319566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
16329566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A));
16339566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(b, &bv));
16349566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(x, &xv));
163586a27549SJunchao Zhang   /* Solve U^T tmpv = b */
163686a27549SJunchao Zhang   KokkosSparse::Experimental::sptrsv_solve(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, bv, factors->workVector);
163786a27549SJunchao Zhang 
163886a27549SJunchao Zhang   /* Solve L^T x = tmpv */
163986a27549SJunchao Zhang   KokkosSparse::Experimental::sptrsv_solve(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d, factors->workVector, xv);
16409566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(b, &bv));
16419566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
16429566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
16433ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
164486a27549SJunchao Zhang }
164586a27549SJunchao Zhang 
1646d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatILUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1647d71ae5a4SJacob Faibussowitsch {
164886a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
164986a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
165086a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
165186a27549SJunchao Zhang 
165286a27549SJunchao Zhang   PetscFunctionBegin;
16539566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
16549566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1655076ba34aSJunchao Zhang 
1656076ba34aSJunchao Zhang   auto a_d = aijkok->a_dual.view_device();
1657076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
1658076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
1659076ba34aSJunchao Zhang 
1660076ba34aSJunchao Zhang   KokkosSparse::Experimental::spiluk_numeric(&factors->kh, fill_lev, i_d, j_d, a_d, factors->iL_d, factors->jL_d, factors->aL_d, factors->iU_d, factors->jU_d, factors->aU_d);
166186a27549SJunchao Zhang 
166286a27549SJunchao Zhang   B->assembled              = PETSC_TRUE;
166386a27549SJunchao Zhang   B->preallocated           = PETSC_TRUE;
166486a27549SJunchao Zhang   B->ops->solve             = MatSolve_SeqAIJKokkos;
166586a27549SJunchao Zhang   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJKokkos;
166686a27549SJunchao Zhang   B->ops->matsolve          = NULL;
166786a27549SJunchao Zhang   B->ops->matsolvetranspose = NULL;
166886a27549SJunchao Zhang   B->offloadmask            = PETSC_OFFLOAD_GPU;
166986a27549SJunchao Zhang 
167086a27549SJunchao Zhang   /* Once the factors' value changed, we need to update their transpose and sptrsv handle */
167186a27549SJunchao Zhang   factors->transpose_updated         = PETSC_FALSE;
167286a27549SJunchao Zhang   factors->sptrsv_symbolic_completed = PETSC_FALSE;
1673eeadb341SJunchao Zhang   /* TODO: log flops, but how to know that? */
16749566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
16753ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
167686a27549SJunchao Zhang }
167786a27549SJunchao Zhang 
1678d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1679d71ae5a4SJacob Faibussowitsch {
168086a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok;
168186a27549SJunchao Zhang   Mat_SeqAIJ                 *b;
168286a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
168386a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
168486a27549SJunchao Zhang   PetscInt                    nnzA     = ((Mat_SeqAIJ *)A->data)->nz, nnzL, nnzU;
168586a27549SJunchao Zhang   PetscInt                    n        = A->rmap->n;
168686a27549SJunchao Zhang 
168786a27549SJunchao Zhang   PetscFunctionBegin;
16889566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
168986a27549SJunchao Zhang   /* Rebuild factors */
16909371c9d4SSatish Balay   if (factors) {
16919371c9d4SSatish Balay     factors->Destroy();
16929371c9d4SSatish Balay   } /* Destroy the old if it exists */
16939371c9d4SSatish Balay   else {
16949371c9d4SSatish Balay     B->spptr = factors = new Mat_SeqAIJKokkosTriFactors(n);
16959371c9d4SSatish Balay   }
169686a27549SJunchao Zhang 
169786a27549SJunchao Zhang   /* Create a spiluk handle and then do symbolic factorization */
169886a27549SJunchao Zhang   nnzL = nnzU = PetscRealIntMultTruncate(info->fill, nnzA);
169986a27549SJunchao Zhang   factors->kh.create_spiluk_handle(KokkosSparse::Experimental::SPILUKAlgorithm::SEQLVLSCHD_TP1, n, nnzL, nnzU);
170086a27549SJunchao Zhang 
170186a27549SJunchao Zhang   auto spiluk_handle = factors->kh.get_spiluk_handle();
170286a27549SJunchao Zhang 
170386a27549SJunchao Zhang   Kokkos::realloc(factors->iL_d, n + 1); /* Free old arrays and realloc */
170486a27549SJunchao Zhang   Kokkos::realloc(factors->jL_d, spiluk_handle->get_nnzL());
170586a27549SJunchao Zhang   Kokkos::realloc(factors->iU_d, n + 1);
170686a27549SJunchao Zhang   Kokkos::realloc(factors->jU_d, spiluk_handle->get_nnzU());
170786a27549SJunchao Zhang 
170886a27549SJunchao Zhang   aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
1709076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
1710076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
1711076ba34aSJunchao Zhang   KokkosSparse::Experimental::spiluk_symbolic(&factors->kh, fill_lev, i_d, j_d, factors->iL_d, factors->jL_d, factors->iU_d, factors->jU_d);
171286a27549SJunchao Zhang   /* TODO: if spiluk_symbolic is asynchronous, do we need to sync before calling get_nnzL()? */
171386a27549SJunchao Zhang 
171486a27549SJunchao Zhang   Kokkos::resize(factors->jL_d, spiluk_handle->get_nnzL()); /* Shrink or expand, and retain old value */
171586a27549SJunchao Zhang   Kokkos::resize(factors->jU_d, spiluk_handle->get_nnzU());
171686a27549SJunchao Zhang   Kokkos::realloc(factors->aL_d, spiluk_handle->get_nnzL()); /* No need to retain old value */
171786a27549SJunchao Zhang   Kokkos::realloc(factors->aU_d, spiluk_handle->get_nnzU());
171886a27549SJunchao Zhang 
171986a27549SJunchao Zhang   /* TODO: add options to select sptrsv algorithms */
172086a27549SJunchao Zhang   /* Create sptrsv handles for L, U and their transpose */
172186a27549SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
172286a27549SJunchao Zhang   auto sptrsv_alg = KokkosSparse::Experimental::SPTRSVAlgorithm::SPTRSV_CUSPARSE;
172386a27549SJunchao Zhang #else
172486a27549SJunchao Zhang   auto sptrsv_alg = KokkosSparse::Experimental::SPTRSVAlgorithm::SEQLVLSCHD_TP1;
172586a27549SJunchao Zhang #endif
172686a27549SJunchao Zhang 
172786a27549SJunchao Zhang   factors->khL.create_sptrsv_handle(sptrsv_alg, n, true /* L is lower tri */);
172886a27549SJunchao Zhang   factors->khU.create_sptrsv_handle(sptrsv_alg, n, false /* U is not lower tri */);
172986a27549SJunchao Zhang   factors->khLt.create_sptrsv_handle(sptrsv_alg, n, false /* L^T is not lower tri */);
173086a27549SJunchao Zhang   factors->khUt.create_sptrsv_handle(sptrsv_alg, n, true /* U^T is lower tri */);
173186a27549SJunchao Zhang 
173286a27549SJunchao Zhang   /* Fill fields of the factor matrix B */
17339566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(B, MAT_SKIP_ALLOCATION, NULL));
173486a27549SJunchao Zhang   b     = (Mat_SeqAIJ *)B->data;
173586a27549SJunchao Zhang   b->nz = b->maxnz          = spiluk_handle->get_nnzL() + spiluk_handle->get_nnzU();
173686a27549SJunchao Zhang   B->info.fill_ratio_given  = info->fill;
1737a1e4e190SStefano Zampini   B->info.fill_ratio_needed = nnzA > 0 ? ((PetscReal)b->nz) / ((PetscReal)nnzA) : 1.0;
173886a27549SJunchao Zhang 
173986a27549SJunchao Zhang   B->offloadmask          = PETSC_OFFLOAD_GPU;
174086a27549SJunchao Zhang   B->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJKokkos;
17413ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1742930e68a5SMark Adams }
1743930e68a5SMark Adams 
1744d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatFactorGetSolverType_SeqAIJKokkos(Mat A, MatSolverType *type)
1745d71ae5a4SJacob Faibussowitsch {
1746930e68a5SMark Adams   PetscFunctionBegin;
1747930e68a5SMark Adams   *type = MATSOLVERKOKKOS;
17483ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1749930e68a5SMark Adams }
1750930e68a5SMark Adams 
1751930e68a5SMark Adams /*MC
175286a27549SJunchao Zhang   MATSOLVERKOKKOS = "Kokkos" - A matrix solver type providing triangular solvers for sequential matrices
  on a single GPU, for matrices of type `MATSEQAIJKOKKOS` or `MATAIJKOKKOS`.
1754930e68a5SMark Adams 
1755930e68a5SMark Adams   Level: beginner
1756930e68a5SMark Adams 
17571cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJKokkos()`, `MATAIJKOKKOS`, `MatKokkosSetFormat()`, `MatKokkosStorageFormat`, `MatKokkosFormatOperation`
1758930e68a5SMark Adams M*/
175986a27549SJunchao Zhang PETSC_EXTERN PetscErrorCode MatGetFactor_SeqAIJKokkos_Kokkos(Mat A, MatFactorType ftype, Mat *B) /* MatGetFactor_<MatType>_<MatSolverType> */
1760930e68a5SMark Adams {
1761930e68a5SMark Adams   PetscInt n = A->rmap->n;
1762930e68a5SMark Adams 
1763930e68a5SMark Adams   PetscFunctionBegin;
17649566063dSJacob Faibussowitsch   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), B));
17659566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*B, n, n, n, n));
1766930e68a5SMark Adams   (*B)->factortype = ftype;
17679566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
17689566063dSJacob Faibussowitsch   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
1769930e68a5SMark Adams 
17708f7e8f9dSMark Adams   if (ftype == MAT_FACTOR_LU) {
17719566063dSJacob Faibussowitsch     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
177286a27549SJunchao Zhang     (*B)->canuseordering        = PETSC_TRUE;
177386a27549SJunchao Zhang     (*B)->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqAIJKokkos;
177486a27549SJunchao Zhang   } else if (ftype == MAT_FACTOR_ILU) {
17759566063dSJacob Faibussowitsch     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
177686a27549SJunchao Zhang     (*B)->canuseordering         = PETSC_FALSE;
177786a27549SJunchao Zhang     (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJKokkos;
177898921bdaSJacob Faibussowitsch   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "MatFactorType %s is not supported by MatType SeqAIJKokkos", MatFactorTypes[ftype]);
1779930e68a5SMark Adams 
17809566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
17819566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)(*B), "MatFactorGetSolverType_C", MatFactorGetSolverType_SeqAIJKokkos));
17823ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1783930e68a5SMark Adams }
17848f7e8f9dSMark Adams 
1785d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatSolverTypeRegister_KOKKOS(void)
1786d71ae5a4SJacob Faibussowitsch {
178786a27549SJunchao Zhang   PetscFunctionBegin;
17889566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_SeqAIJKokkos_Kokkos));
17899566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ILU, MatGetFactor_SeqAIJKokkos_Kokkos));
17903ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
179186a27549SJunchao Zhang }
179286a27549SJunchao Zhang 
1793076ba34aSJunchao Zhang /* Utility to print out a KokkosCsrMatrix for debugging */
1794d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode PrintCsrMatrix(const KokkosCsrMatrix &csrmat)
1795d71ae5a4SJacob Faibussowitsch {
1796076ba34aSJunchao Zhang   const auto        &iv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.graph.row_map);
1797076ba34aSJunchao Zhang   const auto        &jv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.graph.entries);
1798076ba34aSJunchao Zhang   const auto        &av = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.values);
1799076ba34aSJunchao Zhang   const PetscInt    *i  = iv.data();
1800076ba34aSJunchao Zhang   const PetscInt    *j  = jv.data();
1801076ba34aSJunchao Zhang   const PetscScalar *a  = av.data();
1802076ba34aSJunchao Zhang   PetscInt           m = csrmat.numRows(), n = csrmat.numCols(), nnz = csrmat.nnz();
1803076ba34aSJunchao Zhang 
1804076ba34aSJunchao Zhang   PetscFunctionBegin;
18059566063dSJacob Faibussowitsch   PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT " x %" PetscInt_FMT " SeqAIJKokkos, with %" PetscInt_FMT " nonzeros\n", m, n, nnz));
1806076ba34aSJunchao Zhang   for (PetscInt k = 0; k < m; k++) {
18079566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT ": ", k));
180848a46eb9SPierre Jolivet     for (PetscInt p = i[k]; p < i[k + 1]; p++) PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT "(%.1f), ", j[p], (double)PetscRealPart(a[p])));
18099566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "\n"));
1810076ba34aSJunchao Zhang   }
18113ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1812076ba34aSJunchao Zhang }
1813