xref: /petsc/src/mat/impls/aij/seq/kokkos/aijkok.kokkos.cxx (revision f3d3cd90648576fafae91a24e2611daaef7bcd2e)
1e36ced11SJunchao Zhang #include <petsc_kokkos.hpp>
211d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
3c0c276a7Ssdargavi #include <petscmat_kokkos.hpp>
4076ba34aSJunchao Zhang #include <petscpkg_version.h>
5152b3e56SJunchao Zhang #include <petsc/private/petscimpl.h>
642550becSJunchao Zhang #include <petsc/private/sfimpl.h>
7aac854edSJunchao Zhang #include <petsc/private/kokkosimpl.hpp>
87233ce55SJed Brown #include <petscsys.h>
98c3ff71bSJunchao Zhang 
108c3ff71bSJunchao Zhang #include <Kokkos_Core.hpp>
11f0cf5187SStefano Zampini #include <KokkosBlas.hpp>
128c3ff71bSJunchao Zhang #include <KokkosSparse_CrsMatrix.hpp>
13cc6e31f1SJunchao Zhang 
14cc6e31f1SJunchao Zhang // To suppress compiler warnings:
15cc6e31f1SJunchao Zhang // /path/include/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp:434:63:
16cc6e31f1SJunchao Zhang // warning: 'cusparseStatus_t cusparseDbsrmm(cusparseHandle_t, cusparseDirection_t, cusparseOperation_t,
17cc6e31f1SJunchao Zhang // cusparseOperation_t, int, int, int, int, const double*, cusparseMatDescr_t, const double*, const int*, const int*,
18cc6e31f1SJunchao Zhang // int, const double*, int, const double*, double*, int)' is deprecated: please use cusparseSpMM instead [-Wdeprecated-declarations]
19cc6e31f1SJunchao Zhang PETSC_PRAGMA_DIAGNOSTIC_IGNORED_BEGIN("-Wdeprecated-declarations")
208c3ff71bSJunchao Zhang #include <KokkosSparse_spmv.hpp>
21cc6e31f1SJunchao Zhang PETSC_PRAGMA_DIAGNOSTIC_IGNORED_END()
22cc6e31f1SJunchao Zhang 
2386a27549SJunchao Zhang #include <KokkosSparse_spiluk.hpp>
2486a27549SJunchao Zhang #include <KokkosSparse_sptrsv.hpp>
25076ba34aSJunchao Zhang #include <KokkosSparse_spgemm.hpp>
26076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
279d13fa56SJunchao Zhang #include <KokkosBatched_LU_Decl.hpp>
289d13fa56SJunchao Zhang #include <KokkosBatched_InverseLU_Decl.hpp>
2986a27549SJunchao Zhang 
3042550becSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
318c3ff71bSJunchao Zhang 
320e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(3, 7, 0)
33f98996d3SJunchao Zhang   #include <KokkosSparse_Utils.hpp>
34f98996d3SJunchao Zhang using KokkosSparse::sort_crs_matrix;
359371c9d4SSatish Balay using KokkosSparse::Impl::transpose_matrix;
36f98996d3SJunchao Zhang #else
37f98996d3SJunchao Zhang   #include <KokkosKernels_Sorting.hpp>
38f98996d3SJunchao Zhang using KokkosKernels::sort_crs_matrix;
399371c9d4SSatish Balay using KokkosKernels::Impl::transpose_matrix;
40f98996d3SJunchao Zhang #endif
41f98996d3SJunchao Zhang 
42aac854edSJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(4, 6, 0)
43aac854edSJunchao Zhang using KokkosSparse::spiluk_symbolic;
44aac854edSJunchao Zhang using KokkosSparse::spiluk_numeric;
45aac854edSJunchao Zhang using KokkosSparse::sptrsv_symbolic;
46aac854edSJunchao Zhang using KokkosSparse::sptrsv_solve;
47aac854edSJunchao Zhang using KokkosSparse::Experimental::SPTRSVAlgorithm;
48aac854edSJunchao Zhang using KokkosSparse::Experimental::SPILUKAlgorithm;
49aac854edSJunchao Zhang #else
50aac854edSJunchao Zhang using KokkosSparse::Experimental::spiluk_symbolic;
51aac854edSJunchao Zhang using KokkosSparse::Experimental::spiluk_numeric;
52aac854edSJunchao Zhang using KokkosSparse::Experimental::sptrsv_symbolic;
53aac854edSJunchao Zhang using KokkosSparse::Experimental::sptrsv_solve;
54aac854edSJunchao Zhang using KokkosSparse::Experimental::SPTRSVAlgorithm;
55aac854edSJunchao Zhang using KokkosSparse::Experimental::SPILUKAlgorithm;
56aac854edSJunchao Zhang #endif
57aac854edSJunchao Zhang 
588c3ff71bSJunchao Zhang static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat); /* Forward declaration */
598c3ff71bSJunchao Zhang 
60076ba34aSJunchao Zhang /* MatAssemblyEnd_SeqAIJKokkos() happens when we finalized nonzeros of the matrix, either after
61076ba34aSJunchao Zhang    we assembled the matrix on host, or after we directly produced the matrix data on device (ex., through MatMatMult).
62076ba34aSJunchao Zhang    In the latter case, it is important to set a_dual's sync state correctly.
63076ba34aSJunchao Zhang  */
64d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_SeqAIJKokkos(Mat A, MatAssemblyType mode)
65d71ae5a4SJacob Faibussowitsch {
66076ba34aSJunchao Zhang   Mat_SeqAIJ       *aijseq;
67076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
688c3ff71bSJunchao Zhang 
698c3ff71bSJunchao Zhang   PetscFunctionBegin;
703ba16761SJacob Faibussowitsch   if (mode == MAT_FLUSH_ASSEMBLY) PetscFunctionReturn(PETSC_SUCCESS);
719566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
72076ba34aSJunchao Zhang 
73076ba34aSJunchao Zhang   aijseq = static_cast<Mat_SeqAIJ *>(A->data);
74076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
75076ba34aSJunchao Zhang 
76076ba34aSJunchao Zhang   /* If aijkok does not exist, we just copy i, j to device.
77076ba34aSJunchao Zhang      If aijkok already exists, but the device's nonzero pattern does not match with the host's, we assume the latest data is on host.
78076ba34aSJunchao Zhang      In both cases, we build a new aijkok structure.
79076ba34aSJunchao Zhang   */
80076ba34aSJunchao Zhang   if (!aijkok || aijkok->nonzerostate != A->nonzerostate) { /* aijkok might not exist yet or nonzero pattern has changed */
81076ba34aSJunchao Zhang     delete aijkok;
82f4747e26SJunchao Zhang     aijkok   = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aijseq, A->nonzerostate, PETSC_FALSE /*don't copy mat values to device*/);
83076ba34aSJunchao Zhang     A->spptr = aijkok;
84f4747e26SJunchao Zhang   } else if (A->rmap->n && aijkok->diag_dual.extent(0) == 0) { // MatProduct might directly produce AIJ on device, but not the diag.
85f4747e26SJunchao Zhang     MatRowMapKokkosViewHost diag_h(aijseq->diag, A->rmap->n);
86f4747e26SJunchao Zhang     auto                    diag_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), diag_h);
87f4747e26SJunchao Zhang     aijkok->diag_dual              = MatRowMapKokkosDualView(diag_d, diag_h);
88076ba34aSJunchao Zhang   }
893ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
908c3ff71bSJunchao Zhang }
918c3ff71bSJunchao Zhang 
9286a27549SJunchao Zhang /* Sync CSR data to device if not yet */
93d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosSyncDevice(Mat A)
94d71ae5a4SJacob Faibussowitsch {
958c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
968c3ff71bSJunchao Zhang 
978c3ff71bSJunchao Zhang   PetscFunctionBegin;
98aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from host to device");
995f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
100076ba34aSJunchao Zhang   if (aijkok->a_dual.need_sync_device()) {
101*f3d3cd90SZach Atkins     PetscCall(KokkosDualViewSyncDevice(aijkok->a_dual, PetscGetKokkosExecutionSpace()));
102580c7c76SPierre Jolivet     aijkok->transpose_updated = PETSC_FALSE; /* values of the transpose is out-of-date */
10386a27549SJunchao Zhang     aijkok->hermitian_updated = PETSC_FALSE;
1048c3ff71bSJunchao Zhang   }
1053ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1068c3ff71bSJunchao Zhang }
1078c3ff71bSJunchao Zhang 
108076ba34aSJunchao Zhang /* Mark the CSR data on device as modified */
109d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosModifyDevice(Mat A)
110d71ae5a4SJacob Faibussowitsch {
11186a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
11286a27549SJunchao Zhang 
11386a27549SJunchao Zhang   PetscFunctionBegin;
1145f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Not supported for factorized matries");
11586a27549SJunchao Zhang   aijkok->a_dual.clear_sync_state();
11686a27549SJunchao Zhang   aijkok->a_dual.modify_device();
11786a27549SJunchao Zhang   aijkok->transpose_updated = PETSC_FALSE;
11886a27549SJunchao Zhang   aijkok->hermitian_updated = PETSC_FALSE;
1199566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJInvalidateDiagonal(A));
1209566063dSJacob Faibussowitsch   PetscCall(PetscObjectStateIncrease((PetscObject)A));
1213ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
12286a27549SJunchao Zhang }
12386a27549SJunchao Zhang 
124d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSyncHost(Mat A)
125d71ae5a4SJacob Faibussowitsch {
126f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1274df4a32cSJunchao Zhang   auto              exec   = PetscGetKokkosExecutionSpace();
128f0cf5187SStefano Zampini 
129f0cf5187SStefano Zampini   PetscFunctionBegin;
130f0cf5187SStefano Zampini   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
13186a27549SJunchao Zhang   /* We do not expect one needs factors on host  */
132aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from device to host");
1335f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing AIJKOK");
134*f3d3cd90SZach Atkins   PetscCall(KokkosDualViewSyncHost(aijkok->a_dual, exec));
1353ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
136f0cf5187SStefano Zampini }
137f0cf5187SStefano Zampini 
138d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
139d71ae5a4SJacob Faibussowitsch {
140076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
141f0cf5187SStefano Zampini 
142f0cf5187SStefano Zampini   PetscFunctionBegin;
1435519a089SJose E. Roman   /* aijkok contains valid pointers only if the host's nonzerostate matches with the device's.
1445519a089SJose E. Roman     Calling MatSeqAIJSetPreallocation() or MatSetValues() on host, where aijseq->{i,j,a} might be
1455519a089SJose E. Roman     reallocated, will lead to stale {i,j,a}_dual in aijkok. In both operations, the hosts's nonzerostate
1465519a089SJose E. Roman     must have been updated. The stale aijkok will be rebuilt during MatAssemblyEnd.
1475519a089SJose E. Roman   */
1485519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
149*f3d3cd90SZach Atkins     PetscCall(KokkosDualViewSyncHost(aijkok->a_dual, PetscGetKokkosExecutionSpace()));
150076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
151076ba34aSJunchao Zhang   } else { /* Happens when calling MatSetValues on a newly created matrix */
152076ba34aSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
153076ba34aSJunchao Zhang   }
1543ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
155076ba34aSJunchao Zhang }
156076ba34aSJunchao Zhang 
157d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
158d71ae5a4SJacob Faibussowitsch {
159fabba767SZach Atkins #if !defined(KOKKOS_ENABLE_UNIFIED_MEMORY)
160076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
161fabba767SZach Atkins #endif
162076ba34aSJunchao Zhang 
163076ba34aSJunchao Zhang   PetscFunctionBegin;
164fabba767SZach Atkins #if !defined(KOKKOS_ENABLE_UNIFIED_MEMORY)
1655519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) aijkok->a_dual.modify_host();
166fabba767SZach Atkins #endif
1673ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
168076ba34aSJunchao Zhang }
169076ba34aSJunchao Zhang 
170d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
171d71ae5a4SJacob Faibussowitsch {
172076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
173076ba34aSJunchao Zhang 
174076ba34aSJunchao Zhang   PetscFunctionBegin;
1755519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
176*f3d3cd90SZach Atkins     PetscCall(KokkosDualViewSyncHost(aijkok->a_dual, PetscGetKokkosExecutionSpace()));
177076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1782328674fSJunchao Zhang   } else {
1792328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1802328674fSJunchao Zhang   }
1813ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
182076ba34aSJunchao Zhang }
183076ba34aSJunchao Zhang 
184d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
185d71ae5a4SJacob Faibussowitsch {
186076ba34aSJunchao Zhang   PetscFunctionBegin;
187076ba34aSJunchao Zhang   *array = NULL;
1883ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
189076ba34aSJunchao Zhang }
190076ba34aSJunchao Zhang 
191d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
192d71ae5a4SJacob Faibussowitsch {
193076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
194076ba34aSJunchao Zhang 
195076ba34aSJunchao Zhang   PetscFunctionBegin;
1965519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
197076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1982328674fSJunchao Zhang   } else { /* Ex. happens with MatZeroEntries on a preallocated but not assembled matrix */
1992328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
2002328674fSJunchao Zhang   }
2013ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
202076ba34aSJunchao Zhang }
203076ba34aSJunchao Zhang 
204d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
205d71ae5a4SJacob Faibussowitsch {
206076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
207076ba34aSJunchao Zhang 
208076ba34aSJunchao Zhang   PetscFunctionBegin;
209fabba767SZach Atkins #if !defined(KOKKOS_ENABLE_UNIFIED_MEMORY)
2105519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
211076ba34aSJunchao Zhang     aijkok->a_dual.clear_sync_state();
212076ba34aSJunchao Zhang     aijkok->a_dual.modify_host();
2132328674fSJunchao Zhang   }
214fabba767SZach Atkins #else
215fabba767SZach Atkins   (void)aijkok;
216fabba767SZach Atkins #endif
2173ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
218f0cf5187SStefano Zampini }
219f0cf5187SStefano Zampini 
220d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJKokkos(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype)
221d71ae5a4SJacob Faibussowitsch {
2227ee59b9bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
2237ee59b9bSJunchao Zhang 
2247ee59b9bSJunchao Zhang   PetscFunctionBegin;
2257ee59b9bSJunchao Zhang   PetscCheck(aijkok != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "aijkok is NULL");
2267ee59b9bSJunchao Zhang 
2277ee59b9bSJunchao Zhang   if (i) *i = aijkok->i_device_data();
2287ee59b9bSJunchao Zhang   if (j) *j = aijkok->j_device_data();
2297ee59b9bSJunchao Zhang   if (a) {
230*f3d3cd90SZach Atkins     PetscCall(KokkosDualViewSyncDevice(aijkok->a_dual, PetscGetKokkosExecutionSpace()));
2317ee59b9bSJunchao Zhang     *a = aijkok->a_device_data();
2327ee59b9bSJunchao Zhang   }
2337ee59b9bSJunchao Zhang   if (mtype) *mtype = PETSC_MEMTYPE_KOKKOS;
2343ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2357ee59b9bSJunchao Zhang }
2367ee59b9bSJunchao Zhang 
2370e3ece09SJunchao Zhang /*
2380e3ece09SJunchao Zhang   Generate the sparsity pattern of a MatSeqAIJKokkos matrix's transpose on device.
2390e3ece09SJunchao Zhang 
2400e3ece09SJunchao Zhang   Input Parameter:
2410e3ece09SJunchao Zhang .  A       - the MATSEQAIJKOKKOS matrix
2420e3ece09SJunchao Zhang 
2430e3ece09SJunchao Zhang   Output Parameters:
2440e3ece09SJunchao Zhang +  perm_d - the permutation array on device, which connects Ta(i) = Aa(perm(i))
245aaa8cc7dSPierre Jolivet -  T_d    - the transpose on device, whose value array is allocated but not initialized
2460e3ece09SJunchao Zhang */
2470e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateTransposeStructure(Mat A, MatRowMapKokkosView &perm_d, KokkosCsrMatrix &T_d)
248d71ae5a4SJacob Faibussowitsch {
2490e3ece09SJunchao Zhang   Mat_SeqAIJ             *aseq = static_cast<Mat_SeqAIJ *>(A->data);
2500e3ece09SJunchao Zhang   PetscInt                nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
2510e3ece09SJunchao Zhang   const PetscInt         *Ai = aseq->i, *Aj = aseq->j;
2527b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost Ti_h(NoInit("Ti"), n + 1);
2530e3ece09SJunchao Zhang   MatRowMapType          *Ti = Ti_h.data();
2547b8d4ba6SJunchao Zhang   MatColIdxKokkosViewHost Tj_h(NoInit("Tj"), nz);
2557b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost perm_h(NoInit("permutation"), nz);
2560e3ece09SJunchao Zhang   PetscInt               *Tj   = Tj_h.data();
2570e3ece09SJunchao Zhang   PetscInt               *perm = perm_h.data();
2580e3ece09SJunchao Zhang   PetscInt               *offset;
259152b3e56SJunchao Zhang 
260152b3e56SJunchao Zhang   PetscFunctionBegin;
2610e3ece09SJunchao Zhang   // Populate Ti
2620e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::deep_copy(Ti_h, 0));
2630e3ece09SJunchao Zhang   Ti++;
2640e3ece09SJunchao Zhang   for (PetscInt i = 0; i < nz; i++) Ti[Aj[i]]++;
2650e3ece09SJunchao Zhang   Ti--;
2660e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) Ti[i + 1] += Ti[i];
2670e3ece09SJunchao Zhang 
2680e3ece09SJunchao Zhang   // Populate Tj and the permutation array
2690e3ece09SJunchao Zhang   PetscCall(PetscCalloc1(n, &offset)); // offset in each T row to fill in its column indices
2700e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
2710e3ece09SJunchao Zhang     for (PetscInt j = Ai[i]; j < Ai[i + 1]; j++) { // A's (i,j) is T's (j,i)
2720e3ece09SJunchao Zhang       PetscInt r    = Aj[j];                       // row r of T
2730e3ece09SJunchao Zhang       PetscInt disp = Ti[r] + offset[r];
2740e3ece09SJunchao Zhang 
2750e3ece09SJunchao Zhang       Tj[disp]   = i; // col i of T
2760e3ece09SJunchao Zhang       perm[disp] = j;
2770e3ece09SJunchao Zhang       offset[r]++;
278076ba34aSJunchao Zhang     }
2790e3ece09SJunchao Zhang   }
2800e3ece09SJunchao Zhang   PetscCall(PetscFree(offset));
2810e3ece09SJunchao Zhang 
2820e3ece09SJunchao Zhang   // Sort each row of T, along with the permutation array
2830e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) PetscCall(PetscSortIntWithArray(Ti[i + 1] - Ti[i], Tj + Ti[i], perm + Ti[i]));
2840e3ece09SJunchao Zhang 
2850e3ece09SJunchao Zhang   // Output perm and T on device
2860e3ece09SJunchao Zhang   auto Ti_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ti_h);
2870e3ece09SJunchao Zhang   auto Tj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Tj_h);
2880e3ece09SJunchao Zhang   PetscCallCXX(T_d = KokkosCsrMatrix("csrmatT", n, m, nz, MatScalarKokkosView("Ta", nz), Ti_d, Tj_d));
2890e3ece09SJunchao Zhang   PetscCallCXX(perm_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), perm_h));
2903ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
291152b3e56SJunchao Zhang }
292152b3e56SJunchao Zhang 
2930e3ece09SJunchao Zhang // Generate the transpose on device and cache it internally
2940e3ece09SJunchao Zhang // Note: KK transpose_matrix() does not have support symbolic/numeric transpose, so we do it on our own
2950e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGenerateTranspose_Private(Mat A, KokkosCsrMatrix *csrmatT)
296d71ae5a4SJacob Faibussowitsch {
2970e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
2980e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
2990e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
3000e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatT;
301152b3e56SJunchao Zhang 
302152b3e56SJunchao Zhang   PetscFunctionBegin;
3030e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
304*f3d3cd90SZach Atkins   PetscCall(KokkosDualViewSyncDevice(akok->a_dual, PetscGetKokkosExecutionSpace())); // Sync A's values since we are going to access them on device
3050e3ece09SJunchao Zhang 
3060e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
3070e3ece09SJunchao Zhang 
3080e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE) {
3090e3ece09SJunchao Zhang     *csrmatT = akok->csrmat;
3100e3ece09SJunchao Zhang   } else {
3110e3ece09SJunchao Zhang     // See if we already have a cached transpose and its value is up to date
3120e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
3130e3ece09SJunchao Zhang       if (!akok->transpose_updated) {            // if the value is out of date, update the cached version
3140e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
3150e3ece09SJunchao Zhang         auto       &Ta   = T.values;
3160e3ece09SJunchao Zhang 
317d326c3f1SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = Aa(perm(i)); }));
318076ba34aSJunchao Zhang       }
3190e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
3200e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
3210e3ece09SJunchao Zhang 
3220e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
3230e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
324d326c3f1SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = Aa(perm(i)); }));
3250e3ece09SJunchao Zhang     }
3260e3ece09SJunchao Zhang     akok->transpose_updated = PETSC_TRUE;
3270e3ece09SJunchao Zhang     *csrmatT                = akok->csrmatT;
3280e3ece09SJunchao Zhang   }
3290e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3300e3ece09SJunchao Zhang }
3310e3ece09SJunchao Zhang 
3320e3ece09SJunchao Zhang // Generate the Hermitian on device and cache it internally
3330e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateHermitian_Private(Mat A, KokkosCsrMatrix *csrmatH)
3340e3ece09SJunchao Zhang {
3350e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
3360e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3370e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
3380e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatH;
3390e3ece09SJunchao Zhang 
3400e3ece09SJunchao Zhang   PetscFunctionBegin;
3410e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
342*f3d3cd90SZach Atkins   PetscCall(KokkosDualViewSyncDevice(akok->a_dual, PetscGetKokkosExecutionSpace())); // Sync A's values since we are going to access them on device
3430e3ece09SJunchao Zhang 
3440e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
3450e3ece09SJunchao Zhang 
3460e3ece09SJunchao Zhang   if (A->hermitian == PETSC_BOOL3_TRUE) {
3470e3ece09SJunchao Zhang     *csrmatH = akok->csrmat;
3480e3ece09SJunchao Zhang   } else {
3490e3ece09SJunchao Zhang     // See if we already have a cached hermitian and its value is up to date
3500e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
3510e3ece09SJunchao Zhang       if (!akok->hermitian_updated) {            // if the value is out of date, update the cached version
3520e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
3530e3ece09SJunchao Zhang         auto       &Ta   = T.values;
3540e3ece09SJunchao Zhang 
355d326c3f1SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = PetscConj(Aa(perm(i))); }));
3560e3ece09SJunchao Zhang       }
3570e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
3580e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
3590e3ece09SJunchao Zhang 
3600e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
3610e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
362d326c3f1SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = PetscConj(Aa(perm(i))); }));
3630e3ece09SJunchao Zhang     }
3640e3ece09SJunchao Zhang     akok->hermitian_updated = PETSC_TRUE;
3650e3ece09SJunchao Zhang     *csrmatH                = akok->csrmatH;
3660e3ece09SJunchao Zhang   }
3673ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
368152b3e56SJunchao Zhang }
369a587d139SMark 
3708c3ff71bSJunchao Zhang /* y = A x */
371d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMult_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
372d71ae5a4SJacob Faibussowitsch {
3738c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
374152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
375152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
3768c3ff71bSJunchao Zhang 
3778c3ff71bSJunchao Zhang   PetscFunctionBegin;
3789566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3799566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3809566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3819566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
3828c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
383d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), "N", 1.0 /*alpha*/, aijkok->csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A x + beta y */
3849566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3859566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
386076ba34aSJunchao Zhang   /* 2.0*nnz - numRows seems more accurate here but assumes there are no zero-rows. So a little sloppy here. */
3879566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
3889566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
3893ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
3908c3ff71bSJunchao Zhang }
3918c3ff71bSJunchao Zhang 
3928c3ff71bSJunchao Zhang /* y = A^T x */
393d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
394d71ae5a4SJacob Faibussowitsch {
3958c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
396152b3e56SJunchao Zhang   const char                *mode;
397152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
398152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
3990e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4008c3ff71bSJunchao Zhang 
4018c3ff71bSJunchao Zhang   PetscFunctionBegin;
4029566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4039566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4049566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4059566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
406152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4079566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
408152b3e56SJunchao Zhang     mode = "N";
409152b3e56SJunchao Zhang   } else {
410076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4110e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
412152b3e56SJunchao Zhang     mode   = "T";
413152b3e56SJunchao Zhang   }
414d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^T x + beta y */
4159566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4169566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
4170e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4189566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4193ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4208c3ff71bSJunchao Zhang }
4218c3ff71bSJunchao Zhang 
4228c3ff71bSJunchao Zhang /* y = A^H x */
423d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
424d71ae5a4SJacob Faibussowitsch {
4258c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
426152b3e56SJunchao Zhang   const char                *mode;
427152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
428152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
4290e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4308c3ff71bSJunchao Zhang 
4318c3ff71bSJunchao Zhang   PetscFunctionBegin;
4329566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4339566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4349566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4359566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
436152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4379566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
438152b3e56SJunchao Zhang     mode = "N";
439152b3e56SJunchao Zhang   } else {
440076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4410e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
442152b3e56SJunchao Zhang     mode   = "C";
443152b3e56SJunchao Zhang   }
444d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^H x + beta y */
4459566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4469566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
4470e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4489566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4493ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4508c3ff71bSJunchao Zhang }
4518c3ff71bSJunchao Zhang 
4528c3ff71bSJunchao Zhang /* z = A x + y */
453d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
454d71ae5a4SJacob Faibussowitsch {
4558c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
45692896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
457152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4588c3ff71bSJunchao Zhang 
4598c3ff71bSJunchao Zhang   PetscFunctionBegin;
4609566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4619566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
46292896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz)); // depending on yy's sync flags, zz might get its latest data on host
4639566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
46492896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv)); // do after VecCopy(yy, zz) to get the latest data on device
4658c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
466d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), "N", 1.0 /*alpha*/, aijkok->csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A x + beta z */
4679566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
46892896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
4699566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
4709566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4713ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4728c3ff71bSJunchao Zhang }
4738c3ff71bSJunchao Zhang 
4748c3ff71bSJunchao Zhang /* z = A^T x + y */
475d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
476d71ae5a4SJacob Faibussowitsch {
4778c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
478152b3e56SJunchao Zhang   const char                *mode;
47992896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
480152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4810e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4828c3ff71bSJunchao Zhang 
4838c3ff71bSJunchao Zhang   PetscFunctionBegin;
4849566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4859566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
48692896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz));
4879566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
48892896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv));
489152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4909566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
491152b3e56SJunchao Zhang     mode = "N";
492152b3e56SJunchao Zhang   } else {
493076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4940e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
495152b3e56SJunchao Zhang     mode   = "T";
496152b3e56SJunchao Zhang   }
497d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^T x + beta z */
4989566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
49992896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
5000e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
5019566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
5023ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5038c3ff71bSJunchao Zhang }
5048c3ff71bSJunchao Zhang 
5058c3ff71bSJunchao Zhang /* z = A^H x + y */
506d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
507d71ae5a4SJacob Faibussowitsch {
5088c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
509152b3e56SJunchao Zhang   const char                *mode;
51092896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
511152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
5120e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
5138c3ff71bSJunchao Zhang 
5148c3ff71bSJunchao Zhang   PetscFunctionBegin;
5159566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
5169566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
51792896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz));
5189566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
51992896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv));
520152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
5219566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
522152b3e56SJunchao Zhang     mode = "N";
523152b3e56SJunchao Zhang   } else {
524076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
5250e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
526152b3e56SJunchao Zhang     mode   = "C";
527152b3e56SJunchao Zhang   }
528d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^H x + beta z */
5299566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
53092896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
5310e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
5329566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
5333ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
534152b3e56SJunchao Zhang }
535152b3e56SJunchao Zhang 
53666976f2fSJacob Faibussowitsch static PetscErrorCode MatSetOption_SeqAIJKokkos(Mat A, MatOption op, PetscBool flg)
537d71ae5a4SJacob Faibussowitsch {
538152b3e56SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
539152b3e56SJunchao Zhang 
540152b3e56SJunchao Zhang   PetscFunctionBegin;
541152b3e56SJunchao Zhang   switch (op) {
542152b3e56SJunchao Zhang   case MAT_FORM_EXPLICIT_TRANSPOSE:
543152b3e56SJunchao Zhang     /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */
5449566063dSJacob Faibussowitsch     if (A->form_explicit_transpose && !flg && aijkok) PetscCall(aijkok->DestroyMatTranspose());
545152b3e56SJunchao Zhang     A->form_explicit_transpose = flg;
546152b3e56SJunchao Zhang     break;
547d71ae5a4SJacob Faibussowitsch   default:
548d71ae5a4SJacob Faibussowitsch     PetscCall(MatSetOption_SeqAIJ(A, op, flg));
549d71ae5a4SJacob Faibussowitsch     break;
550152b3e56SJunchao Zhang   }
5513ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5528c3ff71bSJunchao Zhang }
5538c3ff71bSJunchao Zhang 
554076ba34aSJunchao Zhang /* Depending on reuse, either build a new mat, or use the existing mat */
555d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
556d71ae5a4SJacob Faibussowitsch {
557076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
5588c3ff71bSJunchao Zhang 
5598c3ff71bSJunchao Zhang   PetscFunctionBegin;
5609566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
561076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {                      /* Build a brand new mat */
5629566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));  /* the returned newmat is a SeqAIJKokkos */
5638c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {                 /* Reuse the mat created before */
5649566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); /* newmat is already a SeqAIJKokkos */
565076ba34aSJunchao Zhang   } else if (reuse == MAT_INPLACE_MATRIX) {               /* newmat is A */
5665f80ce2aSJacob Faibussowitsch     PetscCheck(A == *newmat, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "A != *newmat with MAT_INPLACE_MATRIX");
5679566063dSJacob Faibussowitsch     PetscCall(PetscFree(A->defaultvectype));
5689566063dSJacob Faibussowitsch     PetscCall(PetscStrallocpy(VECKOKKOS, &A->defaultvectype)); /* Allocate and copy the string */
5699566063dSJacob Faibussowitsch     PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATSEQAIJKOKKOS));
5709566063dSJacob Faibussowitsch     PetscCall(MatSetOps_SeqAIJKokkos(A));
571076ba34aSJunchao Zhang     aseq = static_cast<Mat_SeqAIJ *>(A->data);
572394ed5ebSJunchao Zhang     if (A->assembled) { /* Copy i, j (but not values) to device for an assembled matrix if not yet */
5735f80ce2aSJacob Faibussowitsch       PetscCheck(!A->spptr, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Expect NULL (Mat_SeqAIJKokkos*)A->spptr");
574f4747e26SJunchao Zhang       A->spptr = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aseq, A->nonzerostate, PETSC_FALSE);
5758c3ff71bSJunchao Zhang     }
576076ba34aSJunchao Zhang   }
5773ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5788c3ff71bSJunchao Zhang }
5798c3ff71bSJunchao Zhang 
580076ba34aSJunchao Zhang /* MatDuplicate always creates a new matrix. MatDuplicate can be called either on an assembled matrix or
581076ba34aSJunchao Zhang    an unassembled matrix, even though MAT_COPY_VALUES is not allowed for unassembled matrix.
582076ba34aSJunchao Zhang  */
583d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDuplicate_SeqAIJKokkos(Mat A, MatDuplicateOption dupOption, Mat *B)
584d71ae5a4SJacob Faibussowitsch {
585076ba34aSJunchao Zhang   Mat_SeqAIJ       *bseq;
586076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok;
587076ba34aSJunchao Zhang   Mat               mat;
5888c3ff71bSJunchao Zhang 
5898c3ff71bSJunchao Zhang   PetscFunctionBegin;
590076ba34aSJunchao Zhang   /* Do not copy values on host as A's latest values might be on device. We don't want to do sync blindly */
5919566063dSJacob Faibussowitsch   PetscCall(MatDuplicate_SeqAIJ(A, MAT_DO_NOT_COPY_VALUES, B));
592076ba34aSJunchao Zhang   mat = *B;
593f4747e26SJunchao Zhang   if (A->assembled) {
594076ba34aSJunchao Zhang     bseq = static_cast<Mat_SeqAIJ *>(mat->data);
595f4747e26SJunchao Zhang     bkok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, bseq, mat->nonzerostate, PETSC_FALSE);
596076ba34aSJunchao Zhang     bkok->a_dual.clear_sync_state(); /* Clear B's sync state as it will be decided below */
597076ba34aSJunchao Zhang     /* Now copy values to B if needed */
598076ba34aSJunchao Zhang     if (dupOption == MAT_COPY_VALUES) {
599076ba34aSJunchao Zhang       if (akok->a_dual.need_sync_device()) {
600076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_host(), akok->a_dual.view_host());
601076ba34aSJunchao Zhang         bkok->a_dual.modify_host();
602076ba34aSJunchao Zhang       } else { /* If device has the latest data, we only copy data on device */
603076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_device(), akok->a_dual.view_device());
604076ba34aSJunchao Zhang         bkok->a_dual.modify_device();
605076ba34aSJunchao Zhang       }
606076ba34aSJunchao Zhang     } else { /* MAT_DO_NOT_COPY_VALUES or MAT_SHARE_NONZERO_PATTERN. B's values should be zeroed */
607076ba34aSJunchao Zhang       /* B's values on host should be already zeroed by MatDuplicate_SeqAIJ() */
608076ba34aSJunchao Zhang       bkok->a_dual.modify_host();
609076ba34aSJunchao Zhang     }
610076ba34aSJunchao Zhang     mat->spptr = bkok;
611076ba34aSJunchao Zhang   }
612076ba34aSJunchao Zhang 
6139566063dSJacob Faibussowitsch   PetscCall(PetscFree(mat->defaultvectype));
6149566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &mat->defaultvectype)); /* Allocate and copy the string */
6159566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)mat, MATSEQAIJKOKKOS));
6169566063dSJacob Faibussowitsch   PetscCall(MatSetOps_SeqAIJKokkos(mat));
6173ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6188c3ff71bSJunchao Zhang }
6198c3ff71bSJunchao Zhang 
620d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatTranspose_SeqAIJKokkos(Mat A, MatReuse reuse, Mat *B)
621d71ae5a4SJacob Faibussowitsch {
6220ecb592aSJunchao Zhang   Mat               At;
6230e3ece09SJunchao Zhang   KokkosCsrMatrix   internT;
6240ecb592aSJunchao Zhang   Mat_SeqAIJKokkos *atkok, *bkok;
6250ecb592aSJunchao Zhang 
6260ecb592aSJunchao Zhang   PetscFunctionBegin;
6277fb60732SBarry Smith   if (reuse == MAT_REUSE_MATRIX) PetscCall(MatTransposeCheckNonzeroState_Private(A, *B));
6289566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &internT)); /* Generate a transpose internally */
6290ecb592aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX || reuse == MAT_INPLACE_MATRIX) {
630ff751488SJunchao Zhang     /* Deep copy internT, as we want to isolate the internal transpose */
6310e3ece09SJunchao Zhang     PetscCallCXX(atkok = new Mat_SeqAIJKokkos(KokkosCsrMatrix("csrmat", internT)));
6329566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PetscObjectComm((PetscObject)A), atkok, &At));
6330ecb592aSJunchao Zhang     if (reuse == MAT_INITIAL_MATRIX) *B = At;
6349566063dSJacob Faibussowitsch     else PetscCall(MatHeaderReplace(A, &At)); /* Replace A with At inplace */
6350ecb592aSJunchao Zhang   } else {                                    /* MAT_REUSE_MATRIX, just need to copy values to B on device */
6360ecb592aSJunchao Zhang     if ((*B)->assembled) {
6370ecb592aSJunchao Zhang       bkok = static_cast<Mat_SeqAIJKokkos *>((*B)->spptr);
6380e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(bkok->a_dual.view_device(), internT.values));
6399566063dSJacob Faibussowitsch       PetscCall(MatSeqAIJKokkosModifyDevice(*B));
6400ecb592aSJunchao Zhang     } else if ((*B)->preallocated) { /* It is ok for B to be only preallocated, as needed in MatTranspose_MPIAIJ */
6410ecb592aSJunchao Zhang       Mat_SeqAIJ             *bseq = static_cast<Mat_SeqAIJ *>((*B)->data);
6420e3ece09SJunchao Zhang       MatScalarKokkosViewHost a_h(bseq->a, internT.nnz()); /* bseq->nz = 0 if unassembled */
6430e3ece09SJunchao Zhang       MatColIdxKokkosViewHost j_h(bseq->j, internT.nnz());
6440e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(a_h, internT.values));
6450e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(j_h, internT.graph.entries));
6460ecb592aSJunchao Zhang     } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "B must be assembled or preallocated");
6470ecb592aSJunchao Zhang   }
6483ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6490ecb592aSJunchao Zhang }
6500ecb592aSJunchao Zhang 
651d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDestroy_SeqAIJKokkos(Mat A)
652d71ae5a4SJacob Faibussowitsch {
65386a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
6548c3ff71bSJunchao Zhang 
6558c3ff71bSJunchao Zhang   PetscFunctionBegin;
65686a27549SJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) {
65786a27549SJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
6588c3ff71bSJunchao Zhang     delete aijkok;
65986a27549SJunchao Zhang   } else {
66086a27549SJunchao Zhang     delete static_cast<Mat_SeqAIJKokkosTriFactors *>(A->spptr);
66186a27549SJunchao Zhang   }
662cbc6b225SStefano Zampini   A->spptr = NULL;
6639566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
6649566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
6659566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
66657761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
66757761e9aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijkokkos_hypre_C", NULL));
66857761e9aSJunchao Zhang #endif
6699566063dSJacob Faibussowitsch   PetscCall(MatDestroy_SeqAIJ(A));
6703ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6718c3ff71bSJunchao Zhang }
6728c3ff71bSJunchao Zhang 
6733f3ba80aSJunchao Zhang /*MC
6743f3ba80aSJunchao Zhang    MATSEQAIJKOKKOS - MATAIJKOKKOS = "(seq)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos
6753f3ba80aSJunchao Zhang 
67615229ffcSPierre Jolivet    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
6773f3ba80aSJunchao Zhang 
6782ef1f0ffSBarry Smith    Options Database Key:
67911a5261eSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATSEQAIJKOKKOS` during a call to `MatSetFromOptions()`
6803f3ba80aSJunchao Zhang 
6813f3ba80aSJunchao Zhang   Level: beginner
6823f3ba80aSJunchao Zhang 
6831cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MatCreateSeqAIJKokkos()`, `MATMPIAIJKOKKOS`
6843f3ba80aSJunchao Zhang M*/
685d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJKokkos(Mat A)
686d71ae5a4SJacob Faibussowitsch {
68786a27549SJunchao Zhang   PetscFunctionBegin;
6889566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
6899566063dSJacob Faibussowitsch   PetscCall(MatCreate_SeqAIJ(A));
6909566063dSJacob Faibussowitsch   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
6913ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
69286a27549SJunchao Zhang }
69386a27549SJunchao Zhang 
694076ba34aSJunchao Zhang /* Merge A, B into a matrix C. A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n) */
695d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C)
696d71ae5a4SJacob Faibussowitsch {
697076ba34aSJunchao Zhang   Mat_SeqAIJ         *a, *b;
698076ba34aSJunchao Zhang   Mat_SeqAIJKokkos   *akok, *bkok, *ckok;
699076ba34aSJunchao Zhang   MatScalarKokkosView aa, ba, ca;
700076ba34aSJunchao Zhang   MatRowMapKokkosView ai, bi, ci;
701076ba34aSJunchao Zhang   MatColIdxKokkosView aj, bj, cj;
702076ba34aSJunchao Zhang   PetscInt            m, n, nnz, aN;
703a3f881fbSStefano Zampini 
704a3f881fbSStefano Zampini   PetscFunctionBegin;
705076ba34aSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
706076ba34aSJunchao Zhang   PetscValidHeaderSpecific(B, MAT_CLASSID, 2);
7074f572ea9SToby Isaac   PetscAssertPointer(C, 4);
708076ba34aSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
709076ba34aSJunchao Zhang   PetscCheckTypeName(B, MATSEQAIJKOKKOS);
7105f80ce2aSJacob Faibussowitsch   PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n);
7115f80ce2aSJacob Faibussowitsch   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
712076ba34aSJunchao Zhang 
7139566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
7149566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
715076ba34aSJunchao Zhang   a    = static_cast<Mat_SeqAIJ *>(A->data);
716076ba34aSJunchao Zhang   b    = static_cast<Mat_SeqAIJ *>(B->data);
717076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
718076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
719076ba34aSJunchao Zhang   aa   = akok->a_dual.view_device();
720076ba34aSJunchao Zhang   ai   = akok->i_dual.view_device();
721076ba34aSJunchao Zhang   ba   = bkok->a_dual.view_device();
722076ba34aSJunchao Zhang   bi   = bkok->i_dual.view_device();
723076ba34aSJunchao Zhang   m    = A->rmap->n; /* M, N and nnz of C */
724076ba34aSJunchao Zhang   n    = A->cmap->n + B->cmap->n;
725076ba34aSJunchao Zhang   nnz  = a->nz + b->nz;
726076ba34aSJunchao Zhang   aN   = A->cmap->n; /* N of A */
727076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
728076ba34aSJunchao Zhang     aj           = akok->j_dual.view_device();
729076ba34aSJunchao Zhang     bj           = bkok->j_dual.view_device();
730076ba34aSJunchao Zhang     auto ca_dual = MatScalarKokkosDualView("a", aa.extent(0) + ba.extent(0));
731076ba34aSJunchao Zhang     auto ci_dual = MatRowMapKokkosDualView("i", ai.extent(0));
732076ba34aSJunchao Zhang     auto cj_dual = MatColIdxKokkosDualView("j", aj.extent(0) + bj.extent(0));
733076ba34aSJunchao Zhang     ca           = ca_dual.view_device();
734076ba34aSJunchao Zhang     ci           = ci_dual.view_device();
735076ba34aSJunchao Zhang     cj           = cj_dual.view_device();
736076ba34aSJunchao Zhang 
737076ba34aSJunchao Zhang     /* Concatenate A and B in parallel using Kokkos hierarchical parallelism */
7389371c9d4SSatish Balay     Kokkos::parallel_for(
739d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
740076ba34aSJunchao Zhang         PetscInt i       = t.league_rank(); /* row i */
741076ba34aSJunchao Zhang         PetscInt coffset = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
742076ba34aSJunchao Zhang 
743076ba34aSJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() { /* this side effect only happens once per whole team */
744076ba34aSJunchao Zhang                                                    ci(i) = coffset;
745076ba34aSJunchao Zhang                                                    if (i == m - 1) ci(m) = ai(m) + bi(m);
746076ba34aSJunchao Zhang         });
747076ba34aSJunchao Zhang 
748076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
749076ba34aSJunchao Zhang           if (k < alen) {
750076ba34aSJunchao Zhang             ca(coffset + k) = aa(ai(i) + k);
751076ba34aSJunchao Zhang             cj(coffset + k) = aj(ai(i) + k);
752076ba34aSJunchao Zhang           } else {
753076ba34aSJunchao Zhang             ca(coffset + k) = ba(bi(i) + k - alen);
754076ba34aSJunchao Zhang             cj(coffset + k) = bj(bi(i) + k - alen) + aN; /* Entries in B get new column indices in C */
755076ba34aSJunchao Zhang           }
756076ba34aSJunchao Zhang         });
757076ba34aSJunchao Zhang       });
758076ba34aSJunchao Zhang     ca_dual.modify_device();
759076ba34aSJunchao Zhang     ci_dual.modify_device();
760076ba34aSJunchao Zhang     cj_dual.modify_device();
7619566063dSJacob Faibussowitsch     PetscCallCXX(ckok = new Mat_SeqAIJKokkos(m, n, nnz, ci_dual, cj_dual, ca_dual));
7629566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, C));
763076ba34aSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
764076ba34aSJunchao Zhang     PetscValidHeaderSpecific(*C, MAT_CLASSID, 4);
765076ba34aSJunchao Zhang     PetscCheckTypeName(*C, MATSEQAIJKOKKOS);
766076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos *>((*C)->spptr);
767076ba34aSJunchao Zhang     ca   = ckok->a_dual.view_device();
768076ba34aSJunchao Zhang     ci   = ckok->i_dual.view_device();
769076ba34aSJunchao Zhang 
7709371c9d4SSatish Balay     Kokkos::parallel_for(
771d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
772076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* row i */
773076ba34aSJunchao Zhang         PetscInt alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
774076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
775076ba34aSJunchao Zhang           if (k < alen) ca(ci(i) + k) = aa(ai(i) + k);
776076ba34aSJunchao Zhang           else ca(ci(i) + k) = ba(bi(i) + k - alen);
777076ba34aSJunchao Zhang         });
778076ba34aSJunchao Zhang       });
7799566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(*C));
780076ba34aSJunchao Zhang   }
7813ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
782076ba34aSJunchao Zhang }
783076ba34aSJunchao Zhang 
784d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_SeqAIJKokkos(void *pdata)
785d71ae5a4SJacob Faibussowitsch {
786076ba34aSJunchao Zhang   PetscFunctionBegin;
787076ba34aSJunchao Zhang   delete static_cast<MatProductData_SeqAIJKokkos *>(pdata);
7883ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
789a3f881fbSStefano Zampini }
790a3f881fbSStefano Zampini 
791d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos(Mat C)
792d71ae5a4SJacob Faibussowitsch {
793a3f881fbSStefano Zampini   Mat_Product                 *product = C->product;
794a3f881fbSStefano Zampini   Mat                          A, B;
795076ba34aSJunchao Zhang   bool                         transA, transB; /* use bool, since KK needs this type */
796a3f881fbSStefano Zampini   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
797a3f881fbSStefano Zampini   Mat_SeqAIJ                  *c;
798076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
7990e3ece09SJunchao Zhang   KokkosCsrMatrix              csrmatA, csrmatB;
800a3f881fbSStefano Zampini 
801a3f881fbSStefano Zampini   PetscFunctionBegin;
802a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
8035f80ce2aSJacob Faibussowitsch   PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data empty");
804076ba34aSJunchao Zhang   pdata = static_cast<MatProductData_SeqAIJKokkos *>(C->product->data);
805076ba34aSJunchao Zhang 
8060e3ece09SJunchao Zhang   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
8070e3ece09SJunchao Zhang   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
8080e3ece09SJunchao Zhang   // we still do numeric.
8090e3ece09SJunchao Zhang   if (pdata->reusesym) { // numeric reuses results from symbolic
8100e3ece09SJunchao Zhang     pdata->reusesym = PETSC_FALSE;
8113ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
812076ba34aSJunchao Zhang   }
813076ba34aSJunchao Zhang 
814076ba34aSJunchao Zhang   switch (product->type) {
8159371c9d4SSatish Balay   case MATPRODUCT_AB:
8169371c9d4SSatish Balay     transA = false;
8179371c9d4SSatish Balay     transB = false;
8189371c9d4SSatish Balay     break;
8199371c9d4SSatish Balay   case MATPRODUCT_AtB:
8209371c9d4SSatish Balay     transA = true;
8219371c9d4SSatish Balay     transB = false;
8229371c9d4SSatish Balay     break;
8239371c9d4SSatish Balay   case MATPRODUCT_ABt:
8249371c9d4SSatish Balay     transA = false;
8259371c9d4SSatish Balay     transB = true;
8269371c9d4SSatish Balay     break;
827d71ae5a4SJacob Faibussowitsch   default:
828d71ae5a4SJacob Faibussowitsch     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
829076ba34aSJunchao Zhang   }
830076ba34aSJunchao Zhang 
831a3f881fbSStefano Zampini   A = product->A;
832a3f881fbSStefano Zampini   B = product->B;
8339566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
8349566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
835a3f881fbSStefano Zampini   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
836a3f881fbSStefano Zampini   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
837a3f881fbSStefano Zampini   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
838076ba34aSJunchao Zhang 
8395f80ce2aSJacob Faibussowitsch   PetscCheck(ckok, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Device data structure spptr is empty");
840076ba34aSJunchao Zhang 
8410e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
8420e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
843076ba34aSJunchao Zhang 
844076ba34aSJunchao Zhang   /* TODO: Once KK spgemm implements transpose, we can get rid of the explicit transpose here */
845076ba34aSJunchao Zhang   if (transA) {
8469566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
847076ba34aSJunchao Zhang     transA = false;
848a3f881fbSStefano Zampini   }
849a3f881fbSStefano Zampini 
850076ba34aSJunchao Zhang   if (transB) {
8519566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
852076ba34aSJunchao Zhang     transB = false;
853076ba34aSJunchao Zhang   }
8549566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
8550e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, ckok->csrmat));
8560e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
857866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
858866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(ckok->csrmat)); /* without sort, mat_tests-ex62_14_seqaijkokkos fails */
859e944a159SJunchao Zhang #endif
860866eb059SJunchao Zhang 
8619566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
8629566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(C));
863a3f881fbSStefano Zampini   /* shorter version of MatAssemblyEnd_SeqAIJ */
864a3f881fbSStefano Zampini   c = (Mat_SeqAIJ *)C->data;
8659566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded,%" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz));
8669566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n"));
8679566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax));
868a3f881fbSStefano Zampini   c->reallocs         = 0;
869076ba34aSJunchao Zhang   C->info.mallocs     = 0;
870a3f881fbSStefano Zampini   C->info.nz_unneeded = 0;
871a3f881fbSStefano Zampini   C->assembled = C->was_assembled = PETSC_TRUE;
872a3f881fbSStefano Zampini   C->num_ass++;
8733ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
874a3f881fbSStefano Zampini }
875a3f881fbSStefano Zampini 
876d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos(Mat C)
877d71ae5a4SJacob Faibussowitsch {
878076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
879076ba34aSJunchao Zhang   MatProductType               ptype;
880076ba34aSJunchao Zhang   Mat                          A, B;
881076ba34aSJunchao Zhang   bool                         transA, transB;
882076ba34aSJunchao Zhang   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
883076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
884076ba34aSJunchao Zhang   MPI_Comm                     comm;
8850e3ece09SJunchao Zhang   KokkosCsrMatrix              csrmatA, csrmatB, csrmatC;
886a3f881fbSStefano Zampini 
887a3f881fbSStefano Zampini   PetscFunctionBegin;
888a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
8899566063dSJacob Faibussowitsch   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
8905f80ce2aSJacob Faibussowitsch   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
891a3f881fbSStefano Zampini   A = product->A;
892a3f881fbSStefano Zampini   B = product->B;
8939566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
8949566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
895a3f881fbSStefano Zampini   akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
896a3f881fbSStefano Zampini   bkok    = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
8970e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
8980e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
899076ba34aSJunchao Zhang 
900a3f881fbSStefano Zampini   ptype = product->type;
9010e3ece09SJunchao Zhang   // Take advantage of the symmetry if true
9020e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) {
9030e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
9040e3ece09SJunchao Zhang     product->symbolic_used_the_fact_A_is_symmetric = PETSC_TRUE;
9050e3ece09SJunchao Zhang   }
9060e3ece09SJunchao Zhang   if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) {
9070e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
9080e3ece09SJunchao Zhang     product->symbolic_used_the_fact_B_is_symmetric = PETSC_TRUE;
9090e3ece09SJunchao Zhang   }
9100e3ece09SJunchao Zhang 
911a3f881fbSStefano Zampini   switch (ptype) {
9129371c9d4SSatish Balay   case MATPRODUCT_AB:
9139371c9d4SSatish Balay     transA = false;
9149371c9d4SSatish Balay     transB = false;
9150e6a1e94SMark Adams     PetscCall(MatSetBlockSizesFromMats(C, A, B));
9169371c9d4SSatish Balay     break;
9179371c9d4SSatish Balay   case MATPRODUCT_AtB:
9189371c9d4SSatish Balay     transA = true;
9199371c9d4SSatish Balay     transB = false;
9200e6a1e94SMark Adams     if (A->cmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->rmap, A->cmap->bs));
9210e6a1e94SMark Adams     if (B->cmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->cmap, B->cmap->bs));
9229371c9d4SSatish Balay     break;
9239371c9d4SSatish Balay   case MATPRODUCT_ABt:
9249371c9d4SSatish Balay     transA = false;
9259371c9d4SSatish Balay     transB = true;
9260e6a1e94SMark Adams     if (A->rmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->rmap, A->rmap->bs));
9270e6a1e94SMark Adams     if (B->rmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->cmap, B->rmap->bs));
9289371c9d4SSatish Balay     break;
929d71ae5a4SJacob Faibussowitsch   default:
930d71ae5a4SJacob Faibussowitsch     SETERRQ(comm, PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
931a3f881fbSStefano Zampini   }
9320e3ece09SJunchao Zhang   PetscCallCXX(product->data = pdata = new MatProductData_SeqAIJKokkos());
933076ba34aSJunchao Zhang   pdata->reusesym = product->api_user;
934a3f881fbSStefano Zampini 
935076ba34aSJunchao Zhang   /* TODO: add command line options to select spgemm algorithms */
936866eb059SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; /* default alg is TPL if enabled, otherwise KK */
937866eb059SJunchao Zhang 
938866eb059SJunchao Zhang   /* CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 */
939866eb059SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
940866eb059SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
941866eb059SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
942866eb059SJunchao Zhang   #endif
943866eb059SJunchao Zhang #endif
9440e3ece09SJunchao Zhang   PetscCallCXX(pdata->kh.create_spgemm_handle(spgemm_alg));
945076ba34aSJunchao Zhang 
9469566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
947076ba34aSJunchao Zhang   /* TODO: Get rid of the explicit transpose once KK-spgemm implements the transpose option */
948076ba34aSJunchao Zhang   if (transA) {
9499566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
950076ba34aSJunchao Zhang     transA = false;
951076ba34aSJunchao Zhang   }
952076ba34aSJunchao Zhang 
953076ba34aSJunchao Zhang   if (transB) {
9549566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
955076ba34aSJunchao Zhang     transB = false;
956076ba34aSJunchao Zhang   }
957076ba34aSJunchao Zhang 
9580e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
959076ba34aSJunchao Zhang   /* spgemm_symbolic() only populates C's rowmap, but not C's column indices.
960076ba34aSJunchao Zhang     So we have to do a fake spgemm_numeric() here to get csrmatC.j_d setup, before
961076ba34aSJunchao Zhang     calling new Mat_SeqAIJKokkos().
962076ba34aSJunchao Zhang     TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
963076ba34aSJunchao Zhang   */
9640e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
9650e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
966866eb059SJunchao Zhang   /* Query if KK outputs a sorted matrix. If not, we need to sort it */
967866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
968866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(csrmatC)); /* sort_option defaults to -1 in KK!*/
969e944a159SJunchao Zhang #endif
9709566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
971076ba34aSJunchao Zhang 
9729566063dSJacob Faibussowitsch   PetscCallCXX(ckok = new Mat_SeqAIJKokkos(csrmatC));
9739566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(C, ckok));
974076ba34aSJunchao Zhang   C->product->destroy = MatProductDataDestroy_SeqAIJKokkos;
9753ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
976a3f881fbSStefano Zampini }
977a3f881fbSStefano Zampini 
978a3f881fbSStefano Zampini /* handles sparse matrix matrix ops */
979d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSetFromOptions_SeqAIJKokkos(Mat mat)
980d71ae5a4SJacob Faibussowitsch {
981076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
982a3f881fbSStefano Zampini   PetscBool    Biskok = PETSC_FALSE, Ciskok = PETSC_TRUE;
983a3f881fbSStefano Zampini 
984a3f881fbSStefano Zampini   PetscFunctionBegin;
985a3f881fbSStefano Zampini   MatCheckProduct(mat, 1);
9869566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJKOKKOS, &Biskok));
98748a46eb9SPierre Jolivet   if (product->type == MATPRODUCT_ABC) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJKOKKOS, &Ciskok));
988a3f881fbSStefano Zampini   if (Biskok && Ciskok) {
989a3f881fbSStefano Zampini     switch (product->type) {
990a3f881fbSStefano Zampini     case MATPRODUCT_AB:
991a3f881fbSStefano Zampini     case MATPRODUCT_AtB:
992d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABt:
993d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos;
994d71ae5a4SJacob Faibussowitsch       break;
995a3f881fbSStefano Zampini     case MATPRODUCT_PtAP:
996a3f881fbSStefano Zampini     case MATPRODUCT_RARt:
997d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABC:
998d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
999d71ae5a4SJacob Faibussowitsch       break;
1000d71ae5a4SJacob Faibussowitsch     default:
1001d71ae5a4SJacob Faibussowitsch       break;
1002a3f881fbSStefano Zampini     }
1003a3f881fbSStefano Zampini   } else { /* fallback for AIJ */
10049566063dSJacob Faibussowitsch     PetscCall(MatProductSetFromOptions_SeqAIJ(mat));
1005a3f881fbSStefano Zampini   }
10063ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1007a3f881fbSStefano Zampini }
1008a587d139SMark 
1009d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatScale_SeqAIJKokkos(Mat A, PetscScalar a)
1010d71ae5a4SJacob Faibussowitsch {
1011f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok;
1012f0cf5187SStefano Zampini 
1013f0cf5187SStefano Zampini   PetscFunctionBegin;
10149566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
10159566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1016f0cf5187SStefano Zampini   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1017d326c3f1SJunchao Zhang   KokkosBlas::scal(PetscGetKokkosExecutionSpace(), aijkok->a_dual.view_device(), a, aijkok->a_dual.view_device());
10189566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
10199566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(aijkok->a_dual.extent(0)));
10209566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
10213ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1022f0cf5187SStefano Zampini }
1023f0cf5187SStefano Zampini 
1024f4747e26SJunchao Zhang // add a to A's diagonal (if A is square) or main diagonal (if A is rectangular)
1025f4747e26SJunchao Zhang static PetscErrorCode MatShift_SeqAIJKokkos(Mat A, PetscScalar a)
1026f4747e26SJunchao Zhang {
1027f4747e26SJunchao Zhang   Mat_SeqAIJ *aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1028f4747e26SJunchao Zhang 
1029f4747e26SJunchao Zhang   PetscFunctionBegin;
1030f4747e26SJunchao Zhang   if (A->assembled && aijseq->diagonaldense) { // no missing diagonals
1031f4747e26SJunchao Zhang     PetscInt n = PetscMin(A->rmap->n, A->cmap->n);
1032f4747e26SJunchao Zhang 
1033f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
1034f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosSyncDevice(A));
1035f4747e26SJunchao Zhang     const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1036f4747e26SJunchao Zhang     const auto &Aa     = aijkok->a_dual.view_device();
1037f4747e26SJunchao Zhang     const auto &Adiag  = aijkok->diag_dual.view_device();
1038d326c3f1SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) { Aa(Adiag(i)) += a; }));
1039f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosModifyDevice(A));
1040f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(n));
1041f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
1042f4747e26SJunchao Zhang   } else { // need reassembly, very slow!
1043f4747e26SJunchao Zhang     PetscCall(MatShift_Basic(A, a));
1044f4747e26SJunchao Zhang   }
1045f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1046f4747e26SJunchao Zhang }
1047f4747e26SJunchao Zhang 
1048f4747e26SJunchao Zhang static PetscErrorCode MatDiagonalSet_SeqAIJKokkos(Mat Y, Vec D, InsertMode is)
1049f4747e26SJunchao Zhang {
1050f4747e26SJunchao Zhang   Mat_SeqAIJ *aijseq = static_cast<Mat_SeqAIJ *>(Y->data);
1051f4747e26SJunchao Zhang 
1052f4747e26SJunchao Zhang   PetscFunctionBegin;
1053f4747e26SJunchao Zhang   if (Y->assembled && aijseq->diagonaldense) { // no missing diagonals
1054f4747e26SJunchao Zhang     ConstPetscScalarKokkosView dv;
1055f4747e26SJunchao Zhang     PetscInt                   n, nv;
1056f4747e26SJunchao Zhang 
1057f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
1058f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosSyncDevice(Y));
1059f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(D, &dv));
1060f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(D, &nv));
1061f4747e26SJunchao Zhang     n = PetscMin(Y->rmap->n, Y->cmap->n);
1062f4747e26SJunchao Zhang     PetscCheck(n == nv, PetscObjectComm((PetscObject)Y), PETSC_ERR_ARG_SIZ, "Matrix size and vector size do not match");
1063f4747e26SJunchao Zhang 
1064f4747e26SJunchao Zhang     const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1065f4747e26SJunchao Zhang     const auto &Aa     = aijkok->a_dual.view_device();
1066f4747e26SJunchao Zhang     const auto &Adiag  = aijkok->diag_dual.view_device();
1067f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(
1068d326c3f1SJunchao Zhang       Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) {
1069f4747e26SJunchao Zhang         if (is == INSERT_VALUES) Aa(Adiag(i)) = dv(i);
1070f4747e26SJunchao Zhang         else Aa(Adiag(i)) += dv(i);
1071f4747e26SJunchao Zhang       }));
1072f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(D, &dv));
1073f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1074f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(n));
1075f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
1076f4747e26SJunchao Zhang   } else { // need reassembly, very slow!
1077f4747e26SJunchao Zhang     PetscCall(MatDiagonalSet_Default(Y, D, is));
1078f4747e26SJunchao Zhang   }
1079f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1080f4747e26SJunchao Zhang }
1081f4747e26SJunchao Zhang 
1082f4747e26SJunchao Zhang static PetscErrorCode MatDiagonalScale_SeqAIJKokkos(Mat A, Vec ll, Vec rr)
1083f4747e26SJunchao Zhang {
1084f4747e26SJunchao Zhang   Mat_SeqAIJ                *aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1085f4747e26SJunchao Zhang   PetscInt                   m = A->rmap->n, n = A->cmap->n, nz = aijseq->nz;
1086f4747e26SJunchao Zhang   ConstPetscScalarKokkosView lv, rv;
1087f4747e26SJunchao Zhang 
1088f4747e26SJunchao Zhang   PetscFunctionBegin;
1089f4747e26SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1090f4747e26SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1091f4747e26SJunchao Zhang   const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1092f4747e26SJunchao Zhang   const auto &Aa     = aijkok->a_dual.view_device();
1093f4747e26SJunchao Zhang   const auto &Ai     = aijkok->i_dual.view_device();
1094f4747e26SJunchao Zhang   const auto &Aj     = aijkok->j_dual.view_device();
1095f4747e26SJunchao Zhang   if (ll) {
1096f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(ll, &m));
1097f4747e26SJunchao Zhang     PetscCheck(m == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Left scaling vector wrong length");
1098f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(ll, &lv));
1099f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for( // for each row
1100d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1101f4747e26SJunchao Zhang         PetscInt i   = t.league_rank(); // row i
1102f4747e26SJunchao Zhang         PetscInt len = Ai(i + 1) - Ai(i);
1103f4747e26SJunchao Zhang         // scale entries on the row
1104f4747e26SJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt j) { Aa(Ai(i) + j) *= lv(i); });
1105f4747e26SJunchao Zhang       }));
1106f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(ll, &lv));
1107f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(nz));
1108f4747e26SJunchao Zhang   }
1109f4747e26SJunchao Zhang   if (rr) {
1110f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(rr, &n));
1111f4747e26SJunchao Zhang     PetscCheck(n == A->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Right scaling vector wrong length");
1112f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(rr, &rv));
1113f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for( // for each nonzero
1114d326c3f1SJunchao Zhang       Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt k) { Aa(k) *= rv(Aj(k)); }));
1115f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(rr, &lv));
1116f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(nz));
1117f4747e26SJunchao Zhang   }
1118f4747e26SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(A));
1119f4747e26SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1120f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1121f4747e26SJunchao Zhang }
1122f4747e26SJunchao Zhang 
1123d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatZeroEntries_SeqAIJKokkos(Mat A)
1124d71ae5a4SJacob Faibussowitsch {
1125076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1126a587d139SMark 
1127a587d139SMark   PetscFunctionBegin;
1128076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
11292328674fSJunchao Zhang   if (aijkok) { /* Only zero the device if data is already there */
1130d326c3f1SJunchao Zhang     KokkosBlas::fill(PetscGetKokkosExecutionSpace(), aijkok->a_dual.view_device(), 0.0);
11319566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(A));
11322328674fSJunchao Zhang   } else { /* Might be preallocated but not assembled */
11339566063dSJacob Faibussowitsch     PetscCall(MatZeroEntries_SeqAIJ(A));
11342328674fSJunchao Zhang   }
11353ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1136a587d139SMark }
1137a587d139SMark 
1138d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatGetDiagonal_SeqAIJKokkos(Mat A, Vec x)
1139d71ae5a4SJacob Faibussowitsch {
1140f78ce678SMark Adams   Mat_SeqAIJKokkos     *aijkok;
1141f78ce678SMark Adams   PetscInt              n;
1142f78ce678SMark Adams   PetscScalarKokkosView xv;
1143f78ce678SMark Adams 
1144f78ce678SMark Adams   PetscFunctionBegin;
1145f78ce678SMark Adams   PetscCall(VecGetLocalSize(x, &n));
1146f78ce678SMark Adams   PetscCheck(n == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Nonconforming matrix and vector");
1147f78ce678SMark Adams   PetscCheck(A->factortype == MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_SUP, "MatGetDiagonal_SeqAIJKokkos not supported on factored matrices");
1148f78ce678SMark Adams 
1149f78ce678SMark Adams   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1150f78ce678SMark Adams   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1151f78ce678SMark Adams 
1152f78ce678SMark Adams   const auto &Aa    = aijkok->a_dual.view_device();
1153f78ce678SMark Adams   const auto &Ai    = aijkok->i_dual.view_device();
1154f78ce678SMark Adams   const auto &Adiag = aijkok->diag_dual.view_device();
1155f78ce678SMark Adams 
1156f78ce678SMark Adams   PetscCall(VecGetKokkosViewWrite(x, &xv));
11579371c9d4SSatish Balay   Kokkos::parallel_for(
1158d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) {
1159f78ce678SMark Adams       if (Adiag(i) < Ai(i + 1)) xv(i) = Aa(Adiag(i));
1160f78ce678SMark Adams       else xv(i) = 0;
1161f78ce678SMark Adams     });
1162f78ce678SMark Adams   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
11633ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1164f78ce678SMark Adams }
1165f78ce678SMark Adams 
1166db78de30SJunchao Zhang /* Get a Kokkos View from a mat of type MatSeqAIJKokkos */
1167d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1168d71ae5a4SJacob Faibussowitsch {
1169db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1170db78de30SJunchao Zhang 
1171db78de30SJunchao Zhang   PetscFunctionBegin;
1172db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11734f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1174db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11759566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1176db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1177076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
11783ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1179db78de30SJunchao Zhang }
1180db78de30SJunchao Zhang 
1181d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1182d71ae5a4SJacob Faibussowitsch {
1183db78de30SJunchao Zhang   PetscFunctionBegin;
1184db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11854f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1186db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11873ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1188db78de30SJunchao Zhang }
1189db78de30SJunchao Zhang 
1190d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, MatScalarKokkosView *kv)
1191d71ae5a4SJacob Faibussowitsch {
1192db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1193db78de30SJunchao Zhang 
1194db78de30SJunchao Zhang   PetscFunctionBegin;
1195db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11964f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1197db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11989566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1199db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1200076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
12013ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1202db78de30SJunchao Zhang }
1203db78de30SJunchao Zhang 
1204d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, MatScalarKokkosView *kv)
1205d71ae5a4SJacob Faibussowitsch {
1206db78de30SJunchao Zhang   PetscFunctionBegin;
1207db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12084f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1209db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
12109566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
12113ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1212db78de30SJunchao Zhang }
1213db78de30SJunchao Zhang 
1214d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1215d71ae5a4SJacob Faibussowitsch {
1216db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1217db78de30SJunchao Zhang 
1218db78de30SJunchao Zhang   PetscFunctionBegin;
1219db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12204f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1221db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1222db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1223076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
12243ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1225db78de30SJunchao Zhang }
1226db78de30SJunchao Zhang 
1227d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1228d71ae5a4SJacob Faibussowitsch {
1229db78de30SJunchao Zhang   PetscFunctionBegin;
1230db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12314f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1232db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
12339566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
12343ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1235db78de30SJunchao Zhang }
1236db78de30SJunchao Zhang 
1237c0c276a7Ssdargavi PetscErrorCode MatCreateSeqAIJKokkosWithKokkosViews(MPI_Comm comm, PetscInt m, PetscInt n, Kokkos::View<PetscInt *> &i_d, Kokkos::View<PetscInt *> &j_d, Kokkos::View<PetscScalar *> &a_d, Mat *A)
1238c0c276a7Ssdargavi {
1239c0c276a7Ssdargavi   Mat_SeqAIJKokkos *akok;
1240c0c276a7Ssdargavi 
1241c0c276a7Ssdargavi   PetscFunctionBegin;
1242c0c276a7Ssdargavi   auto exec = PetscGetKokkosExecutionSpace();
1243c0c276a7Ssdargavi   // Create host copies of the input aij
1244c0c276a7Ssdargavi   auto i_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), i_d);
1245c0c276a7Ssdargavi   auto j_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), j_d);
1246c0c276a7Ssdargavi   // Don't copy the vals to the host now
1247c0c276a7Ssdargavi   auto a_h = Kokkos::create_mirror_view(HostMirrorMemorySpace(), a_d);
1248c0c276a7Ssdargavi 
1249c0c276a7Ssdargavi   MatScalarKokkosDualView a_dual = MatScalarKokkosDualView(a_d, a_h);
1250c0c276a7Ssdargavi   // Note we have modified device data so it will copy lazily
1251c0c276a7Ssdargavi   a_dual.modify_device();
1252c0c276a7Ssdargavi   MatRowMapKokkosDualView i_dual = MatRowMapKokkosDualView(i_d, i_h);
1253c0c276a7Ssdargavi   MatColIdxKokkosDualView j_dual = MatColIdxKokkosDualView(j_d, j_h);
1254c0c276a7Ssdargavi 
1255c0c276a7Ssdargavi   PetscCallCXX(akok = new Mat_SeqAIJKokkos(m, n, j_dual.extent(0), i_dual, j_dual, a_dual));
1256c0c276a7Ssdargavi   PetscCall(MatCreate(comm, A));
1257c0c276a7Ssdargavi   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
1258c0c276a7Ssdargavi   PetscFunctionReturn(PETSC_SUCCESS);
1259c0c276a7Ssdargavi }
1260c0c276a7Ssdargavi 
1261c17cf699SJunchao Zhang /* Computes Y += alpha X */
1262d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAXPY_SeqAIJKokkos(Mat Y, PetscScalar alpha, Mat X, MatStructure pattern)
1263d71ae5a4SJacob Faibussowitsch {
1264a587d139SMark   Mat_SeqAIJ              *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data;
1265c17cf699SJunchao Zhang   Mat_SeqAIJKokkos        *xkok, *ykok, *zkok;
1266c17cf699SJunchao Zhang   ConstMatScalarKokkosView Xa;
1267c17cf699SJunchao Zhang   MatScalarKokkosView      Ya;
12684df4a32cSJunchao Zhang   auto                     exec = PetscGetKokkosExecutionSpace();
1269a587d139SMark 
1270a587d139SMark   PetscFunctionBegin;
1271c17cf699SJunchao Zhang   PetscCheckTypeName(Y, MATSEQAIJKOKKOS);
1272c17cf699SJunchao Zhang   PetscCheckTypeName(X, MATSEQAIJKOKKOS);
12739566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(Y));
12749566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(X));
12759566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1276db78de30SJunchao Zhang 
1277c17cf699SJunchao Zhang   if (pattern != SAME_NONZERO_PATTERN && x->nz == y->nz) {
1278a587d139SMark     PetscBool e;
12799566063dSJacob Faibussowitsch     PetscCall(PetscArraycmp(x->i, y->i, Y->rmap->n + 1, &e));
1280a587d139SMark     if (e) {
12819566063dSJacob Faibussowitsch       PetscCall(PetscArraycmp(x->j, y->j, y->nz, &e));
1282c17cf699SJunchao Zhang       if (e) pattern = SAME_NONZERO_PATTERN;
1283a587d139SMark     }
1284a587d139SMark   }
1285db78de30SJunchao Zhang 
1286c17cf699SJunchao Zhang   /* cusparseDcsrgeam2() computes C = alpha A + beta B. If one knew sparsity pattern of C, one can skip
1287c17cf699SJunchao Zhang     cusparseScsrgeam2_bufferSizeExt() / cusparseXcsrgeam2Nnz(), and directly call cusparseScsrgeam2().
1288c17cf699SJunchao Zhang     If X is SUBSET_NONZERO_PATTERN of Y, we could take advantage of this cusparse feature. However,
1289c17cf699SJunchao Zhang     KokkosSparse::spadd(alpha,A,beta,B,C) has symbolic and numeric phases, MatAXPY does not.
1290c17cf699SJunchao Zhang   */
1291c17cf699SJunchao Zhang   ykok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1292c17cf699SJunchao Zhang   xkok = static_cast<Mat_SeqAIJKokkos *>(X->spptr);
1293c17cf699SJunchao Zhang   Xa   = xkok->a_dual.view_device();
1294c17cf699SJunchao Zhang   Ya   = ykok->a_dual.view_device();
1295c17cf699SJunchao Zhang 
1296c17cf699SJunchao Zhang   if (pattern == SAME_NONZERO_PATTERN) {
1297d326c3f1SJunchao Zhang     KokkosBlas::axpy(exec, alpha, Xa, Ya);
12989566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1299c17cf699SJunchao Zhang   } else if (pattern == SUBSET_NONZERO_PATTERN) {
1300c17cf699SJunchao Zhang     MatRowMapKokkosView Xi = xkok->i_dual.view_device(), Yi = ykok->i_dual.view_device();
1301c17cf699SJunchao Zhang     MatColIdxKokkosView Xj = xkok->j_dual.view_device(), Yj = ykok->j_dual.view_device();
1302c17cf699SJunchao Zhang 
13039371c9d4SSatish Balay     Kokkos::parallel_for(
1304d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(exec, Y->rmap->n, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
13050e3ece09SJunchao Zhang         PetscInt i = t.league_rank(); // row i
13060e3ece09SJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() {
13070e3ece09SJunchao Zhang           // Only one thread works in a team
1308c17cf699SJunchao Zhang           PetscInt p, q = Yi(i);
13090e3ece09SJunchao Zhang           for (p = Xi(i); p < Xi(i + 1); p++) {          // For each nonzero on row i of X,
13100e3ece09SJunchao Zhang             while (Xj(p) != Yj(q) && q < Yi(i + 1)) q++; // find the matching nonzero on row i of Y.
13110e3ece09SJunchao Zhang             if (Xj(p) == Yj(q)) {                        // Found it
1312c17cf699SJunchao Zhang               Ya(q) += alpha * Xa(p);
1313c17cf699SJunchao Zhang               q++;
1314a587d139SMark             } else {
13150e3ece09SJunchao Zhang             // If not found, it indicates the input is wrong (X is not a SUBSET_NONZERO_PATTERN of Y).
13160e3ece09SJunchao Zhang             // Just insert a NaN at the beginning of row i if it is not empty, to make the result wrong.
13170e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_VERSION_GE(3, 7, 0)
13180e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::ArithTraits<PetscScalar>::nan();
13198b8b16f9SJunchao Zhang #else
13200e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::Experimental::nan("1");
13218b8b16f9SJunchao Zhang #endif
1322a587d139SMark             }
1323c17cf699SJunchao Zhang           }
1324c17cf699SJunchao Zhang         });
1325c17cf699SJunchao Zhang       });
13269566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
13270e3ece09SJunchao Zhang   } else { // different nonzero patterns
1328c17cf699SJunchao Zhang     Mat             Z;
1329c17cf699SJunchao Zhang     KokkosCsrMatrix zcsr;
1330c17cf699SJunchao Zhang     KernelHandle    kh;
13310e3ece09SJunchao Zhang     kh.create_spadd_handle(true); // X, Y are sorted
1332c17cf699SJunchao Zhang     KokkosSparse::spadd_symbolic(&kh, xkok->csrmat, ykok->csrmat, zcsr);
1333c17cf699SJunchao Zhang     KokkosSparse::spadd_numeric(&kh, alpha, xkok->csrmat, (PetscScalar)1.0, ykok->csrmat, zcsr);
1334c17cf699SJunchao Zhang     zkok = new Mat_SeqAIJKokkos(zcsr);
13359566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, zkok, &Z));
13369566063dSJacob Faibussowitsch     PetscCall(MatHeaderReplace(Y, &Z));
1337c17cf699SJunchao Zhang     kh.destroy_spadd_handle();
1338c17cf699SJunchao Zhang   }
13399566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
13400e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(xkok->a_dual.extent(0) * 2)); // Because we scaled X and then added it to Y
13413ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1342a587d139SMark }
1343a587d139SMark 
13442c4ab24aSJunchao Zhang struct MatCOOStruct_SeqAIJKokkos {
13452c4ab24aSJunchao Zhang   PetscCount           n;
13462c4ab24aSJunchao Zhang   PetscCount           Atot;
13472c4ab24aSJunchao Zhang   PetscInt             nz;
13482c4ab24aSJunchao Zhang   PetscCountKokkosView jmap;
13492c4ab24aSJunchao Zhang   PetscCountKokkosView perm;
13502c4ab24aSJunchao Zhang 
13512c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos(const MatCOOStruct_SeqAIJ *coo_h)
13522c4ab24aSJunchao Zhang   {
13532c4ab24aSJunchao Zhang     nz   = coo_h->nz;
13542c4ab24aSJunchao Zhang     n    = coo_h->n;
13552c4ab24aSJunchao Zhang     Atot = coo_h->Atot;
13562c4ab24aSJunchao Zhang     jmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->jmap, nz + 1));
13572c4ab24aSJunchao Zhang     perm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->perm, Atot));
13582c4ab24aSJunchao Zhang   }
13592c4ab24aSJunchao Zhang };
13602c4ab24aSJunchao Zhang 
136149abdd8aSBarry Smith static PetscErrorCode MatCOOStructDestroy_SeqAIJKokkos(void **data)
13622c4ab24aSJunchao Zhang {
13632c4ab24aSJunchao Zhang   PetscFunctionBegin;
136449abdd8aSBarry Smith   PetscCallCXX(delete static_cast<MatCOOStruct_SeqAIJKokkos *>(*data));
13652c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
13662c4ab24aSJunchao Zhang }
13672c4ab24aSJunchao Zhang 
1368d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_SeqAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1369d71ae5a4SJacob Faibussowitsch {
137042550becSJunchao Zhang   Mat_SeqAIJKokkos          *akok;
137142550becSJunchao Zhang   Mat_SeqAIJ                *aseq;
137203e76207SPierre Jolivet   PetscContainer             container_h;
13732c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJ       *coo_h;
13742c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo_d;
137542550becSJunchao Zhang 
137642550becSJunchao Zhang   PetscFunctionBegin;
13779566063dSJacob Faibussowitsch   PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, coo_i, coo_j));
1378394ed5ebSJunchao Zhang   aseq = static_cast<Mat_SeqAIJ *>(mat->data);
137942550becSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(mat->spptr);
1380cbc6b225SStefano Zampini   delete akok;
1381f4747e26SJunchao Zhang   mat->spptr = akok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, aseq, mat->nonzerostate + 1, PETSC_FALSE);
13829566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries_SeqAIJKokkos(mat));
13832c4ab24aSJunchao Zhang 
13842c4ab24aSJunchao Zhang   // Copy the COO struct to device
13852c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
13862c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
13872c4ab24aSJunchao Zhang   PetscCallCXX(coo_d = new MatCOOStruct_SeqAIJKokkos(coo_h));
13882c4ab24aSJunchao Zhang 
13892c4ab24aSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
139003e76207SPierre Jolivet   PetscCall(PetscObjectContainerCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", coo_d, MatCOOStructDestroy_SeqAIJKokkos));
13913ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
139242550becSJunchao Zhang }
139342550becSJunchao Zhang 
1394d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_SeqAIJKokkos(Mat A, const PetscScalar v[], InsertMode imode)
1395d71ae5a4SJacob Faibussowitsch {
139642550becSJunchao Zhang   MatScalarKokkosView        Aa;
139742550becSJunchao Zhang   ConstMatScalarKokkosView   kv;
139842550becSJunchao Zhang   PetscMemType               memtype;
13992c4ab24aSJunchao Zhang   PetscContainer             container;
14002c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo;
140142550becSJunchao Zhang 
140242550becSJunchao Zhang   PetscFunctionBegin;
14032c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)A, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
14042c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
14052c4ab24aSJunchao Zhang 
14062c4ab24aSJunchao Zhang   const auto &n    = coo->n;
14072c4ab24aSJunchao Zhang   const auto &Annz = coo->nz;
14082c4ab24aSJunchao Zhang   const auto &jmap = coo->jmap;
14092c4ab24aSJunchao Zhang   const auto &perm = coo->perm;
14102c4ab24aSJunchao Zhang 
14119566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype));
141242550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */
14132c4ab24aSJunchao Zhang     kv = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatScalarKokkosViewHost(v, n));
141442550becSJunchao Zhang   } else {
14152c4ab24aSJunchao Zhang     kv = ConstMatScalarKokkosView(v, n); /* Directly use v[]'s memory */
141642550becSJunchao Zhang   }
141742550becSJunchao Zhang 
1418c7b718f4SJunchao Zhang   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1419c7b718f4SJunchao Zhang   else PetscCall(MatSeqAIJGetKokkosView(A, &Aa));                             /* read & write matrix values */
142042550becSJunchao Zhang 
142108bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
14229371c9d4SSatish Balay   Kokkos::parallel_for(
1423d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz), KOKKOS_LAMBDA(const PetscCount i) {
1424c7b718f4SJunchao Zhang       PetscScalar sum = 0.0;
1425c7b718f4SJunchao Zhang       for (PetscCount k = jmap(i); k < jmap(i + 1); k++) sum += kv(perm(k));
1426c7b718f4SJunchao Zhang       Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1427c7b718f4SJunchao Zhang     });
142808bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1429394ed5ebSJunchao Zhang 
14309566063dSJacob Faibussowitsch   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
14319566063dSJacob Faibussowitsch   else PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
14323ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
143342550becSJunchao Zhang }
143442550becSJunchao Zhang 
1435d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat A)
1436d71ae5a4SJacob Faibussowitsch {
1437076ba34aSJunchao Zhang   Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
1438076ba34aSJunchao Zhang 
14398c3ff71bSJunchao Zhang   PetscFunctionBegin;
1440076ba34aSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_KOKKOS; /* We do not really use this flag */
14416f3d89d0SStefano Zampini   A->boundtocpu  = PETSC_FALSE;
14426f3d89d0SStefano Zampini 
14438c3ff71bSJunchao Zhang   A->ops->assemblyend               = MatAssemblyEnd_SeqAIJKokkos;
14448c3ff71bSJunchao Zhang   A->ops->destroy                   = MatDestroy_SeqAIJKokkos;
14458c3ff71bSJunchao Zhang   A->ops->duplicate                 = MatDuplicate_SeqAIJKokkos;
1446a587d139SMark   A->ops->axpy                      = MatAXPY_SeqAIJKokkos;
1447f0cf5187SStefano Zampini   A->ops->scale                     = MatScale_SeqAIJKokkos;
1448a587d139SMark   A->ops->zeroentries               = MatZeroEntries_SeqAIJKokkos;
1449076ba34aSJunchao Zhang   A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJKokkos;
14508c3ff71bSJunchao Zhang   A->ops->mult                      = MatMult_SeqAIJKokkos;
14518c3ff71bSJunchao Zhang   A->ops->multadd                   = MatMultAdd_SeqAIJKokkos;
14528c3ff71bSJunchao Zhang   A->ops->multtranspose             = MatMultTranspose_SeqAIJKokkos;
14538c3ff71bSJunchao Zhang   A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJKokkos;
14548c3ff71bSJunchao Zhang   A->ops->multhermitiantranspose    = MatMultHermitianTranspose_SeqAIJKokkos;
14558c3ff71bSJunchao Zhang   A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJKokkos;
1456076ba34aSJunchao Zhang   A->ops->productnumeric            = MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos;
14570ecb592aSJunchao Zhang   A->ops->transpose                 = MatTranspose_SeqAIJKokkos;
1458152b3e56SJunchao Zhang   A->ops->setoption                 = MatSetOption_SeqAIJKokkos;
1459f78ce678SMark Adams   A->ops->getdiagonal               = MatGetDiagonal_SeqAIJKokkos;
1460f4747e26SJunchao Zhang   A->ops->shift                     = MatShift_SeqAIJKokkos;
1461f4747e26SJunchao Zhang   A->ops->diagonalset               = MatDiagonalSet_SeqAIJKokkos;
1462f4747e26SJunchao Zhang   A->ops->diagonalscale             = MatDiagonalScale_SeqAIJKokkos;
1463076ba34aSJunchao Zhang   a->ops->getarray                  = MatSeqAIJGetArray_SeqAIJKokkos;
1464076ba34aSJunchao Zhang   a->ops->restorearray              = MatSeqAIJRestoreArray_SeqAIJKokkos;
1465076ba34aSJunchao Zhang   a->ops->getarrayread              = MatSeqAIJGetArrayRead_SeqAIJKokkos;
1466076ba34aSJunchao Zhang   a->ops->restorearrayread          = MatSeqAIJRestoreArrayRead_SeqAIJKokkos;
1467076ba34aSJunchao Zhang   a->ops->getarraywrite             = MatSeqAIJGetArrayWrite_SeqAIJKokkos;
1468076ba34aSJunchao Zhang   a->ops->restorearraywrite         = MatSeqAIJRestoreArrayWrite_SeqAIJKokkos;
14697ee59b9bSJunchao Zhang   a->ops->getcsrandmemtype          = MatSeqAIJGetCSRAndMemType_SeqAIJKokkos;
147042550becSJunchao Zhang 
14719566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJKokkos));
14729566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJKokkos));
147357761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
147457761e9aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijkokkos_hypre_C", MatConvert_AIJ_HYPRE));
147557761e9aSJunchao Zhang #endif
14763ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1477076ba34aSJunchao Zhang }
1478076ba34aSJunchao Zhang 
14799d13fa56SJunchao Zhang /*
14809d13fa56SJunchao Zhang    Extract the (prescribled) diagonal blocks of the matrix and then invert them
14819d13fa56SJunchao Zhang 
14829d13fa56SJunchao Zhang   Input Parameters:
14839d13fa56SJunchao Zhang +  A       - the MATSEQAIJKOKKOS matrix
14849d13fa56SJunchao Zhang .  bs      - block sizes in 'csr' format, i.e., the i-th block has size bs(i+1) - bs(i)
14859d13fa56SJunchao Zhang .  bs2     - square of block sizes in 'csr' format, i.e., the i-th block should be stored at offset bs2(i) in diagVal[]
14869d13fa56SJunchao Zhang .  blkMap  - map row ids to block ids, i.e., row i belongs to the block blkMap(i)
14879d13fa56SJunchao Zhang -  work    - a pre-allocated work buffer (as big as diagVal) for use by this routine
14889d13fa56SJunchao Zhang 
14899d13fa56SJunchao Zhang   Output Parameter:
14909d13fa56SJunchao Zhang .  diagVal - the (pre-allocated) buffer to store the inverted blocks (each block is stored in column-major order)
14919d13fa56SJunchao Zhang */
14929d13fa56SJunchao Zhang PETSC_INTERN PetscErrorCode MatInvertVariableBlockDiagonal_SeqAIJKokkos(Mat A, const PetscIntKokkosView &bs, const PetscIntKokkosView &bs2, const PetscIntKokkosView &blkMap, PetscScalarKokkosView &work, PetscScalarKokkosView &diagVal)
14939d13fa56SJunchao Zhang {
14949d13fa56SJunchao Zhang   Mat_SeqAIJKokkos *akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
14959d13fa56SJunchao Zhang   PetscInt          nblocks = bs.extent(0) - 1;
14969d13fa56SJunchao Zhang 
14979d13fa56SJunchao Zhang   PetscFunctionBegin;
14989d13fa56SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A)); // Since we'll access A's value on device
14999d13fa56SJunchao Zhang 
15009d13fa56SJunchao Zhang   // Pull out the diagonal blocks of the matrix and then invert the blocks
15019d13fa56SJunchao Zhang   auto Aa    = akok->a_dual.view_device();
15029d13fa56SJunchao Zhang   auto Ai    = akok->i_dual.view_device();
15039d13fa56SJunchao Zhang   auto Aj    = akok->j_dual.view_device();
15049d13fa56SJunchao Zhang   auto Adiag = akok->diag_dual.view_device();
15059d13fa56SJunchao Zhang   // TODO: how to tune the team size?
150645402d8aSJunchao Zhang #if defined(KOKKOS_ENABLE_UNIFIED_MEMORY)
15079d13fa56SJunchao Zhang   auto ts = Kokkos::AUTO();
15089d13fa56SJunchao Zhang #else
15099d13fa56SJunchao Zhang   auto ts = 16; // improved performance 30% over Kokkos::AUTO() with CUDA, but failed with "Kokkos::abort: Requested Team Size is too large!" on CPUs
15109d13fa56SJunchao Zhang #endif
15119d13fa56SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1512d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), nblocks, ts), KOKKOS_LAMBDA(const KokkosTeamMemberType &teamMember) {
15139d13fa56SJunchao Zhang       const PetscInt bid    = teamMember.league_rank();                                                   // block id
15149d13fa56SJunchao Zhang       const PetscInt rstart = bs(bid);                                                                    // this block starts from this row
15159d13fa56SJunchao Zhang       const PetscInt m      = bs(bid + 1) - bs(bid);                                                      // size of this block
15169d13fa56SJunchao Zhang       const auto    &B      = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft>(&diagVal(bs2(bid)), m, m); // column-major order
15179d13fa56SJunchao Zhang       const auto    &W      = PetscScalarKokkosView(&work(bs2(bid)), m * m);
15189d13fa56SJunchao Zhang 
15199d13fa56SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(teamMember, m), [=](const PetscInt &r) { // r-th row in B
15209d13fa56SJunchao Zhang         PetscInt i = rstart + r;                                                            // i-th row in A
15219d13fa56SJunchao Zhang 
15229d13fa56SJunchao Zhang         if (Ai(i) <= Adiag(i) && Adiag(i) < Ai(i + 1)) { // if the diagonal exists (common case)
15239d13fa56SJunchao Zhang           PetscInt first = Adiag(i) - r;                 // we start to check nonzeros from here along this row
15249d13fa56SJunchao Zhang 
15259d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) {                   // walk n steps to see what column indices we will meet
15269d13fa56SJunchao Zhang             if (first + c < Ai(i) || first + c >= Ai(i + 1)) { // this entry (first+c) is out of range of this row, in other words, its value is zero
15279d13fa56SJunchao Zhang               B(r, c) = 0.0;
15289d13fa56SJunchao Zhang             } else if (Aj(first + c) == rstart + c) { // this entry is right on the (rstart+c) column
15299d13fa56SJunchao Zhang               B(r, c) = Aa(first + c);
15309d13fa56SJunchao Zhang             } else { // this entry does not show up in the CSR
15319d13fa56SJunchao Zhang               B(r, c) = 0.0;
15329d13fa56SJunchao Zhang             }
15339d13fa56SJunchao Zhang           }
15349d13fa56SJunchao Zhang         } else { // rare case that the diagonal does not exist
15359d13fa56SJunchao Zhang           const PetscInt begin = Ai(i);
15369d13fa56SJunchao Zhang           const PetscInt end   = Ai(i + 1);
15379d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) B(r, c) = 0.0;
15389d13fa56SJunchao Zhang           for (PetscInt j = begin; j < end; j++) { // scan the whole row; could use binary search but this is a rare case so we did not.
15399d13fa56SJunchao Zhang             if (rstart <= Aj(j) && Aj(j) < rstart + m) B(r, Aj(j) - rstart) = Aa(j);
15409d13fa56SJunchao Zhang             else if (Aj(j) >= rstart + m) break;
15419d13fa56SJunchao Zhang           }
15429d13fa56SJunchao Zhang         }
15439d13fa56SJunchao Zhang       });
15449d13fa56SJunchao Zhang 
15459d13fa56SJunchao Zhang       // LU-decompose B (w/o pivoting) and then invert B
15469d13fa56SJunchao Zhang       KokkosBatched::TeamLU<KokkosTeamMemberType, KokkosBatched::Algo::LU::Unblocked>::invoke(teamMember, B, 0.0);
15479d13fa56SJunchao Zhang       KokkosBatched::TeamInverseLU<KokkosTeamMemberType, KokkosBatched::Algo::InverseLU::Unblocked>::invoke(teamMember, B, W);
15489d13fa56SJunchao Zhang     }));
15499d13fa56SJunchao Zhang   // PetscLogGpuFlops() is done in the caller PCSetUp_VPBJacobi_Kokkos as we don't want to compute the flops in kernels
15509d13fa56SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
15519d13fa56SJunchao Zhang }
15529d13fa56SJunchao Zhang 
1553d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSetSeqAIJKokkosWithCSRMatrix(Mat A, Mat_SeqAIJKokkos *akok)
1554d71ae5a4SJacob Faibussowitsch {
1555076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
1556076ba34aSJunchao Zhang   PetscInt    i, m, n;
15574df4a32cSJunchao Zhang   auto        exec = PetscGetKokkosExecutionSpace();
1558076ba34aSJunchao Zhang 
1559076ba34aSJunchao Zhang   PetscFunctionBegin;
15605f80ce2aSJacob Faibussowitsch   PetscCheck(!A->spptr, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A->spptr is supposed to be empty");
1561076ba34aSJunchao Zhang 
1562076ba34aSJunchao Zhang   m = akok->nrows();
1563076ba34aSJunchao Zhang   n = akok->ncols();
15649566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(A, m, n, m, n));
15659566063dSJacob Faibussowitsch   PetscCall(MatSetType(A, MATSEQAIJKOKKOS));
1566076ba34aSJunchao Zhang 
1567076ba34aSJunchao Zhang   /* Set up data structures of A as a MATSEQAIJ */
15689566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(A, MAT_SKIP_ALLOCATION, NULL));
156957508eceSPierre Jolivet   aseq = (Mat_SeqAIJ *)A->data;
1570076ba34aSJunchao Zhang 
1571*f3d3cd90SZach Atkins   PetscCall(KokkosDualViewSyncHost(akok->i_dual, exec)); /* We always need sync'ed i, j on host */
1572*f3d3cd90SZach Atkins   PetscCall(KokkosDualViewSyncHost(akok->j_dual, exec));
1573076ba34aSJunchao Zhang 
1574076ba34aSJunchao Zhang   aseq->i       = akok->i_host_data();
1575076ba34aSJunchao Zhang   aseq->j       = akok->j_host_data();
1576076ba34aSJunchao Zhang   aseq->a       = akok->a_host_data();
1577076ba34aSJunchao Zhang   aseq->nonew   = -1; /*this indicates that inserting a new value in the matrix that generates a new nonzero is an error*/
1578076ba34aSJunchao Zhang   aseq->free_a  = PETSC_FALSE;
1579076ba34aSJunchao Zhang   aseq->free_ij = PETSC_FALSE;
1580076ba34aSJunchao Zhang   aseq->nz      = akok->nnz();
1581076ba34aSJunchao Zhang   aseq->maxnz   = aseq->nz;
1582076ba34aSJunchao Zhang 
15839566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->imax));
15849566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->ilen));
1585ad540459SPierre Jolivet   for (i = 0; i < m; i++) aseq->ilen[i] = aseq->imax[i] = aseq->i[i + 1] - aseq->i[i];
1586076ba34aSJunchao Zhang 
1587076ba34aSJunchao Zhang   /* It is critical to set the nonzerostate, as we use it to check if sparsity pattern (hence data) has changed on host in MatAssemblyEnd */
1588076ba34aSJunchao Zhang   akok->nonzerostate = A->nonzerostate;
1589ff751488SJunchao Zhang   A->spptr           = akok; /* Set A->spptr before MatAssembly so that A->spptr won't be allocated again there */
15909566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
15919566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
15923ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1593076ba34aSJunchao Zhang }
1594076ba34aSJunchao Zhang 
15950e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetKokkosCsrMatrix(Mat A, KokkosCsrMatrix *csr)
15960e3ece09SJunchao Zhang {
15970e3ece09SJunchao Zhang   PetscFunctionBegin;
15980e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
15990e3ece09SJunchao Zhang   *csr = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
16000e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
16010e3ece09SJunchao Zhang }
16020e3ece09SJunchao Zhang 
16030e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithKokkosCsrMatrix(MPI_Comm comm, KokkosCsrMatrix csr, Mat *A)
16040e3ece09SJunchao Zhang {
16050e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok;
16064d86920dSPierre Jolivet 
16070e3ece09SJunchao Zhang   PetscFunctionBegin;
16080e3ece09SJunchao Zhang   PetscCallCXX(akok = new Mat_SeqAIJKokkos(csr));
16090e3ece09SJunchao Zhang   PetscCall(MatCreate(comm, A));
16100e3ece09SJunchao Zhang   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
16110e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
16120e3ece09SJunchao Zhang }
16130e3ece09SJunchao Zhang 
1614076ba34aSJunchao Zhang /* Crete a SEQAIJKOKKOS matrix with a Mat_SeqAIJKokkos data structure
1615076ba34aSJunchao Zhang 
1616076ba34aSJunchao Zhang    Note we have names like MatSeqAIJSetPreallocationCSR, so I use capitalized CSR
1617076ba34aSJunchao Zhang  */
1618d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithCSRMatrix(MPI_Comm comm, Mat_SeqAIJKokkos *akok, Mat *A)
1619d71ae5a4SJacob Faibussowitsch {
1620076ba34aSJunchao Zhang   PetscFunctionBegin;
16219566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
16229566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
16233ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
16248c3ff71bSJunchao Zhang }
16258c3ff71bSJunchao Zhang 
1626152b3e56SJunchao Zhang /*@C
162711a5261eSBarry Smith   MatCreateSeqAIJKokkos - Creates a sparse matrix in `MATSEQAIJKOKKOS` (compressed row) format
16288c3ff71bSJunchao Zhang   (the default parallel PETSc format). This matrix will ultimately be handled by
162920f4b53cSBarry Smith   Kokkos for calculations.
16308c3ff71bSJunchao Zhang 
16318c3ff71bSJunchao Zhang   Collective
16328c3ff71bSJunchao Zhang 
16338c3ff71bSJunchao Zhang   Input Parameters:
163411a5261eSBarry Smith + comm - MPI communicator, set to `PETSC_COMM_SELF`
16358c3ff71bSJunchao Zhang . m    - number of rows
16368c3ff71bSJunchao Zhang . n    - number of columns
163720f4b53cSBarry Smith . nz   - number of nonzeros per row (same for all rows), ignored if `nnz` is provided
163820f4b53cSBarry Smith - nnz  - array containing the number of nonzeros in the various rows (possibly different for each row) or `NULL`
16398c3ff71bSJunchao Zhang 
16408c3ff71bSJunchao Zhang   Output Parameter:
16418c3ff71bSJunchao Zhang . A - the matrix
16428c3ff71bSJunchao Zhang 
16432ef1f0ffSBarry Smith   Level: intermediate
16442ef1f0ffSBarry Smith 
16452ef1f0ffSBarry Smith   Notes:
164611a5261eSBarry Smith   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
16478c3ff71bSJunchao Zhang   MatXXXXSetPreallocation() paradgm instead of this routine directly.
164811a5261eSBarry Smith   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
16498c3ff71bSJunchao Zhang 
165011a5261eSBarry Smith   The AIJ format, also called
16512ef1f0ffSBarry Smith   compressed row storage, is fully compatible with standard Fortran
16528c3ff71bSJunchao Zhang   storage.  That is, the stored row and column indices can begin at
165320f4b53cSBarry Smith   either one (as in Fortran) or zero.
16548c3ff71bSJunchao Zhang 
16552ef1f0ffSBarry Smith   Specify the preallocated storage with either `nz` or `nnz` (not both).
16562ef1f0ffSBarry Smith   Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory
16572ef1f0ffSBarry Smith   allocation.
16588c3ff71bSJunchao Zhang 
1659fe59aa6dSJacob Faibussowitsch .seealso: [](ch_matrices), `Mat`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
16608c3ff71bSJunchao Zhang @*/
1661d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateSeqAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A)
1662d71ae5a4SJacob Faibussowitsch {
16638c3ff71bSJunchao Zhang   PetscFunctionBegin;
16649566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
16659566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
16669566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, m, n));
16679566063dSJacob Faibussowitsch   PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
16689566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
16693ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
16708c3ff71bSJunchao Zhang }
1671930e68a5SMark Adams 
1672aac854edSJunchao Zhang // After matrix numeric factorization, there are still steps to do before triangular solve can be called.
1673aac854edSJunchao Zhang // For example, for transpose solve, we might need to compute the transpose matrices if the solver does not support it (such as KK, while cusparse does).
1674aac854edSJunchao Zhang // In cusparse, one has to call cusparseSpSV_analysis() with updated triangular matrix values before calling cusparseSpSV_solve().
1675aac854edSJunchao Zhang // Simiarily, in KK sptrsv_symbolic() has to be called before sptrsv_solve(). We put these steps in MatSeqAIJKokkos{Transpose}SolveCheck.
1676aac854edSJunchao Zhang static PetscErrorCode MatSeqAIJKokkosSolveCheck(Mat A)
1677d71ae5a4SJacob Faibussowitsch {
167886a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors   = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1679aac854edSJunchao Zhang   const PetscBool             has_lower = factors->iL_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // false with Choleksy
1680aac854edSJunchao Zhang   const PetscBool             has_upper = factors->iU_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // true with LU and Choleksy
168186a27549SJunchao Zhang 
168286a27549SJunchao Zhang   PetscFunctionBegin;
1683aac854edSJunchao Zhang   if (!factors->sptrsv_symbolic_completed) { // If sptrsv_symbolic was not called yet
1684aac854edSJunchao Zhang     if (has_upper) PetscCallCXX(sptrsv_symbolic(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d));
1685aac854edSJunchao Zhang     if (has_lower) PetscCallCXX(sptrsv_symbolic(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d));
168686a27549SJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_TRUE;
168786a27549SJunchao Zhang   }
16883ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
168986a27549SJunchao Zhang }
169086a27549SJunchao Zhang 
1691d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosTransposeSolveCheck(Mat A)
1692d71ae5a4SJacob Faibussowitsch {
1693aac854edSJunchao Zhang   const PetscInt              n         = A->rmap->n;
169486a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors   = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1695aac854edSJunchao Zhang   const PetscBool             has_lower = factors->iL_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // false with Choleksy
1696aac854edSJunchao Zhang   const PetscBool             has_upper = factors->iU_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // true with LU or Choleksy
169786a27549SJunchao Zhang 
169886a27549SJunchao Zhang   PetscFunctionBegin;
1699aac854edSJunchao Zhang   if (!factors->transpose_updated) {
1700aac854edSJunchao Zhang     if (has_upper) {
1701aac854edSJunchao Zhang       if (!factors->iUt_d.extent(0)) {                                 // Allocate Ut on device if not yet
1702aac854edSJunchao Zhang         factors->iUt_d = MatRowMapKokkosView("factors->iUt_d", n + 1); // KK requires this view to be initialized to 0 to call transpose_matrix
17037b8d4ba6SJunchao Zhang         factors->jUt_d = MatColIdxKokkosView(NoInit("factors->jUt_d"), factors->jU_d.extent(0));
17047b8d4ba6SJunchao Zhang         factors->aUt_d = MatScalarKokkosView(NoInit("factors->aUt_d"), factors->aU_d.extent(0));
1705aac854edSJunchao Zhang       }
170686a27549SJunchao Zhang 
1707aac854edSJunchao Zhang       if (factors->iU_h.extent(0)) { // If U is on host (factorization was done on host), we also compute the transpose on host
1708aac854edSJunchao Zhang         if (!factors->U) {
1709aac854edSJunchao Zhang           Mat_SeqAIJ *seq;
171086a27549SJunchao Zhang 
1711aac854edSJunchao Zhang           PetscCall(MatCreateSeqAIJWithArrays(PETSC_COMM_SELF, n, n, factors->iU_h.data(), factors->jU_h.data(), factors->aU_h.data(), &factors->U));
1712aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->U, MAT_INITIAL_MATRIX, &factors->Ut));
171386a27549SJunchao Zhang 
1714aac854edSJunchao Zhang           seq            = static_cast<Mat_SeqAIJ *>(factors->Ut->data);
1715aac854edSJunchao Zhang           factors->iUt_h = MatRowMapKokkosViewHost(seq->i, n + 1);
1716aac854edSJunchao Zhang           factors->jUt_h = MatColIdxKokkosViewHost(seq->j, seq->nz);
1717aac854edSJunchao Zhang           factors->aUt_h = MatScalarKokkosViewHost(seq->a, seq->nz);
1718aac854edSJunchao Zhang         } else {
1719aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->U, MAT_REUSE_MATRIX, &factors->Ut)); // Matrix Ut' data is aliased with {i, j, a}Ut_h
1720aac854edSJunchao Zhang         }
1721aac854edSJunchao Zhang         // Copy Ut from host to device
1722aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->iUt_d, factors->iUt_h));
1723aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->jUt_d, factors->jUt_h));
1724aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->aUt_d, factors->aUt_h));
1725aac854edSJunchao Zhang       } else { // If U was computed on device, we also compute the transpose there
1726aac854edSJunchao Zhang         // TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices. We have to sort the indices, until KK provides finer control options.
1727aac854edSJunchao Zhang         PetscCallCXX(transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iU_d,
1728aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->jU_d, factors->aU_d,
1729aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->iUt_d, factors->jUt_d,
1730aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->aUt_d));
1731aac854edSJunchao Zhang         PetscCallCXX(sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iUt_d, factors->jUt_d, factors->aUt_d));
1732aac854edSJunchao Zhang       }
1733aac854edSJunchao Zhang       PetscCallCXX(sptrsv_symbolic(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d));
1734aac854edSJunchao Zhang     }
1735aac854edSJunchao Zhang 
1736aac854edSJunchao Zhang     // do the same for L with LU
1737aac854edSJunchao Zhang     if (has_lower) {
1738aac854edSJunchao Zhang       if (!factors->iLt_d.extent(0)) {                                 // Allocate Lt on device if not yet
1739aac854edSJunchao Zhang         factors->iLt_d = MatRowMapKokkosView("factors->iLt_d", n + 1); // KK requires this view to be initialized to 0 to call transpose_matrix
1740aac854edSJunchao Zhang         factors->jLt_d = MatColIdxKokkosView(NoInit("factors->jLt_d"), factors->jL_d.extent(0));
1741aac854edSJunchao Zhang         factors->aLt_d = MatScalarKokkosView(NoInit("factors->aLt_d"), factors->aL_d.extent(0));
1742aac854edSJunchao Zhang       }
1743aac854edSJunchao Zhang 
1744aac854edSJunchao Zhang       if (factors->iL_h.extent(0)) { // If L is on host, we also compute the transpose on host
1745aac854edSJunchao Zhang         if (!factors->L) {
1746aac854edSJunchao Zhang           Mat_SeqAIJ *seq;
1747aac854edSJunchao Zhang 
1748aac854edSJunchao Zhang           PetscCall(MatCreateSeqAIJWithArrays(PETSC_COMM_SELF, n, n, factors->iL_h.data(), factors->jL_h.data(), factors->aL_h.data(), &factors->L));
1749aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->L, MAT_INITIAL_MATRIX, &factors->Lt));
1750aac854edSJunchao Zhang 
1751aac854edSJunchao Zhang           seq            = static_cast<Mat_SeqAIJ *>(factors->Lt->data);
1752aac854edSJunchao Zhang           factors->iLt_h = MatRowMapKokkosViewHost(seq->i, n + 1);
1753aac854edSJunchao Zhang           factors->jLt_h = MatColIdxKokkosViewHost(seq->j, seq->nz);
1754aac854edSJunchao Zhang           factors->aLt_h = MatScalarKokkosViewHost(seq->a, seq->nz);
1755aac854edSJunchao Zhang         } else {
1756aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->L, MAT_REUSE_MATRIX, &factors->Lt)); // Matrix Lt' data is aliased with {i, j, a}Lt_h
1757aac854edSJunchao Zhang         }
1758aac854edSJunchao Zhang         // Copy Lt from host to device
1759aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->iLt_d, factors->iLt_h));
1760aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->jLt_d, factors->jLt_h));
1761aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->aLt_d, factors->aLt_h));
1762aac854edSJunchao Zhang       } else { // If L was computed on device, we also compute the transpose there
1763aac854edSJunchao Zhang         // TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices. We have to sort the indices, until KK provides finer control options.
1764aac854edSJunchao Zhang         PetscCallCXX(transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iL_d,
1765aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->jL_d, factors->aL_d,
1766aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->iLt_d, factors->jLt_d,
1767aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->aLt_d));
1768aac854edSJunchao Zhang         PetscCallCXX(sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iLt_d, factors->jLt_d, factors->aLt_d));
1769aac854edSJunchao Zhang       }
1770aac854edSJunchao Zhang       PetscCallCXX(sptrsv_symbolic(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d));
1771aac854edSJunchao Zhang     }
1772aac854edSJunchao Zhang 
177386a27549SJunchao Zhang     factors->transpose_updated = PETSC_TRUE;
177486a27549SJunchao Zhang   }
17753ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
177686a27549SJunchao Zhang }
177786a27549SJunchao Zhang 
1778aac854edSJunchao Zhang // Solve Ax = b, with RAR = U^T D U, where R is the row (and col) permutation matrix on A.
1779aac854edSJunchao Zhang // R is represented by rowperm in factors. If R is identity (i.e, no reordering), then rowperm is empty.
1780aac854edSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJKokkos_Cholesky(Mat A, Vec bb, Vec xx)
1781d71ae5a4SJacob Faibussowitsch {
1782aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
178386a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1784aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1785aac854edSJunchao Zhang   PetscScalarKokkosView       D       = factors->D_d;
1786aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1787aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1788aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1789aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm  = factors->rowperm;
1790aac854edSJunchao Zhang   PetscBool                   identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
179186a27549SJunchao Zhang 
179286a27549SJunchao Zhang   PetscFunctionBegin;
17939566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1794aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSolveCheck(A));          // for UX = T
1795aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A)); // for U^T Y = B
1796aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1797aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
1798aac854edSJunchao Zhang 
1799aac854edSJunchao Zhang   // Solve U^T Y = B
1800aac854edSJunchao Zhang   if (identity) { // Reorder b with the row permutation
1801aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1802aac854edSJunchao Zhang     Y = factors->workVector;
1803aac854edSJunchao Zhang   } else {
1804aac854edSJunchao Zhang     B = factors->workVector;
1805aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(rowperm(i)); }));
1806aac854edSJunchao Zhang     Y = x;
1807aac854edSJunchao Zhang   }
1808aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, B, Y));
1809aac854edSJunchao Zhang 
1810aac854edSJunchao Zhang   // Solve diag(D) Y' = Y.
1811aac854edSJunchao Zhang   // Actually just do Y' = Y*D since D is already inverted in MatCholeskyFactorNumeric_SeqAIJ(). It is basically a vector element-wise multiplication.
1812aac854edSJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { Y(i) = Y(i) * D(i); }));
1813aac854edSJunchao Zhang 
1814aac854edSJunchao Zhang   // Solve UX = Y
1815aac854edSJunchao Zhang   if (identity) {
1816aac854edSJunchao Zhang     X = x;
1817aac854edSJunchao Zhang   } else {
1818aac854edSJunchao Zhang     X = factors->workVector; // B is not needed anymore
1819aac854edSJunchao Zhang   }
1820aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, Y, X));
1821aac854edSJunchao Zhang 
1822aac854edSJunchao Zhang   // Reorder X with the inverse column (row) permutation
1823aac854edSJunchao Zhang   if (!identity) {
1824aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(rowperm(i)) = X(i); }));
1825aac854edSJunchao Zhang   }
1826aac854edSJunchao Zhang 
1827aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1828aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
18299566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
18303ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
183186a27549SJunchao Zhang }
183286a27549SJunchao Zhang 
1833aac854edSJunchao Zhang // Solve Ax = b, with RAC = LU, where R and C are row and col permutation matrices on A respectively.
1834aac854edSJunchao Zhang // R and C are represented by rowperm and colperm in factors.
1835aac854edSJunchao Zhang // If R or C is identity (i.e, no reordering), then rowperm or colperm is empty.
1836aac854edSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJKokkos_LU(Mat A, Vec bb, Vec xx)
1837d71ae5a4SJacob Faibussowitsch {
1838aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
183986a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1840aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1841aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1842aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1843aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1844aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm      = factors->rowperm;
1845aac854edSJunchao Zhang   PetscIntKokkosView         &colperm      = factors->colperm;
1846aac854edSJunchao Zhang   PetscBool                   row_identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1847aac854edSJunchao Zhang   PetscBool                   col_identity = colperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
184886a27549SJunchao Zhang 
184986a27549SJunchao Zhang   PetscFunctionBegin;
18509566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1851aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSolveCheck(A));
1852aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1853aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
185486a27549SJunchao Zhang 
1855aac854edSJunchao Zhang   // Solve L Y = B (i.e., L (U C^- x) = R b).  R b indicates applying the row permutation on b.
1856aac854edSJunchao Zhang   if (row_identity) {
1857aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1858aac854edSJunchao Zhang     Y = factors->workVector;
1859aac854edSJunchao Zhang   } else {
1860aac854edSJunchao Zhang     B = factors->workVector;
1861aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(rowperm(i)); }));
1862aac854edSJunchao Zhang     Y = x;
1863aac854edSJunchao Zhang   }
1864aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khL, factors->iL_d, factors->jL_d, factors->aL_d, B, Y));
1865aac854edSJunchao Zhang 
1866aac854edSJunchao Zhang   // Solve U C^- x = Y
1867aac854edSJunchao Zhang   if (col_identity) {
1868aac854edSJunchao Zhang     X = x;
1869aac854edSJunchao Zhang   } else {
1870aac854edSJunchao Zhang     X = factors->workVector;
1871aac854edSJunchao Zhang   }
1872aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, Y, X));
1873aac854edSJunchao Zhang 
1874aac854edSJunchao Zhang   // x = C X; Reorder X with the inverse col permutation
1875aac854edSJunchao Zhang   if (!col_identity) {
1876aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(colperm(i)) = X(i); }));
1877aac854edSJunchao Zhang   }
1878aac854edSJunchao Zhang 
1879aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1880aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
18819566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
18823ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
188386a27549SJunchao Zhang }
188486a27549SJunchao Zhang 
1885aac854edSJunchao Zhang // Solve A^T x = b, with RAC = LU, where R and C are row and col permutation matrices on A respectively.
1886aac854edSJunchao Zhang // R and C are represented by rowperm and colperm in factors.
1887aac854edSJunchao Zhang // If R or C is identity (i.e, no reordering), then rowperm or colperm is empty.
1888aac854edSJunchao Zhang // A = R^-1 L U C^-1, so A^T = C^-T U^T L^T R^-T. But since C^- = C^T, R^- = R^T, we have A^T = C U^T L^T R.
1889aac854edSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJKokkos_LU(Mat A, Vec bb, Vec xx)
1890aac854edSJunchao Zhang {
1891aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
1892aac854edSJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1893aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1894aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1895aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1896aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1897aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm      = factors->rowperm;
1898aac854edSJunchao Zhang   PetscIntKokkosView         &colperm      = factors->colperm;
1899aac854edSJunchao Zhang   PetscBool                   row_identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1900aac854edSJunchao Zhang   PetscBool                   col_identity = colperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1901aac854edSJunchao Zhang 
1902aac854edSJunchao Zhang   PetscFunctionBegin;
1903aac854edSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1904aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A)); // Update L^T, U^T if needed, and do sptrsv symbolic for L^T, U^T
1905aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1906aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
1907aac854edSJunchao Zhang 
1908aac854edSJunchao Zhang   // Solve U^T Y = B (i.e., U^T (L^T R x) = C^- b).  Note C^- b = C^T b, which means applying the column permutation on b.
1909aac854edSJunchao Zhang   if (col_identity) { // Reorder b with the col permutation
1910aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1911aac854edSJunchao Zhang     Y = factors->workVector;
1912aac854edSJunchao Zhang   } else {
1913aac854edSJunchao Zhang     B = factors->workVector;
1914aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(colperm(i)); }));
1915aac854edSJunchao Zhang     Y = x;
1916aac854edSJunchao Zhang   }
1917aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, B, Y));
1918aac854edSJunchao Zhang 
1919aac854edSJunchao Zhang   // Solve L^T X = Y
1920aac854edSJunchao Zhang   if (row_identity) {
1921aac854edSJunchao Zhang     X = x;
1922aac854edSJunchao Zhang   } else {
1923aac854edSJunchao Zhang     X = factors->workVector;
1924aac854edSJunchao Zhang   }
1925aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d, Y, X));
1926aac854edSJunchao Zhang 
1927aac854edSJunchao Zhang   // x = R^- X = R^T X; Reorder X with the inverse row permutation
1928aac854edSJunchao Zhang   if (!row_identity) {
1929aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(rowperm(i)) = X(i); }));
1930aac854edSJunchao Zhang   }
1931aac854edSJunchao Zhang 
1932aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1933aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
1934aac854edSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1935aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1936aac854edSJunchao Zhang }
1937aac854edSJunchao Zhang 
1938aac854edSJunchao Zhang static PetscErrorCode MatLUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1939aac854edSJunchao Zhang {
1940aac854edSJunchao Zhang   PetscFunctionBegin;
1941aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncHost(A));
1942aac854edSJunchao Zhang   PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info));
1943aac854edSJunchao Zhang 
1944aac854edSJunchao Zhang   if (!info->solveonhost) { // if solve on host, then we don't need to copy L, U to device
1945aac854edSJunchao Zhang     Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
1946aac854edSJunchao Zhang     Mat_SeqAIJ                 *b       = static_cast<Mat_SeqAIJ *>(B->data);
1947aac854edSJunchao Zhang     const PetscInt             *Bi = b->i, *Bj = b->j, *Bdiag = b->diag;
1948aac854edSJunchao Zhang     const MatScalar            *Ba = b->a;
1949aac854edSJunchao Zhang     PetscInt                    m = B->rmap->n, n = B->cmap->n;
1950aac854edSJunchao Zhang 
1951aac854edSJunchao Zhang     if (factors->iL_h.extent(0) == 0) { // Allocate memory and copy the L, U structure for the first time
1952aac854edSJunchao Zhang       // Allocate memory and copy the structure
1953aac854edSJunchao Zhang       factors->iL_h = MatRowMapKokkosViewHost(NoInit("iL_h"), m + 1);
1954aac854edSJunchao Zhang       factors->jL_h = MatColIdxKokkosViewHost(NoInit("jL_h"), (Bi[m] - Bi[0]) + m); // + the diagonal entries
1955aac854edSJunchao Zhang       factors->aL_h = MatScalarKokkosViewHost(NoInit("aL_h"), (Bi[m] - Bi[0]) + m);
1956aac854edSJunchao Zhang       factors->iU_h = MatRowMapKokkosViewHost(NoInit("iU_h"), m + 1);
1957aac854edSJunchao Zhang       factors->jU_h = MatColIdxKokkosViewHost(NoInit("jU_h"), (Bdiag[0] - Bdiag[m]));
1958aac854edSJunchao Zhang       factors->aU_h = MatScalarKokkosViewHost(NoInit("aU_h"), (Bdiag[0] - Bdiag[m]));
1959aac854edSJunchao Zhang 
1960aac854edSJunchao Zhang       PetscInt *Li = factors->iL_h.data();
1961aac854edSJunchao Zhang       PetscInt *Lj = factors->jL_h.data();
1962aac854edSJunchao Zhang       PetscInt *Ui = factors->iU_h.data();
1963aac854edSJunchao Zhang       PetscInt *Uj = factors->jU_h.data();
1964aac854edSJunchao Zhang 
1965aac854edSJunchao Zhang       Li[0] = Ui[0] = 0;
1966aac854edSJunchao Zhang       for (PetscInt i = 0; i < m; i++) {
1967aac854edSJunchao Zhang         PetscInt llen = Bi[i + 1] - Bi[i];       // exclusive of the diagonal entry
1968aac854edSJunchao Zhang         PetscInt ulen = Bdiag[i] - Bdiag[i + 1]; // inclusive of the diagonal entry
1969aac854edSJunchao Zhang 
1970aac854edSJunchao Zhang         PetscArraycpy(Lj + Li[i], Bj + Bi[i], llen); // entries of L on the left of the diagonal
1971aac854edSJunchao Zhang         Lj[Li[i] + llen] = i;                        // diagonal entry of L
1972aac854edSJunchao Zhang 
1973aac854edSJunchao Zhang         Uj[Ui[i]] = i;                                                  // diagonal entry of U
1974aac854edSJunchao Zhang         PetscArraycpy(Uj + Ui[i] + 1, Bj + Bdiag[i + 1] + 1, ulen - 1); // entries of U on  the right of the diagonal
1975aac854edSJunchao Zhang 
1976aac854edSJunchao Zhang         Li[i + 1] = Li[i] + llen + 1;
1977aac854edSJunchao Zhang         Ui[i + 1] = Ui[i] + ulen;
1978aac854edSJunchao Zhang       }
1979aac854edSJunchao Zhang 
1980aac854edSJunchao Zhang       factors->iL_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iL_h);
1981aac854edSJunchao Zhang       factors->jL_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jL_h);
1982aac854edSJunchao Zhang       factors->iU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iU_h);
1983aac854edSJunchao Zhang       factors->jU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jU_h);
1984aac854edSJunchao Zhang       factors->aL_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aL_h);
1985aac854edSJunchao Zhang       factors->aU_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aU_h);
1986aac854edSJunchao Zhang 
1987aac854edSJunchao Zhang       // Copy row/col permutation to device
1988aac854edSJunchao Zhang       IS        rowperm = ((Mat_SeqAIJ *)B->data)->row;
1989aac854edSJunchao Zhang       PetscBool row_identity;
1990aac854edSJunchao Zhang       PetscCall(ISIdentity(rowperm, &row_identity));
1991aac854edSJunchao Zhang       if (!row_identity) {
1992aac854edSJunchao Zhang         const PetscInt *ip;
1993aac854edSJunchao Zhang 
1994aac854edSJunchao Zhang         PetscCall(ISGetIndices(rowperm, &ip));
1995aac854edSJunchao Zhang         factors->rowperm = PetscIntKokkosView(NoInit("rowperm"), m);
1996aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->rowperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), m)));
1997aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(rowperm, &ip));
1998aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(m * sizeof(PetscInt)));
1999aac854edSJunchao Zhang       }
2000aac854edSJunchao Zhang 
2001aac854edSJunchao Zhang       IS        colperm = ((Mat_SeqAIJ *)B->data)->col;
2002aac854edSJunchao Zhang       PetscBool col_identity;
2003aac854edSJunchao Zhang       PetscCall(ISIdentity(colperm, &col_identity));
2004aac854edSJunchao Zhang       if (!col_identity) {
2005aac854edSJunchao Zhang         const PetscInt *ip;
2006aac854edSJunchao Zhang 
2007aac854edSJunchao Zhang         PetscCall(ISGetIndices(colperm, &ip));
2008aac854edSJunchao Zhang         factors->colperm = PetscIntKokkosView(NoInit("colperm"), n);
2009aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->colperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), n)));
2010aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(colperm, &ip));
2011aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt)));
2012aac854edSJunchao Zhang       }
2013aac854edSJunchao Zhang 
2014aac854edSJunchao Zhang       /* Create sptrsv handles for L, U and their transpose */
2015aac854edSJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2016aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
2017aac854edSJunchao Zhang #else
2018aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
2019aac854edSJunchao Zhang #endif
2020aac854edSJunchao Zhang       factors->khL.create_sptrsv_handle(sptrsv_alg, m, true /* L is lower tri */);
2021aac854edSJunchao Zhang       factors->khU.create_sptrsv_handle(sptrsv_alg, m, false /* U is not lower tri */);
2022aac854edSJunchao Zhang       factors->khLt.create_sptrsv_handle(sptrsv_alg, m, false /* L^T is not lower tri */);
2023aac854edSJunchao Zhang       factors->khUt.create_sptrsv_handle(sptrsv_alg, m, true /* U^T is lower tri */);
2024aac854edSJunchao Zhang     }
2025aac854edSJunchao Zhang 
2026aac854edSJunchao Zhang     // Copy the value
2027aac854edSJunchao Zhang     for (PetscInt i = 0; i < m; i++) {
2028aac854edSJunchao Zhang       PetscInt        llen = Bi[i + 1] - Bi[i];
2029aac854edSJunchao Zhang       PetscInt        ulen = Bdiag[i] - Bdiag[i + 1];
2030aac854edSJunchao Zhang       const PetscInt *Li   = factors->iL_h.data();
2031aac854edSJunchao Zhang       const PetscInt *Ui   = factors->iU_h.data();
2032aac854edSJunchao Zhang 
2033aac854edSJunchao Zhang       PetscScalar *La = factors->aL_h.data();
2034aac854edSJunchao Zhang       PetscScalar *Ua = factors->aU_h.data();
2035aac854edSJunchao Zhang 
2036aac854edSJunchao Zhang       PetscArraycpy(La + Li[i], Ba + Bi[i], llen); // entries of L
2037aac854edSJunchao Zhang       La[Li[i] + llen] = 1.0;                      // diagonal entry
2038aac854edSJunchao Zhang 
2039aac854edSJunchao Zhang       Ua[Ui[i]] = 1.0 / Ba[Bdiag[i]];                                 // diagonal entry
2040aac854edSJunchao Zhang       PetscArraycpy(Ua + Ui[i] + 1, Ba + Bdiag[i + 1] + 1, ulen - 1); // entries of U
2041aac854edSJunchao Zhang     }
2042aac854edSJunchao Zhang 
2043aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aL_d, factors->aL_h));
2044aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aU_d, factors->aU_h));
2045aac854edSJunchao Zhang     // Once the factors' values have changed, we need to update their transpose and redo sptrsv symbolic
2046aac854edSJunchao Zhang     factors->transpose_updated         = PETSC_FALSE;
2047aac854edSJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_FALSE;
2048aac854edSJunchao Zhang 
2049aac854edSJunchao Zhang     B->ops->solve          = MatSolve_SeqAIJKokkos_LU;
2050aac854edSJunchao Zhang     B->ops->solvetranspose = MatSolveTranspose_SeqAIJKokkos_LU;
2051aac854edSJunchao Zhang   }
2052aac854edSJunchao Zhang 
2053aac854edSJunchao Zhang   B->ops->matsolve          = NULL;
2054aac854edSJunchao Zhang   B->ops->matsolvetranspose = NULL;
2055aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2056aac854edSJunchao Zhang }
2057aac854edSJunchao Zhang 
2058aac854edSJunchao Zhang static PetscErrorCode MatILUFactorNumeric_SeqAIJKokkos_ILU0(Mat B, Mat A, const MatFactorInfo *info)
2059d71ae5a4SJacob Faibussowitsch {
206086a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
206186a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
206286a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
206386a27549SJunchao Zhang 
206486a27549SJunchao Zhang   PetscFunctionBegin;
20659566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
2066aac854edSJunchao Zhang   PetscCheck(!info->factoronhost, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "MatFactorInfo.factoronhost should be false");
20679566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
2068076ba34aSJunchao Zhang 
2069076ba34aSJunchao Zhang   auto a_d = aijkok->a_dual.view_device();
2070076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
2071076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
2072076ba34aSJunchao Zhang 
2073aac854edSJunchao Zhang   PetscCallCXX(spiluk_numeric(&factors->kh, fill_lev, i_d, j_d, a_d, factors->iL_d, factors->jL_d, factors->aL_d, factors->iU_d, factors->jU_d, factors->aU_d));
207486a27549SJunchao Zhang 
207586a27549SJunchao Zhang   B->assembled              = PETSC_TRUE;
207686a27549SJunchao Zhang   B->preallocated           = PETSC_TRUE;
2077aac854edSJunchao Zhang   B->ops->solve             = MatSolve_SeqAIJKokkos_LU;
2078aac854edSJunchao Zhang   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJKokkos_LU;
207986a27549SJunchao Zhang   B->ops->matsolve          = NULL;
208086a27549SJunchao Zhang   B->ops->matsolvetranspose = NULL;
208186a27549SJunchao Zhang 
208286a27549SJunchao Zhang   /* Once the factors' value changed, we need to update their transpose and sptrsv handle */
208386a27549SJunchao Zhang   factors->transpose_updated         = PETSC_FALSE;
208486a27549SJunchao Zhang   factors->sptrsv_symbolic_completed = PETSC_FALSE;
2085eeadb341SJunchao Zhang   /* TODO: log flops, but how to know that? */
20869566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
20873ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
208886a27549SJunchao Zhang }
208986a27549SJunchao Zhang 
2090aac854edSJunchao Zhang // Use KK's spiluk_symbolic() to do ILU0 symbolic factorization, with no row/col reordering
2091aac854edSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos_ILU0(Mat B, Mat A, IS, IS, const MatFactorInfo *info)
2092d71ae5a4SJacob Faibussowitsch {
209386a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok;
209486a27549SJunchao Zhang   Mat_SeqAIJ                 *b;
209586a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
209686a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
209786a27549SJunchao Zhang   PetscInt                    nnzA     = ((Mat_SeqAIJ *)A->data)->nz, nnzL, nnzU;
209886a27549SJunchao Zhang   PetscInt                    n        = A->rmap->n;
209986a27549SJunchao Zhang 
210086a27549SJunchao Zhang   PetscFunctionBegin;
2101aac854edSJunchao Zhang   PetscCheck(!info->factoronhost, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "MatFactorInfo's factoronhost should be false as we are doing it on device right now");
21029566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
210386a27549SJunchao Zhang 
210486a27549SJunchao Zhang   /* Create a spiluk handle and then do symbolic factorization */
210586a27549SJunchao Zhang   nnzL = nnzU = PetscRealIntMultTruncate(info->fill, nnzA);
2106aac854edSJunchao Zhang   factors->kh.create_spiluk_handle(SPILUKAlgorithm::SEQLVLSCHD_TP1, n, nnzL, nnzU);
210786a27549SJunchao Zhang 
210886a27549SJunchao Zhang   auto spiluk_handle = factors->kh.get_spiluk_handle();
210986a27549SJunchao Zhang 
211086a27549SJunchao Zhang   Kokkos::realloc(factors->iL_d, n + 1); /* Free old arrays and realloc */
211186a27549SJunchao Zhang   Kokkos::realloc(factors->jL_d, spiluk_handle->get_nnzL());
211286a27549SJunchao Zhang   Kokkos::realloc(factors->iU_d, n + 1);
211386a27549SJunchao Zhang   Kokkos::realloc(factors->jU_d, spiluk_handle->get_nnzU());
211486a27549SJunchao Zhang 
211586a27549SJunchao Zhang   aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
2116076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
2117076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
2118aac854edSJunchao Zhang   PetscCallCXX(spiluk_symbolic(&factors->kh, fill_lev, i_d, j_d, factors->iL_d, factors->jL_d, factors->iU_d, factors->jU_d));
211986a27549SJunchao Zhang   /* TODO: if spiluk_symbolic is asynchronous, do we need to sync before calling get_nnzL()? */
212086a27549SJunchao Zhang 
212186a27549SJunchao Zhang   Kokkos::resize(factors->jL_d, spiluk_handle->get_nnzL()); /* Shrink or expand, and retain old value */
212286a27549SJunchao Zhang   Kokkos::resize(factors->jU_d, spiluk_handle->get_nnzU());
212386a27549SJunchao Zhang   Kokkos::realloc(factors->aL_d, spiluk_handle->get_nnzL()); /* No need to retain old value */
212486a27549SJunchao Zhang   Kokkos::realloc(factors->aU_d, spiluk_handle->get_nnzU());
212586a27549SJunchao Zhang 
212686a27549SJunchao Zhang   /* TODO: add options to select sptrsv algorithms */
212786a27549SJunchao Zhang   /* Create sptrsv handles for L, U and their transpose */
212886a27549SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2129aac854edSJunchao Zhang   auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
213086a27549SJunchao Zhang #else
2131aac854edSJunchao Zhang   auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
213286a27549SJunchao Zhang #endif
213386a27549SJunchao Zhang 
213486a27549SJunchao Zhang   factors->khL.create_sptrsv_handle(sptrsv_alg, n, true /* L is lower tri */);
213586a27549SJunchao Zhang   factors->khU.create_sptrsv_handle(sptrsv_alg, n, false /* U is not lower tri */);
213686a27549SJunchao Zhang   factors->khLt.create_sptrsv_handle(sptrsv_alg, n, false /* L^T is not lower tri */);
213786a27549SJunchao Zhang   factors->khUt.create_sptrsv_handle(sptrsv_alg, n, true /* U^T is lower tri */);
213886a27549SJunchao Zhang 
213986a27549SJunchao Zhang   /* Fill fields of the factor matrix B */
21409566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(B, MAT_SKIP_ALLOCATION, NULL));
214186a27549SJunchao Zhang   b     = (Mat_SeqAIJ *)B->data;
214286a27549SJunchao Zhang   b->nz = b->maxnz          = spiluk_handle->get_nnzL() + spiluk_handle->get_nnzU();
214386a27549SJunchao Zhang   B->info.fill_ratio_given  = info->fill;
2144a1e4e190SStefano Zampini   B->info.fill_ratio_needed = nnzA > 0 ? ((PetscReal)b->nz) / ((PetscReal)nnzA) : 1.0;
214586a27549SJunchao Zhang 
2146aac854edSJunchao Zhang   B->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJKokkos_ILU0;
21473ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2148930e68a5SMark Adams }
2149930e68a5SMark Adams 
2150aac854edSJunchao Zhang static PetscErrorCode MatLUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
2151aac854edSJunchao Zhang {
2152aac854edSJunchao Zhang   PetscFunctionBegin;
2153aac854edSJunchao Zhang   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
2154aac854edSJunchao Zhang   PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2155aac854edSJunchao Zhang   PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2156aac854edSJunchao Zhang   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
2157aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2158aac854edSJunchao Zhang }
2159aac854edSJunchao Zhang 
2160aac854edSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
2161aac854edSJunchao Zhang {
2162aac854edSJunchao Zhang   PetscBool row_identity = PETSC_FALSE, col_identity = PETSC_FALSE;
2163aac854edSJunchao Zhang 
2164aac854edSJunchao Zhang   PetscFunctionBegin;
2165aac854edSJunchao Zhang   if (!info->factoronhost) {
2166aac854edSJunchao Zhang     PetscCall(ISIdentity(isrow, &row_identity));
2167aac854edSJunchao Zhang     PetscCall(ISIdentity(iscol, &col_identity));
2168aac854edSJunchao Zhang   }
2169aac854edSJunchao Zhang 
2170aac854edSJunchao Zhang   PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2171aac854edSJunchao Zhang   PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2172aac854edSJunchao Zhang 
2173aac854edSJunchao Zhang   if (!info->factoronhost && !info->levels && row_identity && col_identity) { // if level 0 and no reordering
2174aac854edSJunchao Zhang     PetscCall(MatILUFactorSymbolic_SeqAIJKokkos_ILU0(B, A, isrow, iscol, info));
2175aac854edSJunchao Zhang   } else {
2176aac854edSJunchao Zhang     PetscCall(MatILUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info)); // otherwise, use PETSc's ILU on host
2177aac854edSJunchao Zhang     B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
2178aac854edSJunchao Zhang   }
2179aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2180aac854edSJunchao Zhang }
2181aac854edSJunchao Zhang 
2182aac854edSJunchao Zhang static PetscErrorCode MatCholeskyFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
2183aac854edSJunchao Zhang {
2184aac854edSJunchao Zhang   PetscFunctionBegin;
2185aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncHost(A));
2186aac854edSJunchao Zhang   PetscCall(MatCholeskyFactorNumeric_SeqAIJ(B, A, info));
2187aac854edSJunchao Zhang 
2188aac854edSJunchao Zhang   if (!info->solveonhost) { // if solve on host, then we don't need to copy L, U to device
2189aac854edSJunchao Zhang     Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
2190aac854edSJunchao Zhang     Mat_SeqAIJ                 *b       = static_cast<Mat_SeqAIJ *>(B->data);
2191aac854edSJunchao Zhang     const PetscInt             *Bi = b->i, *Bj = b->j, *Bdiag = b->diag;
2192aac854edSJunchao Zhang     const MatScalar            *Ba = b->a;
2193aac854edSJunchao Zhang     PetscInt                    m  = B->rmap->n;
2194aac854edSJunchao Zhang 
2195aac854edSJunchao Zhang     if (factors->iU_h.extent(0) == 0) { // First time of numeric factorization
2196aac854edSJunchao Zhang       // Allocate memory and copy the structure
2197aac854edSJunchao Zhang       factors->iU_h = PetscIntKokkosViewHost(const_cast<PetscInt *>(Bi), m + 1); // wrap Bi as iU_h
2198aac854edSJunchao Zhang       factors->jU_h = MatColIdxKokkosViewHost(NoInit("jU_h"), Bi[m]);
2199aac854edSJunchao Zhang       factors->aU_h = MatScalarKokkosViewHost(NoInit("aU_h"), Bi[m]);
2200aac854edSJunchao Zhang       factors->D_h  = MatScalarKokkosViewHost(NoInit("D_h"), m);
2201aac854edSJunchao Zhang       factors->aU_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aU_h);
2202aac854edSJunchao Zhang       factors->D_d  = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->D_h);
2203aac854edSJunchao Zhang 
2204aac854edSJunchao Zhang       // Build jU_h from the skewed Aj
2205aac854edSJunchao Zhang       PetscInt *Uj = factors->jU_h.data();
2206aac854edSJunchao Zhang       for (PetscInt i = 0; i < m; i++) {
2207aac854edSJunchao Zhang         PetscInt ulen = Bi[i + 1] - Bi[i];
2208aac854edSJunchao Zhang         Uj[Bi[i]]     = i;                                              // diagonal entry
2209aac854edSJunchao Zhang         PetscCall(PetscArraycpy(Uj + Bi[i] + 1, Bj + Bi[i], ulen - 1)); // entries of U on the right of the diagonal
2210aac854edSJunchao Zhang       }
2211aac854edSJunchao Zhang 
2212aac854edSJunchao Zhang       // Copy iU, jU to device
2213aac854edSJunchao Zhang       PetscCallCXX(factors->iU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iU_h));
2214aac854edSJunchao Zhang       PetscCallCXX(factors->jU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jU_h));
2215aac854edSJunchao Zhang 
2216aac854edSJunchao Zhang       // Copy row/col permutation to device
2217aac854edSJunchao Zhang       IS        rowperm = ((Mat_SeqAIJ *)B->data)->row;
2218aac854edSJunchao Zhang       PetscBool row_identity;
2219aac854edSJunchao Zhang       PetscCall(ISIdentity(rowperm, &row_identity));
2220aac854edSJunchao Zhang       if (!row_identity) {
2221aac854edSJunchao Zhang         const PetscInt *ip;
2222aac854edSJunchao Zhang 
2223aac854edSJunchao Zhang         PetscCall(ISGetIndices(rowperm, &ip));
2224aac854edSJunchao Zhang         PetscCallCXX(factors->rowperm = PetscIntKokkosView(NoInit("rowperm"), m));
2225aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->rowperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), m)));
2226aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(rowperm, &ip));
2227aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(m * sizeof(PetscInt)));
2228aac854edSJunchao Zhang       }
2229aac854edSJunchao Zhang 
2230aac854edSJunchao Zhang       // Create sptrsv handles for U and U^T
2231aac854edSJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2232aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
2233aac854edSJunchao Zhang #else
2234aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
2235aac854edSJunchao Zhang #endif
2236aac854edSJunchao Zhang       factors->khU.create_sptrsv_handle(sptrsv_alg, m, false /* U is not lower tri */);
2237aac854edSJunchao Zhang       factors->khUt.create_sptrsv_handle(sptrsv_alg, m, true /* U^T is lower tri */);
2238aac854edSJunchao Zhang     }
2239aac854edSJunchao Zhang     // These pointers were set by MatCholeskyFactorNumeric_SeqAIJ(), so we always need to update them
2240aac854edSJunchao Zhang     B->ops->solve          = MatSolve_SeqAIJKokkos_Cholesky;
2241aac854edSJunchao Zhang     B->ops->solvetranspose = MatSolve_SeqAIJKokkos_Cholesky;
2242aac854edSJunchao Zhang 
2243aac854edSJunchao Zhang     // Copy the value
2244aac854edSJunchao Zhang     PetscScalar *Ua = factors->aU_h.data();
2245aac854edSJunchao Zhang     PetscScalar *D  = factors->D_h.data();
2246aac854edSJunchao Zhang     for (PetscInt i = 0; i < m; i++) {
2247aac854edSJunchao Zhang       D[i]      = Ba[Bdiag[i]];     // actually Aa[Adiag[i]] is the inverse of the diagonal
2248aac854edSJunchao Zhang       Ua[Bi[i]] = (PetscScalar)1.0; // set the unit diagonal for U
2249aac854edSJunchao Zhang       for (PetscInt k = 0; k < Bi[i + 1] - Bi[i] - 1; k++) Ua[Bi[i] + 1 + k] = -Ba[Bi[i] + k];
2250aac854edSJunchao Zhang     }
2251aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aU_d, factors->aU_h));
2252aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->D_d, factors->D_h));
2253aac854edSJunchao Zhang 
2254aac854edSJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_FALSE; // When numeric value changed, we must do these again
2255aac854edSJunchao Zhang     factors->transpose_updated         = PETSC_FALSE;
2256aac854edSJunchao Zhang   }
2257aac854edSJunchao Zhang 
2258aac854edSJunchao Zhang   B->ops->matsolve          = NULL;
2259aac854edSJunchao Zhang   B->ops->matsolvetranspose = NULL;
2260aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2261aac854edSJunchao Zhang }
2262aac854edSJunchao Zhang 
2263aac854edSJunchao Zhang static PetscErrorCode MatICCFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS perm, const MatFactorInfo *info)
2264aac854edSJunchao Zhang {
2265aac854edSJunchao Zhang   PetscFunctionBegin;
2266aac854edSJunchao Zhang   if (info->solveonhost) {
2267aac854edSJunchao Zhang     // If solve on host, we have to change the type, as eventually we need to call MatSolve_SeqSBAIJ_1_NaturalOrdering() etc.
2268aac854edSJunchao Zhang     PetscCall(MatSetType(B, MATSEQSBAIJ));
2269aac854edSJunchao Zhang     PetscCall(MatSeqSBAIJSetPreallocation(B, 1, MAT_SKIP_ALLOCATION, NULL));
2270aac854edSJunchao Zhang   }
2271aac854edSJunchao Zhang 
2272aac854edSJunchao Zhang   PetscCall(MatICCFactorSymbolic_SeqAIJ(B, A, perm, info));
2273aac854edSJunchao Zhang 
2274aac854edSJunchao Zhang   if (!info->solveonhost) {
2275bfe80ac4SPierre Jolivet     // If solve on device, B is still a MATSEQAIJKOKKOS, so we are good to allocate B->spptr
2276aac854edSJunchao Zhang     PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2277aac854edSJunchao Zhang     PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2278aac854edSJunchao Zhang     B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJKokkos;
2279aac854edSJunchao Zhang   }
2280aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2281aac854edSJunchao Zhang }
2282aac854edSJunchao Zhang 
2283aac854edSJunchao Zhang static PetscErrorCode MatCholeskyFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS perm, const MatFactorInfo *info)
2284aac854edSJunchao Zhang {
2285aac854edSJunchao Zhang   PetscFunctionBegin;
2286aac854edSJunchao Zhang   if (info->solveonhost) {
2287aac854edSJunchao Zhang     // If solve on host, we have to change the type, as eventually we need to call MatSolve_SeqSBAIJ_1_NaturalOrdering() etc.
2288aac854edSJunchao Zhang     PetscCall(MatSetType(B, MATSEQSBAIJ));
2289aac854edSJunchao Zhang     PetscCall(MatSeqSBAIJSetPreallocation(B, 1, MAT_SKIP_ALLOCATION, NULL));
2290aac854edSJunchao Zhang   }
2291aac854edSJunchao Zhang 
2292aac854edSJunchao Zhang   PetscCall(MatCholeskyFactorSymbolic_SeqAIJ(B, A, perm, info)); // it sets B's two ISes ((Mat_SeqAIJ*)B->data)->{row, col} to perm
2293aac854edSJunchao Zhang 
2294aac854edSJunchao Zhang   if (!info->solveonhost) {
2295bfe80ac4SPierre Jolivet     // If solve on device, B is still a MATSEQAIJKOKKOS, so we are good to allocate B->spptr
2296aac854edSJunchao Zhang     PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2297aac854edSJunchao Zhang     PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2298aac854edSJunchao Zhang     B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJKokkos;
2299aac854edSJunchao Zhang   }
2300aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2301aac854edSJunchao Zhang }
2302aac854edSJunchao Zhang 
2303aac854edSJunchao Zhang // The _Kokkos suffix means we will use Kokkos as a solver for the SeqAIJKokkos matrix
2304aac854edSJunchao Zhang static PetscErrorCode MatFactorGetSolverType_SeqAIJKokkos_Kokkos(Mat A, MatSolverType *type)
2305d71ae5a4SJacob Faibussowitsch {
2306930e68a5SMark Adams   PetscFunctionBegin;
2307930e68a5SMark Adams   *type = MATSOLVERKOKKOS;
23083ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2309930e68a5SMark Adams }
2310930e68a5SMark Adams 
2311930e68a5SMark Adams /*MC
231286a27549SJunchao Zhang   MATSOLVERKOKKOS = "Kokkos" - A matrix solver type providing triangular solvers for sequential matrices
231311a5261eSBarry Smith   on a single GPU of type, `MATSEQAIJKOKKOS`, `MATAIJKOKKOS`.
2314930e68a5SMark Adams 
2315930e68a5SMark Adams   Level: beginner
2316930e68a5SMark Adams 
23171cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJKokkos()`, `MATAIJKOKKOS`, `MatKokkosSetFormat()`, `MatKokkosStorageFormat`, `MatKokkosFormatOperation`
2318930e68a5SMark Adams M*/
231986a27549SJunchao Zhang PETSC_EXTERN PetscErrorCode MatGetFactor_SeqAIJKokkos_Kokkos(Mat A, MatFactorType ftype, Mat *B) /* MatGetFactor_<MatType>_<MatSolverType> */
2320930e68a5SMark Adams {
2321930e68a5SMark Adams   PetscInt n = A->rmap->n;
2322aac854edSJunchao Zhang   MPI_Comm comm;
2323930e68a5SMark Adams 
2324930e68a5SMark Adams   PetscFunctionBegin;
2325aac854edSJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
2326aac854edSJunchao Zhang   PetscCall(MatCreate(comm, B));
23279566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*B, n, n, n, n));
2328aac854edSJunchao Zhang   PetscCall(MatSetBlockSizesFromMats(*B, A, A));
2329930e68a5SMark Adams   (*B)->factortype = ftype;
23309566063dSJacob Faibussowitsch   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
23319566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
2332aac854edSJunchao Zhang   PetscCheck(!(*B)->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2333aac854edSJunchao Zhang 
2334aac854edSJunchao Zhang   if (ftype == MAT_FACTOR_LU || ftype == MAT_FACTOR_ILU || ftype == MAT_FACTOR_ILUDT) {
2335aac854edSJunchao Zhang     (*B)->ops->lufactorsymbolic  = MatLUFactorSymbolic_SeqAIJKokkos;
2336aac854edSJunchao Zhang     (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJKokkos;
2337aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
2338aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILU]));
2339aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILUDT]));
2340aac854edSJunchao Zhang   } else if (ftype == MAT_FACTOR_CHOLESKY || ftype == MAT_FACTOR_ICC) {
2341aac854edSJunchao Zhang     (*B)->ops->iccfactorsymbolic      = MatICCFactorSymbolic_SeqAIJKokkos;
2342aac854edSJunchao Zhang     (*B)->ops->choleskyfactorsymbolic = MatCholeskyFactorSymbolic_SeqAIJKokkos;
2343aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_CHOLESKY]));
2344aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ICC]));
2345aac854edSJunchao Zhang   } else SETERRQ(comm, PETSC_ERR_SUP, "MatFactorType %s is not supported by MatType SeqAIJKokkos", MatFactorTypes[ftype]);
2346aac854edSJunchao Zhang 
2347aac854edSJunchao Zhang   // The factorization can use the ordering provided in MatLUFactorSymbolic(), MatCholeskyFactorSymbolic() etc, though we do it on host
2348aac854edSJunchao Zhang   (*B)->canuseordering = PETSC_TRUE;
2349aac854edSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)*B, "MatFactorGetSolverType_C", MatFactorGetSolverType_SeqAIJKokkos_Kokkos));
23503ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2351930e68a5SMark Adams }
23528f7e8f9dSMark Adams 
2353aac854edSJunchao Zhang PETSC_INTERN PetscErrorCode MatSolverTypeRegister_Kokkos(void)
2354d71ae5a4SJacob Faibussowitsch {
235586a27549SJunchao Zhang   PetscFunctionBegin;
23569566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_SeqAIJKokkos_Kokkos));
2357aac854edSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_CHOLESKY, MatGetFactor_SeqAIJKokkos_Kokkos));
23589566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ILU, MatGetFactor_SeqAIJKokkos_Kokkos));
2359aac854edSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ICC, MatGetFactor_SeqAIJKokkos_Kokkos));
23603ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
236186a27549SJunchao Zhang }
236286a27549SJunchao Zhang 
2363076ba34aSJunchao Zhang /* Utility to print out a KokkosCsrMatrix for debugging */
2364d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode PrintCsrMatrix(const KokkosCsrMatrix &csrmat)
2365d71ae5a4SJacob Faibussowitsch {
236645402d8aSJunchao Zhang   const auto        &iv = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.graph.row_map);
236745402d8aSJunchao Zhang   const auto        &jv = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.graph.entries);
236845402d8aSJunchao Zhang   const auto        &av = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.values);
2369076ba34aSJunchao Zhang   const PetscInt    *i  = iv.data();
2370076ba34aSJunchao Zhang   const PetscInt    *j  = jv.data();
2371076ba34aSJunchao Zhang   const PetscScalar *a  = av.data();
2372076ba34aSJunchao Zhang   PetscInt           m = csrmat.numRows(), n = csrmat.numCols(), nnz = csrmat.nnz();
2373076ba34aSJunchao Zhang 
2374076ba34aSJunchao Zhang   PetscFunctionBegin;
23759566063dSJacob Faibussowitsch   PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT " x %" PetscInt_FMT " SeqAIJKokkos, with %" PetscInt_FMT " nonzeros\n", m, n, nnz));
2376076ba34aSJunchao Zhang   for (PetscInt k = 0; k < m; k++) {
23779566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT ": ", k));
237848a46eb9SPierre Jolivet     for (PetscInt p = i[k]; p < i[k + 1]; p++) PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT "(%.1f), ", j[p], (double)PetscRealPart(a[p])));
23799566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "\n"));
2380076ba34aSJunchao Zhang   }
23813ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2382076ba34aSJunchao Zhang }
2383