xref: /petsc/src/mat/impls/aij/seq/kokkos/aijkok.kokkos.cxx (revision bfe80ac4a46d58cb7760074b25f5e81b2f541d8a)
1e36ced11SJunchao Zhang #include <petsc_kokkos.hpp>
211d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
3c0c276a7Ssdargavi #include <petscmat_kokkos.hpp>
4076ba34aSJunchao Zhang #include <petscpkg_version.h>
5152b3e56SJunchao Zhang #include <petsc/private/petscimpl.h>
642550becSJunchao Zhang #include <petsc/private/sfimpl.h>
7aac854edSJunchao Zhang #include <petsc/private/kokkosimpl.hpp>
88c3ff71bSJunchao Zhang #include <petscsystypes.h>
98c3ff71bSJunchao Zhang #include <petscerror.h>
108c3ff71bSJunchao Zhang 
118c3ff71bSJunchao Zhang #include <Kokkos_Core.hpp>
12f0cf5187SStefano Zampini #include <KokkosBlas.hpp>
138c3ff71bSJunchao Zhang #include <KokkosSparse_CrsMatrix.hpp>
14cc6e31f1SJunchao Zhang 
15cc6e31f1SJunchao Zhang // To suppress compiler warnings:
16cc6e31f1SJunchao Zhang // /path/include/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp:434:63:
17cc6e31f1SJunchao Zhang // warning: 'cusparseStatus_t cusparseDbsrmm(cusparseHandle_t, cusparseDirection_t, cusparseOperation_t,
18cc6e31f1SJunchao Zhang // cusparseOperation_t, int, int, int, int, const double*, cusparseMatDescr_t, const double*, const int*, const int*,
19cc6e31f1SJunchao Zhang // int, const double*, int, const double*, double*, int)' is deprecated: please use cusparseSpMM instead [-Wdeprecated-declarations]
20cc6e31f1SJunchao Zhang PETSC_PRAGMA_DIAGNOSTIC_IGNORED_BEGIN("-Wdeprecated-declarations")
218c3ff71bSJunchao Zhang #include <KokkosSparse_spmv.hpp>
22cc6e31f1SJunchao Zhang PETSC_PRAGMA_DIAGNOSTIC_IGNORED_END()
23cc6e31f1SJunchao Zhang 
2486a27549SJunchao Zhang #include <KokkosSparse_spiluk.hpp>
2586a27549SJunchao Zhang #include <KokkosSparse_sptrsv.hpp>
26076ba34aSJunchao Zhang #include <KokkosSparse_spgemm.hpp>
27076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
289d13fa56SJunchao Zhang #include <KokkosBatched_LU_Decl.hpp>
299d13fa56SJunchao Zhang #include <KokkosBatched_InverseLU_Decl.hpp>
3086a27549SJunchao Zhang 
3142550becSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
328c3ff71bSJunchao Zhang 
330e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(3, 7, 0)
34f98996d3SJunchao Zhang   #include <KokkosSparse_Utils.hpp>
35f98996d3SJunchao Zhang using KokkosSparse::sort_crs_matrix;
369371c9d4SSatish Balay using KokkosSparse::Impl::transpose_matrix;
37f98996d3SJunchao Zhang #else
38f98996d3SJunchao Zhang   #include <KokkosKernels_Sorting.hpp>
39f98996d3SJunchao Zhang using KokkosKernels::sort_crs_matrix;
409371c9d4SSatish Balay using KokkosKernels::Impl::transpose_matrix;
41f98996d3SJunchao Zhang #endif
42f98996d3SJunchao Zhang 
43aac854edSJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(4, 6, 0)
44aac854edSJunchao Zhang using KokkosSparse::spiluk_symbolic;
45aac854edSJunchao Zhang using KokkosSparse::spiluk_numeric;
46aac854edSJunchao Zhang using KokkosSparse::sptrsv_symbolic;
47aac854edSJunchao Zhang using KokkosSparse::sptrsv_solve;
48aac854edSJunchao Zhang using KokkosSparse::Experimental::SPTRSVAlgorithm;
49aac854edSJunchao Zhang using KokkosSparse::Experimental::SPILUKAlgorithm;
50aac854edSJunchao Zhang #else
51aac854edSJunchao Zhang using KokkosSparse::Experimental::spiluk_symbolic;
52aac854edSJunchao Zhang using KokkosSparse::Experimental::spiluk_numeric;
53aac854edSJunchao Zhang using KokkosSparse::Experimental::sptrsv_symbolic;
54aac854edSJunchao Zhang using KokkosSparse::Experimental::sptrsv_solve;
55aac854edSJunchao Zhang using KokkosSparse::Experimental::SPTRSVAlgorithm;
56aac854edSJunchao Zhang using KokkosSparse::Experimental::SPILUKAlgorithm;
57aac854edSJunchao Zhang #endif
58aac854edSJunchao Zhang 
598c3ff71bSJunchao Zhang static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat); /* Forward declaration */
608c3ff71bSJunchao Zhang 
61076ba34aSJunchao Zhang /* MatAssemblyEnd_SeqAIJKokkos() happens when we finalized nonzeros of the matrix, either after
62076ba34aSJunchao Zhang    we assembled the matrix on host, or after we directly produced the matrix data on device (ex., through MatMatMult).
63076ba34aSJunchao Zhang    In the latter case, it is important to set a_dual's sync state correctly.
64076ba34aSJunchao Zhang  */
65d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_SeqAIJKokkos(Mat A, MatAssemblyType mode)
66d71ae5a4SJacob Faibussowitsch {
67076ba34aSJunchao Zhang   Mat_SeqAIJ       *aijseq;
68076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
698c3ff71bSJunchao Zhang 
708c3ff71bSJunchao Zhang   PetscFunctionBegin;
713ba16761SJacob Faibussowitsch   if (mode == MAT_FLUSH_ASSEMBLY) PetscFunctionReturn(PETSC_SUCCESS);
729566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
73076ba34aSJunchao Zhang 
74076ba34aSJunchao Zhang   aijseq = static_cast<Mat_SeqAIJ *>(A->data);
75076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
76076ba34aSJunchao Zhang 
77076ba34aSJunchao Zhang   /* If aijkok does not exist, we just copy i, j to device.
78076ba34aSJunchao Zhang      If aijkok already exists, but the device's nonzero pattern does not match with the host's, we assume the latest data is on host.
79076ba34aSJunchao Zhang      In both cases, we build a new aijkok structure.
80076ba34aSJunchao Zhang   */
81076ba34aSJunchao Zhang   if (!aijkok || aijkok->nonzerostate != A->nonzerostate) { /* aijkok might not exist yet or nonzero pattern has changed */
82076ba34aSJunchao Zhang     delete aijkok;
83f4747e26SJunchao Zhang     aijkok   = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aijseq, A->nonzerostate, PETSC_FALSE /*don't copy mat values to device*/);
84076ba34aSJunchao Zhang     A->spptr = aijkok;
85f4747e26SJunchao Zhang   } else if (A->rmap->n && aijkok->diag_dual.extent(0) == 0) { // MatProduct might directly produce AIJ on device, but not the diag.
86f4747e26SJunchao Zhang     MatRowMapKokkosViewHost diag_h(aijseq->diag, A->rmap->n);
87f4747e26SJunchao Zhang     auto                    diag_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), diag_h);
88f4747e26SJunchao Zhang     aijkok->diag_dual              = MatRowMapKokkosDualView(diag_d, diag_h);
89076ba34aSJunchao Zhang   }
903ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
918c3ff71bSJunchao Zhang }
928c3ff71bSJunchao Zhang 
9386a27549SJunchao Zhang /* Sync CSR data to device if not yet */
94d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosSyncDevice(Mat A)
95d71ae5a4SJacob Faibussowitsch {
968c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
978c3ff71bSJunchao Zhang 
988c3ff71bSJunchao Zhang   PetscFunctionBegin;
99aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from host to device");
1005f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
101076ba34aSJunchao Zhang   if (aijkok->a_dual.need_sync_device()) {
102076ba34aSJunchao Zhang     aijkok->a_dual.sync_device();
103580c7c76SPierre Jolivet     aijkok->transpose_updated = PETSC_FALSE; /* values of the transpose is out-of-date */
10486a27549SJunchao Zhang     aijkok->hermitian_updated = PETSC_FALSE;
1058c3ff71bSJunchao Zhang   }
1063ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1078c3ff71bSJunchao Zhang }
1088c3ff71bSJunchao Zhang 
109076ba34aSJunchao Zhang /* Mark the CSR data on device as modified */
110d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosModifyDevice(Mat A)
111d71ae5a4SJacob Faibussowitsch {
11286a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
11386a27549SJunchao Zhang 
11486a27549SJunchao Zhang   PetscFunctionBegin;
1155f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Not supported for factorized matries");
11686a27549SJunchao Zhang   aijkok->a_dual.clear_sync_state();
11786a27549SJunchao Zhang   aijkok->a_dual.modify_device();
11886a27549SJunchao Zhang   aijkok->transpose_updated = PETSC_FALSE;
11986a27549SJunchao Zhang   aijkok->hermitian_updated = PETSC_FALSE;
1209566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJInvalidateDiagonal(A));
1219566063dSJacob Faibussowitsch   PetscCall(PetscObjectStateIncrease((PetscObject)A));
1223ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
12386a27549SJunchao Zhang }
12486a27549SJunchao Zhang 
125d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSyncHost(Mat A)
126d71ae5a4SJacob Faibussowitsch {
127f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1284df4a32cSJunchao Zhang   auto              exec   = PetscGetKokkosExecutionSpace();
129f0cf5187SStefano Zampini 
130f0cf5187SStefano Zampini   PetscFunctionBegin;
131f0cf5187SStefano Zampini   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
13286a27549SJunchao Zhang   /* We do not expect one needs factors on host  */
133aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from device to host");
1345f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing AIJKOK");
135aac854edSJunchao Zhang   PetscCall(KokkosDualViewSync<HostMirrorMemorySpace>(aijkok->a_dual, exec));
1363ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
137f0cf5187SStefano Zampini }
138f0cf5187SStefano Zampini 
139d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
140d71ae5a4SJacob Faibussowitsch {
141076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
142f0cf5187SStefano Zampini 
143f0cf5187SStefano Zampini   PetscFunctionBegin;
1445519a089SJose E. Roman   /* aijkok contains valid pointers only if the host's nonzerostate matches with the device's.
1455519a089SJose E. Roman     Calling MatSeqAIJSetPreallocation() or MatSetValues() on host, where aijseq->{i,j,a} might be
1465519a089SJose E. Roman     reallocated, will lead to stale {i,j,a}_dual in aijkok. In both operations, the hosts's nonzerostate
1475519a089SJose E. Roman     must have been updated. The stale aijkok will be rebuilt during MatAssemblyEnd.
1485519a089SJose E. Roman   */
1495519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
1504df4a32cSJunchao Zhang     auto exec = PetscGetKokkosExecutionSpace();
151e36ced11SJunchao Zhang     PetscCallCXX(aijkok->a_dual.sync_host(exec));
152e36ced11SJunchao Zhang     PetscCallCXX(exec.fence());
153076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
154076ba34aSJunchao Zhang   } else { /* Happens when calling MatSetValues on a newly created matrix */
155076ba34aSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
156076ba34aSJunchao Zhang   }
1573ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
158076ba34aSJunchao Zhang }
159076ba34aSJunchao Zhang 
160d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
161d71ae5a4SJacob Faibussowitsch {
162076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
163076ba34aSJunchao Zhang 
164076ba34aSJunchao Zhang   PetscFunctionBegin;
1655519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) aijkok->a_dual.modify_host();
1663ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
167076ba34aSJunchao Zhang }
168076ba34aSJunchao Zhang 
169d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
170d71ae5a4SJacob Faibussowitsch {
171076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
172076ba34aSJunchao Zhang 
173076ba34aSJunchao Zhang   PetscFunctionBegin;
1745519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
1754df4a32cSJunchao Zhang     auto exec = PetscGetKokkosExecutionSpace();
176e36ced11SJunchao Zhang     PetscCallCXX(aijkok->a_dual.sync_host(exec));
177e36ced11SJunchao Zhang     PetscCallCXX(exec.fence());
178076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1792328674fSJunchao Zhang   } else {
1802328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1812328674fSJunchao Zhang   }
1823ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
183076ba34aSJunchao Zhang }
184076ba34aSJunchao Zhang 
185d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
186d71ae5a4SJacob Faibussowitsch {
187076ba34aSJunchao Zhang   PetscFunctionBegin;
188076ba34aSJunchao Zhang   *array = NULL;
1893ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
190076ba34aSJunchao Zhang }
191076ba34aSJunchao Zhang 
192d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
193d71ae5a4SJacob Faibussowitsch {
194076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
195076ba34aSJunchao Zhang 
196076ba34aSJunchao Zhang   PetscFunctionBegin;
1975519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
198076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1992328674fSJunchao Zhang   } else { /* Ex. happens with MatZeroEntries on a preallocated but not assembled matrix */
2002328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
2012328674fSJunchao Zhang   }
2023ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
203076ba34aSJunchao Zhang }
204076ba34aSJunchao Zhang 
205d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
206d71ae5a4SJacob Faibussowitsch {
207076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
208076ba34aSJunchao Zhang 
209076ba34aSJunchao Zhang   PetscFunctionBegin;
2105519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
211076ba34aSJunchao Zhang     aijkok->a_dual.clear_sync_state();
212076ba34aSJunchao Zhang     aijkok->a_dual.modify_host();
2132328674fSJunchao Zhang   }
2143ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
215f0cf5187SStefano Zampini }
216f0cf5187SStefano Zampini 
217d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJKokkos(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype)
218d71ae5a4SJacob Faibussowitsch {
2197ee59b9bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
2207ee59b9bSJunchao Zhang 
2217ee59b9bSJunchao Zhang   PetscFunctionBegin;
2227ee59b9bSJunchao Zhang   PetscCheck(aijkok != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "aijkok is NULL");
2237ee59b9bSJunchao Zhang 
2247ee59b9bSJunchao Zhang   if (i) *i = aijkok->i_device_data();
2257ee59b9bSJunchao Zhang   if (j) *j = aijkok->j_device_data();
2267ee59b9bSJunchao Zhang   if (a) {
2277ee59b9bSJunchao Zhang     aijkok->a_dual.sync_device();
2287ee59b9bSJunchao Zhang     *a = aijkok->a_device_data();
2297ee59b9bSJunchao Zhang   }
2307ee59b9bSJunchao Zhang   if (mtype) *mtype = PETSC_MEMTYPE_KOKKOS;
2313ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2327ee59b9bSJunchao Zhang }
2337ee59b9bSJunchao Zhang 
2340e3ece09SJunchao Zhang /*
2350e3ece09SJunchao Zhang   Generate the sparsity pattern of a MatSeqAIJKokkos matrix's transpose on device.
2360e3ece09SJunchao Zhang 
2370e3ece09SJunchao Zhang   Input Parameter:
2380e3ece09SJunchao Zhang .  A       - the MATSEQAIJKOKKOS matrix
2390e3ece09SJunchao Zhang 
2400e3ece09SJunchao Zhang   Output Parameters:
2410e3ece09SJunchao Zhang +  perm_d - the permutation array on device, which connects Ta(i) = Aa(perm(i))
242aaa8cc7dSPierre Jolivet -  T_d    - the transpose on device, whose value array is allocated but not initialized
2430e3ece09SJunchao Zhang */
2440e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateTransposeStructure(Mat A, MatRowMapKokkosView &perm_d, KokkosCsrMatrix &T_d)
245d71ae5a4SJacob Faibussowitsch {
2460e3ece09SJunchao Zhang   Mat_SeqAIJ             *aseq = static_cast<Mat_SeqAIJ *>(A->data);
2470e3ece09SJunchao Zhang   PetscInt                nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
2480e3ece09SJunchao Zhang   const PetscInt         *Ai = aseq->i, *Aj = aseq->j;
2497b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost Ti_h(NoInit("Ti"), n + 1);
2500e3ece09SJunchao Zhang   MatRowMapType          *Ti = Ti_h.data();
2517b8d4ba6SJunchao Zhang   MatColIdxKokkosViewHost Tj_h(NoInit("Tj"), nz);
2527b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost perm_h(NoInit("permutation"), nz);
2530e3ece09SJunchao Zhang   PetscInt               *Tj   = Tj_h.data();
2540e3ece09SJunchao Zhang   PetscInt               *perm = perm_h.data();
2550e3ece09SJunchao Zhang   PetscInt               *offset;
256152b3e56SJunchao Zhang 
257152b3e56SJunchao Zhang   PetscFunctionBegin;
2580e3ece09SJunchao Zhang   // Populate Ti
2590e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::deep_copy(Ti_h, 0));
2600e3ece09SJunchao Zhang   Ti++;
2610e3ece09SJunchao Zhang   for (PetscInt i = 0; i < nz; i++) Ti[Aj[i]]++;
2620e3ece09SJunchao Zhang   Ti--;
2630e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) Ti[i + 1] += Ti[i];
2640e3ece09SJunchao Zhang 
2650e3ece09SJunchao Zhang   // Populate Tj and the permutation array
2660e3ece09SJunchao Zhang   PetscCall(PetscCalloc1(n, &offset)); // offset in each T row to fill in its column indices
2670e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
2680e3ece09SJunchao Zhang     for (PetscInt j = Ai[i]; j < Ai[i + 1]; j++) { // A's (i,j) is T's (j,i)
2690e3ece09SJunchao Zhang       PetscInt r    = Aj[j];                       // row r of T
2700e3ece09SJunchao Zhang       PetscInt disp = Ti[r] + offset[r];
2710e3ece09SJunchao Zhang 
2720e3ece09SJunchao Zhang       Tj[disp]   = i; // col i of T
2730e3ece09SJunchao Zhang       perm[disp] = j;
2740e3ece09SJunchao Zhang       offset[r]++;
275076ba34aSJunchao Zhang     }
2760e3ece09SJunchao Zhang   }
2770e3ece09SJunchao Zhang   PetscCall(PetscFree(offset));
2780e3ece09SJunchao Zhang 
2790e3ece09SJunchao Zhang   // Sort each row of T, along with the permutation array
2800e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) PetscCall(PetscSortIntWithArray(Ti[i + 1] - Ti[i], Tj + Ti[i], perm + Ti[i]));
2810e3ece09SJunchao Zhang 
2820e3ece09SJunchao Zhang   // Output perm and T on device
2830e3ece09SJunchao Zhang   auto Ti_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ti_h);
2840e3ece09SJunchao Zhang   auto Tj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Tj_h);
2850e3ece09SJunchao Zhang   PetscCallCXX(T_d = KokkosCsrMatrix("csrmatT", n, m, nz, MatScalarKokkosView("Ta", nz), Ti_d, Tj_d));
2860e3ece09SJunchao Zhang   PetscCallCXX(perm_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), perm_h));
2873ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
288152b3e56SJunchao Zhang }
289152b3e56SJunchao Zhang 
2900e3ece09SJunchao Zhang // Generate the transpose on device and cache it internally
2910e3ece09SJunchao Zhang // Note: KK transpose_matrix() does not have support symbolic/numeric transpose, so we do it on our own
2920e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGenerateTranspose_Private(Mat A, KokkosCsrMatrix *csrmatT)
293d71ae5a4SJacob Faibussowitsch {
2940e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
2950e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
2960e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
2970e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatT;
298152b3e56SJunchao Zhang 
299152b3e56SJunchao Zhang   PetscFunctionBegin;
3000e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
301145b44c9SPierre Jolivet   PetscCallCXX(akok->a_dual.sync_device()); // Sync A's values since we are going to access them on device
3020e3ece09SJunchao Zhang 
3030e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
3040e3ece09SJunchao Zhang 
3050e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE) {
3060e3ece09SJunchao Zhang     *csrmatT = akok->csrmat;
3070e3ece09SJunchao Zhang   } else {
3080e3ece09SJunchao Zhang     // See if we already have a cached transpose and its value is up to date
3090e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
3100e3ece09SJunchao Zhang       if (!akok->transpose_updated) {            // if the value is out of date, update the cached version
3110e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
3120e3ece09SJunchao Zhang         auto       &Ta   = T.values;
3130e3ece09SJunchao Zhang 
314d326c3f1SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = Aa(perm(i)); }));
315076ba34aSJunchao Zhang       }
3160e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
3170e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
3180e3ece09SJunchao Zhang 
3190e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
3200e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
321d326c3f1SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = Aa(perm(i)); }));
3220e3ece09SJunchao Zhang     }
3230e3ece09SJunchao Zhang     akok->transpose_updated = PETSC_TRUE;
3240e3ece09SJunchao Zhang     *csrmatT                = akok->csrmatT;
3250e3ece09SJunchao Zhang   }
3260e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3270e3ece09SJunchao Zhang }
3280e3ece09SJunchao Zhang 
3290e3ece09SJunchao Zhang // Generate the Hermitian on device and cache it internally
3300e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateHermitian_Private(Mat A, KokkosCsrMatrix *csrmatH)
3310e3ece09SJunchao Zhang {
3320e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
3330e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3340e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
3350e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatH;
3360e3ece09SJunchao Zhang 
3370e3ece09SJunchao Zhang   PetscFunctionBegin;
3380e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
3390e3ece09SJunchao Zhang   PetscCallCXX(akok->a_dual.sync_device()); // Sync A's values since we are going to access them on device
3400e3ece09SJunchao Zhang 
3410e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
3420e3ece09SJunchao Zhang 
3430e3ece09SJunchao Zhang   if (A->hermitian == PETSC_BOOL3_TRUE) {
3440e3ece09SJunchao Zhang     *csrmatH = akok->csrmat;
3450e3ece09SJunchao Zhang   } else {
3460e3ece09SJunchao Zhang     // See if we already have a cached hermitian and its value is up to date
3470e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
3480e3ece09SJunchao Zhang       if (!akok->hermitian_updated) {            // if the value is out of date, update the cached version
3490e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
3500e3ece09SJunchao Zhang         auto       &Ta   = T.values;
3510e3ece09SJunchao Zhang 
352d326c3f1SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = PetscConj(Aa(perm(i))); }));
3530e3ece09SJunchao Zhang       }
3540e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
3550e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
3560e3ece09SJunchao Zhang 
3570e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
3580e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
359d326c3f1SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = PetscConj(Aa(perm(i))); }));
3600e3ece09SJunchao Zhang     }
3610e3ece09SJunchao Zhang     akok->hermitian_updated = PETSC_TRUE;
3620e3ece09SJunchao Zhang     *csrmatH                = akok->csrmatH;
3630e3ece09SJunchao Zhang   }
3643ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
365152b3e56SJunchao Zhang }
366a587d139SMark 
3678c3ff71bSJunchao Zhang /* y = A x */
368d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMult_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
369d71ae5a4SJacob Faibussowitsch {
3708c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
371152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
372152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
3738c3ff71bSJunchao Zhang 
3748c3ff71bSJunchao Zhang   PetscFunctionBegin;
3759566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3769566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3779566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3789566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
3798c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
380d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), "N", 1.0 /*alpha*/, aijkok->csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A x + beta y */
3819566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3829566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
383076ba34aSJunchao Zhang   /* 2.0*nnz - numRows seems more accurate here but assumes there are no zero-rows. So a little sloppy here. */
3849566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
3859566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
3863ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
3878c3ff71bSJunchao Zhang }
3888c3ff71bSJunchao Zhang 
3898c3ff71bSJunchao Zhang /* y = A^T x */
390d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
391d71ae5a4SJacob Faibussowitsch {
3928c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
393152b3e56SJunchao Zhang   const char                *mode;
394152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
395152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
3960e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
3978c3ff71bSJunchao Zhang 
3988c3ff71bSJunchao Zhang   PetscFunctionBegin;
3999566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4009566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4019566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4029566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
403152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4049566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
405152b3e56SJunchao Zhang     mode = "N";
406152b3e56SJunchao Zhang   } else {
407076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4080e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
409152b3e56SJunchao Zhang     mode   = "T";
410152b3e56SJunchao Zhang   }
411d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^T x + beta y */
4129566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4139566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
4140e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4159566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4163ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4178c3ff71bSJunchao Zhang }
4188c3ff71bSJunchao Zhang 
4198c3ff71bSJunchao Zhang /* y = A^H x */
420d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
421d71ae5a4SJacob Faibussowitsch {
4228c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
423152b3e56SJunchao Zhang   const char                *mode;
424152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
425152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
4260e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4278c3ff71bSJunchao Zhang 
4288c3ff71bSJunchao Zhang   PetscFunctionBegin;
4299566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4309566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4319566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4329566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
433152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4349566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
435152b3e56SJunchao Zhang     mode = "N";
436152b3e56SJunchao Zhang   } else {
437076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4380e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
439152b3e56SJunchao Zhang     mode   = "C";
440152b3e56SJunchao Zhang   }
441d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^H x + beta y */
4429566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4439566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
4440e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4459566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4463ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4478c3ff71bSJunchao Zhang }
4488c3ff71bSJunchao Zhang 
4498c3ff71bSJunchao Zhang /* z = A x + y */
450d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
451d71ae5a4SJacob Faibussowitsch {
4528c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
45392896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
454152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4558c3ff71bSJunchao Zhang 
4568c3ff71bSJunchao Zhang   PetscFunctionBegin;
4579566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4589566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
45992896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz)); // depending on yy's sync flags, zz might get its latest data on host
4609566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
46192896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv)); // do after VecCopy(yy, zz) to get the latest data on device
4628c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
463d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), "N", 1.0 /*alpha*/, aijkok->csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A x + beta z */
4649566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
46592896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
4669566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
4679566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4683ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4698c3ff71bSJunchao Zhang }
4708c3ff71bSJunchao Zhang 
4718c3ff71bSJunchao Zhang /* z = A^T x + y */
472d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
473d71ae5a4SJacob Faibussowitsch {
4748c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
475152b3e56SJunchao Zhang   const char                *mode;
47692896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
477152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4780e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4798c3ff71bSJunchao Zhang 
4808c3ff71bSJunchao Zhang   PetscFunctionBegin;
4819566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4829566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
48392896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz));
4849566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
48592896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv));
486152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4879566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
488152b3e56SJunchao Zhang     mode = "N";
489152b3e56SJunchao Zhang   } else {
490076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4910e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
492152b3e56SJunchao Zhang     mode   = "T";
493152b3e56SJunchao Zhang   }
494d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^T x + beta z */
4959566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
49692896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
4970e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4989566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4993ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5008c3ff71bSJunchao Zhang }
5018c3ff71bSJunchao Zhang 
5028c3ff71bSJunchao Zhang /* z = A^H x + y */
503d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
504d71ae5a4SJacob Faibussowitsch {
5058c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
506152b3e56SJunchao Zhang   const char                *mode;
50792896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
508152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
5090e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
5108c3ff71bSJunchao Zhang 
5118c3ff71bSJunchao Zhang   PetscFunctionBegin;
5129566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
5139566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
51492896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz));
5159566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
51692896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv));
517152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
5189566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
519152b3e56SJunchao Zhang     mode = "N";
520152b3e56SJunchao Zhang   } else {
521076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
5220e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
523152b3e56SJunchao Zhang     mode   = "C";
524152b3e56SJunchao Zhang   }
525d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^H x + beta z */
5269566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
52792896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
5280e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
5299566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
5303ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
531152b3e56SJunchao Zhang }
532152b3e56SJunchao Zhang 
53366976f2fSJacob Faibussowitsch static PetscErrorCode MatSetOption_SeqAIJKokkos(Mat A, MatOption op, PetscBool flg)
534d71ae5a4SJacob Faibussowitsch {
535152b3e56SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
536152b3e56SJunchao Zhang 
537152b3e56SJunchao Zhang   PetscFunctionBegin;
538152b3e56SJunchao Zhang   switch (op) {
539152b3e56SJunchao Zhang   case MAT_FORM_EXPLICIT_TRANSPOSE:
540152b3e56SJunchao Zhang     /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */
5419566063dSJacob Faibussowitsch     if (A->form_explicit_transpose && !flg && aijkok) PetscCall(aijkok->DestroyMatTranspose());
542152b3e56SJunchao Zhang     A->form_explicit_transpose = flg;
543152b3e56SJunchao Zhang     break;
544d71ae5a4SJacob Faibussowitsch   default:
545d71ae5a4SJacob Faibussowitsch     PetscCall(MatSetOption_SeqAIJ(A, op, flg));
546d71ae5a4SJacob Faibussowitsch     break;
547152b3e56SJunchao Zhang   }
5483ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5498c3ff71bSJunchao Zhang }
5508c3ff71bSJunchao Zhang 
551076ba34aSJunchao Zhang /* Depending on reuse, either build a new mat, or use the existing mat */
552d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
553d71ae5a4SJacob Faibussowitsch {
554076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
5558c3ff71bSJunchao Zhang 
5568c3ff71bSJunchao Zhang   PetscFunctionBegin;
5579566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
558076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {                      /* Build a brand new mat */
5599566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));  /* the returned newmat is a SeqAIJKokkos */
5608c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {                 /* Reuse the mat created before */
5619566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); /* newmat is already a SeqAIJKokkos */
562076ba34aSJunchao Zhang   } else if (reuse == MAT_INPLACE_MATRIX) {               /* newmat is A */
5635f80ce2aSJacob Faibussowitsch     PetscCheck(A == *newmat, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "A != *newmat with MAT_INPLACE_MATRIX");
5649566063dSJacob Faibussowitsch     PetscCall(PetscFree(A->defaultvectype));
5659566063dSJacob Faibussowitsch     PetscCall(PetscStrallocpy(VECKOKKOS, &A->defaultvectype)); /* Allocate and copy the string */
5669566063dSJacob Faibussowitsch     PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATSEQAIJKOKKOS));
5679566063dSJacob Faibussowitsch     PetscCall(MatSetOps_SeqAIJKokkos(A));
568076ba34aSJunchao Zhang     aseq = static_cast<Mat_SeqAIJ *>(A->data);
569394ed5ebSJunchao Zhang     if (A->assembled) { /* Copy i, j (but not values) to device for an assembled matrix if not yet */
5705f80ce2aSJacob Faibussowitsch       PetscCheck(!A->spptr, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Expect NULL (Mat_SeqAIJKokkos*)A->spptr");
571f4747e26SJunchao Zhang       A->spptr = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aseq, A->nonzerostate, PETSC_FALSE);
5728c3ff71bSJunchao Zhang     }
573076ba34aSJunchao Zhang   }
5743ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5758c3ff71bSJunchao Zhang }
5768c3ff71bSJunchao Zhang 
577076ba34aSJunchao Zhang /* MatDuplicate always creates a new matrix. MatDuplicate can be called either on an assembled matrix or
578076ba34aSJunchao Zhang    an unassembled matrix, even though MAT_COPY_VALUES is not allowed for unassembled matrix.
579076ba34aSJunchao Zhang  */
580d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDuplicate_SeqAIJKokkos(Mat A, MatDuplicateOption dupOption, Mat *B)
581d71ae5a4SJacob Faibussowitsch {
582076ba34aSJunchao Zhang   Mat_SeqAIJ       *bseq;
583076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok;
584076ba34aSJunchao Zhang   Mat               mat;
5858c3ff71bSJunchao Zhang 
5868c3ff71bSJunchao Zhang   PetscFunctionBegin;
587076ba34aSJunchao Zhang   /* Do not copy values on host as A's latest values might be on device. We don't want to do sync blindly */
5889566063dSJacob Faibussowitsch   PetscCall(MatDuplicate_SeqAIJ(A, MAT_DO_NOT_COPY_VALUES, B));
589076ba34aSJunchao Zhang   mat = *B;
590f4747e26SJunchao Zhang   if (A->assembled) {
591076ba34aSJunchao Zhang     bseq = static_cast<Mat_SeqAIJ *>(mat->data);
592f4747e26SJunchao Zhang     bkok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, bseq, mat->nonzerostate, PETSC_FALSE);
593076ba34aSJunchao Zhang     bkok->a_dual.clear_sync_state(); /* Clear B's sync state as it will be decided below */
594076ba34aSJunchao Zhang     /* Now copy values to B if needed */
595076ba34aSJunchao Zhang     if (dupOption == MAT_COPY_VALUES) {
596076ba34aSJunchao Zhang       if (akok->a_dual.need_sync_device()) {
597076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_host(), akok->a_dual.view_host());
598076ba34aSJunchao Zhang         bkok->a_dual.modify_host();
599076ba34aSJunchao Zhang       } else { /* If device has the latest data, we only copy data on device */
600076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_device(), akok->a_dual.view_device());
601076ba34aSJunchao Zhang         bkok->a_dual.modify_device();
602076ba34aSJunchao Zhang       }
603076ba34aSJunchao Zhang     } else { /* MAT_DO_NOT_COPY_VALUES or MAT_SHARE_NONZERO_PATTERN. B's values should be zeroed */
604076ba34aSJunchao Zhang       /* B's values on host should be already zeroed by MatDuplicate_SeqAIJ() */
605076ba34aSJunchao Zhang       bkok->a_dual.modify_host();
606076ba34aSJunchao Zhang     }
607076ba34aSJunchao Zhang     mat->spptr = bkok;
608076ba34aSJunchao Zhang   }
609076ba34aSJunchao Zhang 
6109566063dSJacob Faibussowitsch   PetscCall(PetscFree(mat->defaultvectype));
6119566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &mat->defaultvectype)); /* Allocate and copy the string */
6129566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)mat, MATSEQAIJKOKKOS));
6139566063dSJacob Faibussowitsch   PetscCall(MatSetOps_SeqAIJKokkos(mat));
6143ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6158c3ff71bSJunchao Zhang }
6168c3ff71bSJunchao Zhang 
617d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatTranspose_SeqAIJKokkos(Mat A, MatReuse reuse, Mat *B)
618d71ae5a4SJacob Faibussowitsch {
6190ecb592aSJunchao Zhang   Mat               At;
6200e3ece09SJunchao Zhang   KokkosCsrMatrix   internT;
6210ecb592aSJunchao Zhang   Mat_SeqAIJKokkos *atkok, *bkok;
6220ecb592aSJunchao Zhang 
6230ecb592aSJunchao Zhang   PetscFunctionBegin;
6247fb60732SBarry Smith   if (reuse == MAT_REUSE_MATRIX) PetscCall(MatTransposeCheckNonzeroState_Private(A, *B));
6259566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &internT)); /* Generate a transpose internally */
6260ecb592aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX || reuse == MAT_INPLACE_MATRIX) {
627ff751488SJunchao Zhang     /* Deep copy internT, as we want to isolate the internal transpose */
6280e3ece09SJunchao Zhang     PetscCallCXX(atkok = new Mat_SeqAIJKokkos(KokkosCsrMatrix("csrmat", internT)));
6299566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PetscObjectComm((PetscObject)A), atkok, &At));
6300ecb592aSJunchao Zhang     if (reuse == MAT_INITIAL_MATRIX) *B = At;
6319566063dSJacob Faibussowitsch     else PetscCall(MatHeaderReplace(A, &At)); /* Replace A with At inplace */
6320ecb592aSJunchao Zhang   } else {                                    /* MAT_REUSE_MATRIX, just need to copy values to B on device */
6330ecb592aSJunchao Zhang     if ((*B)->assembled) {
6340ecb592aSJunchao Zhang       bkok = static_cast<Mat_SeqAIJKokkos *>((*B)->spptr);
6350e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(bkok->a_dual.view_device(), internT.values));
6369566063dSJacob Faibussowitsch       PetscCall(MatSeqAIJKokkosModifyDevice(*B));
6370ecb592aSJunchao Zhang     } else if ((*B)->preallocated) { /* It is ok for B to be only preallocated, as needed in MatTranspose_MPIAIJ */
6380ecb592aSJunchao Zhang       Mat_SeqAIJ             *bseq = static_cast<Mat_SeqAIJ *>((*B)->data);
6390e3ece09SJunchao Zhang       MatScalarKokkosViewHost a_h(bseq->a, internT.nnz()); /* bseq->nz = 0 if unassembled */
6400e3ece09SJunchao Zhang       MatColIdxKokkosViewHost j_h(bseq->j, internT.nnz());
6410e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(a_h, internT.values));
6420e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(j_h, internT.graph.entries));
6430ecb592aSJunchao Zhang     } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "B must be assembled or preallocated");
6440ecb592aSJunchao Zhang   }
6453ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6460ecb592aSJunchao Zhang }
6470ecb592aSJunchao Zhang 
648d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDestroy_SeqAIJKokkos(Mat A)
649d71ae5a4SJacob Faibussowitsch {
65086a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
6518c3ff71bSJunchao Zhang 
6528c3ff71bSJunchao Zhang   PetscFunctionBegin;
65386a27549SJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) {
65486a27549SJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
6558c3ff71bSJunchao Zhang     delete aijkok;
65686a27549SJunchao Zhang   } else {
65786a27549SJunchao Zhang     delete static_cast<Mat_SeqAIJKokkosTriFactors *>(A->spptr);
65886a27549SJunchao Zhang   }
659cbc6b225SStefano Zampini   A->spptr = NULL;
6609566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
6619566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
6629566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
66357761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
66457761e9aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijkokkos_hypre_C", NULL));
66557761e9aSJunchao Zhang #endif
6669566063dSJacob Faibussowitsch   PetscCall(MatDestroy_SeqAIJ(A));
6673ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6688c3ff71bSJunchao Zhang }
6698c3ff71bSJunchao Zhang 
6703f3ba80aSJunchao Zhang /*MC
6713f3ba80aSJunchao Zhang    MATSEQAIJKOKKOS - MATAIJKOKKOS = "(seq)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos
6723f3ba80aSJunchao Zhang 
67315229ffcSPierre Jolivet    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
6743f3ba80aSJunchao Zhang 
6752ef1f0ffSBarry Smith    Options Database Key:
67611a5261eSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATSEQAIJKOKKOS` during a call to `MatSetFromOptions()`
6773f3ba80aSJunchao Zhang 
6783f3ba80aSJunchao Zhang   Level: beginner
6793f3ba80aSJunchao Zhang 
6801cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MatCreateSeqAIJKokkos()`, `MATMPIAIJKOKKOS`
6813f3ba80aSJunchao Zhang M*/
682d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJKokkos(Mat A)
683d71ae5a4SJacob Faibussowitsch {
68486a27549SJunchao Zhang   PetscFunctionBegin;
6859566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
6869566063dSJacob Faibussowitsch   PetscCall(MatCreate_SeqAIJ(A));
6879566063dSJacob Faibussowitsch   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
6883ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
68986a27549SJunchao Zhang }
69086a27549SJunchao Zhang 
691076ba34aSJunchao Zhang /* Merge A, B into a matrix C. A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n) */
692d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C)
693d71ae5a4SJacob Faibussowitsch {
694076ba34aSJunchao Zhang   Mat_SeqAIJ         *a, *b;
695076ba34aSJunchao Zhang   Mat_SeqAIJKokkos   *akok, *bkok, *ckok;
696076ba34aSJunchao Zhang   MatScalarKokkosView aa, ba, ca;
697076ba34aSJunchao Zhang   MatRowMapKokkosView ai, bi, ci;
698076ba34aSJunchao Zhang   MatColIdxKokkosView aj, bj, cj;
699076ba34aSJunchao Zhang   PetscInt            m, n, nnz, aN;
700a3f881fbSStefano Zampini 
701a3f881fbSStefano Zampini   PetscFunctionBegin;
702076ba34aSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
703076ba34aSJunchao Zhang   PetscValidHeaderSpecific(B, MAT_CLASSID, 2);
7044f572ea9SToby Isaac   PetscAssertPointer(C, 4);
705076ba34aSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
706076ba34aSJunchao Zhang   PetscCheckTypeName(B, MATSEQAIJKOKKOS);
7075f80ce2aSJacob Faibussowitsch   PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n);
7085f80ce2aSJacob Faibussowitsch   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
709076ba34aSJunchao Zhang 
7109566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
7119566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
712076ba34aSJunchao Zhang   a    = static_cast<Mat_SeqAIJ *>(A->data);
713076ba34aSJunchao Zhang   b    = static_cast<Mat_SeqAIJ *>(B->data);
714076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
715076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
716076ba34aSJunchao Zhang   aa   = akok->a_dual.view_device();
717076ba34aSJunchao Zhang   ai   = akok->i_dual.view_device();
718076ba34aSJunchao Zhang   ba   = bkok->a_dual.view_device();
719076ba34aSJunchao Zhang   bi   = bkok->i_dual.view_device();
720076ba34aSJunchao Zhang   m    = A->rmap->n; /* M, N and nnz of C */
721076ba34aSJunchao Zhang   n    = A->cmap->n + B->cmap->n;
722076ba34aSJunchao Zhang   nnz  = a->nz + b->nz;
723076ba34aSJunchao Zhang   aN   = A->cmap->n; /* N of A */
724076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
725076ba34aSJunchao Zhang     aj           = akok->j_dual.view_device();
726076ba34aSJunchao Zhang     bj           = bkok->j_dual.view_device();
727076ba34aSJunchao Zhang     auto ca_dual = MatScalarKokkosDualView("a", aa.extent(0) + ba.extent(0));
728076ba34aSJunchao Zhang     auto ci_dual = MatRowMapKokkosDualView("i", ai.extent(0));
729076ba34aSJunchao Zhang     auto cj_dual = MatColIdxKokkosDualView("j", aj.extent(0) + bj.extent(0));
730076ba34aSJunchao Zhang     ca           = ca_dual.view_device();
731076ba34aSJunchao Zhang     ci           = ci_dual.view_device();
732076ba34aSJunchao Zhang     cj           = cj_dual.view_device();
733076ba34aSJunchao Zhang 
734076ba34aSJunchao Zhang     /* Concatenate A and B in parallel using Kokkos hierarchical parallelism */
7359371c9d4SSatish Balay     Kokkos::parallel_for(
736d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
737076ba34aSJunchao Zhang         PetscInt i       = t.league_rank(); /* row i */
738076ba34aSJunchao Zhang         PetscInt coffset = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
739076ba34aSJunchao Zhang 
740076ba34aSJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() { /* this side effect only happens once per whole team */
741076ba34aSJunchao Zhang                                                    ci(i) = coffset;
742076ba34aSJunchao Zhang                                                    if (i == m - 1) ci(m) = ai(m) + bi(m);
743076ba34aSJunchao Zhang         });
744076ba34aSJunchao Zhang 
745076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
746076ba34aSJunchao Zhang           if (k < alen) {
747076ba34aSJunchao Zhang             ca(coffset + k) = aa(ai(i) + k);
748076ba34aSJunchao Zhang             cj(coffset + k) = aj(ai(i) + k);
749076ba34aSJunchao Zhang           } else {
750076ba34aSJunchao Zhang             ca(coffset + k) = ba(bi(i) + k - alen);
751076ba34aSJunchao Zhang             cj(coffset + k) = bj(bi(i) + k - alen) + aN; /* Entries in B get new column indices in C */
752076ba34aSJunchao Zhang           }
753076ba34aSJunchao Zhang         });
754076ba34aSJunchao Zhang       });
755076ba34aSJunchao Zhang     ca_dual.modify_device();
756076ba34aSJunchao Zhang     ci_dual.modify_device();
757076ba34aSJunchao Zhang     cj_dual.modify_device();
7589566063dSJacob Faibussowitsch     PetscCallCXX(ckok = new Mat_SeqAIJKokkos(m, n, nnz, ci_dual, cj_dual, ca_dual));
7599566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, C));
760076ba34aSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
761076ba34aSJunchao Zhang     PetscValidHeaderSpecific(*C, MAT_CLASSID, 4);
762076ba34aSJunchao Zhang     PetscCheckTypeName(*C, MATSEQAIJKOKKOS);
763076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos *>((*C)->spptr);
764076ba34aSJunchao Zhang     ca   = ckok->a_dual.view_device();
765076ba34aSJunchao Zhang     ci   = ckok->i_dual.view_device();
766076ba34aSJunchao Zhang 
7679371c9d4SSatish Balay     Kokkos::parallel_for(
768d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
769076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* row i */
770076ba34aSJunchao Zhang         PetscInt alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
771076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
772076ba34aSJunchao Zhang           if (k < alen) ca(ci(i) + k) = aa(ai(i) + k);
773076ba34aSJunchao Zhang           else ca(ci(i) + k) = ba(bi(i) + k - alen);
774076ba34aSJunchao Zhang         });
775076ba34aSJunchao Zhang       });
7769566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(*C));
777076ba34aSJunchao Zhang   }
7783ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
779076ba34aSJunchao Zhang }
780076ba34aSJunchao Zhang 
781d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_SeqAIJKokkos(void *pdata)
782d71ae5a4SJacob Faibussowitsch {
783076ba34aSJunchao Zhang   PetscFunctionBegin;
784076ba34aSJunchao Zhang   delete static_cast<MatProductData_SeqAIJKokkos *>(pdata);
7853ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
786a3f881fbSStefano Zampini }
787a3f881fbSStefano Zampini 
/* Numeric phase of C = op(A) op(B) for two MATSEQAIJKOKKOS matrices, with product types AB, AtB or ABt.
   It reuses the KokkosKernels spgemm handle stored in C->product->data (MatProductData_SeqAIJKokkos) and
   writes the values of C on the device. Transposed operands are formed explicitly because spgemm is called
   with transA = transB = false. */
static PetscErrorCode MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos(Mat C)
{
  Mat_Product                 *product = C->product;
  Mat                          A, B;
  bool                         transA, transB; /* use bool, since KK needs this type */
  Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
  Mat_SeqAIJ                  *c;
  MatProductData_SeqAIJKokkos *pdata;
  KokkosCsrMatrix              csrmatA, csrmatB;

  PetscFunctionBegin;
  MatCheckProduct(C, 1);
  PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data empty");
  pdata = static_cast<MatProductData_SeqAIJKokkos *>(C->product->data);

  // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
  // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
  // we still do numeric.
  if (pdata->reusesym) { // numeric reuses results from symbolic
    pdata->reusesym = PETSC_FALSE;
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  // Map the product type to transpose flags for A and B.
  // NOTE(review): the symbolic phase may replace AtB/ABt by AB when A/B is flagged symmetric
  // (product->symbolic_used_the_fact_*_is_symmetric), but this phase switches on product->type and so still forms
  // the explicit transpose. Confirm that is intended.
  switch (product->type) {
  case MATPRODUCT_AB:
    transA = false;
    transB = false;
    break;
  case MATPRODUCT_AtB:
    transA = true;
    transB = false;
    break;
  case MATPRODUCT_ABt:
    transA = false;
    transB = true;
    break;
  default:
    SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
  }

  A = product->A;
  B = product->B;
  // Make sure the device copies of A and B are current before their device CSR matrices are used
  PetscCall(MatSeqAIJKokkosSyncDevice(A));
  PetscCall(MatSeqAIJKokkosSyncDevice(B));
  akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
  bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
  ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);

  PetscCheck(ckok, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Device data structure spptr is empty");

  csrmatA = akok->csrmat;
  csrmatB = bkok->csrmat;

  /* TODO: Once KK spgemm implements transpose, we can get rid of the explicit transpose here */
  if (transA) {
    PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
    transA = false; // csrmatA now already holds A^T
  }

  if (transB) {
    PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
    transB = false; // csrmatB now already holds B^T
  }
  PetscCall(PetscLogGpuTimeBegin());
  // Fill in the values of C; its structure was computed by the symbolic phase that set up pdata->kh
  PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, ckok->csrmat));
#if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
  // With older Kokkos-Kernels, sort C's rows unless the spgemm handle's sort option is 1
  auto spgemmHandle = pdata->kh.get_spgemm_handle();
  if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(ckok->csrmat)); /* without sort, mat_tests-ex62_14_seqaijkokkos fails */
#endif

  PetscCall(PetscLogGpuTimeEnd());
  PetscCall(MatSeqAIJKokkosModifyDevice(C)); // the device now holds C's latest values
  /* shorter version of MatAssemblyEnd_SeqAIJ */
  c = (Mat_SeqAIJ *)C->data;
  PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded,%" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz));
  PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n"));
  PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax));
  c->reallocs         = 0;
  C->info.mallocs     = 0;
  C->info.nz_unneeded = 0;
  C->assembled = C->was_assembled = PETSC_TRUE;
  C->num_ass++;
  PetscFunctionReturn(PETSC_SUCCESS);
}
872a3f881fbSStefano Zampini 
873d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos(Mat C)
874d71ae5a4SJacob Faibussowitsch {
875076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
876076ba34aSJunchao Zhang   MatProductType               ptype;
877076ba34aSJunchao Zhang   Mat                          A, B;
878076ba34aSJunchao Zhang   bool                         transA, transB;
879076ba34aSJunchao Zhang   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
880076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
881076ba34aSJunchao Zhang   MPI_Comm                     comm;
8820e3ece09SJunchao Zhang   KokkosCsrMatrix              csrmatA, csrmatB, csrmatC;
883a3f881fbSStefano Zampini 
884a3f881fbSStefano Zampini   PetscFunctionBegin;
885a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
8869566063dSJacob Faibussowitsch   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
8875f80ce2aSJacob Faibussowitsch   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
888a3f881fbSStefano Zampini   A = product->A;
889a3f881fbSStefano Zampini   B = product->B;
8909566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
8919566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
892a3f881fbSStefano Zampini   akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
893a3f881fbSStefano Zampini   bkok    = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
8940e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
8950e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
896076ba34aSJunchao Zhang 
897a3f881fbSStefano Zampini   ptype = product->type;
8980e3ece09SJunchao Zhang   // Take advantage of the symmetry if true
8990e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) {
9000e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
9010e3ece09SJunchao Zhang     product->symbolic_used_the_fact_A_is_symmetric = PETSC_TRUE;
9020e3ece09SJunchao Zhang   }
9030e3ece09SJunchao Zhang   if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) {
9040e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
9050e3ece09SJunchao Zhang     product->symbolic_used_the_fact_B_is_symmetric = PETSC_TRUE;
9060e3ece09SJunchao Zhang   }
9070e3ece09SJunchao Zhang 
908a3f881fbSStefano Zampini   switch (ptype) {
9099371c9d4SSatish Balay   case MATPRODUCT_AB:
9109371c9d4SSatish Balay     transA = false;
9119371c9d4SSatish Balay     transB = false;
9120e6a1e94SMark Adams     PetscCall(MatSetBlockSizesFromMats(C, A, B));
9139371c9d4SSatish Balay     break;
9149371c9d4SSatish Balay   case MATPRODUCT_AtB:
9159371c9d4SSatish Balay     transA = true;
9169371c9d4SSatish Balay     transB = false;
9170e6a1e94SMark Adams     if (A->cmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->rmap, A->cmap->bs));
9180e6a1e94SMark Adams     if (B->cmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->cmap, B->cmap->bs));
9199371c9d4SSatish Balay     break;
9209371c9d4SSatish Balay   case MATPRODUCT_ABt:
9219371c9d4SSatish Balay     transA = false;
9229371c9d4SSatish Balay     transB = true;
9230e6a1e94SMark Adams     if (A->rmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->rmap, A->rmap->bs));
9240e6a1e94SMark Adams     if (B->rmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->cmap, B->rmap->bs));
9259371c9d4SSatish Balay     break;
926d71ae5a4SJacob Faibussowitsch   default:
927d71ae5a4SJacob Faibussowitsch     SETERRQ(comm, PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
928a3f881fbSStefano Zampini   }
9290e3ece09SJunchao Zhang   PetscCallCXX(product->data = pdata = new MatProductData_SeqAIJKokkos());
930076ba34aSJunchao Zhang   pdata->reusesym = product->api_user;
931a3f881fbSStefano Zampini 
932076ba34aSJunchao Zhang   /* TODO: add command line options to select spgemm algorithms */
933866eb059SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; /* default alg is TPL if enabled, otherwise KK */
934866eb059SJunchao Zhang 
935866eb059SJunchao Zhang   /* CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 */
936866eb059SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
937866eb059SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
938866eb059SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
939866eb059SJunchao Zhang   #endif
940866eb059SJunchao Zhang #endif
9410e3ece09SJunchao Zhang   PetscCallCXX(pdata->kh.create_spgemm_handle(spgemm_alg));
942076ba34aSJunchao Zhang 
9439566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
944076ba34aSJunchao Zhang   /* TODO: Get rid of the explicit transpose once KK-spgemm implements the transpose option */
945076ba34aSJunchao Zhang   if (transA) {
9469566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
947076ba34aSJunchao Zhang     transA = false;
948076ba34aSJunchao Zhang   }
949076ba34aSJunchao Zhang 
950076ba34aSJunchao Zhang   if (transB) {
9519566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
952076ba34aSJunchao Zhang     transB = false;
953076ba34aSJunchao Zhang   }
954076ba34aSJunchao Zhang 
9550e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
956076ba34aSJunchao Zhang   /* spgemm_symbolic() only populates C's rowmap, but not C's column indices.
957076ba34aSJunchao Zhang     So we have to do a fake spgemm_numeric() here to get csrmatC.j_d setup, before
958076ba34aSJunchao Zhang     calling new Mat_SeqAIJKokkos().
959076ba34aSJunchao Zhang     TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
960076ba34aSJunchao Zhang   */
9610e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
9620e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
963866eb059SJunchao Zhang   /* Query if KK outputs a sorted matrix. If not, we need to sort it */
964866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
965866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(csrmatC)); /* sort_option defaults to -1 in KK!*/
966e944a159SJunchao Zhang #endif
9679566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
968076ba34aSJunchao Zhang 
9699566063dSJacob Faibussowitsch   PetscCallCXX(ckok = new Mat_SeqAIJKokkos(csrmatC));
9709566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(C, ckok));
971076ba34aSJunchao Zhang   C->product->destroy = MatProductDataDestroy_SeqAIJKokkos;
9723ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
973a3f881fbSStefano Zampini }
974a3f881fbSStefano Zampini 
975a3f881fbSStefano Zampini /* handles sparse matrix matrix ops */
976d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSetFromOptions_SeqAIJKokkos(Mat mat)
977d71ae5a4SJacob Faibussowitsch {
978076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
979a3f881fbSStefano Zampini   PetscBool    Biskok = PETSC_FALSE, Ciskok = PETSC_TRUE;
980a3f881fbSStefano Zampini 
981a3f881fbSStefano Zampini   PetscFunctionBegin;
982a3f881fbSStefano Zampini   MatCheckProduct(mat, 1);
9839566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJKOKKOS, &Biskok));
98448a46eb9SPierre Jolivet   if (product->type == MATPRODUCT_ABC) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJKOKKOS, &Ciskok));
985a3f881fbSStefano Zampini   if (Biskok && Ciskok) {
986a3f881fbSStefano Zampini     switch (product->type) {
987a3f881fbSStefano Zampini     case MATPRODUCT_AB:
988a3f881fbSStefano Zampini     case MATPRODUCT_AtB:
989d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABt:
990d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos;
991d71ae5a4SJacob Faibussowitsch       break;
992a3f881fbSStefano Zampini     case MATPRODUCT_PtAP:
993a3f881fbSStefano Zampini     case MATPRODUCT_RARt:
994d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABC:
995d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
996d71ae5a4SJacob Faibussowitsch       break;
997d71ae5a4SJacob Faibussowitsch     default:
998d71ae5a4SJacob Faibussowitsch       break;
999a3f881fbSStefano Zampini     }
1000a3f881fbSStefano Zampini   } else { /* fallback for AIJ */
10019566063dSJacob Faibussowitsch     PetscCall(MatProductSetFromOptions_SeqAIJ(mat));
1002a3f881fbSStefano Zampini   }
10033ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1004a3f881fbSStefano Zampini }
1005a587d139SMark 
1006d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatScale_SeqAIJKokkos(Mat A, PetscScalar a)
1007d71ae5a4SJacob Faibussowitsch {
1008f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok;
1009f0cf5187SStefano Zampini 
1010f0cf5187SStefano Zampini   PetscFunctionBegin;
10119566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
10129566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1013f0cf5187SStefano Zampini   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1014d326c3f1SJunchao Zhang   KokkosBlas::scal(PetscGetKokkosExecutionSpace(), aijkok->a_dual.view_device(), a, aijkok->a_dual.view_device());
10159566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
10169566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(aijkok->a_dual.extent(0)));
10179566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
10183ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1019f0cf5187SStefano Zampini }
1020f0cf5187SStefano Zampini 
1021f4747e26SJunchao Zhang // add a to A's diagonal (if A is square) or main diagonal (if A is rectangular)
1022f4747e26SJunchao Zhang static PetscErrorCode MatShift_SeqAIJKokkos(Mat A, PetscScalar a)
1023f4747e26SJunchao Zhang {
1024f4747e26SJunchao Zhang   Mat_SeqAIJ *aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1025f4747e26SJunchao Zhang 
1026f4747e26SJunchao Zhang   PetscFunctionBegin;
1027f4747e26SJunchao Zhang   if (A->assembled && aijseq->diagonaldense) { // no missing diagonals
1028f4747e26SJunchao Zhang     PetscInt n = PetscMin(A->rmap->n, A->cmap->n);
1029f4747e26SJunchao Zhang 
1030f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
1031f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosSyncDevice(A));
1032f4747e26SJunchao Zhang     const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1033f4747e26SJunchao Zhang     const auto &Aa     = aijkok->a_dual.view_device();
1034f4747e26SJunchao Zhang     const auto &Adiag  = aijkok->diag_dual.view_device();
1035d326c3f1SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) { Aa(Adiag(i)) += a; }));
1036f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosModifyDevice(A));
1037f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(n));
1038f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
1039f4747e26SJunchao Zhang   } else { // need reassembly, very slow!
1040f4747e26SJunchao Zhang     PetscCall(MatShift_Basic(A, a));
1041f4747e26SJunchao Zhang   }
1042f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1043f4747e26SJunchao Zhang }
1044f4747e26SJunchao Zhang 
1045f4747e26SJunchao Zhang static PetscErrorCode MatDiagonalSet_SeqAIJKokkos(Mat Y, Vec D, InsertMode is)
1046f4747e26SJunchao Zhang {
1047f4747e26SJunchao Zhang   Mat_SeqAIJ *aijseq = static_cast<Mat_SeqAIJ *>(Y->data);
1048f4747e26SJunchao Zhang 
1049f4747e26SJunchao Zhang   PetscFunctionBegin;
1050f4747e26SJunchao Zhang   if (Y->assembled && aijseq->diagonaldense) { // no missing diagonals
1051f4747e26SJunchao Zhang     ConstPetscScalarKokkosView dv;
1052f4747e26SJunchao Zhang     PetscInt                   n, nv;
1053f4747e26SJunchao Zhang 
1054f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
1055f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosSyncDevice(Y));
1056f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(D, &dv));
1057f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(D, &nv));
1058f4747e26SJunchao Zhang     n = PetscMin(Y->rmap->n, Y->cmap->n);
1059f4747e26SJunchao Zhang     PetscCheck(n == nv, PetscObjectComm((PetscObject)Y), PETSC_ERR_ARG_SIZ, "Matrix size and vector size do not match");
1060f4747e26SJunchao Zhang 
1061f4747e26SJunchao Zhang     const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1062f4747e26SJunchao Zhang     const auto &Aa     = aijkok->a_dual.view_device();
1063f4747e26SJunchao Zhang     const auto &Adiag  = aijkok->diag_dual.view_device();
1064f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(
1065d326c3f1SJunchao Zhang       Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) {
1066f4747e26SJunchao Zhang         if (is == INSERT_VALUES) Aa(Adiag(i)) = dv(i);
1067f4747e26SJunchao Zhang         else Aa(Adiag(i)) += dv(i);
1068f4747e26SJunchao Zhang       }));
1069f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(D, &dv));
1070f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1071f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(n));
1072f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
1073f4747e26SJunchao Zhang   } else { // need reassembly, very slow!
1074f4747e26SJunchao Zhang     PetscCall(MatDiagonalSet_Default(Y, D, is));
1075f4747e26SJunchao Zhang   }
1076f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1077f4747e26SJunchao Zhang }
1078f4747e26SJunchao Zhang 
1079f4747e26SJunchao Zhang static PetscErrorCode MatDiagonalScale_SeqAIJKokkos(Mat A, Vec ll, Vec rr)
1080f4747e26SJunchao Zhang {
1081f4747e26SJunchao Zhang   Mat_SeqAIJ                *aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1082f4747e26SJunchao Zhang   PetscInt                   m = A->rmap->n, n = A->cmap->n, nz = aijseq->nz;
1083f4747e26SJunchao Zhang   ConstPetscScalarKokkosView lv, rv;
1084f4747e26SJunchao Zhang 
1085f4747e26SJunchao Zhang   PetscFunctionBegin;
1086f4747e26SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1087f4747e26SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1088f4747e26SJunchao Zhang   const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1089f4747e26SJunchao Zhang   const auto &Aa     = aijkok->a_dual.view_device();
1090f4747e26SJunchao Zhang   const auto &Ai     = aijkok->i_dual.view_device();
1091f4747e26SJunchao Zhang   const auto &Aj     = aijkok->j_dual.view_device();
1092f4747e26SJunchao Zhang   if (ll) {
1093f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(ll, &m));
1094f4747e26SJunchao Zhang     PetscCheck(m == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Left scaling vector wrong length");
1095f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(ll, &lv));
1096f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for( // for each row
1097d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1098f4747e26SJunchao Zhang         PetscInt i   = t.league_rank(); // row i
1099f4747e26SJunchao Zhang         PetscInt len = Ai(i + 1) - Ai(i);
1100f4747e26SJunchao Zhang         // scale entries on the row
1101f4747e26SJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt j) { Aa(Ai(i) + j) *= lv(i); });
1102f4747e26SJunchao Zhang       }));
1103f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(ll, &lv));
1104f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(nz));
1105f4747e26SJunchao Zhang   }
1106f4747e26SJunchao Zhang   if (rr) {
1107f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(rr, &n));
1108f4747e26SJunchao Zhang     PetscCheck(n == A->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Right scaling vector wrong length");
1109f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(rr, &rv));
1110f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for( // for each nonzero
1111d326c3f1SJunchao Zhang       Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt k) { Aa(k) *= rv(Aj(k)); }));
1112f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(rr, &lv));
1113f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(nz));
1114f4747e26SJunchao Zhang   }
1115f4747e26SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(A));
1116f4747e26SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1117f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1118f4747e26SJunchao Zhang }
1119f4747e26SJunchao Zhang 
1120d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatZeroEntries_SeqAIJKokkos(Mat A)
1121d71ae5a4SJacob Faibussowitsch {
1122076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1123a587d139SMark 
1124a587d139SMark   PetscFunctionBegin;
1125076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
11262328674fSJunchao Zhang   if (aijkok) { /* Only zero the device if data is already there */
1127d326c3f1SJunchao Zhang     KokkosBlas::fill(PetscGetKokkosExecutionSpace(), aijkok->a_dual.view_device(), 0.0);
11289566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(A));
11292328674fSJunchao Zhang   } else { /* Might be preallocated but not assembled */
11309566063dSJacob Faibussowitsch     PetscCall(MatZeroEntries_SeqAIJ(A));
11312328674fSJunchao Zhang   }
11323ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1133a587d139SMark }
1134a587d139SMark 
1135d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatGetDiagonal_SeqAIJKokkos(Mat A, Vec x)
1136d71ae5a4SJacob Faibussowitsch {
1137f78ce678SMark Adams   Mat_SeqAIJKokkos     *aijkok;
1138f78ce678SMark Adams   PetscInt              n;
1139f78ce678SMark Adams   PetscScalarKokkosView xv;
1140f78ce678SMark Adams 
1141f78ce678SMark Adams   PetscFunctionBegin;
1142f78ce678SMark Adams   PetscCall(VecGetLocalSize(x, &n));
1143f78ce678SMark Adams   PetscCheck(n == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Nonconforming matrix and vector");
1144f78ce678SMark Adams   PetscCheck(A->factortype == MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_SUP, "MatGetDiagonal_SeqAIJKokkos not supported on factored matrices");
1145f78ce678SMark Adams 
1146f78ce678SMark Adams   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1147f78ce678SMark Adams   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1148f78ce678SMark Adams 
1149f78ce678SMark Adams   const auto &Aa    = aijkok->a_dual.view_device();
1150f78ce678SMark Adams   const auto &Ai    = aijkok->i_dual.view_device();
1151f78ce678SMark Adams   const auto &Adiag = aijkok->diag_dual.view_device();
1152f78ce678SMark Adams 
1153f78ce678SMark Adams   PetscCall(VecGetKokkosViewWrite(x, &xv));
11549371c9d4SSatish Balay   Kokkos::parallel_for(
1155d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) {
1156f78ce678SMark Adams       if (Adiag(i) < Ai(i + 1)) xv(i) = Aa(Adiag(i));
1157f78ce678SMark Adams       else xv(i) = 0;
1158f78ce678SMark Adams     });
1159f78ce678SMark Adams   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
11603ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1161f78ce678SMark Adams }
1162f78ce678SMark Adams 
1163db78de30SJunchao Zhang /* Get a Kokkos View from a mat of type MatSeqAIJKokkos */
1164d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1165d71ae5a4SJacob Faibussowitsch {
1166db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1167db78de30SJunchao Zhang 
1168db78de30SJunchao Zhang   PetscFunctionBegin;
1169db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11704f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1171db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11729566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1173db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1174076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
11753ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1176db78de30SJunchao Zhang }
1177db78de30SJunchao Zhang 
1178d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1179d71ae5a4SJacob Faibussowitsch {
1180db78de30SJunchao Zhang   PetscFunctionBegin;
1181db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11824f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1183db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11843ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1185db78de30SJunchao Zhang }
1186db78de30SJunchao Zhang 
1187d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, MatScalarKokkosView *kv)
1188d71ae5a4SJacob Faibussowitsch {
1189db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1190db78de30SJunchao Zhang 
1191db78de30SJunchao Zhang   PetscFunctionBegin;
1192db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11934f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1194db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11959566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1196db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1197076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
11983ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1199db78de30SJunchao Zhang }
1200db78de30SJunchao Zhang 
1201d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, MatScalarKokkosView *kv)
1202d71ae5a4SJacob Faibussowitsch {
1203db78de30SJunchao Zhang   PetscFunctionBegin;
1204db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12054f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1206db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
12079566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
12083ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1209db78de30SJunchao Zhang }
1210db78de30SJunchao Zhang 
1211d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1212d71ae5a4SJacob Faibussowitsch {
1213db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1214db78de30SJunchao Zhang 
1215db78de30SJunchao Zhang   PetscFunctionBegin;
1216db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12174f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1218db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1219db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1220076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
12213ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1222db78de30SJunchao Zhang }
1223db78de30SJunchao Zhang 
1224d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1225d71ae5a4SJacob Faibussowitsch {
1226db78de30SJunchao Zhang   PetscFunctionBegin;
1227db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12284f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1229db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
12309566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
12313ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1232db78de30SJunchao Zhang }
1233db78de30SJunchao Zhang 
1234c0c276a7Ssdargavi PetscErrorCode MatCreateSeqAIJKokkosWithKokkosViews(MPI_Comm comm, PetscInt m, PetscInt n, Kokkos::View<PetscInt *> &i_d, Kokkos::View<PetscInt *> &j_d, Kokkos::View<PetscScalar *> &a_d, Mat *A)
1235c0c276a7Ssdargavi {
1236c0c276a7Ssdargavi   Mat_SeqAIJKokkos *akok;
1237c0c276a7Ssdargavi 
1238c0c276a7Ssdargavi   PetscFunctionBegin;
1239c0c276a7Ssdargavi   auto exec = PetscGetKokkosExecutionSpace();
1240c0c276a7Ssdargavi   // Create host copies of the input aij
1241c0c276a7Ssdargavi   auto i_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), i_d);
1242c0c276a7Ssdargavi   auto j_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), j_d);
1243c0c276a7Ssdargavi   // Don't copy the vals to the host now
1244c0c276a7Ssdargavi   auto a_h = Kokkos::create_mirror_view(HostMirrorMemorySpace(), a_d);
1245c0c276a7Ssdargavi 
1246c0c276a7Ssdargavi   MatScalarKokkosDualView a_dual = MatScalarKokkosDualView(a_d, a_h);
1247c0c276a7Ssdargavi   // Note we have modified device data so it will copy lazily
1248c0c276a7Ssdargavi   a_dual.modify_device();
1249c0c276a7Ssdargavi   MatRowMapKokkosDualView i_dual = MatRowMapKokkosDualView(i_d, i_h);
1250c0c276a7Ssdargavi   MatColIdxKokkosDualView j_dual = MatColIdxKokkosDualView(j_d, j_h);
1251c0c276a7Ssdargavi 
1252c0c276a7Ssdargavi   PetscCallCXX(akok = new Mat_SeqAIJKokkos(m, n, j_dual.extent(0), i_dual, j_dual, a_dual));
1253c0c276a7Ssdargavi   PetscCall(MatCreate(comm, A));
1254c0c276a7Ssdargavi   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
1255c0c276a7Ssdargavi   PetscFunctionReturn(PETSC_SUCCESS);
1256c0c276a7Ssdargavi }
1257c0c276a7Ssdargavi 
1258c17cf699SJunchao Zhang /* Computes Y += alpha X */
1259d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAXPY_SeqAIJKokkos(Mat Y, PetscScalar alpha, Mat X, MatStructure pattern)
1260d71ae5a4SJacob Faibussowitsch {
1261a587d139SMark   Mat_SeqAIJ              *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data;
1262c17cf699SJunchao Zhang   Mat_SeqAIJKokkos        *xkok, *ykok, *zkok;
1263c17cf699SJunchao Zhang   ConstMatScalarKokkosView Xa;
1264c17cf699SJunchao Zhang   MatScalarKokkosView      Ya;
12654df4a32cSJunchao Zhang   auto                     exec = PetscGetKokkosExecutionSpace();
1266a587d139SMark 
1267a587d139SMark   PetscFunctionBegin;
1268c17cf699SJunchao Zhang   PetscCheckTypeName(Y, MATSEQAIJKOKKOS);
1269c17cf699SJunchao Zhang   PetscCheckTypeName(X, MATSEQAIJKOKKOS);
12709566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(Y));
12719566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(X));
12729566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1273db78de30SJunchao Zhang 
1274c17cf699SJunchao Zhang   if (pattern != SAME_NONZERO_PATTERN && x->nz == y->nz) {
1275a587d139SMark     PetscBool e;
12769566063dSJacob Faibussowitsch     PetscCall(PetscArraycmp(x->i, y->i, Y->rmap->n + 1, &e));
1277a587d139SMark     if (e) {
12789566063dSJacob Faibussowitsch       PetscCall(PetscArraycmp(x->j, y->j, y->nz, &e));
1279c17cf699SJunchao Zhang       if (e) pattern = SAME_NONZERO_PATTERN;
1280a587d139SMark     }
1281a587d139SMark   }
1282db78de30SJunchao Zhang 
1283c17cf699SJunchao Zhang   /* cusparseDcsrgeam2() computes C = alpha A + beta B. If one knew sparsity pattern of C, one can skip
1284c17cf699SJunchao Zhang     cusparseScsrgeam2_bufferSizeExt() / cusparseXcsrgeam2Nnz(), and directly call cusparseScsrgeam2().
1285c17cf699SJunchao Zhang     If X is SUBSET_NONZERO_PATTERN of Y, we could take advantage of this cusparse feature. However,
1286c17cf699SJunchao Zhang     KokkosSparse::spadd(alpha,A,beta,B,C) has symbolic and numeric phases, MatAXPY does not.
1287c17cf699SJunchao Zhang   */
1288c17cf699SJunchao Zhang   ykok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1289c17cf699SJunchao Zhang   xkok = static_cast<Mat_SeqAIJKokkos *>(X->spptr);
1290c17cf699SJunchao Zhang   Xa   = xkok->a_dual.view_device();
1291c17cf699SJunchao Zhang   Ya   = ykok->a_dual.view_device();
1292c17cf699SJunchao Zhang 
1293c17cf699SJunchao Zhang   if (pattern == SAME_NONZERO_PATTERN) {
1294d326c3f1SJunchao Zhang     KokkosBlas::axpy(exec, alpha, Xa, Ya);
12959566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1296c17cf699SJunchao Zhang   } else if (pattern == SUBSET_NONZERO_PATTERN) {
1297c17cf699SJunchao Zhang     MatRowMapKokkosView Xi = xkok->i_dual.view_device(), Yi = ykok->i_dual.view_device();
1298c17cf699SJunchao Zhang     MatColIdxKokkosView Xj = xkok->j_dual.view_device(), Yj = ykok->j_dual.view_device();
1299c17cf699SJunchao Zhang 
13009371c9d4SSatish Balay     Kokkos::parallel_for(
1301d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(exec, Y->rmap->n, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
13020e3ece09SJunchao Zhang         PetscInt i = t.league_rank(); // row i
13030e3ece09SJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() {
13040e3ece09SJunchao Zhang           // Only one thread works in a team
1305c17cf699SJunchao Zhang           PetscInt p, q = Yi(i);
13060e3ece09SJunchao Zhang           for (p = Xi(i); p < Xi(i + 1); p++) {          // For each nonzero on row i of X,
13070e3ece09SJunchao Zhang             while (Xj(p) != Yj(q) && q < Yi(i + 1)) q++; // find the matching nonzero on row i of Y.
13080e3ece09SJunchao Zhang             if (Xj(p) == Yj(q)) {                        // Found it
1309c17cf699SJunchao Zhang               Ya(q) += alpha * Xa(p);
1310c17cf699SJunchao Zhang               q++;
1311a587d139SMark             } else {
13120e3ece09SJunchao Zhang             // If not found, it indicates the input is wrong (X is not a SUBSET_NONZERO_PATTERN of Y).
13130e3ece09SJunchao Zhang             // Just insert a NaN at the beginning of row i if it is not empty, to make the result wrong.
13140e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_VERSION_GE(3, 7, 0)
13150e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::ArithTraits<PetscScalar>::nan();
13168b8b16f9SJunchao Zhang #else
13170e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::Experimental::nan("1");
13188b8b16f9SJunchao Zhang #endif
1319a587d139SMark             }
1320c17cf699SJunchao Zhang           }
1321c17cf699SJunchao Zhang         });
1322c17cf699SJunchao Zhang       });
13239566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
13240e3ece09SJunchao Zhang   } else { // different nonzero patterns
1325c17cf699SJunchao Zhang     Mat             Z;
1326c17cf699SJunchao Zhang     KokkosCsrMatrix zcsr;
1327c17cf699SJunchao Zhang     KernelHandle    kh;
13280e3ece09SJunchao Zhang     kh.create_spadd_handle(true); // X, Y are sorted
1329c17cf699SJunchao Zhang     KokkosSparse::spadd_symbolic(&kh, xkok->csrmat, ykok->csrmat, zcsr);
1330c17cf699SJunchao Zhang     KokkosSparse::spadd_numeric(&kh, alpha, xkok->csrmat, (PetscScalar)1.0, ykok->csrmat, zcsr);
1331c17cf699SJunchao Zhang     zkok = new Mat_SeqAIJKokkos(zcsr);
13329566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, zkok, &Z));
13339566063dSJacob Faibussowitsch     PetscCall(MatHeaderReplace(Y, &Z));
1334c17cf699SJunchao Zhang     kh.destroy_spadd_handle();
1335c17cf699SJunchao Zhang   }
13369566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
13370e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(xkok->a_dual.extent(0) * 2)); // Because we scaled X and then added it to Y
13383ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1339a587d139SMark }
1340a587d139SMark 
13412c4ab24aSJunchao Zhang struct MatCOOStruct_SeqAIJKokkos {
13422c4ab24aSJunchao Zhang   PetscCount           n;
13432c4ab24aSJunchao Zhang   PetscCount           Atot;
13442c4ab24aSJunchao Zhang   PetscInt             nz;
13452c4ab24aSJunchao Zhang   PetscCountKokkosView jmap;
13462c4ab24aSJunchao Zhang   PetscCountKokkosView perm;
13472c4ab24aSJunchao Zhang 
13482c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos(const MatCOOStruct_SeqAIJ *coo_h)
13492c4ab24aSJunchao Zhang   {
13502c4ab24aSJunchao Zhang     nz   = coo_h->nz;
13512c4ab24aSJunchao Zhang     n    = coo_h->n;
13522c4ab24aSJunchao Zhang     Atot = coo_h->Atot;
13532c4ab24aSJunchao Zhang     jmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->jmap, nz + 1));
13542c4ab24aSJunchao Zhang     perm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->perm, Atot));
13552c4ab24aSJunchao Zhang   }
13562c4ab24aSJunchao Zhang };
13572c4ab24aSJunchao Zhang 
135849abdd8aSBarry Smith static PetscErrorCode MatCOOStructDestroy_SeqAIJKokkos(void **data)
13592c4ab24aSJunchao Zhang {
13602c4ab24aSJunchao Zhang   PetscFunctionBegin;
136149abdd8aSBarry Smith   PetscCallCXX(delete static_cast<MatCOOStruct_SeqAIJKokkos *>(*data));
13622c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
13632c4ab24aSJunchao Zhang }
13642c4ab24aSJunchao Zhang 
1365d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_SeqAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1366d71ae5a4SJacob Faibussowitsch {
136742550becSJunchao Zhang   Mat_SeqAIJKokkos          *akok;
136842550becSJunchao Zhang   Mat_SeqAIJ                *aseq;
136903e76207SPierre Jolivet   PetscContainer             container_h;
13702c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJ       *coo_h;
13712c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo_d;
137242550becSJunchao Zhang 
137342550becSJunchao Zhang   PetscFunctionBegin;
13749566063dSJacob Faibussowitsch   PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, coo_i, coo_j));
1375394ed5ebSJunchao Zhang   aseq = static_cast<Mat_SeqAIJ *>(mat->data);
137642550becSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(mat->spptr);
1377cbc6b225SStefano Zampini   delete akok;
1378f4747e26SJunchao Zhang   mat->spptr = akok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, aseq, mat->nonzerostate + 1, PETSC_FALSE);
13799566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries_SeqAIJKokkos(mat));
13802c4ab24aSJunchao Zhang 
13812c4ab24aSJunchao Zhang   // Copy the COO struct to device
13822c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
13832c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
13842c4ab24aSJunchao Zhang   PetscCallCXX(coo_d = new MatCOOStruct_SeqAIJKokkos(coo_h));
13852c4ab24aSJunchao Zhang 
13862c4ab24aSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
138703e76207SPierre Jolivet   PetscCall(PetscObjectContainerCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", coo_d, MatCOOStructDestroy_SeqAIJKokkos));
13883ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
138942550becSJunchao Zhang }
139042550becSJunchao Zhang 
1391d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_SeqAIJKokkos(Mat A, const PetscScalar v[], InsertMode imode)
1392d71ae5a4SJacob Faibussowitsch {
139342550becSJunchao Zhang   MatScalarKokkosView        Aa;
139442550becSJunchao Zhang   ConstMatScalarKokkosView   kv;
139542550becSJunchao Zhang   PetscMemType               memtype;
13962c4ab24aSJunchao Zhang   PetscContainer             container;
13972c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo;
139842550becSJunchao Zhang 
139942550becSJunchao Zhang   PetscFunctionBegin;
14002c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)A, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
14012c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
14022c4ab24aSJunchao Zhang 
14032c4ab24aSJunchao Zhang   const auto &n    = coo->n;
14042c4ab24aSJunchao Zhang   const auto &Annz = coo->nz;
14052c4ab24aSJunchao Zhang   const auto &jmap = coo->jmap;
14062c4ab24aSJunchao Zhang   const auto &perm = coo->perm;
14072c4ab24aSJunchao Zhang 
14089566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype));
140942550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */
14102c4ab24aSJunchao Zhang     kv = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatScalarKokkosViewHost(v, n));
141142550becSJunchao Zhang   } else {
14122c4ab24aSJunchao Zhang     kv = ConstMatScalarKokkosView(v, n); /* Directly use v[]'s memory */
141342550becSJunchao Zhang   }
141442550becSJunchao Zhang 
1415c7b718f4SJunchao Zhang   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1416c7b718f4SJunchao Zhang   else PetscCall(MatSeqAIJGetKokkosView(A, &Aa));                             /* read & write matrix values */
141742550becSJunchao Zhang 
141808bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
14199371c9d4SSatish Balay   Kokkos::parallel_for(
1420d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz), KOKKOS_LAMBDA(const PetscCount i) {
1421c7b718f4SJunchao Zhang       PetscScalar sum = 0.0;
1422c7b718f4SJunchao Zhang       for (PetscCount k = jmap(i); k < jmap(i + 1); k++) sum += kv(perm(k));
1423c7b718f4SJunchao Zhang       Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1424c7b718f4SJunchao Zhang     });
142508bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1426394ed5ebSJunchao Zhang 
14279566063dSJacob Faibussowitsch   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
14289566063dSJacob Faibussowitsch   else PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
14293ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
143042550becSJunchao Zhang }
143142550becSJunchao Zhang 
1432d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat A)
1433d71ae5a4SJacob Faibussowitsch {
1434076ba34aSJunchao Zhang   Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
1435076ba34aSJunchao Zhang 
14368c3ff71bSJunchao Zhang   PetscFunctionBegin;
1437076ba34aSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_KOKKOS; /* We do not really use this flag */
14386f3d89d0SStefano Zampini   A->boundtocpu  = PETSC_FALSE;
14396f3d89d0SStefano Zampini 
14408c3ff71bSJunchao Zhang   A->ops->assemblyend               = MatAssemblyEnd_SeqAIJKokkos;
14418c3ff71bSJunchao Zhang   A->ops->destroy                   = MatDestroy_SeqAIJKokkos;
14428c3ff71bSJunchao Zhang   A->ops->duplicate                 = MatDuplicate_SeqAIJKokkos;
1443a587d139SMark   A->ops->axpy                      = MatAXPY_SeqAIJKokkos;
1444f0cf5187SStefano Zampini   A->ops->scale                     = MatScale_SeqAIJKokkos;
1445a587d139SMark   A->ops->zeroentries               = MatZeroEntries_SeqAIJKokkos;
1446076ba34aSJunchao Zhang   A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJKokkos;
14478c3ff71bSJunchao Zhang   A->ops->mult                      = MatMult_SeqAIJKokkos;
14488c3ff71bSJunchao Zhang   A->ops->multadd                   = MatMultAdd_SeqAIJKokkos;
14498c3ff71bSJunchao Zhang   A->ops->multtranspose             = MatMultTranspose_SeqAIJKokkos;
14508c3ff71bSJunchao Zhang   A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJKokkos;
14518c3ff71bSJunchao Zhang   A->ops->multhermitiantranspose    = MatMultHermitianTranspose_SeqAIJKokkos;
14528c3ff71bSJunchao Zhang   A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJKokkos;
1453076ba34aSJunchao Zhang   A->ops->productnumeric            = MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos;
14540ecb592aSJunchao Zhang   A->ops->transpose                 = MatTranspose_SeqAIJKokkos;
1455152b3e56SJunchao Zhang   A->ops->setoption                 = MatSetOption_SeqAIJKokkos;
1456f78ce678SMark Adams   A->ops->getdiagonal               = MatGetDiagonal_SeqAIJKokkos;
1457f4747e26SJunchao Zhang   A->ops->shift                     = MatShift_SeqAIJKokkos;
1458f4747e26SJunchao Zhang   A->ops->diagonalset               = MatDiagonalSet_SeqAIJKokkos;
1459f4747e26SJunchao Zhang   A->ops->diagonalscale             = MatDiagonalScale_SeqAIJKokkos;
1460076ba34aSJunchao Zhang   a->ops->getarray                  = MatSeqAIJGetArray_SeqAIJKokkos;
1461076ba34aSJunchao Zhang   a->ops->restorearray              = MatSeqAIJRestoreArray_SeqAIJKokkos;
1462076ba34aSJunchao Zhang   a->ops->getarrayread              = MatSeqAIJGetArrayRead_SeqAIJKokkos;
1463076ba34aSJunchao Zhang   a->ops->restorearrayread          = MatSeqAIJRestoreArrayRead_SeqAIJKokkos;
1464076ba34aSJunchao Zhang   a->ops->getarraywrite             = MatSeqAIJGetArrayWrite_SeqAIJKokkos;
1465076ba34aSJunchao Zhang   a->ops->restorearraywrite         = MatSeqAIJRestoreArrayWrite_SeqAIJKokkos;
14667ee59b9bSJunchao Zhang   a->ops->getcsrandmemtype          = MatSeqAIJGetCSRAndMemType_SeqAIJKokkos;
146742550becSJunchao Zhang 
14689566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJKokkos));
14699566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJKokkos));
147057761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
147157761e9aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijkokkos_hypre_C", MatConvert_AIJ_HYPRE));
147257761e9aSJunchao Zhang #endif
14733ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1474076ba34aSJunchao Zhang }
1475076ba34aSJunchao Zhang 
14769d13fa56SJunchao Zhang /*
14779d13fa56SJunchao Zhang    Extract the (prescribled) diagonal blocks of the matrix and then invert them
14789d13fa56SJunchao Zhang 
14799d13fa56SJunchao Zhang   Input Parameters:
14809d13fa56SJunchao Zhang +  A       - the MATSEQAIJKOKKOS matrix
14819d13fa56SJunchao Zhang .  bs      - block sizes in 'csr' format, i.e., the i-th block has size bs(i+1) - bs(i)
14829d13fa56SJunchao Zhang .  bs2     - square of block sizes in 'csr' format, i.e., the i-th block should be stored at offset bs2(i) in diagVal[]
14839d13fa56SJunchao Zhang .  blkMap  - map row ids to block ids, i.e., row i belongs to the block blkMap(i)
14849d13fa56SJunchao Zhang -  work    - a pre-allocated work buffer (as big as diagVal) for use by this routine
14859d13fa56SJunchao Zhang 
14869d13fa56SJunchao Zhang   Output Parameter:
14879d13fa56SJunchao Zhang .  diagVal - the (pre-allocated) buffer to store the inverted blocks (each block is stored in column-major order)
14889d13fa56SJunchao Zhang */
14899d13fa56SJunchao Zhang PETSC_INTERN PetscErrorCode MatInvertVariableBlockDiagonal_SeqAIJKokkos(Mat A, const PetscIntKokkosView &bs, const PetscIntKokkosView &bs2, const PetscIntKokkosView &blkMap, PetscScalarKokkosView &work, PetscScalarKokkosView &diagVal)
14909d13fa56SJunchao Zhang {
14919d13fa56SJunchao Zhang   Mat_SeqAIJKokkos *akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
14929d13fa56SJunchao Zhang   PetscInt          nblocks = bs.extent(0) - 1;
14939d13fa56SJunchao Zhang 
14949d13fa56SJunchao Zhang   PetscFunctionBegin;
14959d13fa56SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A)); // Since we'll access A's value on device
14969d13fa56SJunchao Zhang 
14979d13fa56SJunchao Zhang   // Pull out the diagonal blocks of the matrix and then invert the blocks
14989d13fa56SJunchao Zhang   auto Aa    = akok->a_dual.view_device();
14999d13fa56SJunchao Zhang   auto Ai    = akok->i_dual.view_device();
15009d13fa56SJunchao Zhang   auto Aj    = akok->j_dual.view_device();
15019d13fa56SJunchao Zhang   auto Adiag = akok->diag_dual.view_device();
15029d13fa56SJunchao Zhang   // TODO: how to tune the team size?
150345402d8aSJunchao Zhang #if defined(KOKKOS_ENABLE_UNIFIED_MEMORY)
15049d13fa56SJunchao Zhang   auto ts = Kokkos::AUTO();
15059d13fa56SJunchao Zhang #else
15069d13fa56SJunchao Zhang   auto ts = 16; // improved performance 30% over Kokkos::AUTO() with CUDA, but failed with "Kokkos::abort: Requested Team Size is too large!" on CPUs
15079d13fa56SJunchao Zhang #endif
15089d13fa56SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1509d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), nblocks, ts), KOKKOS_LAMBDA(const KokkosTeamMemberType &teamMember) {
15109d13fa56SJunchao Zhang       const PetscInt bid    = teamMember.league_rank();                                                   // block id
15119d13fa56SJunchao Zhang       const PetscInt rstart = bs(bid);                                                                    // this block starts from this row
15129d13fa56SJunchao Zhang       const PetscInt m      = bs(bid + 1) - bs(bid);                                                      // size of this block
15139d13fa56SJunchao Zhang       const auto    &B      = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft>(&diagVal(bs2(bid)), m, m); // column-major order
15149d13fa56SJunchao Zhang       const auto    &W      = PetscScalarKokkosView(&work(bs2(bid)), m * m);
15159d13fa56SJunchao Zhang 
15169d13fa56SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(teamMember, m), [=](const PetscInt &r) { // r-th row in B
15179d13fa56SJunchao Zhang         PetscInt i = rstart + r;                                                            // i-th row in A
15189d13fa56SJunchao Zhang 
15199d13fa56SJunchao Zhang         if (Ai(i) <= Adiag(i) && Adiag(i) < Ai(i + 1)) { // if the diagonal exists (common case)
15209d13fa56SJunchao Zhang           PetscInt first = Adiag(i) - r;                 // we start to check nonzeros from here along this row
15219d13fa56SJunchao Zhang 
15229d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) {                   // walk n steps to see what column indices we will meet
15239d13fa56SJunchao Zhang             if (first + c < Ai(i) || first + c >= Ai(i + 1)) { // this entry (first+c) is out of range of this row, in other words, its value is zero
15249d13fa56SJunchao Zhang               B(r, c) = 0.0;
15259d13fa56SJunchao Zhang             } else if (Aj(first + c) == rstart + c) { // this entry is right on the (rstart+c) column
15269d13fa56SJunchao Zhang               B(r, c) = Aa(first + c);
15279d13fa56SJunchao Zhang             } else { // this entry does not show up in the CSR
15289d13fa56SJunchao Zhang               B(r, c) = 0.0;
15299d13fa56SJunchao Zhang             }
15309d13fa56SJunchao Zhang           }
15319d13fa56SJunchao Zhang         } else { // rare case that the diagonal does not exist
15329d13fa56SJunchao Zhang           const PetscInt begin = Ai(i);
15339d13fa56SJunchao Zhang           const PetscInt end   = Ai(i + 1);
15349d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) B(r, c) = 0.0;
15359d13fa56SJunchao Zhang           for (PetscInt j = begin; j < end; j++) { // scan the whole row; could use binary search but this is a rare case so we did not.
15369d13fa56SJunchao Zhang             if (rstart <= Aj(j) && Aj(j) < rstart + m) B(r, Aj(j) - rstart) = Aa(j);
15379d13fa56SJunchao Zhang             else if (Aj(j) >= rstart + m) break;
15389d13fa56SJunchao Zhang           }
15399d13fa56SJunchao Zhang         }
15409d13fa56SJunchao Zhang       });
15419d13fa56SJunchao Zhang 
15429d13fa56SJunchao Zhang       // LU-decompose B (w/o pivoting) and then invert B
15439d13fa56SJunchao Zhang       KokkosBatched::TeamLU<KokkosTeamMemberType, KokkosBatched::Algo::LU::Unblocked>::invoke(teamMember, B, 0.0);
15449d13fa56SJunchao Zhang       KokkosBatched::TeamInverseLU<KokkosTeamMemberType, KokkosBatched::Algo::InverseLU::Unblocked>::invoke(teamMember, B, W);
15459d13fa56SJunchao Zhang     }));
15469d13fa56SJunchao Zhang   // PetscLogGpuFlops() is done in the caller PCSetUp_VPBJacobi_Kokkos as we don't want to compute the flops in kernels
15479d13fa56SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
15489d13fa56SJunchao Zhang }
15499d13fa56SJunchao Zhang 
1550d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSetSeqAIJKokkosWithCSRMatrix(Mat A, Mat_SeqAIJKokkos *akok)
1551d71ae5a4SJacob Faibussowitsch {
1552076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
1553076ba34aSJunchao Zhang   PetscInt    i, m, n;
15544df4a32cSJunchao Zhang   auto        exec = PetscGetKokkosExecutionSpace();
1555076ba34aSJunchao Zhang 
1556076ba34aSJunchao Zhang   PetscFunctionBegin;
15575f80ce2aSJacob Faibussowitsch   PetscCheck(!A->spptr, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A->spptr is supposed to be empty");
1558076ba34aSJunchao Zhang 
1559076ba34aSJunchao Zhang   m = akok->nrows();
1560076ba34aSJunchao Zhang   n = akok->ncols();
15619566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(A, m, n, m, n));
15629566063dSJacob Faibussowitsch   PetscCall(MatSetType(A, MATSEQAIJKOKKOS));
1563076ba34aSJunchao Zhang 
1564076ba34aSJunchao Zhang   /* Set up data structures of A as a MATSEQAIJ */
15659566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(A, MAT_SKIP_ALLOCATION, NULL));
156657508eceSPierre Jolivet   aseq = (Mat_SeqAIJ *)A->data;
1567076ba34aSJunchao Zhang 
1568e36ced11SJunchao Zhang   PetscCallCXX(akok->i_dual.sync_host(exec)); /* We always need sync'ed i, j on host */
1569e36ced11SJunchao Zhang   PetscCallCXX(akok->j_dual.sync_host(exec));
1570e36ced11SJunchao Zhang   PetscCallCXX(exec.fence());
1571076ba34aSJunchao Zhang 
1572076ba34aSJunchao Zhang   aseq->i       = akok->i_host_data();
1573076ba34aSJunchao Zhang   aseq->j       = akok->j_host_data();
1574076ba34aSJunchao Zhang   aseq->a       = akok->a_host_data();
1575076ba34aSJunchao Zhang   aseq->nonew   = -1; /*this indicates that inserting a new value in the matrix that generates a new nonzero is an error*/
1576076ba34aSJunchao Zhang   aseq->free_a  = PETSC_FALSE;
1577076ba34aSJunchao Zhang   aseq->free_ij = PETSC_FALSE;
1578076ba34aSJunchao Zhang   aseq->nz      = akok->nnz();
1579076ba34aSJunchao Zhang   aseq->maxnz   = aseq->nz;
1580076ba34aSJunchao Zhang 
15819566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->imax));
15829566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->ilen));
1583ad540459SPierre Jolivet   for (i = 0; i < m; i++) aseq->ilen[i] = aseq->imax[i] = aseq->i[i + 1] - aseq->i[i];
1584076ba34aSJunchao Zhang 
1585076ba34aSJunchao Zhang   /* It is critical to set the nonzerostate, as we use it to check if sparsity pattern (hence data) has changed on host in MatAssemblyEnd */
1586076ba34aSJunchao Zhang   akok->nonzerostate = A->nonzerostate;
1587ff751488SJunchao Zhang   A->spptr           = akok; /* Set A->spptr before MatAssembly so that A->spptr won't be allocated again there */
15889566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
15899566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
15903ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1591076ba34aSJunchao Zhang }
1592076ba34aSJunchao Zhang 
15930e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetKokkosCsrMatrix(Mat A, KokkosCsrMatrix *csr)
15940e3ece09SJunchao Zhang {
15950e3ece09SJunchao Zhang   PetscFunctionBegin;
15960e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
15970e3ece09SJunchao Zhang   *csr = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
15980e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
15990e3ece09SJunchao Zhang }
16000e3ece09SJunchao Zhang 
16010e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithKokkosCsrMatrix(MPI_Comm comm, KokkosCsrMatrix csr, Mat *A)
16020e3ece09SJunchao Zhang {
16030e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok;
16044d86920dSPierre Jolivet 
16050e3ece09SJunchao Zhang   PetscFunctionBegin;
16060e3ece09SJunchao Zhang   PetscCallCXX(akok = new Mat_SeqAIJKokkos(csr));
16070e3ece09SJunchao Zhang   PetscCall(MatCreate(comm, A));
16080e3ece09SJunchao Zhang   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
16090e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
16100e3ece09SJunchao Zhang }
16110e3ece09SJunchao Zhang 
1612076ba34aSJunchao Zhang /* Crete a SEQAIJKOKKOS matrix with a Mat_SeqAIJKokkos data structure
1613076ba34aSJunchao Zhang 
1614076ba34aSJunchao Zhang    Note we have names like MatSeqAIJSetPreallocationCSR, so I use capitalized CSR
1615076ba34aSJunchao Zhang  */
1616d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithCSRMatrix(MPI_Comm comm, Mat_SeqAIJKokkos *akok, Mat *A)
1617d71ae5a4SJacob Faibussowitsch {
1618076ba34aSJunchao Zhang   PetscFunctionBegin;
16199566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
16209566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
16213ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
16228c3ff71bSJunchao Zhang }
16238c3ff71bSJunchao Zhang 
1624152b3e56SJunchao Zhang /*@C
162511a5261eSBarry Smith   MatCreateSeqAIJKokkos - Creates a sparse matrix in `MATSEQAIJKOKKOS` (compressed row) format
16268c3ff71bSJunchao Zhang   (the default parallel PETSc format). This matrix will ultimately be handled by
162720f4b53cSBarry Smith   Kokkos for calculations.
16288c3ff71bSJunchao Zhang 
16298c3ff71bSJunchao Zhang   Collective
16308c3ff71bSJunchao Zhang 
16318c3ff71bSJunchao Zhang   Input Parameters:
163211a5261eSBarry Smith + comm - MPI communicator, set to `PETSC_COMM_SELF`
16338c3ff71bSJunchao Zhang . m    - number of rows
16348c3ff71bSJunchao Zhang . n    - number of columns
163520f4b53cSBarry Smith . nz   - number of nonzeros per row (same for all rows), ignored if `nnz` is provided
163620f4b53cSBarry Smith - nnz  - array containing the number of nonzeros in the various rows (possibly different for each row) or `NULL`
16378c3ff71bSJunchao Zhang 
16388c3ff71bSJunchao Zhang   Output Parameter:
16398c3ff71bSJunchao Zhang . A - the matrix
16408c3ff71bSJunchao Zhang 
16412ef1f0ffSBarry Smith   Level: intermediate
16422ef1f0ffSBarry Smith 
16432ef1f0ffSBarry Smith   Notes:
164411a5261eSBarry Smith   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
16458c3ff71bSJunchao Zhang   MatXXXXSetPreallocation() paradgm instead of this routine directly.
164611a5261eSBarry Smith   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
16478c3ff71bSJunchao Zhang 
164811a5261eSBarry Smith   The AIJ format, also called
16492ef1f0ffSBarry Smith   compressed row storage, is fully compatible with standard Fortran
16508c3ff71bSJunchao Zhang   storage.  That is, the stored row and column indices can begin at
165120f4b53cSBarry Smith   either one (as in Fortran) or zero.
16528c3ff71bSJunchao Zhang 
16532ef1f0ffSBarry Smith   Specify the preallocated storage with either `nz` or `nnz` (not both).
16542ef1f0ffSBarry Smith   Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory
16552ef1f0ffSBarry Smith   allocation.
16568c3ff71bSJunchao Zhang 
1657fe59aa6dSJacob Faibussowitsch .seealso: [](ch_matrices), `Mat`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
16588c3ff71bSJunchao Zhang @*/
1659d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateSeqAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A)
1660d71ae5a4SJacob Faibussowitsch {
16618c3ff71bSJunchao Zhang   PetscFunctionBegin;
16629566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
16639566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
16649566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, m, n));
16659566063dSJacob Faibussowitsch   PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
16669566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
16673ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
16688c3ff71bSJunchao Zhang }
1669930e68a5SMark Adams 
1670aac854edSJunchao Zhang // After matrix numeric factorization, there are still steps to do before triangular solve can be called.
1671aac854edSJunchao Zhang // For example, for transpose solve, we might need to compute the transpose matrices if the solver does not support it (such as KK, while cusparse does).
1672aac854edSJunchao Zhang // In cusparse, one has to call cusparseSpSV_analysis() with updated triangular matrix values before calling cusparseSpSV_solve().
1673aac854edSJunchao Zhang // Simiarily, in KK sptrsv_symbolic() has to be called before sptrsv_solve(). We put these steps in MatSeqAIJKokkos{Transpose}SolveCheck.
1674aac854edSJunchao Zhang static PetscErrorCode MatSeqAIJKokkosSolveCheck(Mat A)
1675d71ae5a4SJacob Faibussowitsch {
167686a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors   = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1677aac854edSJunchao Zhang   const PetscBool             has_lower = factors->iL_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // false with Choleksy
1678aac854edSJunchao Zhang   const PetscBool             has_upper = factors->iU_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // true with LU and Choleksy
167986a27549SJunchao Zhang 
168086a27549SJunchao Zhang   PetscFunctionBegin;
1681aac854edSJunchao Zhang   if (!factors->sptrsv_symbolic_completed) { // If sptrsv_symbolic was not called yet
1682aac854edSJunchao Zhang     if (has_upper) PetscCallCXX(sptrsv_symbolic(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d));
1683aac854edSJunchao Zhang     if (has_lower) PetscCallCXX(sptrsv_symbolic(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d));
168486a27549SJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_TRUE;
168586a27549SJunchao Zhang   }
16863ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
168786a27549SJunchao Zhang }
168886a27549SJunchao Zhang 
1689d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosTransposeSolveCheck(Mat A)
1690d71ae5a4SJacob Faibussowitsch {
1691aac854edSJunchao Zhang   const PetscInt              n         = A->rmap->n;
169286a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors   = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1693aac854edSJunchao Zhang   const PetscBool             has_lower = factors->iL_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // false with Choleksy
1694aac854edSJunchao Zhang   const PetscBool             has_upper = factors->iU_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // true with LU or Choleksy
169586a27549SJunchao Zhang 
169686a27549SJunchao Zhang   PetscFunctionBegin;
1697aac854edSJunchao Zhang   if (!factors->transpose_updated) {
1698aac854edSJunchao Zhang     if (has_upper) {
1699aac854edSJunchao Zhang       if (!factors->iUt_d.extent(0)) {                                 // Allocate Ut on device if not yet
1700aac854edSJunchao Zhang         factors->iUt_d = MatRowMapKokkosView("factors->iUt_d", n + 1); // KK requires this view to be initialized to 0 to call transpose_matrix
17017b8d4ba6SJunchao Zhang         factors->jUt_d = MatColIdxKokkosView(NoInit("factors->jUt_d"), factors->jU_d.extent(0));
17027b8d4ba6SJunchao Zhang         factors->aUt_d = MatScalarKokkosView(NoInit("factors->aUt_d"), factors->aU_d.extent(0));
1703aac854edSJunchao Zhang       }
170486a27549SJunchao Zhang 
1705aac854edSJunchao Zhang       if (factors->iU_h.extent(0)) { // If U is on host (factorization was done on host), we also compute the transpose on host
1706aac854edSJunchao Zhang         if (!factors->U) {
1707aac854edSJunchao Zhang           Mat_SeqAIJ *seq;
170886a27549SJunchao Zhang 
1709aac854edSJunchao Zhang           PetscCall(MatCreateSeqAIJWithArrays(PETSC_COMM_SELF, n, n, factors->iU_h.data(), factors->jU_h.data(), factors->aU_h.data(), &factors->U));
1710aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->U, MAT_INITIAL_MATRIX, &factors->Ut));
171186a27549SJunchao Zhang 
1712aac854edSJunchao Zhang           seq            = static_cast<Mat_SeqAIJ *>(factors->Ut->data);
1713aac854edSJunchao Zhang           factors->iUt_h = MatRowMapKokkosViewHost(seq->i, n + 1);
1714aac854edSJunchao Zhang           factors->jUt_h = MatColIdxKokkosViewHost(seq->j, seq->nz);
1715aac854edSJunchao Zhang           factors->aUt_h = MatScalarKokkosViewHost(seq->a, seq->nz);
1716aac854edSJunchao Zhang         } else {
1717aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->U, MAT_REUSE_MATRIX, &factors->Ut)); // Matrix Ut' data is aliased with {i, j, a}Ut_h
1718aac854edSJunchao Zhang         }
1719aac854edSJunchao Zhang         // Copy Ut from host to device
1720aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->iUt_d, factors->iUt_h));
1721aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->jUt_d, factors->jUt_h));
1722aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->aUt_d, factors->aUt_h));
1723aac854edSJunchao Zhang       } else { // If U was computed on device, we also compute the transpose there
1724aac854edSJunchao Zhang         // TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices. We have to sort the indices, until KK provides finer control options.
1725aac854edSJunchao Zhang         PetscCallCXX(transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iU_d,
1726aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->jU_d, factors->aU_d,
1727aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->iUt_d, factors->jUt_d,
1728aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->aUt_d));
1729aac854edSJunchao Zhang         PetscCallCXX(sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iUt_d, factors->jUt_d, factors->aUt_d));
1730aac854edSJunchao Zhang       }
1731aac854edSJunchao Zhang       PetscCallCXX(sptrsv_symbolic(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d));
1732aac854edSJunchao Zhang     }
1733aac854edSJunchao Zhang 
1734aac854edSJunchao Zhang     // do the same for L with LU
1735aac854edSJunchao Zhang     if (has_lower) {
1736aac854edSJunchao Zhang       if (!factors->iLt_d.extent(0)) {                                 // Allocate Lt on device if not yet
1737aac854edSJunchao Zhang         factors->iLt_d = MatRowMapKokkosView("factors->iLt_d", n + 1); // KK requires this view to be initialized to 0 to call transpose_matrix
1738aac854edSJunchao Zhang         factors->jLt_d = MatColIdxKokkosView(NoInit("factors->jLt_d"), factors->jL_d.extent(0));
1739aac854edSJunchao Zhang         factors->aLt_d = MatScalarKokkosView(NoInit("factors->aLt_d"), factors->aL_d.extent(0));
1740aac854edSJunchao Zhang       }
1741aac854edSJunchao Zhang 
1742aac854edSJunchao Zhang       if (factors->iL_h.extent(0)) { // If L is on host, we also compute the transpose on host
1743aac854edSJunchao Zhang         if (!factors->L) {
1744aac854edSJunchao Zhang           Mat_SeqAIJ *seq;
1745aac854edSJunchao Zhang 
1746aac854edSJunchao Zhang           PetscCall(MatCreateSeqAIJWithArrays(PETSC_COMM_SELF, n, n, factors->iL_h.data(), factors->jL_h.data(), factors->aL_h.data(), &factors->L));
1747aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->L, MAT_INITIAL_MATRIX, &factors->Lt));
1748aac854edSJunchao Zhang 
1749aac854edSJunchao Zhang           seq            = static_cast<Mat_SeqAIJ *>(factors->Lt->data);
1750aac854edSJunchao Zhang           factors->iLt_h = MatRowMapKokkosViewHost(seq->i, n + 1);
1751aac854edSJunchao Zhang           factors->jLt_h = MatColIdxKokkosViewHost(seq->j, seq->nz);
1752aac854edSJunchao Zhang           factors->aLt_h = MatScalarKokkosViewHost(seq->a, seq->nz);
1753aac854edSJunchao Zhang         } else {
1754aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->L, MAT_REUSE_MATRIX, &factors->Lt)); // Matrix Lt' data is aliased with {i, j, a}Lt_h
1755aac854edSJunchao Zhang         }
1756aac854edSJunchao Zhang         // Copy Lt from host to device
1757aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->iLt_d, factors->iLt_h));
1758aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->jLt_d, factors->jLt_h));
1759aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->aLt_d, factors->aLt_h));
1760aac854edSJunchao Zhang       } else { // If L was computed on device, we also compute the transpose there
1761aac854edSJunchao Zhang         // TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices. We have to sort the indices, until KK provides finer control options.
1762aac854edSJunchao Zhang         PetscCallCXX(transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iL_d,
1763aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->jL_d, factors->aL_d,
1764aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->iLt_d, factors->jLt_d,
1765aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->aLt_d));
1766aac854edSJunchao Zhang         PetscCallCXX(sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iLt_d, factors->jLt_d, factors->aLt_d));
1767aac854edSJunchao Zhang       }
1768aac854edSJunchao Zhang       PetscCallCXX(sptrsv_symbolic(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d));
1769aac854edSJunchao Zhang     }
1770aac854edSJunchao Zhang 
177186a27549SJunchao Zhang     factors->transpose_updated = PETSC_TRUE;
177286a27549SJunchao Zhang   }
17733ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
177486a27549SJunchao Zhang }
177586a27549SJunchao Zhang 
1776aac854edSJunchao Zhang // Solve Ax = b, with RAR = U^T D U, where R is the row (and col) permutation matrix on A.
1777aac854edSJunchao Zhang // R is represented by rowperm in factors. If R is identity (i.e, no reordering), then rowperm is empty.
1778aac854edSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJKokkos_Cholesky(Mat A, Vec bb, Vec xx)
1779d71ae5a4SJacob Faibussowitsch {
1780aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
178186a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1782aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1783aac854edSJunchao Zhang   PetscScalarKokkosView       D       = factors->D_d;
1784aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1785aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1786aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1787aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm  = factors->rowperm;
1788aac854edSJunchao Zhang   PetscBool                   identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
178986a27549SJunchao Zhang 
179086a27549SJunchao Zhang   PetscFunctionBegin;
17919566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1792aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSolveCheck(A));          // for UX = T
1793aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A)); // for U^T Y = B
1794aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1795aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
1796aac854edSJunchao Zhang 
1797aac854edSJunchao Zhang   // Solve U^T Y = B
1798aac854edSJunchao Zhang   if (identity) { // Reorder b with the row permutation
1799aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1800aac854edSJunchao Zhang     Y = factors->workVector;
1801aac854edSJunchao Zhang   } else {
1802aac854edSJunchao Zhang     B = factors->workVector;
1803aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(rowperm(i)); }));
1804aac854edSJunchao Zhang     Y = x;
1805aac854edSJunchao Zhang   }
1806aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, B, Y));
1807aac854edSJunchao Zhang 
1808aac854edSJunchao Zhang   // Solve diag(D) Y' = Y.
1809aac854edSJunchao Zhang   // Actually just do Y' = Y*D since D is already inverted in MatCholeskyFactorNumeric_SeqAIJ(). It is basically a vector element-wise multiplication.
1810aac854edSJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { Y(i) = Y(i) * D(i); }));
1811aac854edSJunchao Zhang 
1812aac854edSJunchao Zhang   // Solve UX = Y
1813aac854edSJunchao Zhang   if (identity) {
1814aac854edSJunchao Zhang     X = x;
1815aac854edSJunchao Zhang   } else {
1816aac854edSJunchao Zhang     X = factors->workVector; // B is not needed anymore
1817aac854edSJunchao Zhang   }
1818aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, Y, X));
1819aac854edSJunchao Zhang 
1820aac854edSJunchao Zhang   // Reorder X with the inverse column (row) permutation
1821aac854edSJunchao Zhang   if (!identity) {
1822aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(rowperm(i)) = X(i); }));
1823aac854edSJunchao Zhang   }
1824aac854edSJunchao Zhang 
1825aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1826aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
18279566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
18283ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
182986a27549SJunchao Zhang }
183086a27549SJunchao Zhang 
1831aac854edSJunchao Zhang // Solve Ax = b, with RAC = LU, where R and C are row and col permutation matrices on A respectively.
1832aac854edSJunchao Zhang // R and C are represented by rowperm and colperm in factors.
1833aac854edSJunchao Zhang // If R or C is identity (i.e, no reordering), then rowperm or colperm is empty.
1834aac854edSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJKokkos_LU(Mat A, Vec bb, Vec xx)
1835d71ae5a4SJacob Faibussowitsch {
1836aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
183786a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1838aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1839aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1840aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1841aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1842aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm      = factors->rowperm;
1843aac854edSJunchao Zhang   PetscIntKokkosView         &colperm      = factors->colperm;
1844aac854edSJunchao Zhang   PetscBool                   row_identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1845aac854edSJunchao Zhang   PetscBool                   col_identity = colperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
184686a27549SJunchao Zhang 
184786a27549SJunchao Zhang   PetscFunctionBegin;
18489566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1849aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSolveCheck(A));
1850aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1851aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
185286a27549SJunchao Zhang 
1853aac854edSJunchao Zhang   // Solve L Y = B (i.e., L (U C^- x) = R b).  R b indicates applying the row permutation on b.
1854aac854edSJunchao Zhang   if (row_identity) {
1855aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1856aac854edSJunchao Zhang     Y = factors->workVector;
1857aac854edSJunchao Zhang   } else {
1858aac854edSJunchao Zhang     B = factors->workVector;
1859aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(rowperm(i)); }));
1860aac854edSJunchao Zhang     Y = x;
1861aac854edSJunchao Zhang   }
1862aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khL, factors->iL_d, factors->jL_d, factors->aL_d, B, Y));
1863aac854edSJunchao Zhang 
1864aac854edSJunchao Zhang   // Solve U C^- x = Y
1865aac854edSJunchao Zhang   if (col_identity) {
1866aac854edSJunchao Zhang     X = x;
1867aac854edSJunchao Zhang   } else {
1868aac854edSJunchao Zhang     X = factors->workVector;
1869aac854edSJunchao Zhang   }
1870aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, Y, X));
1871aac854edSJunchao Zhang 
1872aac854edSJunchao Zhang   // x = C X; Reorder X with the inverse col permutation
1873aac854edSJunchao Zhang   if (!col_identity) {
1874aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(colperm(i)) = X(i); }));
1875aac854edSJunchao Zhang   }
1876aac854edSJunchao Zhang 
1877aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1878aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
18799566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
18803ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
188186a27549SJunchao Zhang }
188286a27549SJunchao Zhang 
1883aac854edSJunchao Zhang // Solve A^T x = b, with RAC = LU, where R and C are row and col permutation matrices on A respectively.
1884aac854edSJunchao Zhang // R and C are represented by rowperm and colperm in factors.
1885aac854edSJunchao Zhang // If R or C is identity (i.e, no reordering), then rowperm or colperm is empty.
1886aac854edSJunchao Zhang // A = R^-1 L U C^-1, so A^T = C^-T U^T L^T R^-T. But since C^- = C^T, R^- = R^T, we have A^T = C U^T L^T R.
1887aac854edSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJKokkos_LU(Mat A, Vec bb, Vec xx)
1888aac854edSJunchao Zhang {
1889aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
1890aac854edSJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1891aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1892aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1893aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1894aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1895aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm      = factors->rowperm;
1896aac854edSJunchao Zhang   PetscIntKokkosView         &colperm      = factors->colperm;
1897aac854edSJunchao Zhang   PetscBool                   row_identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1898aac854edSJunchao Zhang   PetscBool                   col_identity = colperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1899aac854edSJunchao Zhang 
1900aac854edSJunchao Zhang   PetscFunctionBegin;
1901aac854edSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1902aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A)); // Update L^T, U^T if needed, and do sptrsv symbolic for L^T, U^T
1903aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1904aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
1905aac854edSJunchao Zhang 
1906aac854edSJunchao Zhang   // Solve U^T Y = B (i.e., U^T (L^T R x) = C^- b).  Note C^- b = C^T b, which means applying the column permutation on b.
1907aac854edSJunchao Zhang   if (col_identity) { // Reorder b with the col permutation
1908aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1909aac854edSJunchao Zhang     Y = factors->workVector;
1910aac854edSJunchao Zhang   } else {
1911aac854edSJunchao Zhang     B = factors->workVector;
1912aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(colperm(i)); }));
1913aac854edSJunchao Zhang     Y = x;
1914aac854edSJunchao Zhang   }
1915aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, B, Y));
1916aac854edSJunchao Zhang 
1917aac854edSJunchao Zhang   // Solve L^T X = Y
1918aac854edSJunchao Zhang   if (row_identity) {
1919aac854edSJunchao Zhang     X = x;
1920aac854edSJunchao Zhang   } else {
1921aac854edSJunchao Zhang     X = factors->workVector;
1922aac854edSJunchao Zhang   }
1923aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d, Y, X));
1924aac854edSJunchao Zhang 
1925aac854edSJunchao Zhang   // x = R^- X = R^T X; Reorder X with the inverse row permutation
1926aac854edSJunchao Zhang   if (!row_identity) {
1927aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(rowperm(i)) = X(i); }));
1928aac854edSJunchao Zhang   }
1929aac854edSJunchao Zhang 
1930aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1931aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
1932aac854edSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1933aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1934aac854edSJunchao Zhang }
1935aac854edSJunchao Zhang 
1936aac854edSJunchao Zhang static PetscErrorCode MatLUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1937aac854edSJunchao Zhang {
1938aac854edSJunchao Zhang   PetscFunctionBegin;
1939aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncHost(A));
1940aac854edSJunchao Zhang   PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info));
1941aac854edSJunchao Zhang 
1942aac854edSJunchao Zhang   if (!info->solveonhost) { // if solve on host, then we don't need to copy L, U to device
1943aac854edSJunchao Zhang     Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
1944aac854edSJunchao Zhang     Mat_SeqAIJ                 *b       = static_cast<Mat_SeqAIJ *>(B->data);
1945aac854edSJunchao Zhang     const PetscInt             *Bi = b->i, *Bj = b->j, *Bdiag = b->diag;
1946aac854edSJunchao Zhang     const MatScalar            *Ba = b->a;
1947aac854edSJunchao Zhang     PetscInt                    m = B->rmap->n, n = B->cmap->n;
1948aac854edSJunchao Zhang 
1949aac854edSJunchao Zhang     if (factors->iL_h.extent(0) == 0) { // Allocate memory and copy the L, U structure for the first time
1950aac854edSJunchao Zhang       // Allocate memory and copy the structure
1951aac854edSJunchao Zhang       factors->iL_h = MatRowMapKokkosViewHost(NoInit("iL_h"), m + 1);
1952aac854edSJunchao Zhang       factors->jL_h = MatColIdxKokkosViewHost(NoInit("jL_h"), (Bi[m] - Bi[0]) + m); // + the diagonal entries
1953aac854edSJunchao Zhang       factors->aL_h = MatScalarKokkosViewHost(NoInit("aL_h"), (Bi[m] - Bi[0]) + m);
1954aac854edSJunchao Zhang       factors->iU_h = MatRowMapKokkosViewHost(NoInit("iU_h"), m + 1);
1955aac854edSJunchao Zhang       factors->jU_h = MatColIdxKokkosViewHost(NoInit("jU_h"), (Bdiag[0] - Bdiag[m]));
1956aac854edSJunchao Zhang       factors->aU_h = MatScalarKokkosViewHost(NoInit("aU_h"), (Bdiag[0] - Bdiag[m]));
1957aac854edSJunchao Zhang 
1958aac854edSJunchao Zhang       PetscInt *Li = factors->iL_h.data();
1959aac854edSJunchao Zhang       PetscInt *Lj = factors->jL_h.data();
1960aac854edSJunchao Zhang       PetscInt *Ui = factors->iU_h.data();
1961aac854edSJunchao Zhang       PetscInt *Uj = factors->jU_h.data();
1962aac854edSJunchao Zhang 
1963aac854edSJunchao Zhang       Li[0] = Ui[0] = 0;
1964aac854edSJunchao Zhang       for (PetscInt i = 0; i < m; i++) {
1965aac854edSJunchao Zhang         PetscInt llen = Bi[i + 1] - Bi[i];       // exclusive of the diagonal entry
1966aac854edSJunchao Zhang         PetscInt ulen = Bdiag[i] - Bdiag[i + 1]; // inclusive of the diagonal entry
1967aac854edSJunchao Zhang 
1968aac854edSJunchao Zhang         PetscArraycpy(Lj + Li[i], Bj + Bi[i], llen); // entries of L on the left of the diagonal
1969aac854edSJunchao Zhang         Lj[Li[i] + llen] = i;                        // diagonal entry of L
1970aac854edSJunchao Zhang 
1971aac854edSJunchao Zhang         Uj[Ui[i]] = i;                                                  // diagonal entry of U
1972aac854edSJunchao Zhang         PetscArraycpy(Uj + Ui[i] + 1, Bj + Bdiag[i + 1] + 1, ulen - 1); // entries of U on  the right of the diagonal
1973aac854edSJunchao Zhang 
1974aac854edSJunchao Zhang         Li[i + 1] = Li[i] + llen + 1;
1975aac854edSJunchao Zhang         Ui[i + 1] = Ui[i] + ulen;
1976aac854edSJunchao Zhang       }
1977aac854edSJunchao Zhang 
1978aac854edSJunchao Zhang       factors->iL_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iL_h);
1979aac854edSJunchao Zhang       factors->jL_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jL_h);
1980aac854edSJunchao Zhang       factors->iU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iU_h);
1981aac854edSJunchao Zhang       factors->jU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jU_h);
1982aac854edSJunchao Zhang       factors->aL_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aL_h);
1983aac854edSJunchao Zhang       factors->aU_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aU_h);
1984aac854edSJunchao Zhang 
1985aac854edSJunchao Zhang       // Copy row/col permutation to device
1986aac854edSJunchao Zhang       IS        rowperm = ((Mat_SeqAIJ *)B->data)->row;
1987aac854edSJunchao Zhang       PetscBool row_identity;
1988aac854edSJunchao Zhang       PetscCall(ISIdentity(rowperm, &row_identity));
1989aac854edSJunchao Zhang       if (!row_identity) {
1990aac854edSJunchao Zhang         const PetscInt *ip;
1991aac854edSJunchao Zhang 
1992aac854edSJunchao Zhang         PetscCall(ISGetIndices(rowperm, &ip));
1993aac854edSJunchao Zhang         factors->rowperm = PetscIntKokkosView(NoInit("rowperm"), m);
1994aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->rowperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), m)));
1995aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(rowperm, &ip));
1996aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(m * sizeof(PetscInt)));
1997aac854edSJunchao Zhang       }
1998aac854edSJunchao Zhang 
1999aac854edSJunchao Zhang       IS        colperm = ((Mat_SeqAIJ *)B->data)->col;
2000aac854edSJunchao Zhang       PetscBool col_identity;
2001aac854edSJunchao Zhang       PetscCall(ISIdentity(colperm, &col_identity));
2002aac854edSJunchao Zhang       if (!col_identity) {
2003aac854edSJunchao Zhang         const PetscInt *ip;
2004aac854edSJunchao Zhang 
2005aac854edSJunchao Zhang         PetscCall(ISGetIndices(colperm, &ip));
2006aac854edSJunchao Zhang         factors->colperm = PetscIntKokkosView(NoInit("colperm"), n);
2007aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->colperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), n)));
2008aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(colperm, &ip));
2009aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt)));
2010aac854edSJunchao Zhang       }
2011aac854edSJunchao Zhang 
2012aac854edSJunchao Zhang       /* Create sptrsv handles for L, U and their transpose */
2013aac854edSJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2014aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
2015aac854edSJunchao Zhang #else
2016aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
2017aac854edSJunchao Zhang #endif
2018aac854edSJunchao Zhang       factors->khL.create_sptrsv_handle(sptrsv_alg, m, true /* L is lower tri */);
2019aac854edSJunchao Zhang       factors->khU.create_sptrsv_handle(sptrsv_alg, m, false /* U is not lower tri */);
2020aac854edSJunchao Zhang       factors->khLt.create_sptrsv_handle(sptrsv_alg, m, false /* L^T is not lower tri */);
2021aac854edSJunchao Zhang       factors->khUt.create_sptrsv_handle(sptrsv_alg, m, true /* U^T is lower tri */);
2022aac854edSJunchao Zhang     }
2023aac854edSJunchao Zhang 
2024aac854edSJunchao Zhang     // Copy the value
2025aac854edSJunchao Zhang     for (PetscInt i = 0; i < m; i++) {
2026aac854edSJunchao Zhang       PetscInt        llen = Bi[i + 1] - Bi[i];
2027aac854edSJunchao Zhang       PetscInt        ulen = Bdiag[i] - Bdiag[i + 1];
2028aac854edSJunchao Zhang       const PetscInt *Li   = factors->iL_h.data();
2029aac854edSJunchao Zhang       const PetscInt *Ui   = factors->iU_h.data();
2030aac854edSJunchao Zhang 
2031aac854edSJunchao Zhang       PetscScalar *La = factors->aL_h.data();
2032aac854edSJunchao Zhang       PetscScalar *Ua = factors->aU_h.data();
2033aac854edSJunchao Zhang 
2034aac854edSJunchao Zhang       PetscArraycpy(La + Li[i], Ba + Bi[i], llen); // entries of L
2035aac854edSJunchao Zhang       La[Li[i] + llen] = 1.0;                      // diagonal entry
2036aac854edSJunchao Zhang 
2037aac854edSJunchao Zhang       Ua[Ui[i]] = 1.0 / Ba[Bdiag[i]];                                 // diagonal entry
2038aac854edSJunchao Zhang       PetscArraycpy(Ua + Ui[i] + 1, Ba + Bdiag[i + 1] + 1, ulen - 1); // entries of U
2039aac854edSJunchao Zhang     }
2040aac854edSJunchao Zhang 
2041aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aL_d, factors->aL_h));
2042aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aU_d, factors->aU_h));
2043aac854edSJunchao Zhang     // Once the factors' values have changed, we need to update their transpose and redo sptrsv symbolic
2044aac854edSJunchao Zhang     factors->transpose_updated         = PETSC_FALSE;
2045aac854edSJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_FALSE;
2046aac854edSJunchao Zhang 
2047aac854edSJunchao Zhang     B->ops->solve          = MatSolve_SeqAIJKokkos_LU;
2048aac854edSJunchao Zhang     B->ops->solvetranspose = MatSolveTranspose_SeqAIJKokkos_LU;
2049aac854edSJunchao Zhang   }
2050aac854edSJunchao Zhang 
2051aac854edSJunchao Zhang   B->ops->matsolve          = NULL;
2052aac854edSJunchao Zhang   B->ops->matsolvetranspose = NULL;
2053aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2054aac854edSJunchao Zhang }
2055aac854edSJunchao Zhang 
2056aac854edSJunchao Zhang static PetscErrorCode MatILUFactorNumeric_SeqAIJKokkos_ILU0(Mat B, Mat A, const MatFactorInfo *info)
2057d71ae5a4SJacob Faibussowitsch {
205886a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
205986a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
206086a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
206186a27549SJunchao Zhang 
206286a27549SJunchao Zhang   PetscFunctionBegin;
20639566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
2064aac854edSJunchao Zhang   PetscCheck(!info->factoronhost, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "MatFactorInfo.factoronhost should be false");
20659566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
2066076ba34aSJunchao Zhang 
2067076ba34aSJunchao Zhang   auto a_d = aijkok->a_dual.view_device();
2068076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
2069076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
2070076ba34aSJunchao Zhang 
2071aac854edSJunchao Zhang   PetscCallCXX(spiluk_numeric(&factors->kh, fill_lev, i_d, j_d, a_d, factors->iL_d, factors->jL_d, factors->aL_d, factors->iU_d, factors->jU_d, factors->aU_d));
207286a27549SJunchao Zhang 
207386a27549SJunchao Zhang   B->assembled              = PETSC_TRUE;
207486a27549SJunchao Zhang   B->preallocated           = PETSC_TRUE;
2075aac854edSJunchao Zhang   B->ops->solve             = MatSolve_SeqAIJKokkos_LU;
2076aac854edSJunchao Zhang   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJKokkos_LU;
207786a27549SJunchao Zhang   B->ops->matsolve          = NULL;
207886a27549SJunchao Zhang   B->ops->matsolvetranspose = NULL;
207986a27549SJunchao Zhang 
208086a27549SJunchao Zhang   /* Once the factors' value changed, we need to update their transpose and sptrsv handle */
208186a27549SJunchao Zhang   factors->transpose_updated         = PETSC_FALSE;
208286a27549SJunchao Zhang   factors->sptrsv_symbolic_completed = PETSC_FALSE;
2083eeadb341SJunchao Zhang   /* TODO: log flops, but how to know that? */
20849566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
20853ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
208686a27549SJunchao Zhang }
208786a27549SJunchao Zhang 
2088aac854edSJunchao Zhang // Use KK's spiluk_symbolic() to do ILU0 symbolic factorization, with no row/col reordering
2089aac854edSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos_ILU0(Mat B, Mat A, IS, IS, const MatFactorInfo *info)
2090d71ae5a4SJacob Faibussowitsch {
209186a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok;
209286a27549SJunchao Zhang   Mat_SeqAIJ                 *b;
209386a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
209486a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
209586a27549SJunchao Zhang   PetscInt                    nnzA     = ((Mat_SeqAIJ *)A->data)->nz, nnzL, nnzU;
209686a27549SJunchao Zhang   PetscInt                    n        = A->rmap->n;
209786a27549SJunchao Zhang 
209886a27549SJunchao Zhang   PetscFunctionBegin;
2099aac854edSJunchao Zhang   PetscCheck(!info->factoronhost, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "MatFactorInfo's factoronhost should be false as we are doing it on device right now");
21009566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
210186a27549SJunchao Zhang 
210286a27549SJunchao Zhang   /* Create a spiluk handle and then do symbolic factorization */
210386a27549SJunchao Zhang   nnzL = nnzU = PetscRealIntMultTruncate(info->fill, nnzA);
2104aac854edSJunchao Zhang   factors->kh.create_spiluk_handle(SPILUKAlgorithm::SEQLVLSCHD_TP1, n, nnzL, nnzU);
210586a27549SJunchao Zhang 
210686a27549SJunchao Zhang   auto spiluk_handle = factors->kh.get_spiluk_handle();
210786a27549SJunchao Zhang 
210886a27549SJunchao Zhang   Kokkos::realloc(factors->iL_d, n + 1); /* Free old arrays and realloc */
210986a27549SJunchao Zhang   Kokkos::realloc(factors->jL_d, spiluk_handle->get_nnzL());
211086a27549SJunchao Zhang   Kokkos::realloc(factors->iU_d, n + 1);
211186a27549SJunchao Zhang   Kokkos::realloc(factors->jU_d, spiluk_handle->get_nnzU());
211286a27549SJunchao Zhang 
211386a27549SJunchao Zhang   aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
2114076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
2115076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
2116aac854edSJunchao Zhang   PetscCallCXX(spiluk_symbolic(&factors->kh, fill_lev, i_d, j_d, factors->iL_d, factors->jL_d, factors->iU_d, factors->jU_d));
211786a27549SJunchao Zhang   /* TODO: if spiluk_symbolic is asynchronous, do we need to sync before calling get_nnzL()? */
211886a27549SJunchao Zhang 
211986a27549SJunchao Zhang   Kokkos::resize(factors->jL_d, spiluk_handle->get_nnzL()); /* Shrink or expand, and retain old value */
212086a27549SJunchao Zhang   Kokkos::resize(factors->jU_d, spiluk_handle->get_nnzU());
212186a27549SJunchao Zhang   Kokkos::realloc(factors->aL_d, spiluk_handle->get_nnzL()); /* No need to retain old value */
212286a27549SJunchao Zhang   Kokkos::realloc(factors->aU_d, spiluk_handle->get_nnzU());
212386a27549SJunchao Zhang 
212486a27549SJunchao Zhang   /* TODO: add options to select sptrsv algorithms */
212586a27549SJunchao Zhang   /* Create sptrsv handles for L, U and their transpose */
212686a27549SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2127aac854edSJunchao Zhang   auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
212886a27549SJunchao Zhang #else
2129aac854edSJunchao Zhang   auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
213086a27549SJunchao Zhang #endif
213186a27549SJunchao Zhang 
213286a27549SJunchao Zhang   factors->khL.create_sptrsv_handle(sptrsv_alg, n, true /* L is lower tri */);
213386a27549SJunchao Zhang   factors->khU.create_sptrsv_handle(sptrsv_alg, n, false /* U is not lower tri */);
213486a27549SJunchao Zhang   factors->khLt.create_sptrsv_handle(sptrsv_alg, n, false /* L^T is not lower tri */);
213586a27549SJunchao Zhang   factors->khUt.create_sptrsv_handle(sptrsv_alg, n, true /* U^T is lower tri */);
213686a27549SJunchao Zhang 
213786a27549SJunchao Zhang   /* Fill fields of the factor matrix B */
21389566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(B, MAT_SKIP_ALLOCATION, NULL));
213986a27549SJunchao Zhang   b     = (Mat_SeqAIJ *)B->data;
214086a27549SJunchao Zhang   b->nz = b->maxnz          = spiluk_handle->get_nnzL() + spiluk_handle->get_nnzU();
214186a27549SJunchao Zhang   B->info.fill_ratio_given  = info->fill;
2142a1e4e190SStefano Zampini   B->info.fill_ratio_needed = nnzA > 0 ? ((PetscReal)b->nz) / ((PetscReal)nnzA) : 1.0;
214386a27549SJunchao Zhang 
2144aac854edSJunchao Zhang   B->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJKokkos_ILU0;
21453ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2146930e68a5SMark Adams }
2147930e68a5SMark Adams 
2148aac854edSJunchao Zhang static PetscErrorCode MatLUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
2149aac854edSJunchao Zhang {
2150aac854edSJunchao Zhang   PetscFunctionBegin;
2151aac854edSJunchao Zhang   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
2152aac854edSJunchao Zhang   PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2153aac854edSJunchao Zhang   PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2154aac854edSJunchao Zhang   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
2155aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2156aac854edSJunchao Zhang }
2157aac854edSJunchao Zhang 
2158aac854edSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
2159aac854edSJunchao Zhang {
2160aac854edSJunchao Zhang   PetscBool row_identity = PETSC_FALSE, col_identity = PETSC_FALSE;
2161aac854edSJunchao Zhang 
2162aac854edSJunchao Zhang   PetscFunctionBegin;
2163aac854edSJunchao Zhang   if (!info->factoronhost) {
2164aac854edSJunchao Zhang     PetscCall(ISIdentity(isrow, &row_identity));
2165aac854edSJunchao Zhang     PetscCall(ISIdentity(iscol, &col_identity));
2166aac854edSJunchao Zhang   }
2167aac854edSJunchao Zhang 
2168aac854edSJunchao Zhang   PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2169aac854edSJunchao Zhang   PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2170aac854edSJunchao Zhang 
2171aac854edSJunchao Zhang   if (!info->factoronhost && !info->levels && row_identity && col_identity) { // if level 0 and no reordering
2172aac854edSJunchao Zhang     PetscCall(MatILUFactorSymbolic_SeqAIJKokkos_ILU0(B, A, isrow, iscol, info));
2173aac854edSJunchao Zhang   } else {
2174aac854edSJunchao Zhang     PetscCall(MatILUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info)); // otherwise, use PETSc's ILU on host
2175aac854edSJunchao Zhang     B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
2176aac854edSJunchao Zhang   }
2177aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2178aac854edSJunchao Zhang }
2179aac854edSJunchao Zhang 
2180aac854edSJunchao Zhang static PetscErrorCode MatCholeskyFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
2181aac854edSJunchao Zhang {
2182aac854edSJunchao Zhang   PetscFunctionBegin;
2183aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncHost(A));
2184aac854edSJunchao Zhang   PetscCall(MatCholeskyFactorNumeric_SeqAIJ(B, A, info));
2185aac854edSJunchao Zhang 
2186aac854edSJunchao Zhang   if (!info->solveonhost) { // if solve on host, then we don't need to copy L, U to device
2187aac854edSJunchao Zhang     Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
2188aac854edSJunchao Zhang     Mat_SeqAIJ                 *b       = static_cast<Mat_SeqAIJ *>(B->data);
2189aac854edSJunchao Zhang     const PetscInt             *Bi = b->i, *Bj = b->j, *Bdiag = b->diag;
2190aac854edSJunchao Zhang     const MatScalar            *Ba = b->a;
2191aac854edSJunchao Zhang     PetscInt                    m  = B->rmap->n;
2192aac854edSJunchao Zhang 
2193aac854edSJunchao Zhang     if (factors->iU_h.extent(0) == 0) { // First time of numeric factorization
2194aac854edSJunchao Zhang       // Allocate memory and copy the structure
2195aac854edSJunchao Zhang       factors->iU_h = PetscIntKokkosViewHost(const_cast<PetscInt *>(Bi), m + 1); // wrap Bi as iU_h
2196aac854edSJunchao Zhang       factors->jU_h = MatColIdxKokkosViewHost(NoInit("jU_h"), Bi[m]);
2197aac854edSJunchao Zhang       factors->aU_h = MatScalarKokkosViewHost(NoInit("aU_h"), Bi[m]);
2198aac854edSJunchao Zhang       factors->D_h  = MatScalarKokkosViewHost(NoInit("D_h"), m);
2199aac854edSJunchao Zhang       factors->aU_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aU_h);
2200aac854edSJunchao Zhang       factors->D_d  = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->D_h);
2201aac854edSJunchao Zhang 
2202aac854edSJunchao Zhang       // Build jU_h from the skewed Aj
2203aac854edSJunchao Zhang       PetscInt *Uj = factors->jU_h.data();
2204aac854edSJunchao Zhang       for (PetscInt i = 0; i < m; i++) {
2205aac854edSJunchao Zhang         PetscInt ulen = Bi[i + 1] - Bi[i];
2206aac854edSJunchao Zhang         Uj[Bi[i]]     = i;                                              // diagonal entry
2207aac854edSJunchao Zhang         PetscCall(PetscArraycpy(Uj + Bi[i] + 1, Bj + Bi[i], ulen - 1)); // entries of U on the right of the diagonal
2208aac854edSJunchao Zhang       }
2209aac854edSJunchao Zhang 
2210aac854edSJunchao Zhang       // Copy iU, jU to device
2211aac854edSJunchao Zhang       PetscCallCXX(factors->iU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iU_h));
2212aac854edSJunchao Zhang       PetscCallCXX(factors->jU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jU_h));
2213aac854edSJunchao Zhang 
2214aac854edSJunchao Zhang       // Copy row/col permutation to device
2215aac854edSJunchao Zhang       IS        rowperm = ((Mat_SeqAIJ *)B->data)->row;
2216aac854edSJunchao Zhang       PetscBool row_identity;
2217aac854edSJunchao Zhang       PetscCall(ISIdentity(rowperm, &row_identity));
2218aac854edSJunchao Zhang       if (!row_identity) {
2219aac854edSJunchao Zhang         const PetscInt *ip;
2220aac854edSJunchao Zhang 
2221aac854edSJunchao Zhang         PetscCall(ISGetIndices(rowperm, &ip));
2222aac854edSJunchao Zhang         PetscCallCXX(factors->rowperm = PetscIntKokkosView(NoInit("rowperm"), m));
2223aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->rowperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), m)));
2224aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(rowperm, &ip));
2225aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(m * sizeof(PetscInt)));
2226aac854edSJunchao Zhang       }
2227aac854edSJunchao Zhang 
2228aac854edSJunchao Zhang       // Create sptrsv handles for U and U^T
2229aac854edSJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2230aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
2231aac854edSJunchao Zhang #else
2232aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
2233aac854edSJunchao Zhang #endif
2234aac854edSJunchao Zhang       factors->khU.create_sptrsv_handle(sptrsv_alg, m, false /* U is not lower tri */);
2235aac854edSJunchao Zhang       factors->khUt.create_sptrsv_handle(sptrsv_alg, m, true /* U^T is lower tri */);
2236aac854edSJunchao Zhang     }
2237aac854edSJunchao Zhang     // These pointers were set MatCholeskyFactorNumeric_SeqAIJ(), so we always need to update them
2238aac854edSJunchao Zhang     B->ops->solve          = MatSolve_SeqAIJKokkos_Cholesky;
2239aac854edSJunchao Zhang     B->ops->solvetranspose = MatSolve_SeqAIJKokkos_Cholesky;
2240aac854edSJunchao Zhang 
2241aac854edSJunchao Zhang     // Copy the value
2242aac854edSJunchao Zhang     PetscScalar *Ua = factors->aU_h.data();
2243aac854edSJunchao Zhang     PetscScalar *D  = factors->D_h.data();
2244aac854edSJunchao Zhang     for (PetscInt i = 0; i < m; i++) {
2245aac854edSJunchao Zhang       D[i]      = Ba[Bdiag[i]];     // actually Aa[Adiag[i]] is the inverse of the diagonal
2246aac854edSJunchao Zhang       Ua[Bi[i]] = (PetscScalar)1.0; // set the unit diagonal for U
2247aac854edSJunchao Zhang       for (PetscInt k = 0; k < Bi[i + 1] - Bi[i] - 1; k++) Ua[Bi[i] + 1 + k] = -Ba[Bi[i] + k];
2248aac854edSJunchao Zhang     }
2249aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aU_d, factors->aU_h));
2250aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->D_d, factors->D_h));
2251aac854edSJunchao Zhang 
2252aac854edSJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_FALSE; // When numeric value changed, we must do these again
2253aac854edSJunchao Zhang     factors->transpose_updated         = PETSC_FALSE;
2254aac854edSJunchao Zhang   }
2255aac854edSJunchao Zhang 
2256aac854edSJunchao Zhang   B->ops->matsolve          = NULL;
2257aac854edSJunchao Zhang   B->ops->matsolvetranspose = NULL;
2258aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2259aac854edSJunchao Zhang }
2260aac854edSJunchao Zhang 
2261aac854edSJunchao Zhang static PetscErrorCode MatICCFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS perm, const MatFactorInfo *info)
2262aac854edSJunchao Zhang {
2263aac854edSJunchao Zhang   PetscFunctionBegin;
2264aac854edSJunchao Zhang   if (info->solveonhost) {
2265aac854edSJunchao Zhang     // If solve on host, we have to change the type, as eventually we need to call MatSolve_SeqSBAIJ_1_NaturalOrdering() etc.
2266aac854edSJunchao Zhang     PetscCall(MatSetType(B, MATSEQSBAIJ));
2267aac854edSJunchao Zhang     PetscCall(MatSeqSBAIJSetPreallocation(B, 1, MAT_SKIP_ALLOCATION, NULL));
2268aac854edSJunchao Zhang   }
2269aac854edSJunchao Zhang 
2270aac854edSJunchao Zhang   PetscCall(MatICCFactorSymbolic_SeqAIJ(B, A, perm, info));
2271aac854edSJunchao Zhang 
2272aac854edSJunchao Zhang   if (!info->solveonhost) {
2273*bfe80ac4SPierre Jolivet     // If solve on device, B is still a MATSEQAIJKOKKOS, so we are good to allocate B->spptr
2274aac854edSJunchao Zhang     PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2275aac854edSJunchao Zhang     PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2276aac854edSJunchao Zhang     B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJKokkos;
2277aac854edSJunchao Zhang   }
2278aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2279aac854edSJunchao Zhang }
2280aac854edSJunchao Zhang 
2281aac854edSJunchao Zhang static PetscErrorCode MatCholeskyFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS perm, const MatFactorInfo *info)
2282aac854edSJunchao Zhang {
2283aac854edSJunchao Zhang   PetscFunctionBegin;
2284aac854edSJunchao Zhang   if (info->solveonhost) {
2285aac854edSJunchao Zhang     // If solve on host, we have to change the type, as eventually we need to call MatSolve_SeqSBAIJ_1_NaturalOrdering() etc.
2286aac854edSJunchao Zhang     PetscCall(MatSetType(B, MATSEQSBAIJ));
2287aac854edSJunchao Zhang     PetscCall(MatSeqSBAIJSetPreallocation(B, 1, MAT_SKIP_ALLOCATION, NULL));
2288aac854edSJunchao Zhang   }
2289aac854edSJunchao Zhang 
2290aac854edSJunchao Zhang   PetscCall(MatCholeskyFactorSymbolic_SeqAIJ(B, A, perm, info)); // it sets B's two ISes ((Mat_SeqAIJ*)B->data)->{row, col} to perm
2291aac854edSJunchao Zhang 
2292aac854edSJunchao Zhang   if (!info->solveonhost) {
2293*bfe80ac4SPierre Jolivet     // If solve on device, B is still a MATSEQAIJKOKKOS, so we are good to allocate B->spptr
2294aac854edSJunchao Zhang     PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2295aac854edSJunchao Zhang     PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2296aac854edSJunchao Zhang     B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJKokkos;
2297aac854edSJunchao Zhang   }
2298aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2299aac854edSJunchao Zhang }
2300aac854edSJunchao Zhang 
2301aac854edSJunchao Zhang // The _Kokkos suffix means we will use Kokkos as a solver for the SeqAIJKokkos matrix
2302aac854edSJunchao Zhang static PetscErrorCode MatFactorGetSolverType_SeqAIJKokkos_Kokkos(Mat A, MatSolverType *type)
2303d71ae5a4SJacob Faibussowitsch {
2304930e68a5SMark Adams   PetscFunctionBegin;
2305930e68a5SMark Adams   *type = MATSOLVERKOKKOS;
23063ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2307930e68a5SMark Adams }
2308930e68a5SMark Adams 
2309930e68a5SMark Adams /*MC
  MATSOLVERKOKKOS = "Kokkos" - A matrix solver type providing triangular solvers for sequential matrices
  of type `MATSEQAIJKOKKOS` or `MATAIJKOKKOS` on a single GPU.
2312930e68a5SMark Adams 
2313930e68a5SMark Adams   Level: beginner
2314930e68a5SMark Adams 
23151cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJKokkos()`, `MATAIJKOKKOS`, `MatKokkosSetFormat()`, `MatKokkosStorageFormat`, `MatKokkosFormatOperation`
2316930e68a5SMark Adams M*/
231786a27549SJunchao Zhang PETSC_EXTERN PetscErrorCode MatGetFactor_SeqAIJKokkos_Kokkos(Mat A, MatFactorType ftype, Mat *B) /* MatGetFactor_<MatType>_<MatSolverType> */
2318930e68a5SMark Adams {
2319930e68a5SMark Adams   PetscInt n = A->rmap->n;
2320aac854edSJunchao Zhang   MPI_Comm comm;
2321930e68a5SMark Adams 
2322930e68a5SMark Adams   PetscFunctionBegin;
2323aac854edSJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
2324aac854edSJunchao Zhang   PetscCall(MatCreate(comm, B));
23259566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*B, n, n, n, n));
2326aac854edSJunchao Zhang   PetscCall(MatSetBlockSizesFromMats(*B, A, A));
2327930e68a5SMark Adams   (*B)->factortype = ftype;
23289566063dSJacob Faibussowitsch   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
23299566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
2330aac854edSJunchao Zhang   PetscCheck(!(*B)->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2331aac854edSJunchao Zhang 
2332aac854edSJunchao Zhang   if (ftype == MAT_FACTOR_LU || ftype == MAT_FACTOR_ILU || ftype == MAT_FACTOR_ILUDT) {
2333aac854edSJunchao Zhang     (*B)->ops->lufactorsymbolic  = MatLUFactorSymbolic_SeqAIJKokkos;
2334aac854edSJunchao Zhang     (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJKokkos;
2335aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
2336aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILU]));
2337aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILUDT]));
2338aac854edSJunchao Zhang   } else if (ftype == MAT_FACTOR_CHOLESKY || ftype == MAT_FACTOR_ICC) {
2339aac854edSJunchao Zhang     (*B)->ops->iccfactorsymbolic      = MatICCFactorSymbolic_SeqAIJKokkos;
2340aac854edSJunchao Zhang     (*B)->ops->choleskyfactorsymbolic = MatCholeskyFactorSymbolic_SeqAIJKokkos;
2341aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_CHOLESKY]));
2342aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ICC]));
2343aac854edSJunchao Zhang   } else SETERRQ(comm, PETSC_ERR_SUP, "MatFactorType %s is not supported by MatType SeqAIJKokkos", MatFactorTypes[ftype]);
2344aac854edSJunchao Zhang 
2345aac854edSJunchao Zhang   // The factorization can use the ordering provided in MatLUFactorSymbolic(), MatCholeskyFactorSymbolic() etc, though we do it on host
2346aac854edSJunchao Zhang   (*B)->canuseordering = PETSC_TRUE;
2347aac854edSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)*B, "MatFactorGetSolverType_C", MatFactorGetSolverType_SeqAIJKokkos_Kokkos));
23483ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2349930e68a5SMark Adams }
23508f7e8f9dSMark Adams 
2351aac854edSJunchao Zhang PETSC_INTERN PetscErrorCode MatSolverTypeRegister_Kokkos(void)
2352d71ae5a4SJacob Faibussowitsch {
235386a27549SJunchao Zhang   PetscFunctionBegin;
23549566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_SeqAIJKokkos_Kokkos));
2355aac854edSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_CHOLESKY, MatGetFactor_SeqAIJKokkos_Kokkos));
23569566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ILU, MatGetFactor_SeqAIJKokkos_Kokkos));
2357aac854edSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ICC, MatGetFactor_SeqAIJKokkos_Kokkos));
23583ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
235986a27549SJunchao Zhang }
236086a27549SJunchao Zhang 
2361076ba34aSJunchao Zhang /* Utility to print out a KokkosCsrMatrix for debugging */
2362d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode PrintCsrMatrix(const KokkosCsrMatrix &csrmat)
2363d71ae5a4SJacob Faibussowitsch {
236445402d8aSJunchao Zhang   const auto        &iv = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.graph.row_map);
236545402d8aSJunchao Zhang   const auto        &jv = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.graph.entries);
236645402d8aSJunchao Zhang   const auto        &av = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.values);
2367076ba34aSJunchao Zhang   const PetscInt    *i  = iv.data();
2368076ba34aSJunchao Zhang   const PetscInt    *j  = jv.data();
2369076ba34aSJunchao Zhang   const PetscScalar *a  = av.data();
2370076ba34aSJunchao Zhang   PetscInt           m = csrmat.numRows(), n = csrmat.numCols(), nnz = csrmat.nnz();
2371076ba34aSJunchao Zhang 
2372076ba34aSJunchao Zhang   PetscFunctionBegin;
23739566063dSJacob Faibussowitsch   PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT " x %" PetscInt_FMT " SeqAIJKokkos, with %" PetscInt_FMT " nonzeros\n", m, n, nnz));
2374076ba34aSJunchao Zhang   for (PetscInt k = 0; k < m; k++) {
23759566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT ": ", k));
237648a46eb9SPierre Jolivet     for (PetscInt p = i[k]; p < i[k + 1]; p++) PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT "(%.1f), ", j[p], (double)PetscRealPart(a[p])));
23779566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "\n"));
2378076ba34aSJunchao Zhang   }
23793ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2380076ba34aSJunchao Zhang }
2381