xref: /petsc/src/mat/impls/aij/seq/kokkos/aijkok.kokkos.cxx (revision f4f49eeac7efa77fffa46b7ff95a3ed169f659ed)
1e36ced11SJunchao Zhang #include <petsc_kokkos.hpp>
211d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
3076ba34aSJunchao Zhang #include <petscpkg_version.h>
4152b3e56SJunchao Zhang #include <petsc/private/petscimpl.h>
542550becSJunchao Zhang #include <petsc/private/sfimpl.h>
68c3ff71bSJunchao Zhang #include <petscsystypes.h>
78c3ff71bSJunchao Zhang #include <petscerror.h>
88c3ff71bSJunchao Zhang 
98c3ff71bSJunchao Zhang #include <Kokkos_Core.hpp>
10f0cf5187SStefano Zampini #include <KokkosBlas.hpp>
118c3ff71bSJunchao Zhang #include <KokkosSparse_CrsMatrix.hpp>
128c3ff71bSJunchao Zhang #include <KokkosSparse_spmv.hpp>
1386a27549SJunchao Zhang #include <KokkosSparse_spiluk.hpp>
1486a27549SJunchao Zhang #include <KokkosSparse_sptrsv.hpp>
15076ba34aSJunchao Zhang #include <KokkosSparse_spgemm.hpp>
16076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
179d13fa56SJunchao Zhang #include <KokkosBatched_LU_Decl.hpp>
189d13fa56SJunchao Zhang #include <KokkosBatched_InverseLU_Decl.hpp>
1986a27549SJunchao Zhang 
2042550becSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
218c3ff71bSJunchao Zhang 
220e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(3, 7, 0)
23f98996d3SJunchao Zhang   #include <KokkosSparse_Utils.hpp>
24f98996d3SJunchao Zhang using KokkosSparse::sort_crs_matrix;
259371c9d4SSatish Balay using KokkosSparse::Impl::transpose_matrix;
26f98996d3SJunchao Zhang #else
27f98996d3SJunchao Zhang   #include <KokkosKernels_Sorting.hpp>
28f98996d3SJunchao Zhang using KokkosKernels::sort_crs_matrix;
299371c9d4SSatish Balay using KokkosKernels::Impl::transpose_matrix;
30f98996d3SJunchao Zhang #endif
31f98996d3SJunchao Zhang 
328c3ff71bSJunchao Zhang static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat); /* Forward declaration */
338c3ff71bSJunchao Zhang 
34076ba34aSJunchao Zhang /* MatAssemblyEnd_SeqAIJKokkos() happens when we finalized nonzeros of the matrix, either after
35076ba34aSJunchao Zhang    we assembled the matrix on host, or after we directly produced the matrix data on device (ex., through MatMatMult).
36076ba34aSJunchao Zhang    In the latter case, it is important to set a_dual's sync state correctly.
37076ba34aSJunchao Zhang  */
38d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_SeqAIJKokkos(Mat A, MatAssemblyType mode)
39d71ae5a4SJacob Faibussowitsch {
40076ba34aSJunchao Zhang   Mat_SeqAIJ       *aijseq;
41076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
428c3ff71bSJunchao Zhang 
438c3ff71bSJunchao Zhang   PetscFunctionBegin;
443ba16761SJacob Faibussowitsch   if (mode == MAT_FLUSH_ASSEMBLY) PetscFunctionReturn(PETSC_SUCCESS);
459566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
46076ba34aSJunchao Zhang 
47076ba34aSJunchao Zhang   aijseq = static_cast<Mat_SeqAIJ *>(A->data);
48076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
49076ba34aSJunchao Zhang 
50076ba34aSJunchao Zhang   /* If aijkok does not exist, we just copy i, j to device.
51076ba34aSJunchao Zhang      If aijkok already exists, but the device's nonzero pattern does not match with the host's, we assume the latest data is on host.
52076ba34aSJunchao Zhang      In both cases, we build a new aijkok structure.
53076ba34aSJunchao Zhang   */
54076ba34aSJunchao Zhang   if (!aijkok || aijkok->nonzerostate != A->nonzerostate) { /* aijkok might not exist yet or nonzero pattern has changed */
55076ba34aSJunchao Zhang     delete aijkok;
56f4747e26SJunchao Zhang     aijkok   = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aijseq, A->nonzerostate, PETSC_FALSE /*don't copy mat values to device*/);
57076ba34aSJunchao Zhang     A->spptr = aijkok;
58f4747e26SJunchao Zhang   } else if (A->rmap->n && aijkok->diag_dual.extent(0) == 0) { // MatProduct might directly produce AIJ on device, but not the diag.
59f4747e26SJunchao Zhang     MatRowMapKokkosViewHost diag_h(aijseq->diag, A->rmap->n);
60f4747e26SJunchao Zhang     auto                    diag_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), diag_h);
61f4747e26SJunchao Zhang     aijkok->diag_dual              = MatRowMapKokkosDualView(diag_d, diag_h);
62076ba34aSJunchao Zhang   }
633ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
648c3ff71bSJunchao Zhang }
658c3ff71bSJunchao Zhang 
6686a27549SJunchao Zhang /* Sync CSR data to device if not yet */
67d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosSyncDevice(Mat A)
68d71ae5a4SJacob Faibussowitsch {
698c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
708c3ff71bSJunchao Zhang 
718c3ff71bSJunchao Zhang   PetscFunctionBegin;
72aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from host to device");
735f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
74076ba34aSJunchao Zhang   if (aijkok->a_dual.need_sync_device()) {
75076ba34aSJunchao Zhang     aijkok->a_dual.sync_device();
76580c7c76SPierre Jolivet     aijkok->transpose_updated = PETSC_FALSE; /* values of the transpose is out-of-date */
7786a27549SJunchao Zhang     aijkok->hermitian_updated = PETSC_FALSE;
788c3ff71bSJunchao Zhang   }
793ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
808c3ff71bSJunchao Zhang }
818c3ff71bSJunchao Zhang 
82076ba34aSJunchao Zhang /* Mark the CSR data on device as modified */
83d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosModifyDevice(Mat A)
84d71ae5a4SJacob Faibussowitsch {
8586a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
8686a27549SJunchao Zhang 
8786a27549SJunchao Zhang   PetscFunctionBegin;
885f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Not supported for factorized matries");
8986a27549SJunchao Zhang   aijkok->a_dual.clear_sync_state();
9086a27549SJunchao Zhang   aijkok->a_dual.modify_device();
9186a27549SJunchao Zhang   aijkok->transpose_updated = PETSC_FALSE;
9286a27549SJunchao Zhang   aijkok->hermitian_updated = PETSC_FALSE;
939566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJInvalidateDiagonal(A));
949566063dSJacob Faibussowitsch   PetscCall(PetscObjectStateIncrease((PetscObject)A));
953ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
9686a27549SJunchao Zhang }
9786a27549SJunchao Zhang 
98d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSyncHost(Mat A)
99d71ae5a4SJacob Faibussowitsch {
100f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
101e36ced11SJunchao Zhang   auto             &exec   = PetscGetKokkosExecutionSpace();
102f0cf5187SStefano Zampini 
103f0cf5187SStefano Zampini   PetscFunctionBegin;
104f0cf5187SStefano Zampini   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10586a27549SJunchao Zhang   /* We do not expect one needs factors on host  */
106aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from device to host");
1075f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing AIJKOK");
108e36ced11SJunchao Zhang   PetscCallCXX(aijkok->a_dual.sync_host(exec));
109e36ced11SJunchao Zhang   PetscCallCXX(exec.fence());
1103ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
111f0cf5187SStefano Zampini }
112f0cf5187SStefano Zampini 
113d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
114d71ae5a4SJacob Faibussowitsch {
115076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
116f0cf5187SStefano Zampini 
117f0cf5187SStefano Zampini   PetscFunctionBegin;
1185519a089SJose E. Roman   /* aijkok contains valid pointers only if the host's nonzerostate matches with the device's.
1195519a089SJose E. Roman     Calling MatSeqAIJSetPreallocation() or MatSetValues() on host, where aijseq->{i,j,a} might be
1205519a089SJose E. Roman     reallocated, will lead to stale {i,j,a}_dual in aijkok. In both operations, the hosts's nonzerostate
1215519a089SJose E. Roman     must have been updated. The stale aijkok will be rebuilt during MatAssemblyEnd.
1225519a089SJose E. Roman   */
1235519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
124e36ced11SJunchao Zhang     auto &exec = PetscGetKokkosExecutionSpace();
125e36ced11SJunchao Zhang     PetscCallCXX(aijkok->a_dual.sync_host(exec));
126e36ced11SJunchao Zhang     PetscCallCXX(exec.fence());
127076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
128076ba34aSJunchao Zhang   } else { /* Happens when calling MatSetValues on a newly created matrix */
129076ba34aSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
130076ba34aSJunchao Zhang   }
1313ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
132076ba34aSJunchao Zhang }
133076ba34aSJunchao Zhang 
134d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
135d71ae5a4SJacob Faibussowitsch {
136076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
137076ba34aSJunchao Zhang 
138076ba34aSJunchao Zhang   PetscFunctionBegin;
1395519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) aijkok->a_dual.modify_host();
1403ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
141076ba34aSJunchao Zhang }
142076ba34aSJunchao Zhang 
143d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
144d71ae5a4SJacob Faibussowitsch {
145076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
146076ba34aSJunchao Zhang 
147076ba34aSJunchao Zhang   PetscFunctionBegin;
1485519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
149e36ced11SJunchao Zhang     auto &exec = PetscGetKokkosExecutionSpace();
150e36ced11SJunchao Zhang     PetscCallCXX(aijkok->a_dual.sync_host(exec));
151e36ced11SJunchao Zhang     PetscCallCXX(exec.fence());
152076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1532328674fSJunchao Zhang   } else {
1542328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1552328674fSJunchao Zhang   }
1563ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
157076ba34aSJunchao Zhang }
158076ba34aSJunchao Zhang 
159d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
160d71ae5a4SJacob Faibussowitsch {
161076ba34aSJunchao Zhang   PetscFunctionBegin;
162076ba34aSJunchao Zhang   *array = NULL;
1633ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
164076ba34aSJunchao Zhang }
165076ba34aSJunchao Zhang 
166d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
167d71ae5a4SJacob Faibussowitsch {
168076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
169076ba34aSJunchao Zhang 
170076ba34aSJunchao Zhang   PetscFunctionBegin;
1715519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
172076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1732328674fSJunchao Zhang   } else { /* Ex. happens with MatZeroEntries on a preallocated but not assembled matrix */
1742328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1752328674fSJunchao Zhang   }
1763ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
177076ba34aSJunchao Zhang }
178076ba34aSJunchao Zhang 
179d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
180d71ae5a4SJacob Faibussowitsch {
181076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
182076ba34aSJunchao Zhang 
183076ba34aSJunchao Zhang   PetscFunctionBegin;
1845519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
185076ba34aSJunchao Zhang     aijkok->a_dual.clear_sync_state();
186076ba34aSJunchao Zhang     aijkok->a_dual.modify_host();
1872328674fSJunchao Zhang   }
1883ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
189f0cf5187SStefano Zampini }
190f0cf5187SStefano Zampini 
191d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJKokkos(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype)
192d71ae5a4SJacob Faibussowitsch {
1937ee59b9bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1947ee59b9bSJunchao Zhang 
1957ee59b9bSJunchao Zhang   PetscFunctionBegin;
1967ee59b9bSJunchao Zhang   PetscCheck(aijkok != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "aijkok is NULL");
1977ee59b9bSJunchao Zhang 
1987ee59b9bSJunchao Zhang   if (i) *i = aijkok->i_device_data();
1997ee59b9bSJunchao Zhang   if (j) *j = aijkok->j_device_data();
2007ee59b9bSJunchao Zhang   if (a) {
2017ee59b9bSJunchao Zhang     aijkok->a_dual.sync_device();
2027ee59b9bSJunchao Zhang     *a = aijkok->a_device_data();
2037ee59b9bSJunchao Zhang   }
2047ee59b9bSJunchao Zhang   if (mtype) *mtype = PETSC_MEMTYPE_KOKKOS;
2053ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2067ee59b9bSJunchao Zhang }
2077ee59b9bSJunchao Zhang 
2080e3ece09SJunchao Zhang /*
2090e3ece09SJunchao Zhang   Generate the sparsity pattern of a MatSeqAIJKokkos matrix's transpose on device.
2100e3ece09SJunchao Zhang 
2110e3ece09SJunchao Zhang   Input Parameter:
2120e3ece09SJunchao Zhang .  A       - the MATSEQAIJKOKKOS matrix
2130e3ece09SJunchao Zhang 
2140e3ece09SJunchao Zhang   Output Parameters:
2150e3ece09SJunchao Zhang +  perm_d - the permutation array on device, which connects Ta(i) = Aa(perm(i))
216aaa8cc7dSPierre Jolivet -  T_d    - the transpose on device, whose value array is allocated but not initialized
2170e3ece09SJunchao Zhang */
2180e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateTransposeStructure(Mat A, MatRowMapKokkosView &perm_d, KokkosCsrMatrix &T_d)
219d71ae5a4SJacob Faibussowitsch {
2200e3ece09SJunchao Zhang   Mat_SeqAIJ             *aseq = static_cast<Mat_SeqAIJ *>(A->data);
2210e3ece09SJunchao Zhang   PetscInt                nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
2220e3ece09SJunchao Zhang   const PetscInt         *Ai = aseq->i, *Aj = aseq->j;
2237b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost Ti_h(NoInit("Ti"), n + 1);
2240e3ece09SJunchao Zhang   MatRowMapType          *Ti = Ti_h.data();
2257b8d4ba6SJunchao Zhang   MatColIdxKokkosViewHost Tj_h(NoInit("Tj"), nz);
2267b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost perm_h(NoInit("permutation"), nz);
2270e3ece09SJunchao Zhang   PetscInt               *Tj   = Tj_h.data();
2280e3ece09SJunchao Zhang   PetscInt               *perm = perm_h.data();
2290e3ece09SJunchao Zhang   PetscInt               *offset;
230152b3e56SJunchao Zhang 
231152b3e56SJunchao Zhang   PetscFunctionBegin;
2320e3ece09SJunchao Zhang   // Populate Ti
2330e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::deep_copy(Ti_h, 0));
2340e3ece09SJunchao Zhang   Ti++;
2350e3ece09SJunchao Zhang   for (PetscInt i = 0; i < nz; i++) Ti[Aj[i]]++;
2360e3ece09SJunchao Zhang   Ti--;
2370e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) Ti[i + 1] += Ti[i];
2380e3ece09SJunchao Zhang 
2390e3ece09SJunchao Zhang   // Populate Tj and the permutation array
2400e3ece09SJunchao Zhang   PetscCall(PetscCalloc1(n, &offset)); // offset in each T row to fill in its column indices
2410e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
2420e3ece09SJunchao Zhang     for (PetscInt j = Ai[i]; j < Ai[i + 1]; j++) { // A's (i,j) is T's (j,i)
2430e3ece09SJunchao Zhang       PetscInt r    = Aj[j];                       // row r of T
2440e3ece09SJunchao Zhang       PetscInt disp = Ti[r] + offset[r];
2450e3ece09SJunchao Zhang 
2460e3ece09SJunchao Zhang       Tj[disp]   = i; // col i of T
2470e3ece09SJunchao Zhang       perm[disp] = j;
2480e3ece09SJunchao Zhang       offset[r]++;
249076ba34aSJunchao Zhang     }
2500e3ece09SJunchao Zhang   }
2510e3ece09SJunchao Zhang   PetscCall(PetscFree(offset));
2520e3ece09SJunchao Zhang 
2530e3ece09SJunchao Zhang   // Sort each row of T, along with the permutation array
2540e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) PetscCall(PetscSortIntWithArray(Ti[i + 1] - Ti[i], Tj + Ti[i], perm + Ti[i]));
2550e3ece09SJunchao Zhang 
2560e3ece09SJunchao Zhang   // Output perm and T on device
2570e3ece09SJunchao Zhang   auto Ti_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ti_h);
2580e3ece09SJunchao Zhang   auto Tj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Tj_h);
2590e3ece09SJunchao Zhang   PetscCallCXX(T_d = KokkosCsrMatrix("csrmatT", n, m, nz, MatScalarKokkosView("Ta", nz), Ti_d, Tj_d));
2600e3ece09SJunchao Zhang   PetscCallCXX(perm_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), perm_h));
2613ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
262152b3e56SJunchao Zhang }
263152b3e56SJunchao Zhang 
2640e3ece09SJunchao Zhang // Generate the transpose on device and cache it internally
2650e3ece09SJunchao Zhang // Note: KK transpose_matrix() does not have support symbolic/numeric transpose, so we do it on our own
2660e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGenerateTranspose_Private(Mat A, KokkosCsrMatrix *csrmatT)
267d71ae5a4SJacob Faibussowitsch {
2680e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
2690e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
2700e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
2710e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatT;
272152b3e56SJunchao Zhang 
273152b3e56SJunchao Zhang   PetscFunctionBegin;
2740e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
275145b44c9SPierre Jolivet   PetscCallCXX(akok->a_dual.sync_device()); // Sync A's values since we are going to access them on device
2760e3ece09SJunchao Zhang 
2770e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
2780e3ece09SJunchao Zhang 
2790e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE) {
2800e3ece09SJunchao Zhang     *csrmatT = akok->csrmat;
2810e3ece09SJunchao Zhang   } else {
2820e3ece09SJunchao Zhang     // See if we already have a cached transpose and its value is up to date
2830e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
2840e3ece09SJunchao Zhang       if (!akok->transpose_updated) {            // if the value is out of date, update the cached version
2850e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
2860e3ece09SJunchao Zhang         auto       &Ta   = T.values;
2870e3ece09SJunchao Zhang 
2880e3ece09SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(
2890e3ece09SJunchao Zhang           nz, KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = Aa(perm(i)); }));
290076ba34aSJunchao Zhang       }
2910e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
2920e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
2930e3ece09SJunchao Zhang 
2940e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
2950e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
2960e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(
2970e3ece09SJunchao Zhang         nz, KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = Aa(perm(i)); }));
2980e3ece09SJunchao Zhang     }
2990e3ece09SJunchao Zhang     akok->transpose_updated = PETSC_TRUE;
3000e3ece09SJunchao Zhang     *csrmatT                = akok->csrmatT;
3010e3ece09SJunchao Zhang   }
3020e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3030e3ece09SJunchao Zhang }
3040e3ece09SJunchao Zhang 
3050e3ece09SJunchao Zhang // Generate the Hermitian on device and cache it internally
3060e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateHermitian_Private(Mat A, KokkosCsrMatrix *csrmatH)
3070e3ece09SJunchao Zhang {
3080e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
3090e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3100e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
3110e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatH;
3120e3ece09SJunchao Zhang 
3130e3ece09SJunchao Zhang   PetscFunctionBegin;
3140e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
3150e3ece09SJunchao Zhang   PetscCallCXX(akok->a_dual.sync_device()); // Sync A's values since we are going to access them on device
3160e3ece09SJunchao Zhang 
3170e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
3180e3ece09SJunchao Zhang 
3190e3ece09SJunchao Zhang   if (A->hermitian == PETSC_BOOL3_TRUE) {
3200e3ece09SJunchao Zhang     *csrmatH = akok->csrmat;
3210e3ece09SJunchao Zhang   } else {
3220e3ece09SJunchao Zhang     // See if we already have a cached hermitian and its value is up to date
3230e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
3240e3ece09SJunchao Zhang       if (!akok->hermitian_updated) {            // if the value is out of date, update the cached version
3250e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
3260e3ece09SJunchao Zhang         auto       &Ta   = T.values;
3270e3ece09SJunchao Zhang 
3280e3ece09SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(
3290e3ece09SJunchao Zhang           nz, KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = PetscConj(Aa(perm(i))); }));
3300e3ece09SJunchao Zhang       }
3310e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
3320e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
3330e3ece09SJunchao Zhang 
3340e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
3350e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
3360e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(
3370e3ece09SJunchao Zhang         nz, KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = PetscConj(Aa(perm(i))); }));
3380e3ece09SJunchao Zhang     }
3390e3ece09SJunchao Zhang     akok->hermitian_updated = PETSC_TRUE;
3400e3ece09SJunchao Zhang     *csrmatH                = akok->csrmatH;
3410e3ece09SJunchao Zhang   }
3423ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
343152b3e56SJunchao Zhang }
344a587d139SMark 
3458c3ff71bSJunchao Zhang /* y = A x */
346d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMult_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
347d71ae5a4SJacob Faibussowitsch {
3488c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
349152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
350152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
3518c3ff71bSJunchao Zhang 
3528c3ff71bSJunchao Zhang   PetscFunctionBegin;
3539566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3549566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3559566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3569566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
3578c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3589d52486cSJunchao Zhang   PetscCallCXX(KokkosSparse::spmv("N", 1.0 /*alpha*/, aijkok->csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A x + beta y */
3599566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3609566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
361076ba34aSJunchao Zhang   /* 2.0*nnz - numRows seems more accurate here but assumes there are no zero-rows. So a little sloppy here. */
3629566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
3639566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
3643ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
3658c3ff71bSJunchao Zhang }
3668c3ff71bSJunchao Zhang 
3678c3ff71bSJunchao Zhang /* y = A^T x */
368d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
369d71ae5a4SJacob Faibussowitsch {
3708c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
371152b3e56SJunchao Zhang   const char                *mode;
372152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
373152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
3740e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
3758c3ff71bSJunchao Zhang 
3768c3ff71bSJunchao Zhang   PetscFunctionBegin;
3779566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3789566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3799566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3809566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
381152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
3829566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
383152b3e56SJunchao Zhang     mode = "N";
384152b3e56SJunchao Zhang   } else {
385076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3860e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
387152b3e56SJunchao Zhang     mode   = "T";
388152b3e56SJunchao Zhang   }
3890e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^T x + beta y */
3909566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3919566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
3920e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
3939566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
3943ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
3958c3ff71bSJunchao Zhang }
3968c3ff71bSJunchao Zhang 
3978c3ff71bSJunchao Zhang /* y = A^H x */
398d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
399d71ae5a4SJacob Faibussowitsch {
4008c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
401152b3e56SJunchao Zhang   const char                *mode;
402152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
403152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
4040e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4058c3ff71bSJunchao Zhang 
4068c3ff71bSJunchao Zhang   PetscFunctionBegin;
4079566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4089566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4099566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4109566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
411152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4129566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
413152b3e56SJunchao Zhang     mode = "N";
414152b3e56SJunchao Zhang   } else {
415076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4160e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
417152b3e56SJunchao Zhang     mode   = "C";
418152b3e56SJunchao Zhang   }
4190e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^H x + beta y */
4209566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4219566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
4220e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4239566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4243ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4258c3ff71bSJunchao Zhang }
4268c3ff71bSJunchao Zhang 
4278c3ff71bSJunchao Zhang /* z = A x + y */
428d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
429d71ae5a4SJacob Faibussowitsch {
4308c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
431152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
432152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4338c3ff71bSJunchao Zhang 
4348c3ff71bSJunchao Zhang   PetscFunctionBegin;
4359566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4369566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4379566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4389566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
4399566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
4408c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
4418c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4429d52486cSJunchao Zhang   PetscCallCXX(KokkosSparse::spmv("N", 1.0 /*alpha*/, aijkok->csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A x + beta z */
4439566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4449566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
4459566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
4469566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
4479566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4483ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4498c3ff71bSJunchao Zhang }
4508c3ff71bSJunchao Zhang 
4518c3ff71bSJunchao Zhang /* z = A^T x + y */
452d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
453d71ae5a4SJacob Faibussowitsch {
4548c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
455152b3e56SJunchao Zhang   const char                *mode;
456152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
457152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4580e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4598c3ff71bSJunchao Zhang 
4608c3ff71bSJunchao Zhang   PetscFunctionBegin;
4619566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4629566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4639566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4649566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
4659566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
4668c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
467152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4689566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
469152b3e56SJunchao Zhang     mode = "N";
470152b3e56SJunchao Zhang   } else {
471076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4720e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
473152b3e56SJunchao Zhang     mode   = "T";
474152b3e56SJunchao Zhang   }
4750e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^T x + beta z */
4769566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4779566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
4789566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
4790e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4809566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4813ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4828c3ff71bSJunchao Zhang }
4838c3ff71bSJunchao Zhang 
4848c3ff71bSJunchao Zhang /* z = A^H x + y */
485d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
486d71ae5a4SJacob Faibussowitsch {
4878c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
488152b3e56SJunchao Zhang   const char                *mode;
489152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
490152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4910e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4928c3ff71bSJunchao Zhang 
4938c3ff71bSJunchao Zhang   PetscFunctionBegin;
4949566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4959566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4969566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4979566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
4989566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
4998c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
500152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
5019566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
502152b3e56SJunchao Zhang     mode = "N";
503152b3e56SJunchao Zhang   } else {
504076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
5050e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
506152b3e56SJunchao Zhang     mode   = "C";
507152b3e56SJunchao Zhang   }
5080e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^H x + beta z */
5099566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
5109566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
5119566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
5120e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
5139566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
5143ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
515152b3e56SJunchao Zhang }
516152b3e56SJunchao Zhang 
51766976f2fSJacob Faibussowitsch static PetscErrorCode MatSetOption_SeqAIJKokkos(Mat A, MatOption op, PetscBool flg)
518d71ae5a4SJacob Faibussowitsch {
519152b3e56SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
520152b3e56SJunchao Zhang 
521152b3e56SJunchao Zhang   PetscFunctionBegin;
522152b3e56SJunchao Zhang   switch (op) {
523152b3e56SJunchao Zhang   case MAT_FORM_EXPLICIT_TRANSPOSE:
524152b3e56SJunchao Zhang     /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */
5259566063dSJacob Faibussowitsch     if (A->form_explicit_transpose && !flg && aijkok) PetscCall(aijkok->DestroyMatTranspose());
526152b3e56SJunchao Zhang     A->form_explicit_transpose = flg;
527152b3e56SJunchao Zhang     break;
528d71ae5a4SJacob Faibussowitsch   default:
529d71ae5a4SJacob Faibussowitsch     PetscCall(MatSetOption_SeqAIJ(A, op, flg));
530d71ae5a4SJacob Faibussowitsch     break;
531152b3e56SJunchao Zhang   }
5323ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5338c3ff71bSJunchao Zhang }
5348c3ff71bSJunchao Zhang 
535076ba34aSJunchao Zhang /* Depending on reuse, either build a new mat, or use the existing mat */
536d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
537d71ae5a4SJacob Faibussowitsch {
538076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
5398c3ff71bSJunchao Zhang 
5408c3ff71bSJunchao Zhang   PetscFunctionBegin;
5419566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
542076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {                      /* Build a brand new mat */
5439566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));  /* the returned newmat is a SeqAIJKokkos */
5448c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {                 /* Reuse the mat created before */
5459566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); /* newmat is already a SeqAIJKokkos */
546076ba34aSJunchao Zhang   } else if (reuse == MAT_INPLACE_MATRIX) {               /* newmat is A */
5475f80ce2aSJacob Faibussowitsch     PetscCheck(A == *newmat, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "A != *newmat with MAT_INPLACE_MATRIX");
5489566063dSJacob Faibussowitsch     PetscCall(PetscFree(A->defaultvectype));
5499566063dSJacob Faibussowitsch     PetscCall(PetscStrallocpy(VECKOKKOS, &A->defaultvectype)); /* Allocate and copy the string */
5509566063dSJacob Faibussowitsch     PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATSEQAIJKOKKOS));
5519566063dSJacob Faibussowitsch     PetscCall(MatSetOps_SeqAIJKokkos(A));
552076ba34aSJunchao Zhang     aseq = static_cast<Mat_SeqAIJ *>(A->data);
553394ed5ebSJunchao Zhang     if (A->assembled) { /* Copy i, j (but not values) to device for an assembled matrix if not yet */
5545f80ce2aSJacob Faibussowitsch       PetscCheck(!A->spptr, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Expect NULL (Mat_SeqAIJKokkos*)A->spptr");
555f4747e26SJunchao Zhang       A->spptr = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aseq, A->nonzerostate, PETSC_FALSE);
5568c3ff71bSJunchao Zhang     }
557076ba34aSJunchao Zhang   }
5583ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5598c3ff71bSJunchao Zhang }
5608c3ff71bSJunchao Zhang 
561076ba34aSJunchao Zhang /* MatDuplicate always creates a new matrix. MatDuplicate can be called either on an assembled matrix or
562076ba34aSJunchao Zhang    an unassembled matrix, even though MAT_COPY_VALUES is not allowed for unassembled matrix.
563076ba34aSJunchao Zhang  */
564d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDuplicate_SeqAIJKokkos(Mat A, MatDuplicateOption dupOption, Mat *B)
565d71ae5a4SJacob Faibussowitsch {
566076ba34aSJunchao Zhang   Mat_SeqAIJ       *bseq;
567076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok;
568076ba34aSJunchao Zhang   Mat               mat;
5698c3ff71bSJunchao Zhang 
5708c3ff71bSJunchao Zhang   PetscFunctionBegin;
571076ba34aSJunchao Zhang   /* Do not copy values on host as A's latest values might be on device. We don't want to do sync blindly */
5729566063dSJacob Faibussowitsch   PetscCall(MatDuplicate_SeqAIJ(A, MAT_DO_NOT_COPY_VALUES, B));
573076ba34aSJunchao Zhang   mat = *B;
574f4747e26SJunchao Zhang   if (A->assembled) {
575076ba34aSJunchao Zhang     bseq = static_cast<Mat_SeqAIJ *>(mat->data);
576f4747e26SJunchao Zhang     bkok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, bseq, mat->nonzerostate, PETSC_FALSE);
577076ba34aSJunchao Zhang     bkok->a_dual.clear_sync_state(); /* Clear B's sync state as it will be decided below */
578076ba34aSJunchao Zhang     /* Now copy values to B if needed */
579076ba34aSJunchao Zhang     if (dupOption == MAT_COPY_VALUES) {
580076ba34aSJunchao Zhang       if (akok->a_dual.need_sync_device()) {
581076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_host(), akok->a_dual.view_host());
582076ba34aSJunchao Zhang         bkok->a_dual.modify_host();
583076ba34aSJunchao Zhang       } else { /* If device has the latest data, we only copy data on device */
584076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_device(), akok->a_dual.view_device());
585076ba34aSJunchao Zhang         bkok->a_dual.modify_device();
586076ba34aSJunchao Zhang       }
587076ba34aSJunchao Zhang     } else { /* MAT_DO_NOT_COPY_VALUES or MAT_SHARE_NONZERO_PATTERN. B's values should be zeroed */
588076ba34aSJunchao Zhang       /* B's values on host should be already zeroed by MatDuplicate_SeqAIJ() */
589076ba34aSJunchao Zhang       bkok->a_dual.modify_host();
590076ba34aSJunchao Zhang     }
591076ba34aSJunchao Zhang     mat->spptr = bkok;
592076ba34aSJunchao Zhang   }
593076ba34aSJunchao Zhang 
5949566063dSJacob Faibussowitsch   PetscCall(PetscFree(mat->defaultvectype));
5959566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &mat->defaultvectype)); /* Allocate and copy the string */
5969566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)mat, MATSEQAIJKOKKOS));
5979566063dSJacob Faibussowitsch   PetscCall(MatSetOps_SeqAIJKokkos(mat));
5983ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5998c3ff71bSJunchao Zhang }
6008c3ff71bSJunchao Zhang 
601d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatTranspose_SeqAIJKokkos(Mat A, MatReuse reuse, Mat *B)
602d71ae5a4SJacob Faibussowitsch {
6030ecb592aSJunchao Zhang   Mat               At;
6040e3ece09SJunchao Zhang   KokkosCsrMatrix   internT;
6050ecb592aSJunchao Zhang   Mat_SeqAIJKokkos *atkok, *bkok;
6060ecb592aSJunchao Zhang 
6070ecb592aSJunchao Zhang   PetscFunctionBegin;
6087fb60732SBarry Smith   if (reuse == MAT_REUSE_MATRIX) PetscCall(MatTransposeCheckNonzeroState_Private(A, *B));
6099566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &internT)); /* Generate a transpose internally */
6100ecb592aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX || reuse == MAT_INPLACE_MATRIX) {
611ff751488SJunchao Zhang     /* Deep copy internT, as we want to isolate the internal transpose */
6120e3ece09SJunchao Zhang     PetscCallCXX(atkok = new Mat_SeqAIJKokkos(KokkosCsrMatrix("csrmat", internT)));
6139566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PetscObjectComm((PetscObject)A), atkok, &At));
6140ecb592aSJunchao Zhang     if (reuse == MAT_INITIAL_MATRIX) *B = At;
6159566063dSJacob Faibussowitsch     else PetscCall(MatHeaderReplace(A, &At)); /* Replace A with At inplace */
6160ecb592aSJunchao Zhang   } else {                                    /* MAT_REUSE_MATRIX, just need to copy values to B on device */
6170ecb592aSJunchao Zhang     if ((*B)->assembled) {
6180ecb592aSJunchao Zhang       bkok = static_cast<Mat_SeqAIJKokkos *>((*B)->spptr);
6190e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(bkok->a_dual.view_device(), internT.values));
6209566063dSJacob Faibussowitsch       PetscCall(MatSeqAIJKokkosModifyDevice(*B));
6210ecb592aSJunchao Zhang     } else if ((*B)->preallocated) { /* It is ok for B to be only preallocated, as needed in MatTranspose_MPIAIJ */
6220ecb592aSJunchao Zhang       Mat_SeqAIJ             *bseq = static_cast<Mat_SeqAIJ *>((*B)->data);
6230e3ece09SJunchao Zhang       MatScalarKokkosViewHost a_h(bseq->a, internT.nnz()); /* bseq->nz = 0 if unassembled */
6240e3ece09SJunchao Zhang       MatColIdxKokkosViewHost j_h(bseq->j, internT.nnz());
6250e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(a_h, internT.values));
6260e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(j_h, internT.graph.entries));
6270ecb592aSJunchao Zhang     } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "B must be assembled or preallocated");
6280ecb592aSJunchao Zhang   }
6293ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6300ecb592aSJunchao Zhang }
6310ecb592aSJunchao Zhang 
632d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDestroy_SeqAIJKokkos(Mat A)
633d71ae5a4SJacob Faibussowitsch {
63486a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
6358c3ff71bSJunchao Zhang 
6368c3ff71bSJunchao Zhang   PetscFunctionBegin;
63786a27549SJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) {
63886a27549SJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
6398c3ff71bSJunchao Zhang     delete aijkok;
64086a27549SJunchao Zhang   } else {
64186a27549SJunchao Zhang     delete static_cast<Mat_SeqAIJKokkosTriFactors *>(A->spptr);
64286a27549SJunchao Zhang   }
643cbc6b225SStefano Zampini   A->spptr = NULL;
6449566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
6459566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
6469566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
6479566063dSJacob Faibussowitsch   PetscCall(MatDestroy_SeqAIJ(A));
6483ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6498c3ff71bSJunchao Zhang }
6508c3ff71bSJunchao Zhang 
/*MC
   MATSEQAIJKOKKOS - MATAIJKOKKOS = "(seq)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos

   A matrix type using the Kokkos-Kernels CrsMatrix type for portability across different device types

   Options Database Key:
.  -mat_type aijkokkos - sets the matrix type to `MATSEQAIJKOKKOS` during a call to `MatSetFromOptions()`

  Level: beginner

.seealso: [](ch_matrices), `Mat`, `MatCreateSeqAIJKokkos()`, `MATMPIAIJKOKKOS`
M*/
663d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJKokkos(Mat A)
664d71ae5a4SJacob Faibussowitsch {
66586a27549SJunchao Zhang   PetscFunctionBegin;
6669566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
6679566063dSJacob Faibussowitsch   PetscCall(MatCreate_SeqAIJ(A));
6689566063dSJacob Faibussowitsch   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
6693ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
67086a27549SJunchao Zhang }
67186a27549SJunchao Zhang 
672076ba34aSJunchao Zhang /* Merge A, B into a matrix C. A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n) */
673d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C)
674d71ae5a4SJacob Faibussowitsch {
675076ba34aSJunchao Zhang   Mat_SeqAIJ         *a, *b;
676076ba34aSJunchao Zhang   Mat_SeqAIJKokkos   *akok, *bkok, *ckok;
677076ba34aSJunchao Zhang   MatScalarKokkosView aa, ba, ca;
678076ba34aSJunchao Zhang   MatRowMapKokkosView ai, bi, ci;
679076ba34aSJunchao Zhang   MatColIdxKokkosView aj, bj, cj;
680076ba34aSJunchao Zhang   PetscInt            m, n, nnz, aN;
681a3f881fbSStefano Zampini 
682a3f881fbSStefano Zampini   PetscFunctionBegin;
683076ba34aSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
684076ba34aSJunchao Zhang   PetscValidHeaderSpecific(B, MAT_CLASSID, 2);
6854f572ea9SToby Isaac   PetscAssertPointer(C, 4);
686076ba34aSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
687076ba34aSJunchao Zhang   PetscCheckTypeName(B, MATSEQAIJKOKKOS);
6885f80ce2aSJacob Faibussowitsch   PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n);
6895f80ce2aSJacob Faibussowitsch   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
690076ba34aSJunchao Zhang 
6919566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
6929566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
693076ba34aSJunchao Zhang   a    = static_cast<Mat_SeqAIJ *>(A->data);
694076ba34aSJunchao Zhang   b    = static_cast<Mat_SeqAIJ *>(B->data);
695076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
696076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
697076ba34aSJunchao Zhang   aa   = akok->a_dual.view_device();
698076ba34aSJunchao Zhang   ai   = akok->i_dual.view_device();
699076ba34aSJunchao Zhang   ba   = bkok->a_dual.view_device();
700076ba34aSJunchao Zhang   bi   = bkok->i_dual.view_device();
701076ba34aSJunchao Zhang   m    = A->rmap->n; /* M, N and nnz of C */
702076ba34aSJunchao Zhang   n    = A->cmap->n + B->cmap->n;
703076ba34aSJunchao Zhang   nnz  = a->nz + b->nz;
704076ba34aSJunchao Zhang   aN   = A->cmap->n; /* N of A */
705076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
706076ba34aSJunchao Zhang     aj           = akok->j_dual.view_device();
707076ba34aSJunchao Zhang     bj           = bkok->j_dual.view_device();
708076ba34aSJunchao Zhang     auto ca_dual = MatScalarKokkosDualView("a", aa.extent(0) + ba.extent(0));
709076ba34aSJunchao Zhang     auto ci_dual = MatRowMapKokkosDualView("i", ai.extent(0));
710076ba34aSJunchao Zhang     auto cj_dual = MatColIdxKokkosDualView("j", aj.extent(0) + bj.extent(0));
711076ba34aSJunchao Zhang     ca           = ca_dual.view_device();
712076ba34aSJunchao Zhang     ci           = ci_dual.view_device();
713076ba34aSJunchao Zhang     cj           = cj_dual.view_device();
714076ba34aSJunchao Zhang 
715076ba34aSJunchao Zhang     /* Concatenate A and B in parallel using Kokkos hierarchical parallelism */
7169371c9d4SSatish Balay     Kokkos::parallel_for(
7179371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
718076ba34aSJunchao Zhang         PetscInt i       = t.league_rank(); /* row i */
719076ba34aSJunchao Zhang         PetscInt coffset = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
720076ba34aSJunchao Zhang 
721076ba34aSJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() { /* this side effect only happens once per whole team */
722076ba34aSJunchao Zhang                                                    ci(i) = coffset;
723076ba34aSJunchao Zhang                                                    if (i == m - 1) ci(m) = ai(m) + bi(m);
724076ba34aSJunchao Zhang         });
725076ba34aSJunchao Zhang 
726076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
727076ba34aSJunchao Zhang           if (k < alen) {
728076ba34aSJunchao Zhang             ca(coffset + k) = aa(ai(i) + k);
729076ba34aSJunchao Zhang             cj(coffset + k) = aj(ai(i) + k);
730076ba34aSJunchao Zhang           } else {
731076ba34aSJunchao Zhang             ca(coffset + k) = ba(bi(i) + k - alen);
732076ba34aSJunchao Zhang             cj(coffset + k) = bj(bi(i) + k - alen) + aN; /* Entries in B get new column indices in C */
733076ba34aSJunchao Zhang           }
734076ba34aSJunchao Zhang         });
735076ba34aSJunchao Zhang       });
736076ba34aSJunchao Zhang     ca_dual.modify_device();
737076ba34aSJunchao Zhang     ci_dual.modify_device();
738076ba34aSJunchao Zhang     cj_dual.modify_device();
7399566063dSJacob Faibussowitsch     PetscCallCXX(ckok = new Mat_SeqAIJKokkos(m, n, nnz, ci_dual, cj_dual, ca_dual));
7409566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, C));
741076ba34aSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
742076ba34aSJunchao Zhang     PetscValidHeaderSpecific(*C, MAT_CLASSID, 4);
743076ba34aSJunchao Zhang     PetscCheckTypeName(*C, MATSEQAIJKOKKOS);
744076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos *>((*C)->spptr);
745076ba34aSJunchao Zhang     ca   = ckok->a_dual.view_device();
746076ba34aSJunchao Zhang     ci   = ckok->i_dual.view_device();
747076ba34aSJunchao Zhang 
7489371c9d4SSatish Balay     Kokkos::parallel_for(
7499371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
750076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* row i */
751076ba34aSJunchao Zhang         PetscInt alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
752076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
753076ba34aSJunchao Zhang           if (k < alen) ca(ci(i) + k) = aa(ai(i) + k);
754076ba34aSJunchao Zhang           else ca(ci(i) + k) = ba(bi(i) + k - alen);
755076ba34aSJunchao Zhang         });
756076ba34aSJunchao Zhang       });
7579566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(*C));
758076ba34aSJunchao Zhang   }
7593ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
760076ba34aSJunchao Zhang }
761076ba34aSJunchao Zhang 
762d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_SeqAIJKokkos(void *pdata)
763d71ae5a4SJacob Faibussowitsch {
764076ba34aSJunchao Zhang   PetscFunctionBegin;
765076ba34aSJunchao Zhang   delete static_cast<MatProductData_SeqAIJKokkos *>(pdata);
7663ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
767a3f881fbSStefano Zampini }
768a3f881fbSStefano Zampini 
769d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos(Mat C)
770d71ae5a4SJacob Faibussowitsch {
771a3f881fbSStefano Zampini   Mat_Product                 *product = C->product;
772a3f881fbSStefano Zampini   Mat                          A, B;
773076ba34aSJunchao Zhang   bool                         transA, transB; /* use bool, since KK needs this type */
774a3f881fbSStefano Zampini   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
775a3f881fbSStefano Zampini   Mat_SeqAIJ                  *c;
776076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
7770e3ece09SJunchao Zhang   KokkosCsrMatrix              csrmatA, csrmatB;
778a3f881fbSStefano Zampini 
779a3f881fbSStefano Zampini   PetscFunctionBegin;
780a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
7815f80ce2aSJacob Faibussowitsch   PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data empty");
782076ba34aSJunchao Zhang   pdata = static_cast<MatProductData_SeqAIJKokkos *>(C->product->data);
783076ba34aSJunchao Zhang 
7840e3ece09SJunchao Zhang   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
7850e3ece09SJunchao Zhang   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
7860e3ece09SJunchao Zhang   // we still do numeric.
7870e3ece09SJunchao Zhang   if (pdata->reusesym) { // numeric reuses results from symbolic
7880e3ece09SJunchao Zhang     pdata->reusesym = PETSC_FALSE;
7893ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
790076ba34aSJunchao Zhang   }
791076ba34aSJunchao Zhang 
792076ba34aSJunchao Zhang   switch (product->type) {
7939371c9d4SSatish Balay   case MATPRODUCT_AB:
7949371c9d4SSatish Balay     transA = false;
7959371c9d4SSatish Balay     transB = false;
7969371c9d4SSatish Balay     break;
7979371c9d4SSatish Balay   case MATPRODUCT_AtB:
7989371c9d4SSatish Balay     transA = true;
7999371c9d4SSatish Balay     transB = false;
8009371c9d4SSatish Balay     break;
8019371c9d4SSatish Balay   case MATPRODUCT_ABt:
8029371c9d4SSatish Balay     transA = false;
8039371c9d4SSatish Balay     transB = true;
8049371c9d4SSatish Balay     break;
805d71ae5a4SJacob Faibussowitsch   default:
806d71ae5a4SJacob Faibussowitsch     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
807076ba34aSJunchao Zhang   }
808076ba34aSJunchao Zhang 
809a3f881fbSStefano Zampini   A = product->A;
810a3f881fbSStefano Zampini   B = product->B;
8119566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
8129566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
813a3f881fbSStefano Zampini   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
814a3f881fbSStefano Zampini   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
815a3f881fbSStefano Zampini   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
816076ba34aSJunchao Zhang 
8175f80ce2aSJacob Faibussowitsch   PetscCheck(ckok, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Device data structure spptr is empty");
818076ba34aSJunchao Zhang 
8190e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
8200e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
821076ba34aSJunchao Zhang 
822076ba34aSJunchao Zhang   /* TODO: Once KK spgemm implements transpose, we can get rid of the explicit transpose here */
823076ba34aSJunchao Zhang   if (transA) {
8249566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
825076ba34aSJunchao Zhang     transA = false;
826a3f881fbSStefano Zampini   }
827a3f881fbSStefano Zampini 
828076ba34aSJunchao Zhang   if (transB) {
8299566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
830076ba34aSJunchao Zhang     transB = false;
831076ba34aSJunchao Zhang   }
8329566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
8330e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, ckok->csrmat));
8340e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
835866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
836866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(ckok->csrmat)); /* without sort, mat_tests-ex62_14_seqaijkokkos fails */
837e944a159SJunchao Zhang #endif
838866eb059SJunchao Zhang 
8399566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
8409566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(C));
841a3f881fbSStefano Zampini   /* shorter version of MatAssemblyEnd_SeqAIJ */
842a3f881fbSStefano Zampini   c = (Mat_SeqAIJ *)C->data;
8439566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded,%" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz));
8449566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n"));
8459566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax));
846a3f881fbSStefano Zampini   c->reallocs         = 0;
847076ba34aSJunchao Zhang   C->info.mallocs     = 0;
848a3f881fbSStefano Zampini   C->info.nz_unneeded = 0;
849a3f881fbSStefano Zampini   C->assembled = C->was_assembled = PETSC_TRUE;
850a3f881fbSStefano Zampini   C->num_ass++;
8513ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
852a3f881fbSStefano Zampini }
853a3f881fbSStefano Zampini 
854d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos(Mat C)
855d71ae5a4SJacob Faibussowitsch {
856076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
857076ba34aSJunchao Zhang   MatProductType               ptype;
858076ba34aSJunchao Zhang   Mat                          A, B;
859076ba34aSJunchao Zhang   bool                         transA, transB;
860076ba34aSJunchao Zhang   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
861076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
862076ba34aSJunchao Zhang   MPI_Comm                     comm;
8630e3ece09SJunchao Zhang   KokkosCsrMatrix              csrmatA, csrmatB, csrmatC;
864a3f881fbSStefano Zampini 
865a3f881fbSStefano Zampini   PetscFunctionBegin;
866a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
8679566063dSJacob Faibussowitsch   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
8685f80ce2aSJacob Faibussowitsch   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
869a3f881fbSStefano Zampini   A = product->A;
870a3f881fbSStefano Zampini   B = product->B;
8719566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
8729566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
873a3f881fbSStefano Zampini   akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
874a3f881fbSStefano Zampini   bkok    = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
8750e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
8760e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
877076ba34aSJunchao Zhang 
878a3f881fbSStefano Zampini   ptype = product->type;
8790e3ece09SJunchao Zhang   // Take advantage of the symmetry if true
8800e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) {
8810e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
8820e3ece09SJunchao Zhang     product->symbolic_used_the_fact_A_is_symmetric = PETSC_TRUE;
8830e3ece09SJunchao Zhang   }
8840e3ece09SJunchao Zhang   if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) {
8850e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
8860e3ece09SJunchao Zhang     product->symbolic_used_the_fact_B_is_symmetric = PETSC_TRUE;
8870e3ece09SJunchao Zhang   }
8880e3ece09SJunchao Zhang 
889a3f881fbSStefano Zampini   switch (ptype) {
8909371c9d4SSatish Balay   case MATPRODUCT_AB:
8919371c9d4SSatish Balay     transA = false;
8929371c9d4SSatish Balay     transB = false;
8939371c9d4SSatish Balay     break;
8949371c9d4SSatish Balay   case MATPRODUCT_AtB:
8959371c9d4SSatish Balay     transA = true;
8969371c9d4SSatish Balay     transB = false;
8979371c9d4SSatish Balay     break;
8989371c9d4SSatish Balay   case MATPRODUCT_ABt:
8999371c9d4SSatish Balay     transA = false;
9009371c9d4SSatish Balay     transB = true;
9019371c9d4SSatish Balay     break;
902d71ae5a4SJacob Faibussowitsch   default:
903d71ae5a4SJacob Faibussowitsch     SETERRQ(comm, PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
904a3f881fbSStefano Zampini   }
9050e3ece09SJunchao Zhang   PetscCallCXX(product->data = pdata = new MatProductData_SeqAIJKokkos());
906076ba34aSJunchao Zhang   pdata->reusesym = product->api_user;
907a3f881fbSStefano Zampini 
908076ba34aSJunchao Zhang   /* TODO: add command line options to select spgemm algorithms */
909866eb059SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; /* default alg is TPL if enabled, otherwise KK */
910866eb059SJunchao Zhang 
911866eb059SJunchao Zhang   /* CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 */
912866eb059SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
913866eb059SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
914866eb059SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
915866eb059SJunchao Zhang   #endif
916866eb059SJunchao Zhang #endif
9170e3ece09SJunchao Zhang   PetscCallCXX(pdata->kh.create_spgemm_handle(spgemm_alg));
918076ba34aSJunchao Zhang 
9199566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
920076ba34aSJunchao Zhang   /* TODO: Get rid of the explicit transpose once KK-spgemm implements the transpose option */
921076ba34aSJunchao Zhang   if (transA) {
9229566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
923076ba34aSJunchao Zhang     transA = false;
924076ba34aSJunchao Zhang   }
925076ba34aSJunchao Zhang 
926076ba34aSJunchao Zhang   if (transB) {
9279566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
928076ba34aSJunchao Zhang     transB = false;
929076ba34aSJunchao Zhang   }
930076ba34aSJunchao Zhang 
9310e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
932076ba34aSJunchao Zhang   /* spgemm_symbolic() only populates C's rowmap, but not C's column indices.
933076ba34aSJunchao Zhang     So we have to do a fake spgemm_numeric() here to get csrmatC.j_d setup, before
934076ba34aSJunchao Zhang     calling new Mat_SeqAIJKokkos().
935076ba34aSJunchao Zhang     TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
936076ba34aSJunchao Zhang   */
9370e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
9380e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
939866eb059SJunchao Zhang   /* Query if KK outputs a sorted matrix. If not, we need to sort it */
940866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
941866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(csrmatC)); /* sort_option defaults to -1 in KK!*/
942e944a159SJunchao Zhang #endif
9439566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
944076ba34aSJunchao Zhang 
9459566063dSJacob Faibussowitsch   PetscCallCXX(ckok = new Mat_SeqAIJKokkos(csrmatC));
9469566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(C, ckok));
947076ba34aSJunchao Zhang   C->product->destroy = MatProductDataDestroy_SeqAIJKokkos;
9483ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
949a3f881fbSStefano Zampini }
950a3f881fbSStefano Zampini 
951a3f881fbSStefano Zampini /* handles sparse matrix matrix ops */
952d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSetFromOptions_SeqAIJKokkos(Mat mat)
953d71ae5a4SJacob Faibussowitsch {
954076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
955a3f881fbSStefano Zampini   PetscBool    Biskok = PETSC_FALSE, Ciskok = PETSC_TRUE;
956a3f881fbSStefano Zampini 
957a3f881fbSStefano Zampini   PetscFunctionBegin;
958a3f881fbSStefano Zampini   MatCheckProduct(mat, 1);
9599566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJKOKKOS, &Biskok));
96048a46eb9SPierre Jolivet   if (product->type == MATPRODUCT_ABC) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJKOKKOS, &Ciskok));
961a3f881fbSStefano Zampini   if (Biskok && Ciskok) {
962a3f881fbSStefano Zampini     switch (product->type) {
963a3f881fbSStefano Zampini     case MATPRODUCT_AB:
964a3f881fbSStefano Zampini     case MATPRODUCT_AtB:
965d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABt:
966d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos;
967d71ae5a4SJacob Faibussowitsch       break;
968a3f881fbSStefano Zampini     case MATPRODUCT_PtAP:
969a3f881fbSStefano Zampini     case MATPRODUCT_RARt:
970d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABC:
971d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
972d71ae5a4SJacob Faibussowitsch       break;
973d71ae5a4SJacob Faibussowitsch     default:
974d71ae5a4SJacob Faibussowitsch       break;
975a3f881fbSStefano Zampini     }
976a3f881fbSStefano Zampini   } else { /* fallback for AIJ */
9779566063dSJacob Faibussowitsch     PetscCall(MatProductSetFromOptions_SeqAIJ(mat));
978a3f881fbSStefano Zampini   }
9793ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
980a3f881fbSStefano Zampini }
981a587d139SMark 
982d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatScale_SeqAIJKokkos(Mat A, PetscScalar a)
983d71ae5a4SJacob Faibussowitsch {
984f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok;
985f0cf5187SStefano Zampini 
986f0cf5187SStefano Zampini   PetscFunctionBegin;
9879566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
9889566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
989f0cf5187SStefano Zampini   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
990076ba34aSJunchao Zhang   KokkosBlas::scal(aijkok->a_dual.view_device(), a, aijkok->a_dual.view_device());
9919566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
9929566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(aijkok->a_dual.extent(0)));
9939566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
9943ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
995f0cf5187SStefano Zampini }
996f0cf5187SStefano Zampini 
997f4747e26SJunchao Zhang // add a to A's diagonal (if A is square) or main diagonal (if A is rectangular)
998f4747e26SJunchao Zhang static PetscErrorCode MatShift_SeqAIJKokkos(Mat A, PetscScalar a)
999f4747e26SJunchao Zhang {
1000f4747e26SJunchao Zhang   Mat_SeqAIJ *aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1001f4747e26SJunchao Zhang 
1002f4747e26SJunchao Zhang   PetscFunctionBegin;
1003f4747e26SJunchao Zhang   if (A->assembled && aijseq->diagonaldense) { // no missing diagonals
1004f4747e26SJunchao Zhang     PetscInt n = PetscMin(A->rmap->n, A->cmap->n);
1005f4747e26SJunchao Zhang 
1006f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
1007f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosSyncDevice(A));
1008f4747e26SJunchao Zhang     const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1009f4747e26SJunchao Zhang     const auto &Aa     = aijkok->a_dual.view_device();
1010f4747e26SJunchao Zhang     const auto &Adiag  = aijkok->diag_dual.view_device();
1011f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(
1012f4747e26SJunchao Zhang       n, KOKKOS_LAMBDA(const PetscInt i) { Aa(Adiag(i)) += a; }));
1013f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosModifyDevice(A));
1014f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(n));
1015f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
1016f4747e26SJunchao Zhang   } else { // need reassembly, very slow!
1017f4747e26SJunchao Zhang     PetscCall(MatShift_Basic(A, a));
1018f4747e26SJunchao Zhang   }
1019f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1020f4747e26SJunchao Zhang }
1021f4747e26SJunchao Zhang 
1022f4747e26SJunchao Zhang static PetscErrorCode MatDiagonalSet_SeqAIJKokkos(Mat Y, Vec D, InsertMode is)
1023f4747e26SJunchao Zhang {
1024f4747e26SJunchao Zhang   Mat_SeqAIJ *aijseq = static_cast<Mat_SeqAIJ *>(Y->data);
1025f4747e26SJunchao Zhang 
1026f4747e26SJunchao Zhang   PetscFunctionBegin;
1027f4747e26SJunchao Zhang   if (Y->assembled && aijseq->diagonaldense) { // no missing diagonals
1028f4747e26SJunchao Zhang     ConstPetscScalarKokkosView dv;
1029f4747e26SJunchao Zhang     PetscInt                   n, nv;
1030f4747e26SJunchao Zhang 
1031f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
1032f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosSyncDevice(Y));
1033f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(D, &dv));
1034f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(D, &nv));
1035f4747e26SJunchao Zhang     n = PetscMin(Y->rmap->n, Y->cmap->n);
1036f4747e26SJunchao Zhang     PetscCheck(n == nv, PetscObjectComm((PetscObject)Y), PETSC_ERR_ARG_SIZ, "Matrix size and vector size do not match");
1037f4747e26SJunchao Zhang 
1038f4747e26SJunchao Zhang     const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1039f4747e26SJunchao Zhang     const auto &Aa     = aijkok->a_dual.view_device();
1040f4747e26SJunchao Zhang     const auto &Adiag  = aijkok->diag_dual.view_device();
1041f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(
1042f4747e26SJunchao Zhang       n, KOKKOS_LAMBDA(const PetscInt i) {
1043f4747e26SJunchao Zhang         if (is == INSERT_VALUES) Aa(Adiag(i)) = dv(i);
1044f4747e26SJunchao Zhang         else Aa(Adiag(i)) += dv(i);
1045f4747e26SJunchao Zhang       }));
1046f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(D, &dv));
1047f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1048f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(n));
1049f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
1050f4747e26SJunchao Zhang   } else { // need reassembly, very slow!
1051f4747e26SJunchao Zhang     PetscCall(MatDiagonalSet_Default(Y, D, is));
1052f4747e26SJunchao Zhang   }
1053f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1054f4747e26SJunchao Zhang }
1055f4747e26SJunchao Zhang 
1056f4747e26SJunchao Zhang static PetscErrorCode MatDiagonalScale_SeqAIJKokkos(Mat A, Vec ll, Vec rr)
1057f4747e26SJunchao Zhang {
1058f4747e26SJunchao Zhang   Mat_SeqAIJ                *aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1059f4747e26SJunchao Zhang   PetscInt                   m = A->rmap->n, n = A->cmap->n, nz = aijseq->nz;
1060f4747e26SJunchao Zhang   ConstPetscScalarKokkosView lv, rv;
1061f4747e26SJunchao Zhang 
1062f4747e26SJunchao Zhang   PetscFunctionBegin;
1063f4747e26SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1064f4747e26SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1065f4747e26SJunchao Zhang   const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1066f4747e26SJunchao Zhang   const auto &Aa     = aijkok->a_dual.view_device();
1067f4747e26SJunchao Zhang   const auto &Ai     = aijkok->i_dual.view_device();
1068f4747e26SJunchao Zhang   const auto &Aj     = aijkok->j_dual.view_device();
1069f4747e26SJunchao Zhang   if (ll) {
1070f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(ll, &m));
1071f4747e26SJunchao Zhang     PetscCheck(m == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Left scaling vector wrong length");
1072f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(ll, &lv));
1073f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for( // for each row
1074f4747e26SJunchao Zhang       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1075f4747e26SJunchao Zhang         PetscInt i   = t.league_rank(); // row i
1076f4747e26SJunchao Zhang         PetscInt len = Ai(i + 1) - Ai(i);
1077f4747e26SJunchao Zhang         // scale entries on the row
1078f4747e26SJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt j) { Aa(Ai(i) + j) *= lv(i); });
1079f4747e26SJunchao Zhang       }));
1080f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(ll, &lv));
1081f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(nz));
1082f4747e26SJunchao Zhang   }
1083f4747e26SJunchao Zhang   if (rr) {
1084f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(rr, &n));
1085f4747e26SJunchao Zhang     PetscCheck(n == A->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Right scaling vector wrong length");
1086f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(rr, &rv));
1087f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for( // for each nonzero
1088f4747e26SJunchao Zhang       nz, KOKKOS_LAMBDA(const PetscInt k) { Aa(k) *= rv(Aj(k)); }));
1089f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(rr, &lv));
1090f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(nz));
1091f4747e26SJunchao Zhang   }
1092f4747e26SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(A));
1093f4747e26SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1094f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1095f4747e26SJunchao Zhang }
1096f4747e26SJunchao Zhang 
1097d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatZeroEntries_SeqAIJKokkos(Mat A)
1098d71ae5a4SJacob Faibussowitsch {
1099076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1100a587d139SMark 
1101a587d139SMark   PetscFunctionBegin;
1102076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
11032328674fSJunchao Zhang   if (aijkok) { /* Only zero the device if data is already there */
1104076ba34aSJunchao Zhang     KokkosBlas::fill(aijkok->a_dual.view_device(), 0.0);
11059566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(A));
11062328674fSJunchao Zhang   } else { /* Might be preallocated but not assembled */
11079566063dSJacob Faibussowitsch     PetscCall(MatZeroEntries_SeqAIJ(A));
11082328674fSJunchao Zhang   }
11093ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1110a587d139SMark }
1111a587d139SMark 
1112d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatGetDiagonal_SeqAIJKokkos(Mat A, Vec x)
1113d71ae5a4SJacob Faibussowitsch {
1114f78ce678SMark Adams   Mat_SeqAIJKokkos     *aijkok;
1115f78ce678SMark Adams   PetscInt              n;
1116f78ce678SMark Adams   PetscScalarKokkosView xv;
1117f78ce678SMark Adams 
1118f78ce678SMark Adams   PetscFunctionBegin;
1119f78ce678SMark Adams   PetscCall(VecGetLocalSize(x, &n));
1120f78ce678SMark Adams   PetscCheck(n == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Nonconforming matrix and vector");
1121f78ce678SMark Adams   PetscCheck(A->factortype == MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_SUP, "MatGetDiagonal_SeqAIJKokkos not supported on factored matrices");
1122f78ce678SMark Adams 
1123f78ce678SMark Adams   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1124f78ce678SMark Adams   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1125f78ce678SMark Adams 
1126f78ce678SMark Adams   const auto &Aa    = aijkok->a_dual.view_device();
1127f78ce678SMark Adams   const auto &Ai    = aijkok->i_dual.view_device();
1128f78ce678SMark Adams   const auto &Adiag = aijkok->diag_dual.view_device();
1129f78ce678SMark Adams 
1130f78ce678SMark Adams   PetscCall(VecGetKokkosViewWrite(x, &xv));
11319371c9d4SSatish Balay   Kokkos::parallel_for(
11329371c9d4SSatish Balay     n, KOKKOS_LAMBDA(const PetscInt i) {
1133f78ce678SMark Adams       if (Adiag(i) < Ai(i + 1)) xv(i) = Aa(Adiag(i));
1134f78ce678SMark Adams       else xv(i) = 0;
1135f78ce678SMark Adams     });
1136f78ce678SMark Adams   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
11373ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1138f78ce678SMark Adams }
1139f78ce678SMark Adams 
1140db78de30SJunchao Zhang /* Get a Kokkos View from a mat of type MatSeqAIJKokkos */
1141d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1142d71ae5a4SJacob Faibussowitsch {
1143db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1144db78de30SJunchao Zhang 
1145db78de30SJunchao Zhang   PetscFunctionBegin;
1146db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11474f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1148db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11499566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1150db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1151076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
11523ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1153db78de30SJunchao Zhang }
1154db78de30SJunchao Zhang 
/* Restore a read-only View obtained with MatSeqAIJGetKokkosView(Mat, ConstMatScalarKokkosView *).
   Because the View is read-only, only the arguments are validated; no modify flag is set. */
PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, ConstMatScalarKokkosView *kv)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscAssertPointer(kv, 2);
  PetscCheckTypeName(A, MATSEQAIJKOKKOS);
  PetscFunctionReturn(PETSC_SUCCESS);
}
1163db78de30SJunchao Zhang 
1164d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, MatScalarKokkosView *kv)
1165d71ae5a4SJacob Faibussowitsch {
1166db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1167db78de30SJunchao Zhang 
1168db78de30SJunchao Zhang   PetscFunctionBegin;
1169db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11704f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1171db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11729566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1173db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1174076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
11753ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1176db78de30SJunchao Zhang }
1177db78de30SJunchao Zhang 
/* Restore a read-write View obtained with MatSeqAIJGetKokkosView(Mat, MatScalarKokkosView *).
   The caller may have changed the values through the View, so the device copy is marked as modified. */
PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, MatScalarKokkosView *kv)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscAssertPointer(kv, 2);
  PetscCheckTypeName(A, MATSEQAIJKOKKOS);
  PetscCall(MatSeqAIJKokkosModifyDevice(A));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1187db78de30SJunchao Zhang 
1188d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1189d71ae5a4SJacob Faibussowitsch {
1190db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1191db78de30SJunchao Zhang 
1192db78de30SJunchao Zhang   PetscFunctionBegin;
1193db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11944f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1195db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1196db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1197076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
11983ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1199db78de30SJunchao Zhang }
1200db78de30SJunchao Zhang 
/* Restore a write-only View obtained with MatSeqAIJGetKokkosViewWrite().
   The values were (presumably) overwritten on the device, so the device copy is marked as modified. */
PetscErrorCode MatSeqAIJRestoreKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
  PetscAssertPointer(kv, 2);
  PetscCheckTypeName(A, MATSEQAIJKOKKOS);
  PetscCall(MatSeqAIJKokkosModifyDevice(A));
  PetscFunctionReturn(PETSC_SUCCESS);
}
1210db78de30SJunchao Zhang 
1211c17cf699SJunchao Zhang /* Computes Y += alpha X */
1212d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAXPY_SeqAIJKokkos(Mat Y, PetscScalar alpha, Mat X, MatStructure pattern)
1213d71ae5a4SJacob Faibussowitsch {
1214a587d139SMark   Mat_SeqAIJ              *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data;
1215c17cf699SJunchao Zhang   Mat_SeqAIJKokkos        *xkok, *ykok, *zkok;
1216c17cf699SJunchao Zhang   ConstMatScalarKokkosView Xa;
1217c17cf699SJunchao Zhang   MatScalarKokkosView      Ya;
1218a587d139SMark 
1219a587d139SMark   PetscFunctionBegin;
1220c17cf699SJunchao Zhang   PetscCheckTypeName(Y, MATSEQAIJKOKKOS);
1221c17cf699SJunchao Zhang   PetscCheckTypeName(X, MATSEQAIJKOKKOS);
12229566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(Y));
12239566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(X));
12249566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1225db78de30SJunchao Zhang 
1226c17cf699SJunchao Zhang   if (pattern != SAME_NONZERO_PATTERN && x->nz == y->nz) {
1227a587d139SMark     PetscBool e;
12289566063dSJacob Faibussowitsch     PetscCall(PetscArraycmp(x->i, y->i, Y->rmap->n + 1, &e));
1229a587d139SMark     if (e) {
12309566063dSJacob Faibussowitsch       PetscCall(PetscArraycmp(x->j, y->j, y->nz, &e));
1231c17cf699SJunchao Zhang       if (e) pattern = SAME_NONZERO_PATTERN;
1232a587d139SMark     }
1233a587d139SMark   }
1234db78de30SJunchao Zhang 
1235c17cf699SJunchao Zhang   /* cusparseDcsrgeam2() computes C = alpha A + beta B. If one knew sparsity pattern of C, one can skip
1236c17cf699SJunchao Zhang     cusparseScsrgeam2_bufferSizeExt() / cusparseXcsrgeam2Nnz(), and directly call cusparseScsrgeam2().
1237c17cf699SJunchao Zhang     If X is SUBSET_NONZERO_PATTERN of Y, we could take advantage of this cusparse feature. However,
1238c17cf699SJunchao Zhang     KokkosSparse::spadd(alpha,A,beta,B,C) has symbolic and numeric phases, MatAXPY does not.
1239c17cf699SJunchao Zhang   */
1240c17cf699SJunchao Zhang   ykok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1241c17cf699SJunchao Zhang   xkok = static_cast<Mat_SeqAIJKokkos *>(X->spptr);
1242c17cf699SJunchao Zhang   Xa   = xkok->a_dual.view_device();
1243c17cf699SJunchao Zhang   Ya   = ykok->a_dual.view_device();
1244c17cf699SJunchao Zhang 
1245c17cf699SJunchao Zhang   if (pattern == SAME_NONZERO_PATTERN) {
1246c17cf699SJunchao Zhang     KokkosBlas::axpy(alpha, Xa, Ya);
12479566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1248c17cf699SJunchao Zhang   } else if (pattern == SUBSET_NONZERO_PATTERN) {
1249c17cf699SJunchao Zhang     MatRowMapKokkosView Xi = xkok->i_dual.view_device(), Yi = ykok->i_dual.view_device();
1250c17cf699SJunchao Zhang     MatColIdxKokkosView Xj = xkok->j_dual.view_device(), Yj = ykok->j_dual.view_device();
1251c17cf699SJunchao Zhang 
12529371c9d4SSatish Balay     Kokkos::parallel_for(
12539371c9d4SSatish Balay       Kokkos::TeamPolicy<>(Y->rmap->n, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
12540e3ece09SJunchao Zhang         PetscInt i = t.league_rank(); // row i
12550e3ece09SJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() {
12560e3ece09SJunchao Zhang           // Only one thread works in a team
1257c17cf699SJunchao Zhang           PetscInt p, q = Yi(i);
12580e3ece09SJunchao Zhang           for (p = Xi(i); p < Xi(i + 1); p++) {          // For each nonzero on row i of X,
12590e3ece09SJunchao Zhang             while (Xj(p) != Yj(q) && q < Yi(i + 1)) q++; // find the matching nonzero on row i of Y.
12600e3ece09SJunchao Zhang             if (Xj(p) == Yj(q)) {                        // Found it
1261c17cf699SJunchao Zhang               Ya(q) += alpha * Xa(p);
1262c17cf699SJunchao Zhang               q++;
1263a587d139SMark             } else {
12640e3ece09SJunchao Zhang             // If not found, it indicates the input is wrong (X is not a SUBSET_NONZERO_PATTERN of Y).
12650e3ece09SJunchao Zhang             // Just insert a NaN at the beginning of row i if it is not empty, to make the result wrong.
12660e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_VERSION_GE(3, 7, 0)
12670e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::ArithTraits<PetscScalar>::nan();
12688b8b16f9SJunchao Zhang #else
12690e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::Experimental::nan("1");
12708b8b16f9SJunchao Zhang #endif
1271a587d139SMark             }
1272c17cf699SJunchao Zhang           }
1273c17cf699SJunchao Zhang         });
1274c17cf699SJunchao Zhang       });
12759566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
12760e3ece09SJunchao Zhang   } else { // different nonzero patterns
1277c17cf699SJunchao Zhang     Mat             Z;
1278c17cf699SJunchao Zhang     KokkosCsrMatrix zcsr;
1279c17cf699SJunchao Zhang     KernelHandle    kh;
12800e3ece09SJunchao Zhang     kh.create_spadd_handle(true); // X, Y are sorted
1281c17cf699SJunchao Zhang     KokkosSparse::spadd_symbolic(&kh, xkok->csrmat, ykok->csrmat, zcsr);
1282c17cf699SJunchao Zhang     KokkosSparse::spadd_numeric(&kh, alpha, xkok->csrmat, (PetscScalar)1.0, ykok->csrmat, zcsr);
1283c17cf699SJunchao Zhang     zkok = new Mat_SeqAIJKokkos(zcsr);
12849566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, zkok, &Z));
12859566063dSJacob Faibussowitsch     PetscCall(MatHeaderReplace(Y, &Z));
1286c17cf699SJunchao Zhang     kh.destroy_spadd_handle();
1287c17cf699SJunchao Zhang   }
12889566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
12890e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(xkok->a_dual.extent(0) * 2)); // Because we scaled X and then added it to Y
12903ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1291a587d139SMark }
1292a587d139SMark 
12932c4ab24aSJunchao Zhang struct MatCOOStruct_SeqAIJKokkos {
12942c4ab24aSJunchao Zhang   PetscCount           n;
12952c4ab24aSJunchao Zhang   PetscCount           Atot;
12962c4ab24aSJunchao Zhang   PetscInt             nz;
12972c4ab24aSJunchao Zhang   PetscCountKokkosView jmap;
12982c4ab24aSJunchao Zhang   PetscCountKokkosView perm;
12992c4ab24aSJunchao Zhang 
13002c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos(const MatCOOStruct_SeqAIJ *coo_h)
13012c4ab24aSJunchao Zhang   {
13022c4ab24aSJunchao Zhang     nz   = coo_h->nz;
13032c4ab24aSJunchao Zhang     n    = coo_h->n;
13042c4ab24aSJunchao Zhang     Atot = coo_h->Atot;
13052c4ab24aSJunchao Zhang     jmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->jmap, nz + 1));
13062c4ab24aSJunchao Zhang     perm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->perm, Atot));
13072c4ab24aSJunchao Zhang   }
13082c4ab24aSJunchao Zhang };
13092c4ab24aSJunchao Zhang 
13102c4ab24aSJunchao Zhang static PetscErrorCode MatCOOStructDestroy_SeqAIJKokkos(void *data)
13112c4ab24aSJunchao Zhang {
13122c4ab24aSJunchao Zhang   PetscFunctionBegin;
13132c4ab24aSJunchao Zhang   PetscCallCXX(delete static_cast<MatCOOStruct_SeqAIJKokkos *>(data));
13142c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
13152c4ab24aSJunchao Zhang }
13162c4ab24aSJunchao Zhang 
1317d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_SeqAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1318d71ae5a4SJacob Faibussowitsch {
131942550becSJunchao Zhang   Mat_SeqAIJKokkos          *akok;
132042550becSJunchao Zhang   Mat_SeqAIJ                *aseq;
13212c4ab24aSJunchao Zhang   PetscContainer             container_h, container_d;
13222c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJ       *coo_h;
13232c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo_d;
132442550becSJunchao Zhang 
132542550becSJunchao Zhang   PetscFunctionBegin;
13269566063dSJacob Faibussowitsch   PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, coo_i, coo_j));
1327394ed5ebSJunchao Zhang   aseq = static_cast<Mat_SeqAIJ *>(mat->data);
132842550becSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(mat->spptr);
1329cbc6b225SStefano Zampini   delete akok;
1330f4747e26SJunchao Zhang   mat->spptr = akok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, aseq, mat->nonzerostate + 1, PETSC_FALSE);
13319566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries_SeqAIJKokkos(mat));
13322c4ab24aSJunchao Zhang 
13332c4ab24aSJunchao Zhang   // Copy the COO struct to device
13342c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
13352c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
13362c4ab24aSJunchao Zhang   PetscCallCXX(coo_d = new MatCOOStruct_SeqAIJKokkos(coo_h));
13372c4ab24aSJunchao Zhang 
13382c4ab24aSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
13392c4ab24aSJunchao Zhang   PetscCall(PetscContainerCreate(PETSC_COMM_SELF, &container_d));
13402c4ab24aSJunchao Zhang   PetscCall(PetscContainerSetPointer(container_d, coo_d));
13412c4ab24aSJunchao Zhang   PetscCall(PetscContainerSetUserDestroy(container_d, MatCOOStructDestroy_SeqAIJKokkos));
13422c4ab24aSJunchao Zhang   PetscCall(PetscObjectCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject)container_d));
13432c4ab24aSJunchao Zhang   PetscCall(PetscContainerDestroy(&container_d));
13443ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
134542550becSJunchao Zhang }
134642550becSJunchao Zhang 
1347d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_SeqAIJKokkos(Mat A, const PetscScalar v[], InsertMode imode)
1348d71ae5a4SJacob Faibussowitsch {
134942550becSJunchao Zhang   MatScalarKokkosView        Aa;
135042550becSJunchao Zhang   ConstMatScalarKokkosView   kv;
135142550becSJunchao Zhang   PetscMemType               memtype;
13522c4ab24aSJunchao Zhang   PetscContainer             container;
13532c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo;
135442550becSJunchao Zhang 
135542550becSJunchao Zhang   PetscFunctionBegin;
13562c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)A, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
13572c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
13582c4ab24aSJunchao Zhang 
13592c4ab24aSJunchao Zhang   const auto &n    = coo->n;
13602c4ab24aSJunchao Zhang   const auto &Annz = coo->nz;
13612c4ab24aSJunchao Zhang   const auto &jmap = coo->jmap;
13622c4ab24aSJunchao Zhang   const auto &perm = coo->perm;
13632c4ab24aSJunchao Zhang 
13649566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype));
136542550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */
13662c4ab24aSJunchao Zhang     kv = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatScalarKokkosViewHost(v, n));
136742550becSJunchao Zhang   } else {
13682c4ab24aSJunchao Zhang     kv = ConstMatScalarKokkosView(v, n); /* Directly use v[]'s memory */
136942550becSJunchao Zhang   }
137042550becSJunchao Zhang 
1371c7b718f4SJunchao Zhang   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1372c7b718f4SJunchao Zhang   else PetscCall(MatSeqAIJGetKokkosView(A, &Aa));                             /* read & write matrix values */
137342550becSJunchao Zhang 
137408bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
13759371c9d4SSatish Balay   Kokkos::parallel_for(
13769371c9d4SSatish Balay     Annz, KOKKOS_LAMBDA(const PetscCount i) {
1377c7b718f4SJunchao Zhang       PetscScalar sum = 0.0;
1378c7b718f4SJunchao Zhang       for (PetscCount k = jmap(i); k < jmap(i + 1); k++) sum += kv(perm(k));
1379c7b718f4SJunchao Zhang       Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1380c7b718f4SJunchao Zhang     });
138108bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1382394ed5ebSJunchao Zhang 
13839566063dSJacob Faibussowitsch   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
13849566063dSJacob Faibussowitsch   else PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
13853ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
138642550becSJunchao Zhang }
138742550becSJunchao Zhang 
/* Numeric LU factorization, done on the host: A's values are synced to the host, B is computed by the
   SeqAIJ kernel, and B is then flagged as having its data on the CPU. */
static PetscErrorCode MatLUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
{
  PetscFunctionBegin;
  PetscCall(MatSeqAIJKokkosSyncHost(A)); /* the host kernel below reads A's host arrays */
  PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info));
  B->offloadmask = PETSC_OFFLOAD_CPU; /* the factor was produced on the host */
  PetscFunctionReturn(PETSC_SUCCESS);
}
13968f7e8f9dSMark Adams 
1397d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat A)
1398d71ae5a4SJacob Faibussowitsch {
1399076ba34aSJunchao Zhang   Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
1400076ba34aSJunchao Zhang 
14018c3ff71bSJunchao Zhang   PetscFunctionBegin;
1402076ba34aSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_KOKKOS; /* We do not really use this flag */
14036f3d89d0SStefano Zampini   A->boundtocpu  = PETSC_FALSE;
14046f3d89d0SStefano Zampini 
14058c3ff71bSJunchao Zhang   A->ops->assemblyend               = MatAssemblyEnd_SeqAIJKokkos;
14068c3ff71bSJunchao Zhang   A->ops->destroy                   = MatDestroy_SeqAIJKokkos;
14078c3ff71bSJunchao Zhang   A->ops->duplicate                 = MatDuplicate_SeqAIJKokkos;
1408a587d139SMark   A->ops->axpy                      = MatAXPY_SeqAIJKokkos;
1409f0cf5187SStefano Zampini   A->ops->scale                     = MatScale_SeqAIJKokkos;
1410a587d139SMark   A->ops->zeroentries               = MatZeroEntries_SeqAIJKokkos;
1411076ba34aSJunchao Zhang   A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJKokkos;
14128c3ff71bSJunchao Zhang   A->ops->mult                      = MatMult_SeqAIJKokkos;
14138c3ff71bSJunchao Zhang   A->ops->multadd                   = MatMultAdd_SeqAIJKokkos;
14148c3ff71bSJunchao Zhang   A->ops->multtranspose             = MatMultTranspose_SeqAIJKokkos;
14158c3ff71bSJunchao Zhang   A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJKokkos;
14168c3ff71bSJunchao Zhang   A->ops->multhermitiantranspose    = MatMultHermitianTranspose_SeqAIJKokkos;
14178c3ff71bSJunchao Zhang   A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJKokkos;
1418076ba34aSJunchao Zhang   A->ops->productnumeric            = MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos;
14190ecb592aSJunchao Zhang   A->ops->transpose                 = MatTranspose_SeqAIJKokkos;
1420152b3e56SJunchao Zhang   A->ops->setoption                 = MatSetOption_SeqAIJKokkos;
1421f78ce678SMark Adams   A->ops->getdiagonal               = MatGetDiagonal_SeqAIJKokkos;
1422f4747e26SJunchao Zhang   A->ops->shift                     = MatShift_SeqAIJKokkos;
1423f4747e26SJunchao Zhang   A->ops->diagonalset               = MatDiagonalSet_SeqAIJKokkos;
1424f4747e26SJunchao Zhang   A->ops->diagonalscale             = MatDiagonalScale_SeqAIJKokkos;
1425076ba34aSJunchao Zhang   a->ops->getarray                  = MatSeqAIJGetArray_SeqAIJKokkos;
1426076ba34aSJunchao Zhang   a->ops->restorearray              = MatSeqAIJRestoreArray_SeqAIJKokkos;
1427076ba34aSJunchao Zhang   a->ops->getarrayread              = MatSeqAIJGetArrayRead_SeqAIJKokkos;
1428076ba34aSJunchao Zhang   a->ops->restorearrayread          = MatSeqAIJRestoreArrayRead_SeqAIJKokkos;
1429076ba34aSJunchao Zhang   a->ops->getarraywrite             = MatSeqAIJGetArrayWrite_SeqAIJKokkos;
1430076ba34aSJunchao Zhang   a->ops->restorearraywrite         = MatSeqAIJRestoreArrayWrite_SeqAIJKokkos;
14317ee59b9bSJunchao Zhang   a->ops->getcsrandmemtype          = MatSeqAIJGetCSRAndMemType_SeqAIJKokkos;
143242550becSJunchao Zhang 
14339566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJKokkos));
14349566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJKokkos));
14353ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1436076ba34aSJunchao Zhang }
1437076ba34aSJunchao Zhang 
14389d13fa56SJunchao Zhang /*
14399d13fa56SJunchao Zhang    Extract the (prescribled) diagonal blocks of the matrix and then invert them
14409d13fa56SJunchao Zhang 
14419d13fa56SJunchao Zhang   Input Parameters:
14429d13fa56SJunchao Zhang +  A       - the MATSEQAIJKOKKOS matrix
14439d13fa56SJunchao Zhang .  bs      - block sizes in 'csr' format, i.e., the i-th block has size bs(i+1) - bs(i)
14449d13fa56SJunchao Zhang .  bs2     - square of block sizes in 'csr' format, i.e., the i-th block should be stored at offset bs2(i) in diagVal[]
14459d13fa56SJunchao Zhang .  blkMap  - map row ids to block ids, i.e., row i belongs to the block blkMap(i)
14469d13fa56SJunchao Zhang -  work    - a pre-allocated work buffer (as big as diagVal) for use by this routine
14479d13fa56SJunchao Zhang 
14489d13fa56SJunchao Zhang   Output Parameter:
14499d13fa56SJunchao Zhang .  diagVal - the (pre-allocated) buffer to store the inverted blocks (each block is stored in column-major order)
14509d13fa56SJunchao Zhang */
14519d13fa56SJunchao Zhang PETSC_INTERN PetscErrorCode MatInvertVariableBlockDiagonal_SeqAIJKokkos(Mat A, const PetscIntKokkosView &bs, const PetscIntKokkosView &bs2, const PetscIntKokkosView &blkMap, PetscScalarKokkosView &work, PetscScalarKokkosView &diagVal)
14529d13fa56SJunchao Zhang {
14539d13fa56SJunchao Zhang   Mat_SeqAIJKokkos *akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
14549d13fa56SJunchao Zhang   PetscInt          nblocks = bs.extent(0) - 1;
14559d13fa56SJunchao Zhang 
14569d13fa56SJunchao Zhang   PetscFunctionBegin;
14579d13fa56SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A)); // Since we'll access A's value on device
14589d13fa56SJunchao Zhang 
14599d13fa56SJunchao Zhang   // Pull out the diagonal blocks of the matrix and then invert the blocks
14609d13fa56SJunchao Zhang   auto Aa    = akok->a_dual.view_device();
14619d13fa56SJunchao Zhang   auto Ai    = akok->i_dual.view_device();
14629d13fa56SJunchao Zhang   auto Aj    = akok->j_dual.view_device();
14639d13fa56SJunchao Zhang   auto Adiag = akok->diag_dual.view_device();
14649d13fa56SJunchao Zhang   // TODO: how to tune the team size?
14659d13fa56SJunchao Zhang #if defined(KOKKOS_ENABLE_DEFAULT_DEVICE_TYPE_HOST)
14669d13fa56SJunchao Zhang   auto ts = Kokkos::AUTO();
14679d13fa56SJunchao Zhang #else
14689d13fa56SJunchao Zhang   auto ts         = 16; // improved performance 30% over Kokkos::AUTO() with CUDA, but failed with "Kokkos::abort: Requested Team Size is too large!" on CPUs
14699d13fa56SJunchao Zhang #endif
14709d13fa56SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
14719d13fa56SJunchao Zhang     Kokkos::TeamPolicy<>(nblocks, ts), KOKKOS_LAMBDA(const KokkosTeamMemberType &teamMember) {
14729d13fa56SJunchao Zhang       const PetscInt bid    = teamMember.league_rank();                                                   // block id
14739d13fa56SJunchao Zhang       const PetscInt rstart = bs(bid);                                                                    // this block starts from this row
14749d13fa56SJunchao Zhang       const PetscInt m      = bs(bid + 1) - bs(bid);                                                      // size of this block
14759d13fa56SJunchao Zhang       const auto    &B      = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft>(&diagVal(bs2(bid)), m, m); // column-major order
14769d13fa56SJunchao Zhang       const auto    &W      = PetscScalarKokkosView(&work(bs2(bid)), m * m);
14779d13fa56SJunchao Zhang 
14789d13fa56SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(teamMember, m), [=](const PetscInt &r) { // r-th row in B
14799d13fa56SJunchao Zhang         PetscInt i = rstart + r;                                                            // i-th row in A
14809d13fa56SJunchao Zhang 
14819d13fa56SJunchao Zhang         if (Ai(i) <= Adiag(i) && Adiag(i) < Ai(i + 1)) { // if the diagonal exists (common case)
14829d13fa56SJunchao Zhang           PetscInt first = Adiag(i) - r;                 // we start to check nonzeros from here along this row
14839d13fa56SJunchao Zhang 
14849d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) {                   // walk n steps to see what column indices we will meet
14859d13fa56SJunchao Zhang             if (first + c < Ai(i) || first + c >= Ai(i + 1)) { // this entry (first+c) is out of range of this row, in other words, its value is zero
14869d13fa56SJunchao Zhang               B(r, c) = 0.0;
14879d13fa56SJunchao Zhang             } else if (Aj(first + c) == rstart + c) { // this entry is right on the (rstart+c) column
14889d13fa56SJunchao Zhang               B(r, c) = Aa(first + c);
14899d13fa56SJunchao Zhang             } else { // this entry does not show up in the CSR
14909d13fa56SJunchao Zhang               B(r, c) = 0.0;
14919d13fa56SJunchao Zhang             }
14929d13fa56SJunchao Zhang           }
14939d13fa56SJunchao Zhang         } else { // rare case that the diagonal does not exist
14949d13fa56SJunchao Zhang           const PetscInt begin = Ai(i);
14959d13fa56SJunchao Zhang           const PetscInt end   = Ai(i + 1);
14969d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) B(r, c) = 0.0;
14979d13fa56SJunchao Zhang           for (PetscInt j = begin; j < end; j++) { // scan the whole row; could use binary search but this is a rare case so we did not.
14989d13fa56SJunchao Zhang             if (rstart <= Aj(j) && Aj(j) < rstart + m) B(r, Aj(j) - rstart) = Aa(j);
14999d13fa56SJunchao Zhang             else if (Aj(j) >= rstart + m) break;
15009d13fa56SJunchao Zhang           }
15019d13fa56SJunchao Zhang         }
15029d13fa56SJunchao Zhang       });
15039d13fa56SJunchao Zhang 
15049d13fa56SJunchao Zhang       // LU-decompose B (w/o pivoting) and then invert B
15059d13fa56SJunchao Zhang       KokkosBatched::TeamLU<KokkosTeamMemberType, KokkosBatched::Algo::LU::Unblocked>::invoke(teamMember, B, 0.0);
15069d13fa56SJunchao Zhang       KokkosBatched::TeamInverseLU<KokkosTeamMemberType, KokkosBatched::Algo::InverseLU::Unblocked>::invoke(teamMember, B, W);
15079d13fa56SJunchao Zhang     }));
15089d13fa56SJunchao Zhang   // PetscLogGpuFlops() is done in the caller PCSetUp_VPBJacobi_Kokkos as we don't want to compute the flops in kernels
15099d13fa56SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
15109d13fa56SJunchao Zhang }
15119d13fa56SJunchao Zhang 
1512d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSetSeqAIJKokkosWithCSRMatrix(Mat A, Mat_SeqAIJKokkos *akok)
1513d71ae5a4SJacob Faibussowitsch {
1514076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
1515076ba34aSJunchao Zhang   PetscInt    i, m, n;
1516e36ced11SJunchao Zhang   auto       &exec = PetscGetKokkosExecutionSpace();
1517076ba34aSJunchao Zhang 
1518076ba34aSJunchao Zhang   PetscFunctionBegin;
15195f80ce2aSJacob Faibussowitsch   PetscCheck(!A->spptr, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A->spptr is supposed to be empty");
1520076ba34aSJunchao Zhang 
1521076ba34aSJunchao Zhang   m = akok->nrows();
1522076ba34aSJunchao Zhang   n = akok->ncols();
15239566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(A, m, n, m, n));
15249566063dSJacob Faibussowitsch   PetscCall(MatSetType(A, MATSEQAIJKOKKOS));
1525076ba34aSJunchao Zhang 
1526076ba34aSJunchao Zhang   /* Set up data structures of A as a MATSEQAIJ */
15279566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(A, MAT_SKIP_ALLOCATION, NULL));
1528076ba34aSJunchao Zhang   aseq = (Mat_SeqAIJ *)(A)->data;
1529076ba34aSJunchao Zhang 
1530e36ced11SJunchao Zhang   PetscCallCXX(akok->i_dual.sync_host(exec)); /* We always need sync'ed i, j on host */
1531e36ced11SJunchao Zhang   PetscCallCXX(akok->j_dual.sync_host(exec));
1532e36ced11SJunchao Zhang   PetscCallCXX(exec.fence());
1533076ba34aSJunchao Zhang 
1534076ba34aSJunchao Zhang   aseq->i            = akok->i_host_data();
1535076ba34aSJunchao Zhang   aseq->j            = akok->j_host_data();
1536076ba34aSJunchao Zhang   aseq->a            = akok->a_host_data();
1537076ba34aSJunchao Zhang   aseq->nonew        = -1; /*this indicates that inserting a new value in the matrix that generates a new nonzero is an error*/
1538076ba34aSJunchao Zhang   aseq->singlemalloc = PETSC_FALSE;
1539076ba34aSJunchao Zhang   aseq->free_a       = PETSC_FALSE;
1540076ba34aSJunchao Zhang   aseq->free_ij      = PETSC_FALSE;
1541076ba34aSJunchao Zhang   aseq->nz           = akok->nnz();
1542076ba34aSJunchao Zhang   aseq->maxnz        = aseq->nz;
1543076ba34aSJunchao Zhang 
15449566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->imax));
15459566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->ilen));
1546ad540459SPierre Jolivet   for (i = 0; i < m; i++) aseq->ilen[i] = aseq->imax[i] = aseq->i[i + 1] - aseq->i[i];
1547076ba34aSJunchao Zhang 
1548076ba34aSJunchao Zhang   /* It is critical to set the nonzerostate, as we use it to check if sparsity pattern (hence data) has changed on host in MatAssemblyEnd */
1549076ba34aSJunchao Zhang   akok->nonzerostate = A->nonzerostate;
1550ff751488SJunchao Zhang   A->spptr           = akok; /* Set A->spptr before MatAssembly so that A->spptr won't be allocated again there */
15519566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
15529566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
15533ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1554076ba34aSJunchao Zhang }
1555076ba34aSJunchao Zhang 
15560e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetKokkosCsrMatrix(Mat A, KokkosCsrMatrix *csr)
15570e3ece09SJunchao Zhang {
15580e3ece09SJunchao Zhang   PetscFunctionBegin;
15590e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
15600e3ece09SJunchao Zhang   *csr = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
15610e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
15620e3ece09SJunchao Zhang }
15630e3ece09SJunchao Zhang 
15640e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithKokkosCsrMatrix(MPI_Comm comm, KokkosCsrMatrix csr, Mat *A)
15650e3ece09SJunchao Zhang {
15660e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok;
15674d86920dSPierre Jolivet 
15680e3ece09SJunchao Zhang   PetscFunctionBegin;
15690e3ece09SJunchao Zhang   PetscCallCXX(akok = new Mat_SeqAIJKokkos(csr));
15700e3ece09SJunchao Zhang   PetscCall(MatCreate(comm, A));
15710e3ece09SJunchao Zhang   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
15720e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
15730e3ece09SJunchao Zhang }
15740e3ece09SJunchao Zhang 
1575076ba34aSJunchao Zhang /* Crete a SEQAIJKOKKOS matrix with a Mat_SeqAIJKokkos data structure
1576076ba34aSJunchao Zhang 
1577076ba34aSJunchao Zhang    Note we have names like MatSeqAIJSetPreallocationCSR, so I use capitalized CSR
1578076ba34aSJunchao Zhang  */
1579d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithCSRMatrix(MPI_Comm comm, Mat_SeqAIJKokkos *akok, Mat *A)
1580d71ae5a4SJacob Faibussowitsch {
1581076ba34aSJunchao Zhang   PetscFunctionBegin;
15829566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
15839566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
15843ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
15858c3ff71bSJunchao Zhang }
15868c3ff71bSJunchao Zhang 
1587152b3e56SJunchao Zhang /*@C
158811a5261eSBarry Smith   MatCreateSeqAIJKokkos - Creates a sparse matrix in `MATSEQAIJKOKKOS` (compressed row) format
15898c3ff71bSJunchao Zhang   (the default parallel PETSc format). This matrix will ultimately be handled by
159020f4b53cSBarry Smith   Kokkos for calculations.
15918c3ff71bSJunchao Zhang 
15928c3ff71bSJunchao Zhang   Collective
15938c3ff71bSJunchao Zhang 
15948c3ff71bSJunchao Zhang   Input Parameters:
159511a5261eSBarry Smith + comm - MPI communicator, set to `PETSC_COMM_SELF`
15968c3ff71bSJunchao Zhang . m    - number of rows
15978c3ff71bSJunchao Zhang . n    - number of columns
159820f4b53cSBarry Smith . nz   - number of nonzeros per row (same for all rows), ignored if `nnz` is provided
159920f4b53cSBarry Smith - nnz  - array containing the number of nonzeros in the various rows (possibly different for each row) or `NULL`
16008c3ff71bSJunchao Zhang 
16018c3ff71bSJunchao Zhang   Output Parameter:
16028c3ff71bSJunchao Zhang . A - the matrix
16038c3ff71bSJunchao Zhang 
16042ef1f0ffSBarry Smith   Level: intermediate
16052ef1f0ffSBarry Smith 
16062ef1f0ffSBarry Smith   Notes:
160711a5261eSBarry Smith   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
  MatXXXXSetPreallocation() paradigm instead of this routine directly.
160911a5261eSBarry Smith   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
16108c3ff71bSJunchao Zhang 
161111a5261eSBarry Smith   The AIJ format, also called
16122ef1f0ffSBarry Smith   compressed row storage, is fully compatible with standard Fortran
16138c3ff71bSJunchao Zhang   storage.  That is, the stored row and column indices can begin at
161420f4b53cSBarry Smith   either one (as in Fortran) or zero.
16158c3ff71bSJunchao Zhang 
16162ef1f0ffSBarry Smith   Specify the preallocated storage with either `nz` or `nnz` (not both).
16172ef1f0ffSBarry Smith   Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory
16182ef1f0ffSBarry Smith   allocation.
16198c3ff71bSJunchao Zhang 
1620fe59aa6dSJacob Faibussowitsch .seealso: [](ch_matrices), `Mat`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
16218c3ff71bSJunchao Zhang @*/
1622d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateSeqAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A)
1623d71ae5a4SJacob Faibussowitsch {
16248c3ff71bSJunchao Zhang   PetscFunctionBegin;
16259566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
16269566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
16279566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, m, n));
16289566063dSJacob Faibussowitsch   PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
16299566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
16303ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
16318c3ff71bSJunchao Zhang }
1632930e68a5SMark Adams 
1633d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatLUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1634d71ae5a4SJacob Faibussowitsch {
1635930e68a5SMark Adams   PetscFunctionBegin;
16369566063dSJacob Faibussowitsch   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
163786a27549SJunchao Zhang   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
16383ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
163986a27549SJunchao Zhang }
164086a27549SJunchao Zhang 
1641d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSymbolicSolveCheck(Mat A)
1642d71ae5a4SJacob Faibussowitsch {
164386a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
164486a27549SJunchao Zhang 
164586a27549SJunchao Zhang   PetscFunctionBegin;
164686a27549SJunchao Zhang   if (!factors->sptrsv_symbolic_completed) {
164786a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d);
164886a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d);
164986a27549SJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_TRUE;
165086a27549SJunchao Zhang   }
16513ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
165286a27549SJunchao Zhang }
165386a27549SJunchao Zhang 
165486a27549SJunchao Zhang /* Check if we need to update factors etc for transpose solve */
1655d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosTransposeSolveCheck(Mat A)
1656d71ae5a4SJacob Faibussowitsch {
165786a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1658076ba34aSJunchao Zhang   MatColIdxType               n       = A->rmap->n;
165986a27549SJunchao Zhang 
166086a27549SJunchao Zhang   PetscFunctionBegin;
166186a27549SJunchao Zhang   if (!factors->transpose_updated) { /* TODO: KK needs to provide functions to do numeric transpose only */
166286a27549SJunchao Zhang     /* Update L^T and do sptrsv symbolic */
16637b8d4ba6SJunchao Zhang     factors->iLt_d = MatRowMapKokkosView("factors->iLt_d", n + 1); // KK requires 0
16647b8d4ba6SJunchao Zhang     factors->jLt_d = MatColIdxKokkosView(NoInit("factors->jLt_d"), factors->jL_d.extent(0));
16657b8d4ba6SJunchao Zhang     factors->aLt_d = MatScalarKokkosView(NoInit("factors->aLt_d"), factors->aL_d.extent(0));
166686a27549SJunchao Zhang 
16679371c9d4SSatish Balay     transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iL_d, factors->jL_d, factors->aL_d,
166886a27549SJunchao Zhang                                                                                                                                                                                                               factors->iLt_d, factors->jLt_d, factors->aLt_d);
166986a27549SJunchao Zhang 
167086a27549SJunchao Zhang     /* TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices.
167186a27549SJunchao Zhang       We have to sort the indices, until KK provides finer control options.
167286a27549SJunchao Zhang     */
16739371c9d4SSatish Balay     sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iLt_d, factors->jLt_d, factors->aLt_d);
167486a27549SJunchao Zhang 
167586a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d);
167686a27549SJunchao Zhang 
167786a27549SJunchao Zhang     /* Update U^T and do sptrsv symbolic */
16787b8d4ba6SJunchao Zhang     factors->iUt_d = MatRowMapKokkosView("factors->iUt_d", n + 1); // KK requires 0
16797b8d4ba6SJunchao Zhang     factors->jUt_d = MatColIdxKokkosView(NoInit("factors->jUt_d"), factors->jU_d.extent(0));
16807b8d4ba6SJunchao Zhang     factors->aUt_d = MatScalarKokkosView(NoInit("factors->aUt_d"), factors->aU_d.extent(0));
168186a27549SJunchao Zhang 
16829371c9d4SSatish Balay     transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iU_d, factors->jU_d, factors->aU_d,
168386a27549SJunchao Zhang                                                                                                                                                                                                               factors->iUt_d, factors->jUt_d, factors->aUt_d);
168486a27549SJunchao Zhang 
168586a27549SJunchao Zhang     /* Sort indices. See comments above */
16869371c9d4SSatish Balay     sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iUt_d, factors->jUt_d, factors->aUt_d);
168786a27549SJunchao Zhang 
168886a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d);
168986a27549SJunchao Zhang     factors->transpose_updated = PETSC_TRUE;
169086a27549SJunchao Zhang   }
16913ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
169286a27549SJunchao Zhang }
169386a27549SJunchao Zhang 
169486a27549SJunchao Zhang /* Solve Ax = b, with A = LU */
1695d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSolve_SeqAIJKokkos(Mat A, Vec b, Vec x)
1696d71ae5a4SJacob Faibussowitsch {
169786a27549SJunchao Zhang   ConstPetscScalarKokkosView  bv;
169886a27549SJunchao Zhang   PetscScalarKokkosView       xv;
169986a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
170086a27549SJunchao Zhang 
170186a27549SJunchao Zhang   PetscFunctionBegin;
17029566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
17039566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSymbolicSolveCheck(A));
17049566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(b, &bv));
17059566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(x, &xv));
170686a27549SJunchao Zhang   /* Solve L tmpv = b */
17079566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::Experimental::sptrsv_solve(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d, bv, factors->workVector));
170886a27549SJunchao Zhang   /* Solve Ux = tmpv */
17099566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::Experimental::sptrsv_solve(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, factors->workVector, xv));
17109566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(b, &bv));
17119566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
17129566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
17133ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
171486a27549SJunchao Zhang }
171586a27549SJunchao Zhang 
1716076ba34aSJunchao Zhang /* Solve A^T x = b, where A^T = U^T L^T */
1717d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSolveTranspose_SeqAIJKokkos(Mat A, Vec b, Vec x)
1718d71ae5a4SJacob Faibussowitsch {
171986a27549SJunchao Zhang   ConstPetscScalarKokkosView  bv;
172086a27549SJunchao Zhang   PetscScalarKokkosView       xv;
172186a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
172286a27549SJunchao Zhang 
172386a27549SJunchao Zhang   PetscFunctionBegin;
17249566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
17259566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A));
17269566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(b, &bv));
17279566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(x, &xv));
172886a27549SJunchao Zhang   /* Solve U^T tmpv = b */
172986a27549SJunchao Zhang   KokkosSparse::Experimental::sptrsv_solve(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, bv, factors->workVector);
173086a27549SJunchao Zhang 
173186a27549SJunchao Zhang   /* Solve L^T x = tmpv */
173286a27549SJunchao Zhang   KokkosSparse::Experimental::sptrsv_solve(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d, factors->workVector, xv);
17339566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(b, &bv));
17349566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
17359566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
17363ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
173786a27549SJunchao Zhang }
173886a27549SJunchao Zhang 
1739d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatILUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1740d71ae5a4SJacob Faibussowitsch {
174186a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
174286a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
174386a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
174486a27549SJunchao Zhang 
174586a27549SJunchao Zhang   PetscFunctionBegin;
17469566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
17479566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1748076ba34aSJunchao Zhang 
1749076ba34aSJunchao Zhang   auto a_d = aijkok->a_dual.view_device();
1750076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
1751076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
1752076ba34aSJunchao Zhang 
1753076ba34aSJunchao Zhang   KokkosSparse::Experimental::spiluk_numeric(&factors->kh, fill_lev, i_d, j_d, a_d, factors->iL_d, factors->jL_d, factors->aL_d, factors->iU_d, factors->jU_d, factors->aU_d);
175486a27549SJunchao Zhang 
175586a27549SJunchao Zhang   B->assembled              = PETSC_TRUE;
175686a27549SJunchao Zhang   B->preallocated           = PETSC_TRUE;
175786a27549SJunchao Zhang   B->ops->solve             = MatSolve_SeqAIJKokkos;
175886a27549SJunchao Zhang   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJKokkos;
175986a27549SJunchao Zhang   B->ops->matsolve          = NULL;
176086a27549SJunchao Zhang   B->ops->matsolvetranspose = NULL;
176186a27549SJunchao Zhang   B->offloadmask            = PETSC_OFFLOAD_GPU;
176286a27549SJunchao Zhang 
176386a27549SJunchao Zhang   /* Once the factors' value changed, we need to update their transpose and sptrsv handle */
176486a27549SJunchao Zhang   factors->transpose_updated         = PETSC_FALSE;
176586a27549SJunchao Zhang   factors->sptrsv_symbolic_completed = PETSC_FALSE;
1766eeadb341SJunchao Zhang   /* TODO: log flops, but how to know that? */
17679566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
17683ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
176986a27549SJunchao Zhang }
177086a27549SJunchao Zhang 
1771d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1772d71ae5a4SJacob Faibussowitsch {
177386a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok;
177486a27549SJunchao Zhang   Mat_SeqAIJ                 *b;
177586a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
177686a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
177786a27549SJunchao Zhang   PetscInt                    nnzA     = ((Mat_SeqAIJ *)A->data)->nz, nnzL, nnzU;
177886a27549SJunchao Zhang   PetscInt                    n        = A->rmap->n;
177986a27549SJunchao Zhang 
178086a27549SJunchao Zhang   PetscFunctionBegin;
17819566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
178286a27549SJunchao Zhang   /* Rebuild factors */
17839371c9d4SSatish Balay   if (factors) {
17849371c9d4SSatish Balay     factors->Destroy();
17859371c9d4SSatish Balay   } /* Destroy the old if it exists */
17869371c9d4SSatish Balay   else {
17879371c9d4SSatish Balay     B->spptr = factors = new Mat_SeqAIJKokkosTriFactors(n);
17889371c9d4SSatish Balay   }
178986a27549SJunchao Zhang 
179086a27549SJunchao Zhang   /* Create a spiluk handle and then do symbolic factorization */
179186a27549SJunchao Zhang   nnzL = nnzU = PetscRealIntMultTruncate(info->fill, nnzA);
179286a27549SJunchao Zhang   factors->kh.create_spiluk_handle(KokkosSparse::Experimental::SPILUKAlgorithm::SEQLVLSCHD_TP1, n, nnzL, nnzU);
179386a27549SJunchao Zhang 
179486a27549SJunchao Zhang   auto spiluk_handle = factors->kh.get_spiluk_handle();
179586a27549SJunchao Zhang 
179686a27549SJunchao Zhang   Kokkos::realloc(factors->iL_d, n + 1); /* Free old arrays and realloc */
179786a27549SJunchao Zhang   Kokkos::realloc(factors->jL_d, spiluk_handle->get_nnzL());
179886a27549SJunchao Zhang   Kokkos::realloc(factors->iU_d, n + 1);
179986a27549SJunchao Zhang   Kokkos::realloc(factors->jU_d, spiluk_handle->get_nnzU());
180086a27549SJunchao Zhang 
180186a27549SJunchao Zhang   aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
1802076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
1803076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
1804076ba34aSJunchao Zhang   KokkosSparse::Experimental::spiluk_symbolic(&factors->kh, fill_lev, i_d, j_d, factors->iL_d, factors->jL_d, factors->iU_d, factors->jU_d);
180586a27549SJunchao Zhang   /* TODO: if spiluk_symbolic is asynchronous, do we need to sync before calling get_nnzL()? */
180686a27549SJunchao Zhang 
180786a27549SJunchao Zhang   Kokkos::resize(factors->jL_d, spiluk_handle->get_nnzL()); /* Shrink or expand, and retain old value */
180886a27549SJunchao Zhang   Kokkos::resize(factors->jU_d, spiluk_handle->get_nnzU());
180986a27549SJunchao Zhang   Kokkos::realloc(factors->aL_d, spiluk_handle->get_nnzL()); /* No need to retain old value */
181086a27549SJunchao Zhang   Kokkos::realloc(factors->aU_d, spiluk_handle->get_nnzU());
181186a27549SJunchao Zhang 
181286a27549SJunchao Zhang   /* TODO: add options to select sptrsv algorithms */
181386a27549SJunchao Zhang   /* Create sptrsv handles for L, U and their transpose */
181486a27549SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
181586a27549SJunchao Zhang   auto sptrsv_alg = KokkosSparse::Experimental::SPTRSVAlgorithm::SPTRSV_CUSPARSE;
181686a27549SJunchao Zhang #else
181786a27549SJunchao Zhang   auto sptrsv_alg = KokkosSparse::Experimental::SPTRSVAlgorithm::SEQLVLSCHD_TP1;
181886a27549SJunchao Zhang #endif
181986a27549SJunchao Zhang 
182086a27549SJunchao Zhang   factors->khL.create_sptrsv_handle(sptrsv_alg, n, true /* L is lower tri */);
182186a27549SJunchao Zhang   factors->khU.create_sptrsv_handle(sptrsv_alg, n, false /* U is not lower tri */);
182286a27549SJunchao Zhang   factors->khLt.create_sptrsv_handle(sptrsv_alg, n, false /* L^T is not lower tri */);
182386a27549SJunchao Zhang   factors->khUt.create_sptrsv_handle(sptrsv_alg, n, true /* U^T is lower tri */);
182486a27549SJunchao Zhang 
182586a27549SJunchao Zhang   /* Fill fields of the factor matrix B */
18269566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(B, MAT_SKIP_ALLOCATION, NULL));
182786a27549SJunchao Zhang   b     = (Mat_SeqAIJ *)B->data;
182886a27549SJunchao Zhang   b->nz = b->maxnz          = spiluk_handle->get_nnzL() + spiluk_handle->get_nnzU();
182986a27549SJunchao Zhang   B->info.fill_ratio_given  = info->fill;
1830a1e4e190SStefano Zampini   B->info.fill_ratio_needed = nnzA > 0 ? ((PetscReal)b->nz) / ((PetscReal)nnzA) : 1.0;
183186a27549SJunchao Zhang 
183286a27549SJunchao Zhang   B->offloadmask          = PETSC_OFFLOAD_GPU;
183386a27549SJunchao Zhang   B->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJKokkos;
18343ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1835930e68a5SMark Adams }
1836930e68a5SMark Adams 
1837d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatFactorGetSolverType_SeqAIJKokkos(Mat A, MatSolverType *type)
1838d71ae5a4SJacob Faibussowitsch {
1839930e68a5SMark Adams   PetscFunctionBegin;
1840930e68a5SMark Adams   *type = MATSOLVERKOKKOS;
18413ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1842930e68a5SMark Adams }
1843930e68a5SMark Adams 
1844930e68a5SMark Adams /*MC
184586a27549SJunchao Zhang   MATSOLVERKOKKOS = "Kokkos" - A matrix solver type providing triangular solvers for sequential matrices
184611a5261eSBarry Smith   on a single GPU of type, `MATSEQAIJKOKKOS`, `MATAIJKOKKOS`.
1847930e68a5SMark Adams 
1848930e68a5SMark Adams   Level: beginner
1849930e68a5SMark Adams 
18501cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJKokkos()`, `MATAIJKOKKOS`, `MatKokkosSetFormat()`, `MatKokkosStorageFormat`, `MatKokkosFormatOperation`
1851930e68a5SMark Adams M*/
185286a27549SJunchao Zhang PETSC_EXTERN PetscErrorCode MatGetFactor_SeqAIJKokkos_Kokkos(Mat A, MatFactorType ftype, Mat *B) /* MatGetFactor_<MatType>_<MatSolverType> */
1853930e68a5SMark Adams {
1854930e68a5SMark Adams   PetscInt n = A->rmap->n;
1855930e68a5SMark Adams 
1856930e68a5SMark Adams   PetscFunctionBegin;
18579566063dSJacob Faibussowitsch   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), B));
18589566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*B, n, n, n, n));
1859930e68a5SMark Adams   (*B)->factortype = ftype;
18609566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
18619566063dSJacob Faibussowitsch   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
1862930e68a5SMark Adams 
18638f7e8f9dSMark Adams   if (ftype == MAT_FACTOR_LU) {
18649566063dSJacob Faibussowitsch     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
186586a27549SJunchao Zhang     (*B)->canuseordering        = PETSC_TRUE;
186686a27549SJunchao Zhang     (*B)->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqAIJKokkos;
186786a27549SJunchao Zhang   } else if (ftype == MAT_FACTOR_ILU) {
18689566063dSJacob Faibussowitsch     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
186986a27549SJunchao Zhang     (*B)->canuseordering         = PETSC_FALSE;
187086a27549SJunchao Zhang     (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJKokkos;
187198921bdaSJacob Faibussowitsch   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "MatFactorType %s is not supported by MatType SeqAIJKokkos", MatFactorTypes[ftype]);
1872930e68a5SMark Adams 
18739566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
1874*f4f49eeaSPierre Jolivet   PetscCall(PetscObjectComposeFunction((PetscObject)*B, "MatFactorGetSolverType_C", MatFactorGetSolverType_SeqAIJKokkos));
18753ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1876930e68a5SMark Adams }
18778f7e8f9dSMark Adams 
1878d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatSolverTypeRegister_KOKKOS(void)
1879d71ae5a4SJacob Faibussowitsch {
188086a27549SJunchao Zhang   PetscFunctionBegin;
18819566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_SeqAIJKokkos_Kokkos));
18829566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ILU, MatGetFactor_SeqAIJKokkos_Kokkos));
18833ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
188486a27549SJunchao Zhang }
188586a27549SJunchao Zhang 
1886076ba34aSJunchao Zhang /* Utility to print out a KokkosCsrMatrix for debugging */
1887d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode PrintCsrMatrix(const KokkosCsrMatrix &csrmat)
1888d71ae5a4SJacob Faibussowitsch {
1889076ba34aSJunchao Zhang   const auto        &iv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.graph.row_map);
1890076ba34aSJunchao Zhang   const auto        &jv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.graph.entries);
1891076ba34aSJunchao Zhang   const auto        &av = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.values);
1892076ba34aSJunchao Zhang   const PetscInt    *i  = iv.data();
1893076ba34aSJunchao Zhang   const PetscInt    *j  = jv.data();
1894076ba34aSJunchao Zhang   const PetscScalar *a  = av.data();
1895076ba34aSJunchao Zhang   PetscInt           m = csrmat.numRows(), n = csrmat.numCols(), nnz = csrmat.nnz();
1896076ba34aSJunchao Zhang 
1897076ba34aSJunchao Zhang   PetscFunctionBegin;
18989566063dSJacob Faibussowitsch   PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT " x %" PetscInt_FMT " SeqAIJKokkos, with %" PetscInt_FMT " nonzeros\n", m, n, nnz));
1899076ba34aSJunchao Zhang   for (PetscInt k = 0; k < m; k++) {
19009566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT ": ", k));
190148a46eb9SPierre Jolivet     for (PetscInt p = i[k]; p < i[k + 1]; p++) PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT "(%.1f), ", j[p], (double)PetscRealPart(a[p])));
19029566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "\n"));
1903076ba34aSJunchao Zhang   }
19043ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1905076ba34aSJunchao Zhang }
1906