xref: /petsc/src/mat/impls/aij/seq/kokkos/aijkok.kokkos.cxx (revision 4e8208cbcbc709572b8abe32f33c78b69c819375)
1e36ced11SJunchao Zhang #include <petsc_kokkos.hpp>
211d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
3c0c276a7Ssdargavi #include <petscmat_kokkos.hpp>
4076ba34aSJunchao Zhang #include <petscpkg_version.h>
5152b3e56SJunchao Zhang #include <petsc/private/petscimpl.h>
642550becSJunchao Zhang #include <petsc/private/sfimpl.h>
7aac854edSJunchao Zhang #include <petsc/private/kokkosimpl.hpp>
87233ce55SJed Brown #include <petscsys.h>
98c3ff71bSJunchao Zhang 
108c3ff71bSJunchao Zhang #include <Kokkos_Core.hpp>
11f0cf5187SStefano Zampini #include <KokkosBlas.hpp>
128c3ff71bSJunchao Zhang #include <KokkosSparse_CrsMatrix.hpp>
13cc6e31f1SJunchao Zhang 
14cc6e31f1SJunchao Zhang // To suppress compiler warnings:
15cc6e31f1SJunchao Zhang // /path/include/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp:434:63:
16cc6e31f1SJunchao Zhang // warning: 'cusparseStatus_t cusparseDbsrmm(cusparseHandle_t, cusparseDirection_t, cusparseOperation_t,
17cc6e31f1SJunchao Zhang // cusparseOperation_t, int, int, int, int, const double*, cusparseMatDescr_t, const double*, const int*, const int*,
18cc6e31f1SJunchao Zhang // int, const double*, int, const double*, double*, int)' is deprecated: please use cusparseSpMM instead [-Wdeprecated-declarations]
192695cf96SNuno Nobre #define DISABLE_CUSPARSE_DEPRECATED
208c3ff71bSJunchao Zhang #include <KokkosSparse_spmv.hpp>
21cc6e31f1SJunchao Zhang 
2286a27549SJunchao Zhang #include <KokkosSparse_spiluk.hpp>
2386a27549SJunchao Zhang #include <KokkosSparse_sptrsv.hpp>
24076ba34aSJunchao Zhang #include <KokkosSparse_spgemm.hpp>
25076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
269d13fa56SJunchao Zhang #include <KokkosBatched_LU_Decl.hpp>
279d13fa56SJunchao Zhang #include <KokkosBatched_InverseLU_Decl.hpp>
2886a27549SJunchao Zhang 
2942550becSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
308c3ff71bSJunchao Zhang 
310e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(3, 7, 0)
32f98996d3SJunchao Zhang   #include <KokkosSparse_Utils.hpp>
33f98996d3SJunchao Zhang using KokkosSparse::sort_crs_matrix;
349371c9d4SSatish Balay using KokkosSparse::Impl::transpose_matrix;
35f98996d3SJunchao Zhang #else
36f98996d3SJunchao Zhang   #include <KokkosKernels_Sorting.hpp>
37f98996d3SJunchao Zhang using KokkosKernels::sort_crs_matrix;
389371c9d4SSatish Balay using KokkosKernels::Impl::transpose_matrix;
39f98996d3SJunchao Zhang #endif
40f98996d3SJunchao Zhang 
41aac854edSJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(4, 6, 0)
42aac854edSJunchao Zhang using KokkosSparse::spiluk_symbolic;
43aac854edSJunchao Zhang using KokkosSparse::spiluk_numeric;
44aac854edSJunchao Zhang using KokkosSparse::sptrsv_symbolic;
45aac854edSJunchao Zhang using KokkosSparse::sptrsv_solve;
46aac854edSJunchao Zhang using KokkosSparse::Experimental::SPTRSVAlgorithm;
47aac854edSJunchao Zhang using KokkosSparse::Experimental::SPILUKAlgorithm;
48aac854edSJunchao Zhang #else
49aac854edSJunchao Zhang using KokkosSparse::Experimental::spiluk_symbolic;
50aac854edSJunchao Zhang using KokkosSparse::Experimental::spiluk_numeric;
51aac854edSJunchao Zhang using KokkosSparse::Experimental::sptrsv_symbolic;
52aac854edSJunchao Zhang using KokkosSparse::Experimental::sptrsv_solve;
53aac854edSJunchao Zhang using KokkosSparse::Experimental::SPTRSVAlgorithm;
54aac854edSJunchao Zhang using KokkosSparse::Experimental::SPILUKAlgorithm;
55aac854edSJunchao Zhang #endif
56aac854edSJunchao Zhang 
578c3ff71bSJunchao Zhang static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat); /* Forward declaration */
588c3ff71bSJunchao Zhang 
59076ba34aSJunchao Zhang /* MatAssemblyEnd_SeqAIJKokkos() happens when we finalized nonzeros of the matrix, either after
60076ba34aSJunchao Zhang    we assembled the matrix on host, or after we directly produced the matrix data on device (ex., through MatMatMult).
61076ba34aSJunchao Zhang    In the latter case, it is important to set a_dual's sync state correctly.
62076ba34aSJunchao Zhang  */
63d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_SeqAIJKokkos(Mat A, MatAssemblyType mode)
64d71ae5a4SJacob Faibussowitsch {
65076ba34aSJunchao Zhang   Mat_SeqAIJ       *aijseq;
66076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
678c3ff71bSJunchao Zhang 
688c3ff71bSJunchao Zhang   PetscFunctionBegin;
693ba16761SJacob Faibussowitsch   if (mode == MAT_FLUSH_ASSEMBLY) PetscFunctionReturn(PETSC_SUCCESS);
709566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
71076ba34aSJunchao Zhang 
72076ba34aSJunchao Zhang   aijseq = static_cast<Mat_SeqAIJ *>(A->data);
73076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
74076ba34aSJunchao Zhang 
75076ba34aSJunchao Zhang   /* If aijkok does not exist, we just copy i, j to device.
76076ba34aSJunchao Zhang      If aijkok already exists, but the device's nonzero pattern does not match with the host's, we assume the latest data is on host.
77076ba34aSJunchao Zhang      In both cases, we build a new aijkok structure.
78076ba34aSJunchao Zhang   */
79076ba34aSJunchao Zhang   if (!aijkok || aijkok->nonzerostate != A->nonzerostate) { /* aijkok might not exist yet or nonzero pattern has changed */
8093a54799SPierre Jolivet     if (aijkok && aijkok->host_aij_allocated_by_kokkos) {   /* Avoid accidentally freeing much needed a,i,j on host when deleting aijkok */
81d1c799ffSJunchao Zhang       PetscCall(PetscShmgetAllocateArray(aijkok->nrows() + 1, sizeof(PetscInt), (void **)&aijseq->i));
82d1c799ffSJunchao Zhang       PetscCall(PetscShmgetAllocateArray(aijkok->nnz(), sizeof(PetscInt), (void **)&aijseq->j));
83d1c799ffSJunchao Zhang       PetscCall(PetscShmgetAllocateArray(aijkok->nnz(), sizeof(PetscInt), (void **)&aijseq->a));
84d1c799ffSJunchao Zhang       PetscCall(PetscArraycpy(aijseq->i, aijkok->i_host_data(), aijkok->nrows() + 1));
85d1c799ffSJunchao Zhang       PetscCall(PetscArraycpy(aijseq->j, aijkok->j_host_data(), aijkok->nnz()));
86d1c799ffSJunchao Zhang       PetscCall(PetscArraycpy(aijseq->a, aijkok->a_host_data(), aijkok->nnz()));
87d1c799ffSJunchao Zhang       aijseq->free_a  = PETSC_TRUE;
88d1c799ffSJunchao Zhang       aijseq->free_ij = PETSC_TRUE;
89d1c799ffSJunchao Zhang       /* This arises from MatCreateSeqAIJKokkosWithKokkosCsrMatrix() used in MatMatMult, where
90d1c799ffSJunchao Zhang          we have the CsrMatrix on device first and then copy to host, followed by
91d1c799ffSJunchao Zhang          MatSetMPIAIJWithSplitSeqAIJ() with garray = NULL.
92d1c799ffSJunchao Zhang          One could improve it by not using NULL garray.
93d1c799ffSJunchao Zhang       */
94d1c799ffSJunchao Zhang     }
95076ba34aSJunchao Zhang     delete aijkok;
96421480d9SBarry Smith     aijkok   = new Mat_SeqAIJKokkos(A, A->rmap->n, A->cmap->n, aijseq, A->nonzerostate, PETSC_FALSE /* don't copy mat values to device */);
97076ba34aSJunchao Zhang     A->spptr = aijkok;
98f4747e26SJunchao Zhang   } else if (A->rmap->n && aijkok->diag_dual.extent(0) == 0) { // MatProduct might directly produce AIJ on device, but not the diag.
99421480d9SBarry Smith     const PetscInt *adiag;
100421480d9SBarry Smith     /* the a->diag is created at assmebly here because the rest of the Kokkos AIJ code assumes it always exists. This needs to be fixed since it is now only created when needed! */
101421480d9SBarry Smith     PetscCall(MatGetDiagonalMarkers_SeqAIJ(A, &adiag, NULL));
102f4747e26SJunchao Zhang     MatRowMapKokkosViewHost diag_h(aijseq->diag, A->rmap->n);
103f4747e26SJunchao Zhang     auto                    diag_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), diag_h);
104f4747e26SJunchao Zhang     aijkok->diag_dual              = MatRowMapKokkosDualView(diag_d, diag_h);
105076ba34aSJunchao Zhang   }
1063ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1078c3ff71bSJunchao Zhang }
1088c3ff71bSJunchao Zhang 
10986a27549SJunchao Zhang /* Sync CSR data to device if not yet */
110d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosSyncDevice(Mat A)
111d71ae5a4SJacob Faibussowitsch {
1128c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1138c3ff71bSJunchao Zhang 
1148c3ff71bSJunchao Zhang   PetscFunctionBegin;
115aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from host to device");
1165f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
117076ba34aSJunchao Zhang   if (aijkok->a_dual.need_sync_device()) {
118f3d3cd90SZach Atkins     PetscCall(KokkosDualViewSyncDevice(aijkok->a_dual, PetscGetKokkosExecutionSpace()));
119580c7c76SPierre Jolivet     aijkok->transpose_updated = PETSC_FALSE; /* values of the transpose is out-of-date */
12086a27549SJunchao Zhang     aijkok->hermitian_updated = PETSC_FALSE;
1218c3ff71bSJunchao Zhang   }
1223ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1238c3ff71bSJunchao Zhang }
1248c3ff71bSJunchao Zhang 
125076ba34aSJunchao Zhang /* Mark the CSR data on device as modified */
126d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosModifyDevice(Mat A)
127d71ae5a4SJacob Faibussowitsch {
12886a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
12986a27549SJunchao Zhang 
13086a27549SJunchao Zhang   PetscFunctionBegin;
1315f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Not supported for factorized matries");
13286a27549SJunchao Zhang   aijkok->a_dual.clear_sync_state();
13386a27549SJunchao Zhang   aijkok->a_dual.modify_device();
13486a27549SJunchao Zhang   aijkok->transpose_updated = PETSC_FALSE;
13586a27549SJunchao Zhang   aijkok->hermitian_updated = PETSC_FALSE;
1369566063dSJacob Faibussowitsch   PetscCall(PetscObjectStateIncrease((PetscObject)A));
1373ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
13886a27549SJunchao Zhang }
13986a27549SJunchao Zhang 
140d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSyncHost(Mat A)
141d71ae5a4SJacob Faibussowitsch {
142f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1434df4a32cSJunchao Zhang   auto              exec   = PetscGetKokkosExecutionSpace();
144f0cf5187SStefano Zampini 
145f0cf5187SStefano Zampini   PetscFunctionBegin;
146f0cf5187SStefano Zampini   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
14786a27549SJunchao Zhang   /* We do not expect one needs factors on host  */
148aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from device to host");
1495f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing AIJKOK");
150f3d3cd90SZach Atkins   PetscCall(KokkosDualViewSyncHost(aijkok->a_dual, exec));
1513ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
152f0cf5187SStefano Zampini }
153f0cf5187SStefano Zampini 
154d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
155d71ae5a4SJacob Faibussowitsch {
156076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
157f0cf5187SStefano Zampini 
158f0cf5187SStefano Zampini   PetscFunctionBegin;
1595519a089SJose E. Roman   /* aijkok contains valid pointers only if the host's nonzerostate matches with the device's.
1605519a089SJose E. Roman     Calling MatSeqAIJSetPreallocation() or MatSetValues() on host, where aijseq->{i,j,a} might be
1615519a089SJose E. Roman     reallocated, will lead to stale {i,j,a}_dual in aijkok. In both operations, the hosts's nonzerostate
1625519a089SJose E. Roman     must have been updated. The stale aijkok will be rebuilt during MatAssemblyEnd.
1635519a089SJose E. Roman   */
1645519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
165f3d3cd90SZach Atkins     PetscCall(KokkosDualViewSyncHost(aijkok->a_dual, PetscGetKokkosExecutionSpace()));
166076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
167076ba34aSJunchao Zhang   } else { /* Happens when calling MatSetValues on a newly created matrix */
168076ba34aSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
169076ba34aSJunchao Zhang   }
1703ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
171076ba34aSJunchao Zhang }
172076ba34aSJunchao Zhang 
173d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
174d71ae5a4SJacob Faibussowitsch {
175076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
176076ba34aSJunchao Zhang 
177076ba34aSJunchao Zhang   PetscFunctionBegin;
1785519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) aijkok->a_dual.modify_host();
1793ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
180076ba34aSJunchao Zhang }
181076ba34aSJunchao Zhang 
182d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
183d71ae5a4SJacob Faibussowitsch {
184076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
185076ba34aSJunchao Zhang 
186076ba34aSJunchao Zhang   PetscFunctionBegin;
1875519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
188f3d3cd90SZach Atkins     PetscCall(KokkosDualViewSyncHost(aijkok->a_dual, PetscGetKokkosExecutionSpace()));
189076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1902328674fSJunchao Zhang   } else {
1912328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1922328674fSJunchao Zhang   }
1933ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
194076ba34aSJunchao Zhang }
195076ba34aSJunchao Zhang 
196d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
197d71ae5a4SJacob Faibussowitsch {
198076ba34aSJunchao Zhang   PetscFunctionBegin;
199076ba34aSJunchao Zhang   *array = NULL;
2003ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
201076ba34aSJunchao Zhang }
202076ba34aSJunchao Zhang 
203d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
204d71ae5a4SJacob Faibussowitsch {
205076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
206076ba34aSJunchao Zhang 
207076ba34aSJunchao Zhang   PetscFunctionBegin;
2085519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
209076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
2102328674fSJunchao Zhang   } else { /* Ex. happens with MatZeroEntries on a preallocated but not assembled matrix */
2112328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
2122328674fSJunchao Zhang   }
2133ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
214076ba34aSJunchao Zhang }
215076ba34aSJunchao Zhang 
216d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
217d71ae5a4SJacob Faibussowitsch {
218076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
219076ba34aSJunchao Zhang 
220076ba34aSJunchao Zhang   PetscFunctionBegin;
2215519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
222076ba34aSJunchao Zhang     aijkok->a_dual.clear_sync_state();
223076ba34aSJunchao Zhang     aijkok->a_dual.modify_host();
2242328674fSJunchao Zhang   }
2253ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
226f0cf5187SStefano Zampini }
227f0cf5187SStefano Zampini 
228d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJKokkos(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype)
229d71ae5a4SJacob Faibussowitsch {
2307ee59b9bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
2317ee59b9bSJunchao Zhang 
2327ee59b9bSJunchao Zhang   PetscFunctionBegin;
2337ee59b9bSJunchao Zhang   PetscCheck(aijkok != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "aijkok is NULL");
2347ee59b9bSJunchao Zhang 
2357ee59b9bSJunchao Zhang   if (i) *i = aijkok->i_device_data();
2367ee59b9bSJunchao Zhang   if (j) *j = aijkok->j_device_data();
2377ee59b9bSJunchao Zhang   if (a) {
238f3d3cd90SZach Atkins     PetscCall(KokkosDualViewSyncDevice(aijkok->a_dual, PetscGetKokkosExecutionSpace()));
2397ee59b9bSJunchao Zhang     *a = aijkok->a_device_data();
2407ee59b9bSJunchao Zhang   }
2417ee59b9bSJunchao Zhang   if (mtype) *mtype = PETSC_MEMTYPE_KOKKOS;
2423ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2437ee59b9bSJunchao Zhang }
2447ee59b9bSJunchao Zhang 
24503db1824SAlex Lindsay static PetscErrorCode MatGetCurrentMemType_SeqAIJKokkos(PETSC_UNUSED Mat A, PetscMemType *mtype)
24603db1824SAlex Lindsay {
24703db1824SAlex Lindsay   PetscFunctionBegin;
24803db1824SAlex Lindsay   *mtype = PETSC_MEMTYPE_KOKKOS;
24903db1824SAlex Lindsay   PetscFunctionReturn(PETSC_SUCCESS);
25003db1824SAlex Lindsay }
25103db1824SAlex Lindsay 
2520e3ece09SJunchao Zhang /*
2530e3ece09SJunchao Zhang   Generate the sparsity pattern of a MatSeqAIJKokkos matrix's transpose on device.
2540e3ece09SJunchao Zhang 
2550e3ece09SJunchao Zhang   Input Parameter:
2560e3ece09SJunchao Zhang .  A       - the MATSEQAIJKOKKOS matrix
2570e3ece09SJunchao Zhang 
2580e3ece09SJunchao Zhang   Output Parameters:
2590e3ece09SJunchao Zhang +  perm_d - the permutation array on device, which connects Ta(i) = Aa(perm(i))
260aaa8cc7dSPierre Jolivet -  T_d    - the transpose on device, whose value array is allocated but not initialized
2610e3ece09SJunchao Zhang */
2620e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateTransposeStructure(Mat A, MatRowMapKokkosView &perm_d, KokkosCsrMatrix &T_d)
263d71ae5a4SJacob Faibussowitsch {
2640e3ece09SJunchao Zhang   Mat_SeqAIJ             *aseq = static_cast<Mat_SeqAIJ *>(A->data);
2650e3ece09SJunchao Zhang   PetscInt                nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
2660e3ece09SJunchao Zhang   const PetscInt         *Ai = aseq->i, *Aj = aseq->j;
2677b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost Ti_h(NoInit("Ti"), n + 1);
2680e3ece09SJunchao Zhang   MatRowMapType          *Ti = Ti_h.data();
2697b8d4ba6SJunchao Zhang   MatColIdxKokkosViewHost Tj_h(NoInit("Tj"), nz);
2707b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost perm_h(NoInit("permutation"), nz);
2710e3ece09SJunchao Zhang   PetscInt               *Tj   = Tj_h.data();
2720e3ece09SJunchao Zhang   PetscInt               *perm = perm_h.data();
2730e3ece09SJunchao Zhang   PetscInt               *offset;
274152b3e56SJunchao Zhang 
275152b3e56SJunchao Zhang   PetscFunctionBegin;
2760e3ece09SJunchao Zhang   // Populate Ti
2770e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::deep_copy(Ti_h, 0));
2780e3ece09SJunchao Zhang   Ti++;
2790e3ece09SJunchao Zhang   for (PetscInt i = 0; i < nz; i++) Ti[Aj[i]]++;
2800e3ece09SJunchao Zhang   Ti--;
2810e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) Ti[i + 1] += Ti[i];
2820e3ece09SJunchao Zhang 
2830e3ece09SJunchao Zhang   // Populate Tj and the permutation array
2840e3ece09SJunchao Zhang   PetscCall(PetscCalloc1(n, &offset)); // offset in each T row to fill in its column indices
2850e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
2860e3ece09SJunchao Zhang     for (PetscInt j = Ai[i]; j < Ai[i + 1]; j++) { // A's (i,j) is T's (j,i)
2870e3ece09SJunchao Zhang       PetscInt r    = Aj[j];                       // row r of T
2880e3ece09SJunchao Zhang       PetscInt disp = Ti[r] + offset[r];
2890e3ece09SJunchao Zhang 
2900e3ece09SJunchao Zhang       Tj[disp]   = i; // col i of T
2910e3ece09SJunchao Zhang       perm[disp] = j;
2920e3ece09SJunchao Zhang       offset[r]++;
293076ba34aSJunchao Zhang     }
2940e3ece09SJunchao Zhang   }
2950e3ece09SJunchao Zhang   PetscCall(PetscFree(offset));
2960e3ece09SJunchao Zhang 
2970e3ece09SJunchao Zhang   // Sort each row of T, along with the permutation array
2980e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) PetscCall(PetscSortIntWithArray(Ti[i + 1] - Ti[i], Tj + Ti[i], perm + Ti[i]));
2990e3ece09SJunchao Zhang 
3000e3ece09SJunchao Zhang   // Output perm and T on device
3010e3ece09SJunchao Zhang   auto Ti_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ti_h);
3020e3ece09SJunchao Zhang   auto Tj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Tj_h);
3030e3ece09SJunchao Zhang   PetscCallCXX(T_d = KokkosCsrMatrix("csrmatT", n, m, nz, MatScalarKokkosView("Ta", nz), Ti_d, Tj_d));
3040e3ece09SJunchao Zhang   PetscCallCXX(perm_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), perm_h));
3053ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
306152b3e56SJunchao Zhang }
307152b3e56SJunchao Zhang 
3080e3ece09SJunchao Zhang // Generate the transpose on device and cache it internally
3090e3ece09SJunchao Zhang // Note: KK transpose_matrix() does not have support symbolic/numeric transpose, so we do it on our own
3100e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGenerateTranspose_Private(Mat A, KokkosCsrMatrix *csrmatT)
311d71ae5a4SJacob Faibussowitsch {
3120e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
3130e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3140e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
3150e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatT;
316152b3e56SJunchao Zhang 
317152b3e56SJunchao Zhang   PetscFunctionBegin;
3180e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
319f3d3cd90SZach Atkins   PetscCall(KokkosDualViewSyncDevice(akok->a_dual, PetscGetKokkosExecutionSpace())); // Sync A's values since we are going to access them on device
3200e3ece09SJunchao Zhang 
3210e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
3220e3ece09SJunchao Zhang 
3230e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE) {
3240e3ece09SJunchao Zhang     *csrmatT = akok->csrmat;
3250e3ece09SJunchao Zhang   } else {
3260e3ece09SJunchao Zhang     // See if we already have a cached transpose and its value is up to date
3270e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
3280e3ece09SJunchao Zhang       if (!akok->transpose_updated) {            // if the value is out of date, update the cached version
3290e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
3300e3ece09SJunchao Zhang         auto       &Ta   = T.values;
3310e3ece09SJunchao Zhang 
332d326c3f1SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = Aa(perm(i)); }));
333076ba34aSJunchao Zhang       }
3340e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
3350e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
3360e3ece09SJunchao Zhang 
3370e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
3380e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
339d326c3f1SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = Aa(perm(i)); }));
3400e3ece09SJunchao Zhang     }
3410e3ece09SJunchao Zhang     akok->transpose_updated = PETSC_TRUE;
3420e3ece09SJunchao Zhang     *csrmatT                = akok->csrmatT;
3430e3ece09SJunchao Zhang   }
3440e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3450e3ece09SJunchao Zhang }
3460e3ece09SJunchao Zhang 
3470e3ece09SJunchao Zhang // Generate the Hermitian on device and cache it internally
3480e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateHermitian_Private(Mat A, KokkosCsrMatrix *csrmatH)
3490e3ece09SJunchao Zhang {
3500e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
3510e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3520e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
3530e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatH;
3540e3ece09SJunchao Zhang 
3550e3ece09SJunchao Zhang   PetscFunctionBegin;
3560e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
357f3d3cd90SZach Atkins   PetscCall(KokkosDualViewSyncDevice(akok->a_dual, PetscGetKokkosExecutionSpace())); // Sync A's values since we are going to access them on device
3580e3ece09SJunchao Zhang 
3590e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
3600e3ece09SJunchao Zhang 
3610e3ece09SJunchao Zhang   if (A->hermitian == PETSC_BOOL3_TRUE) {
3620e3ece09SJunchao Zhang     *csrmatH = akok->csrmat;
3630e3ece09SJunchao Zhang   } else {
364b0c98d1dSPierre Jolivet     // See if we already have a cached Hermitian and its value is up to date
3650e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
3660e3ece09SJunchao Zhang       if (!akok->hermitian_updated) {            // if the value is out of date, update the cached version
3670e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
3680e3ece09SJunchao Zhang         auto       &Ta   = T.values;
3690e3ece09SJunchao Zhang 
370d326c3f1SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = PetscConj(Aa(perm(i))); }));
3710e3ece09SJunchao Zhang       }
3720e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
3730e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
3740e3ece09SJunchao Zhang 
3750e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
3760e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
377d326c3f1SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = PetscConj(Aa(perm(i))); }));
3780e3ece09SJunchao Zhang     }
3790e3ece09SJunchao Zhang     akok->hermitian_updated = PETSC_TRUE;
3800e3ece09SJunchao Zhang     *csrmatH                = akok->csrmatH;
3810e3ece09SJunchao Zhang   }
3823ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
383152b3e56SJunchao Zhang }
384a587d139SMark 
3858c3ff71bSJunchao Zhang /* y = A x */
386d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMult_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
387d71ae5a4SJacob Faibussowitsch {
3888c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
389152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
390152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
3918c3ff71bSJunchao Zhang 
3928c3ff71bSJunchao Zhang   PetscFunctionBegin;
3939566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3949566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3959566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3969566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
3978c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
398d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), "N", 1.0 /*alpha*/, aijkok->csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A x + beta y */
3999566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4009566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
401076ba34aSJunchao Zhang   /* 2.0*nnz - numRows seems more accurate here but assumes there are no zero-rows. So a little sloppy here. */
4029566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
4039566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4043ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4058c3ff71bSJunchao Zhang }
4068c3ff71bSJunchao Zhang 
4078c3ff71bSJunchao Zhang /* y = A^T x */
408d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
409d71ae5a4SJacob Faibussowitsch {
4108c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
411152b3e56SJunchao Zhang   const char                *mode;
412152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
413152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
4140e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4158c3ff71bSJunchao Zhang 
4168c3ff71bSJunchao Zhang   PetscFunctionBegin;
4179566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4189566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4199566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4209566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
421152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4229566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
423152b3e56SJunchao Zhang     mode = "N";
424152b3e56SJunchao Zhang   } else {
425076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4260e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
427152b3e56SJunchao Zhang     mode   = "T";
428152b3e56SJunchao Zhang   }
429d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^T x + beta y */
4309566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4319566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
4320e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4339566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4343ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4358c3ff71bSJunchao Zhang }
4368c3ff71bSJunchao Zhang 
4378c3ff71bSJunchao Zhang /* y = A^H x */
438d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
439d71ae5a4SJacob Faibussowitsch {
4408c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
441152b3e56SJunchao Zhang   const char                *mode;
442152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
443152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
4440e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4458c3ff71bSJunchao Zhang 
4468c3ff71bSJunchao Zhang   PetscFunctionBegin;
4479566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4489566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4499566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4509566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
451152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4529566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
453152b3e56SJunchao Zhang     mode = "N";
454152b3e56SJunchao Zhang   } else {
455076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4560e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
457152b3e56SJunchao Zhang     mode   = "C";
458152b3e56SJunchao Zhang   }
459d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^H x + beta y */
4609566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4619566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
4620e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4639566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4643ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4658c3ff71bSJunchao Zhang }
4668c3ff71bSJunchao Zhang 
4678c3ff71bSJunchao Zhang /* z = A x + y */
468d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
469d71ae5a4SJacob Faibussowitsch {
4708c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
47192896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
472152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4738c3ff71bSJunchao Zhang 
4748c3ff71bSJunchao Zhang   PetscFunctionBegin;
4759566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4769566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
47792896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz)); // depending on yy's sync flags, zz might get its latest data on host
4789566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
47992896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv)); // do after VecCopy(yy, zz) to get the latest data on device
4808c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
481d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), "N", 1.0 /*alpha*/, aijkok->csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A x + beta z */
4829566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
48392896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
4849566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
4859566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4863ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4878c3ff71bSJunchao Zhang }
4888c3ff71bSJunchao Zhang 
4898c3ff71bSJunchao Zhang /* z = A^T x + y */
490d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
491d71ae5a4SJacob Faibussowitsch {
4928c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
493152b3e56SJunchao Zhang   const char                *mode;
49492896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
495152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4960e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4978c3ff71bSJunchao Zhang 
4988c3ff71bSJunchao Zhang   PetscFunctionBegin;
4999566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
5009566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
50192896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz));
5029566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
50392896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv));
504152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
5059566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
506152b3e56SJunchao Zhang     mode = "N";
507152b3e56SJunchao Zhang   } else {
508076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
5090e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
510152b3e56SJunchao Zhang     mode   = "T";
511152b3e56SJunchao Zhang   }
512d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^T x + beta z */
5139566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
51492896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
5150e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
5169566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
5173ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5188c3ff71bSJunchao Zhang }
5198c3ff71bSJunchao Zhang 
5208c3ff71bSJunchao Zhang /* z = A^H x + y */
521d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
522d71ae5a4SJacob Faibussowitsch {
5238c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
524152b3e56SJunchao Zhang   const char                *mode;
52592896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
526152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
5270e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
5288c3ff71bSJunchao Zhang 
5298c3ff71bSJunchao Zhang   PetscFunctionBegin;
5309566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
5319566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
53292896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz));
5339566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
53492896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv));
535152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
5369566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
537152b3e56SJunchao Zhang     mode = "N";
538152b3e56SJunchao Zhang   } else {
539076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
5400e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
541152b3e56SJunchao Zhang     mode   = "C";
542152b3e56SJunchao Zhang   }
543d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^H x + beta z */
5449566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
54592896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
5460e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
5479566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
5483ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
549152b3e56SJunchao Zhang }
550152b3e56SJunchao Zhang 
55166976f2fSJacob Faibussowitsch static PetscErrorCode MatSetOption_SeqAIJKokkos(Mat A, MatOption op, PetscBool flg)
552d71ae5a4SJacob Faibussowitsch {
553152b3e56SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
554152b3e56SJunchao Zhang 
555152b3e56SJunchao Zhang   PetscFunctionBegin;
556152b3e56SJunchao Zhang   switch (op) {
557152b3e56SJunchao Zhang   case MAT_FORM_EXPLICIT_TRANSPOSE:
558152b3e56SJunchao Zhang     /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */
5599566063dSJacob Faibussowitsch     if (A->form_explicit_transpose && !flg && aijkok) PetscCall(aijkok->DestroyMatTranspose());
560152b3e56SJunchao Zhang     A->form_explicit_transpose = flg;
561152b3e56SJunchao Zhang     break;
562d71ae5a4SJacob Faibussowitsch   default:
563d71ae5a4SJacob Faibussowitsch     PetscCall(MatSetOption_SeqAIJ(A, op, flg));
564d71ae5a4SJacob Faibussowitsch     break;
565152b3e56SJunchao Zhang   }
5663ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5678c3ff71bSJunchao Zhang }
5688c3ff71bSJunchao Zhang 
569076ba34aSJunchao Zhang /* Depending on reuse, either build a new mat, or use the existing mat */
570d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
571d71ae5a4SJacob Faibussowitsch {
572076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
5738c3ff71bSJunchao Zhang 
5748c3ff71bSJunchao Zhang   PetscFunctionBegin;
5759566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
576076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) { /* Build a brand new mat */
57751ece73cSJunchao Zhang     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
57851ece73cSJunchao Zhang     PetscCall(MatSetType(*newmat, mtype));
5798c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {                 /* Reuse the mat created before */
5809566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); /* newmat is already a SeqAIJKokkos */
581076ba34aSJunchao Zhang   } else if (reuse == MAT_INPLACE_MATRIX) {               /* newmat is A */
5825f80ce2aSJacob Faibussowitsch     PetscCheck(A == *newmat, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "A != *newmat with MAT_INPLACE_MATRIX");
5839566063dSJacob Faibussowitsch     PetscCall(PetscFree(A->defaultvectype));
5849566063dSJacob Faibussowitsch     PetscCall(PetscStrallocpy(VECKOKKOS, &A->defaultvectype)); /* Allocate and copy the string */
5859566063dSJacob Faibussowitsch     PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATSEQAIJKOKKOS));
5869566063dSJacob Faibussowitsch     PetscCall(MatSetOps_SeqAIJKokkos(A));
587076ba34aSJunchao Zhang     aseq = static_cast<Mat_SeqAIJ *>(A->data);
588394ed5ebSJunchao Zhang     if (A->assembled) { /* Copy i, j (but not values) to device for an assembled matrix if not yet */
5895f80ce2aSJacob Faibussowitsch       PetscCheck(!A->spptr, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Expect NULL (Mat_SeqAIJKokkos*)A->spptr");
590421480d9SBarry Smith       A->spptr = new Mat_SeqAIJKokkos(A, A->rmap->n, A->cmap->n, aseq, A->nonzerostate, PETSC_FALSE);
5918c3ff71bSJunchao Zhang     }
592076ba34aSJunchao Zhang   }
5933ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5948c3ff71bSJunchao Zhang }
5958c3ff71bSJunchao Zhang 
596076ba34aSJunchao Zhang /* MatDuplicate always creates a new matrix. MatDuplicate can be called either on an assembled matrix or
597076ba34aSJunchao Zhang    an unassembled matrix, even though MAT_COPY_VALUES is not allowed for unassembled matrix.
598076ba34aSJunchao Zhang  */
599d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDuplicate_SeqAIJKokkos(Mat A, MatDuplicateOption dupOption, Mat *B)
600d71ae5a4SJacob Faibussowitsch {
601076ba34aSJunchao Zhang   Mat_SeqAIJ       *bseq;
602076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok;
603076ba34aSJunchao Zhang   Mat               mat;
6048c3ff71bSJunchao Zhang 
6058c3ff71bSJunchao Zhang   PetscFunctionBegin;
606076ba34aSJunchao Zhang   /* Do not copy values on host as A's latest values might be on device. We don't want to do sync blindly */
6079566063dSJacob Faibussowitsch   PetscCall(MatDuplicate_SeqAIJ(A, MAT_DO_NOT_COPY_VALUES, B));
608076ba34aSJunchao Zhang   mat = *B;
609f4747e26SJunchao Zhang   if (A->assembled) {
610076ba34aSJunchao Zhang     bseq = static_cast<Mat_SeqAIJ *>(mat->data);
611421480d9SBarry Smith     bkok = new Mat_SeqAIJKokkos(mat, mat->rmap->n, mat->cmap->n, bseq, mat->nonzerostate, PETSC_FALSE);
612076ba34aSJunchao Zhang     bkok->a_dual.clear_sync_state(); /* Clear B's sync state as it will be decided below */
613076ba34aSJunchao Zhang     /* Now copy values to B if needed */
614076ba34aSJunchao Zhang     if (dupOption == MAT_COPY_VALUES) {
615076ba34aSJunchao Zhang       if (akok->a_dual.need_sync_device()) {
616076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_host(), akok->a_dual.view_host());
617076ba34aSJunchao Zhang         bkok->a_dual.modify_host();
618076ba34aSJunchao Zhang       } else { /* If device has the latest data, we only copy data on device */
619076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_device(), akok->a_dual.view_device());
620076ba34aSJunchao Zhang         bkok->a_dual.modify_device();
621076ba34aSJunchao Zhang       }
622076ba34aSJunchao Zhang     } else { /* MAT_DO_NOT_COPY_VALUES or MAT_SHARE_NONZERO_PATTERN. B's values should be zeroed */
623076ba34aSJunchao Zhang       /* B's values on host should be already zeroed by MatDuplicate_SeqAIJ() */
624076ba34aSJunchao Zhang       bkok->a_dual.modify_host();
625076ba34aSJunchao Zhang     }
626076ba34aSJunchao Zhang     mat->spptr = bkok;
627076ba34aSJunchao Zhang   }
628076ba34aSJunchao Zhang 
6299566063dSJacob Faibussowitsch   PetscCall(PetscFree(mat->defaultvectype));
6309566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &mat->defaultvectype)); /* Allocate and copy the string */
6319566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)mat, MATSEQAIJKOKKOS));
6329566063dSJacob Faibussowitsch   PetscCall(MatSetOps_SeqAIJKokkos(mat));
6333ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6348c3ff71bSJunchao Zhang }
6358c3ff71bSJunchao Zhang 
636d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatTranspose_SeqAIJKokkos(Mat A, MatReuse reuse, Mat *B)
637d71ae5a4SJacob Faibussowitsch {
6380ecb592aSJunchao Zhang   Mat               At;
6390e3ece09SJunchao Zhang   KokkosCsrMatrix   internT;
6400ecb592aSJunchao Zhang   Mat_SeqAIJKokkos *atkok, *bkok;
6410ecb592aSJunchao Zhang 
6420ecb592aSJunchao Zhang   PetscFunctionBegin;
6437fb60732SBarry Smith   if (reuse == MAT_REUSE_MATRIX) PetscCall(MatTransposeCheckNonzeroState_Private(A, *B));
6449566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &internT)); /* Generate a transpose internally */
6450ecb592aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX || reuse == MAT_INPLACE_MATRIX) {
646ff751488SJunchao Zhang     /* Deep copy internT, as we want to isolate the internal transpose */
6470e3ece09SJunchao Zhang     PetscCallCXX(atkok = new Mat_SeqAIJKokkos(KokkosCsrMatrix("csrmat", internT)));
6489566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PetscObjectComm((PetscObject)A), atkok, &At));
6490ecb592aSJunchao Zhang     if (reuse == MAT_INITIAL_MATRIX) *B = At;
6509566063dSJacob Faibussowitsch     else PetscCall(MatHeaderReplace(A, &At)); /* Replace A with At inplace */
6510ecb592aSJunchao Zhang   } else {                                    /* MAT_REUSE_MATRIX, just need to copy values to B on device */
6520ecb592aSJunchao Zhang     if ((*B)->assembled) {
6530ecb592aSJunchao Zhang       bkok = static_cast<Mat_SeqAIJKokkos *>((*B)->spptr);
6540e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(bkok->a_dual.view_device(), internT.values));
6559566063dSJacob Faibussowitsch       PetscCall(MatSeqAIJKokkosModifyDevice(*B));
6560ecb592aSJunchao Zhang     } else if ((*B)->preallocated) { /* It is ok for B to be only preallocated, as needed in MatTranspose_MPIAIJ */
6570ecb592aSJunchao Zhang       Mat_SeqAIJ             *bseq = static_cast<Mat_SeqAIJ *>((*B)->data);
6580e3ece09SJunchao Zhang       MatScalarKokkosViewHost a_h(bseq->a, internT.nnz()); /* bseq->nz = 0 if unassembled */
6590e3ece09SJunchao Zhang       MatColIdxKokkosViewHost j_h(bseq->j, internT.nnz());
6600e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(a_h, internT.values));
6610e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(j_h, internT.graph.entries));
6620ecb592aSJunchao Zhang     } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "B must be assembled or preallocated");
6630ecb592aSJunchao Zhang   }
6643ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6650ecb592aSJunchao Zhang }
6660ecb592aSJunchao Zhang 
667d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDestroy_SeqAIJKokkos(Mat A)
668d71ae5a4SJacob Faibussowitsch {
66986a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
6708c3ff71bSJunchao Zhang 
6718c3ff71bSJunchao Zhang   PetscFunctionBegin;
67286a27549SJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) {
67386a27549SJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
6748c3ff71bSJunchao Zhang     delete aijkok;
67586a27549SJunchao Zhang   } else {
67686a27549SJunchao Zhang     delete static_cast<Mat_SeqAIJKokkosTriFactors *>(A->spptr);
67786a27549SJunchao Zhang   }
678cbc6b225SStefano Zampini   A->spptr = NULL;
6799566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
6809566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
6819566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
68257761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
68357761e9aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijkokkos_hypre_C", NULL));
68457761e9aSJunchao Zhang #endif
6859566063dSJacob Faibussowitsch   PetscCall(MatDestroy_SeqAIJ(A));
6863ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6878c3ff71bSJunchao Zhang }
6888c3ff71bSJunchao Zhang 
6893f3ba80aSJunchao Zhang /*MC
6903f3ba80aSJunchao Zhang    MATSEQAIJKOKKOS - MATAIJKOKKOS = "(seq)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos
6913f3ba80aSJunchao Zhang 
69215229ffcSPierre Jolivet    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
6933f3ba80aSJunchao Zhang 
6942ef1f0ffSBarry Smith    Options Database Key:
69511a5261eSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATSEQAIJKOKKOS` during a call to `MatSetFromOptions()`
6963f3ba80aSJunchao Zhang 
6973f3ba80aSJunchao Zhang   Level: beginner
6983f3ba80aSJunchao Zhang 
6991cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MatCreateSeqAIJKokkos()`, `MATMPIAIJKOKKOS`
7003f3ba80aSJunchao Zhang M*/
701d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJKokkos(Mat A)
702d71ae5a4SJacob Faibussowitsch {
70386a27549SJunchao Zhang   PetscFunctionBegin;
7049566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
7059566063dSJacob Faibussowitsch   PetscCall(MatCreate_SeqAIJ(A));
7069566063dSJacob Faibussowitsch   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
7073ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
70886a27549SJunchao Zhang }
70986a27549SJunchao Zhang 
710076ba34aSJunchao Zhang /* Merge A, B into a matrix C. A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n) */
711d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C)
712d71ae5a4SJacob Faibussowitsch {
713076ba34aSJunchao Zhang   Mat_SeqAIJ         *a, *b;
714076ba34aSJunchao Zhang   Mat_SeqAIJKokkos   *akok, *bkok, *ckok;
715076ba34aSJunchao Zhang   MatScalarKokkosView aa, ba, ca;
716076ba34aSJunchao Zhang   MatRowMapKokkosView ai, bi, ci;
717076ba34aSJunchao Zhang   MatColIdxKokkosView aj, bj, cj;
718076ba34aSJunchao Zhang   PetscInt            m, n, nnz, aN;
719a3f881fbSStefano Zampini 
720a3f881fbSStefano Zampini   PetscFunctionBegin;
721076ba34aSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
722076ba34aSJunchao Zhang   PetscValidHeaderSpecific(B, MAT_CLASSID, 2);
7234f572ea9SToby Isaac   PetscAssertPointer(C, 4);
724076ba34aSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
725076ba34aSJunchao Zhang   PetscCheckTypeName(B, MATSEQAIJKOKKOS);
7265f80ce2aSJacob Faibussowitsch   PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n);
7275f80ce2aSJacob Faibussowitsch   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
728076ba34aSJunchao Zhang 
7299566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
7309566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
731076ba34aSJunchao Zhang   a    = static_cast<Mat_SeqAIJ *>(A->data);
732076ba34aSJunchao Zhang   b    = static_cast<Mat_SeqAIJ *>(B->data);
733076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
734076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
735076ba34aSJunchao Zhang   aa   = akok->a_dual.view_device();
736076ba34aSJunchao Zhang   ai   = akok->i_dual.view_device();
737076ba34aSJunchao Zhang   ba   = bkok->a_dual.view_device();
738076ba34aSJunchao Zhang   bi   = bkok->i_dual.view_device();
739076ba34aSJunchao Zhang   m    = A->rmap->n; /* M, N and nnz of C */
740076ba34aSJunchao Zhang   n    = A->cmap->n + B->cmap->n;
741076ba34aSJunchao Zhang   nnz  = a->nz + b->nz;
742076ba34aSJunchao Zhang   aN   = A->cmap->n; /* N of A */
743076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
744076ba34aSJunchao Zhang     aj      = akok->j_dual.view_device();
745076ba34aSJunchao Zhang     bj      = bkok->j_dual.view_device();
746ecd797f4SJunchao Zhang     auto ca = MatScalarKokkosView("a", aa.extent(0) + ba.extent(0));
747ecd797f4SJunchao Zhang     auto ci = MatRowMapKokkosView("i", ai.extent(0));
748ecd797f4SJunchao Zhang     auto cj = MatColIdxKokkosView("j", aj.extent(0) + bj.extent(0));
749076ba34aSJunchao Zhang 
750076ba34aSJunchao Zhang     /* Concatenate A and B in parallel using Kokkos hierarchical parallelism */
7519371c9d4SSatish Balay     Kokkos::parallel_for(
752d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
753076ba34aSJunchao Zhang         PetscInt i       = t.league_rank(); /* row i */
754076ba34aSJunchao Zhang         PetscInt coffset = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
755076ba34aSJunchao Zhang 
756076ba34aSJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() { /* this side effect only happens once per whole team */
757076ba34aSJunchao Zhang                                                    ci(i) = coffset;
758076ba34aSJunchao Zhang                                                    if (i == m - 1) ci(m) = ai(m) + bi(m);
759076ba34aSJunchao Zhang         });
760076ba34aSJunchao Zhang 
761076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
762076ba34aSJunchao Zhang           if (k < alen) {
763076ba34aSJunchao Zhang             ca(coffset + k) = aa(ai(i) + k);
764076ba34aSJunchao Zhang             cj(coffset + k) = aj(ai(i) + k);
765076ba34aSJunchao Zhang           } else {
766076ba34aSJunchao Zhang             ca(coffset + k) = ba(bi(i) + k - alen);
767076ba34aSJunchao Zhang             cj(coffset + k) = bj(bi(i) + k - alen) + aN; /* Entries in B get new column indices in C */
768076ba34aSJunchao Zhang           }
769076ba34aSJunchao Zhang         });
770076ba34aSJunchao Zhang       });
771ecd797f4SJunchao Zhang     PetscCallCXX(ckok = new Mat_SeqAIJKokkos(m, n, nnz, ci, cj, ca));
7729566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, C));
773076ba34aSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
774076ba34aSJunchao Zhang     PetscValidHeaderSpecific(*C, MAT_CLASSID, 4);
775076ba34aSJunchao Zhang     PetscCheckTypeName(*C, MATSEQAIJKOKKOS);
776076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos *>((*C)->spptr);
777076ba34aSJunchao Zhang     ca   = ckok->a_dual.view_device();
778076ba34aSJunchao Zhang     ci   = ckok->i_dual.view_device();
779076ba34aSJunchao Zhang 
7809371c9d4SSatish Balay     Kokkos::parallel_for(
781d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
782076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* row i */
783076ba34aSJunchao Zhang         PetscInt alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
784076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
785076ba34aSJunchao Zhang           if (k < alen) ca(ci(i) + k) = aa(ai(i) + k);
786076ba34aSJunchao Zhang           else ca(ci(i) + k) = ba(bi(i) + k - alen);
787076ba34aSJunchao Zhang         });
788076ba34aSJunchao Zhang       });
7899566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(*C));
790076ba34aSJunchao Zhang   }
7913ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
792076ba34aSJunchao Zhang }
793076ba34aSJunchao Zhang 
794*2a8381b2SBarry Smith static PetscErrorCode MatProductCtxDestroy_SeqAIJKokkos(PetscCtxRt pdata)
795d71ae5a4SJacob Faibussowitsch {
796076ba34aSJunchao Zhang   PetscFunctionBegin;
797cc1eb50dSBarry Smith   delete *reinterpret_cast<MatProductCtx_SeqAIJKokkos **>(pdata);
7983ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
799a3f881fbSStefano Zampini }
800a3f881fbSStefano Zampini 
801d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos(Mat C)
802d71ae5a4SJacob Faibussowitsch {
803a3f881fbSStefano Zampini   Mat_Product                *product = C->product;
804a3f881fbSStefano Zampini   Mat                         A, B;
805076ba34aSJunchao Zhang   bool                        transA, transB; /* use bool, since KK needs this type */
806a3f881fbSStefano Zampini   Mat_SeqAIJKokkos           *akok, *bkok, *ckok;
807a3f881fbSStefano Zampini   Mat_SeqAIJ                 *c;
808cc1eb50dSBarry Smith   MatProductCtx_SeqAIJKokkos *pdata;
8090e3ece09SJunchao Zhang   KokkosCsrMatrix             csrmatA, csrmatB;
810a3f881fbSStefano Zampini 
811a3f881fbSStefano Zampini   PetscFunctionBegin;
812a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
8135f80ce2aSJacob Faibussowitsch   PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data empty");
814cc1eb50dSBarry Smith   pdata = static_cast<MatProductCtx_SeqAIJKokkos *>(C->product->data);
815076ba34aSJunchao Zhang 
8160e3ece09SJunchao Zhang   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
8170e3ece09SJunchao Zhang   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
8180e3ece09SJunchao Zhang   // we still do numeric.
8190e3ece09SJunchao Zhang   if (pdata->reusesym) { // numeric reuses results from symbolic
8200e3ece09SJunchao Zhang     pdata->reusesym = PETSC_FALSE;
8213ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
822076ba34aSJunchao Zhang   }
823076ba34aSJunchao Zhang 
824076ba34aSJunchao Zhang   switch (product->type) {
8259371c9d4SSatish Balay   case MATPRODUCT_AB:
8269371c9d4SSatish Balay     transA = false;
8279371c9d4SSatish Balay     transB = false;
8289371c9d4SSatish Balay     break;
8299371c9d4SSatish Balay   case MATPRODUCT_AtB:
8309371c9d4SSatish Balay     transA = true;
8319371c9d4SSatish Balay     transB = false;
8329371c9d4SSatish Balay     break;
8339371c9d4SSatish Balay   case MATPRODUCT_ABt:
8349371c9d4SSatish Balay     transA = false;
8359371c9d4SSatish Balay     transB = true;
8369371c9d4SSatish Balay     break;
837d71ae5a4SJacob Faibussowitsch   default:
838d71ae5a4SJacob Faibussowitsch     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
839076ba34aSJunchao Zhang   }
840076ba34aSJunchao Zhang 
841a3f881fbSStefano Zampini   A = product->A;
842a3f881fbSStefano Zampini   B = product->B;
8439566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
8449566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
845a3f881fbSStefano Zampini   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
846a3f881fbSStefano Zampini   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
847a3f881fbSStefano Zampini   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
848076ba34aSJunchao Zhang 
8495f80ce2aSJacob Faibussowitsch   PetscCheck(ckok, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Device data structure spptr is empty");
850076ba34aSJunchao Zhang 
8510e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
8520e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
853076ba34aSJunchao Zhang 
854076ba34aSJunchao Zhang   /* TODO: Once KK spgemm implements transpose, we can get rid of the explicit transpose here */
855076ba34aSJunchao Zhang   if (transA) {
8569566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
857076ba34aSJunchao Zhang     transA = false;
858a3f881fbSStefano Zampini   }
859a3f881fbSStefano Zampini 
860076ba34aSJunchao Zhang   if (transB) {
8619566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
862076ba34aSJunchao Zhang     transB = false;
863076ba34aSJunchao Zhang   }
8649566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
8650e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, ckok->csrmat));
8660e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
867866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
868866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(ckok->csrmat)); /* without sort, mat_tests-ex62_14_seqaijkokkos fails */
869e944a159SJunchao Zhang #endif
870866eb059SJunchao Zhang 
8719566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
8729566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(C));
873a3f881fbSStefano Zampini   /* shorter version of MatAssemblyEnd_SeqAIJ */
874a3f881fbSStefano Zampini   c = (Mat_SeqAIJ *)C->data;
8759566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded, %" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz));
8769566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n"));
8779566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax));
878a3f881fbSStefano Zampini   c->reallocs         = 0;
879076ba34aSJunchao Zhang   C->info.mallocs     = 0;
880a3f881fbSStefano Zampini   C->info.nz_unneeded = 0;
881a3f881fbSStefano Zampini   C->assembled = C->was_assembled = PETSC_TRUE;
882a3f881fbSStefano Zampini   C->num_ass++;
8833ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
884a3f881fbSStefano Zampini }
885a3f881fbSStefano Zampini 
886d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos(Mat C)
887d71ae5a4SJacob Faibussowitsch {
888076ba34aSJunchao Zhang   Mat_Product                *product = C->product;
889076ba34aSJunchao Zhang   MatProductType              ptype;
890076ba34aSJunchao Zhang   Mat                         A, B;
891076ba34aSJunchao Zhang   bool                        transA, transB;
892076ba34aSJunchao Zhang   Mat_SeqAIJKokkos           *akok, *bkok, *ckok;
893cc1eb50dSBarry Smith   MatProductCtx_SeqAIJKokkos *pdata;
894076ba34aSJunchao Zhang   MPI_Comm                    comm;
8950e3ece09SJunchao Zhang   KokkosCsrMatrix             csrmatA, csrmatB, csrmatC;
896a3f881fbSStefano Zampini 
897a3f881fbSStefano Zampini   PetscFunctionBegin;
898a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
8999566063dSJacob Faibussowitsch   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
9005f80ce2aSJacob Faibussowitsch   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
901a3f881fbSStefano Zampini   A = product->A;
902a3f881fbSStefano Zampini   B = product->B;
9039566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
9049566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
905a3f881fbSStefano Zampini   akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
906a3f881fbSStefano Zampini   bkok    = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
9070e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
9080e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
909076ba34aSJunchao Zhang 
910a3f881fbSStefano Zampini   ptype = product->type;
9110e3ece09SJunchao Zhang   // Take advantage of the symmetry if true
9120e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) {
9130e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
9140e3ece09SJunchao Zhang     product->symbolic_used_the_fact_A_is_symmetric = PETSC_TRUE;
9150e3ece09SJunchao Zhang   }
9160e3ece09SJunchao Zhang   if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) {
9170e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
9180e3ece09SJunchao Zhang     product->symbolic_used_the_fact_B_is_symmetric = PETSC_TRUE;
9190e3ece09SJunchao Zhang   }
9200e3ece09SJunchao Zhang 
921a3f881fbSStefano Zampini   switch (ptype) {
9229371c9d4SSatish Balay   case MATPRODUCT_AB:
9239371c9d4SSatish Balay     transA = false;
9249371c9d4SSatish Balay     transB = false;
9250e6a1e94SMark Adams     PetscCall(MatSetBlockSizesFromMats(C, A, B));
9269371c9d4SSatish Balay     break;
9279371c9d4SSatish Balay   case MATPRODUCT_AtB:
9289371c9d4SSatish Balay     transA = true;
9299371c9d4SSatish Balay     transB = false;
9300e6a1e94SMark Adams     if (A->cmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->rmap, A->cmap->bs));
9310e6a1e94SMark Adams     if (B->cmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->cmap, B->cmap->bs));
9329371c9d4SSatish Balay     break;
9339371c9d4SSatish Balay   case MATPRODUCT_ABt:
9349371c9d4SSatish Balay     transA = false;
9359371c9d4SSatish Balay     transB = true;
9360e6a1e94SMark Adams     if (A->rmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->rmap, A->rmap->bs));
9370e6a1e94SMark Adams     if (B->rmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->cmap, B->rmap->bs));
9389371c9d4SSatish Balay     break;
939d71ae5a4SJacob Faibussowitsch   default:
940d71ae5a4SJacob Faibussowitsch     SETERRQ(comm, PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
941a3f881fbSStefano Zampini   }
942cc1eb50dSBarry Smith   PetscCallCXX(product->data = pdata = new MatProductCtx_SeqAIJKokkos());
943076ba34aSJunchao Zhang   pdata->reusesym = product->api_user;
944a3f881fbSStefano Zampini 
945076ba34aSJunchao Zhang   /* TODO: add command line options to select spgemm algorithms */
946866eb059SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; /* default alg is TPL if enabled, otherwise KK */
947866eb059SJunchao Zhang 
948866eb059SJunchao Zhang   /* CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 */
949866eb059SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
950866eb059SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
951866eb059SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
952866eb059SJunchao Zhang   #endif
953866eb059SJunchao Zhang #endif
9540e3ece09SJunchao Zhang   PetscCallCXX(pdata->kh.create_spgemm_handle(spgemm_alg));
955076ba34aSJunchao Zhang 
9569566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
957076ba34aSJunchao Zhang   /* TODO: Get rid of the explicit transpose once KK-spgemm implements the transpose option */
958076ba34aSJunchao Zhang   if (transA) {
9599566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
960076ba34aSJunchao Zhang     transA = false;
961076ba34aSJunchao Zhang   }
962076ba34aSJunchao Zhang 
963076ba34aSJunchao Zhang   if (transB) {
9649566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
965076ba34aSJunchao Zhang     transB = false;
966076ba34aSJunchao Zhang   }
967076ba34aSJunchao Zhang 
9680e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
969076ba34aSJunchao Zhang   /* spgemm_symbolic() only populates C's rowmap, but not C's column indices.
970076ba34aSJunchao Zhang     So we have to do a fake spgemm_numeric() here to get csrmatC.j_d setup, before
971076ba34aSJunchao Zhang     calling new Mat_SeqAIJKokkos().
972076ba34aSJunchao Zhang     TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
973076ba34aSJunchao Zhang   */
9740e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
9750e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
976866eb059SJunchao Zhang   /* Query if KK outputs a sorted matrix. If not, we need to sort it */
977866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
978866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(csrmatC)); /* sort_option defaults to -1 in KK!*/
979e944a159SJunchao Zhang #endif
9809566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
981076ba34aSJunchao Zhang 
9829566063dSJacob Faibussowitsch   PetscCallCXX(ckok = new Mat_SeqAIJKokkos(csrmatC));
9839566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(C, ckok));
984cc1eb50dSBarry Smith   C->product->destroy = MatProductCtxDestroy_SeqAIJKokkos;
9853ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
986a3f881fbSStefano Zampini }
987a3f881fbSStefano Zampini 
988a3f881fbSStefano Zampini /* handles sparse matrix matrix ops */
989d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSetFromOptions_SeqAIJKokkos(Mat mat)
990d71ae5a4SJacob Faibussowitsch {
991076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
992a3f881fbSStefano Zampini   PetscBool    Biskok = PETSC_FALSE, Ciskok = PETSC_TRUE;
993a3f881fbSStefano Zampini 
994a3f881fbSStefano Zampini   PetscFunctionBegin;
995a3f881fbSStefano Zampini   MatCheckProduct(mat, 1);
9969566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJKOKKOS, &Biskok));
99748a46eb9SPierre Jolivet   if (product->type == MATPRODUCT_ABC) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJKOKKOS, &Ciskok));
998a3f881fbSStefano Zampini   if (Biskok && Ciskok) {
999a3f881fbSStefano Zampini     switch (product->type) {
1000a3f881fbSStefano Zampini     case MATPRODUCT_AB:
1001a3f881fbSStefano Zampini     case MATPRODUCT_AtB:
1002d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABt:
1003d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos;
1004d71ae5a4SJacob Faibussowitsch       break;
1005a3f881fbSStefano Zampini     case MATPRODUCT_PtAP:
1006a3f881fbSStefano Zampini     case MATPRODUCT_RARt:
1007d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABC:
1008d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
1009d71ae5a4SJacob Faibussowitsch       break;
1010d71ae5a4SJacob Faibussowitsch     default:
1011d71ae5a4SJacob Faibussowitsch       break;
1012a3f881fbSStefano Zampini     }
1013a3f881fbSStefano Zampini   } else { /* fallback for AIJ */
10149566063dSJacob Faibussowitsch     PetscCall(MatProductSetFromOptions_SeqAIJ(mat));
1015a3f881fbSStefano Zampini   }
10163ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1017a3f881fbSStefano Zampini }
1018a587d139SMark 
1019d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatScale_SeqAIJKokkos(Mat A, PetscScalar a)
1020d71ae5a4SJacob Faibussowitsch {
1021f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok;
1022f0cf5187SStefano Zampini 
1023f0cf5187SStefano Zampini   PetscFunctionBegin;
10249566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
10259566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1026f0cf5187SStefano Zampini   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1027d326c3f1SJunchao Zhang   KokkosBlas::scal(PetscGetKokkosExecutionSpace(), aijkok->a_dual.view_device(), a, aijkok->a_dual.view_device());
10289566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
10299566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(aijkok->a_dual.extent(0)));
10309566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
10313ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1032f0cf5187SStefano Zampini }
1033f0cf5187SStefano Zampini 
1034f4747e26SJunchao Zhang // add a to A's diagonal (if A is square) or main diagonal (if A is rectangular)
1035f4747e26SJunchao Zhang static PetscErrorCode MatShift_SeqAIJKokkos(Mat A, PetscScalar a)
1036f4747e26SJunchao Zhang {
1037f4747e26SJunchao Zhang   Mat_SeqAIJ *aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1038f4747e26SJunchao Zhang 
1039f4747e26SJunchao Zhang   PetscFunctionBegin;
104007425a8dSBarry Smith   if (A->assembled && aijseq->diagDense) { // no missing diagonals
1041f4747e26SJunchao Zhang     PetscInt n = PetscMin(A->rmap->n, A->cmap->n);
1042f4747e26SJunchao Zhang 
1043f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
1044f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosSyncDevice(A));
1045f4747e26SJunchao Zhang     const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1046f4747e26SJunchao Zhang     const auto &Aa     = aijkok->a_dual.view_device();
1047f4747e26SJunchao Zhang     const auto &Adiag  = aijkok->diag_dual.view_device();
1048d326c3f1SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) { Aa(Adiag(i)) += a; }));
1049f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosModifyDevice(A));
1050f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(n));
1051f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
1052f4747e26SJunchao Zhang   } else { // need reassembly, very slow!
1053f4747e26SJunchao Zhang     PetscCall(MatShift_Basic(A, a));
1054f4747e26SJunchao Zhang   }
1055f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1056f4747e26SJunchao Zhang }
1057f4747e26SJunchao Zhang 
1058f4747e26SJunchao Zhang static PetscErrorCode MatDiagonalSet_SeqAIJKokkos(Mat Y, Vec D, InsertMode is)
1059f4747e26SJunchao Zhang {
1060f4747e26SJunchao Zhang   Mat_SeqAIJ *aijseq = static_cast<Mat_SeqAIJ *>(Y->data);
1061f4747e26SJunchao Zhang 
1062f4747e26SJunchao Zhang   PetscFunctionBegin;
106307425a8dSBarry Smith   if (Y->assembled && aijseq->diagDense) { // no missing diagonals
1064f4747e26SJunchao Zhang     ConstPetscScalarKokkosView dv;
1065f4747e26SJunchao Zhang     PetscInt                   n, nv;
1066f4747e26SJunchao Zhang 
1067f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
1068f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosSyncDevice(Y));
1069f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(D, &dv));
1070f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(D, &nv));
1071f4747e26SJunchao Zhang     n = PetscMin(Y->rmap->n, Y->cmap->n);
1072f4747e26SJunchao Zhang     PetscCheck(n == nv, PetscObjectComm((PetscObject)Y), PETSC_ERR_ARG_SIZ, "Matrix size and vector size do not match");
1073f4747e26SJunchao Zhang 
1074f4747e26SJunchao Zhang     const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1075f4747e26SJunchao Zhang     const auto &Aa     = aijkok->a_dual.view_device();
1076f4747e26SJunchao Zhang     const auto &Adiag  = aijkok->diag_dual.view_device();
1077f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(
1078d326c3f1SJunchao Zhang       Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) {
1079f4747e26SJunchao Zhang         if (is == INSERT_VALUES) Aa(Adiag(i)) = dv(i);
1080f4747e26SJunchao Zhang         else Aa(Adiag(i)) += dv(i);
1081f4747e26SJunchao Zhang       }));
1082f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(D, &dv));
1083f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1084f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(n));
1085f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
1086f4747e26SJunchao Zhang   } else { // need reassembly, very slow!
1087f4747e26SJunchao Zhang     PetscCall(MatDiagonalSet_Default(Y, D, is));
1088f4747e26SJunchao Zhang   }
1089f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1090f4747e26SJunchao Zhang }
1091f4747e26SJunchao Zhang 
1092f4747e26SJunchao Zhang static PetscErrorCode MatDiagonalScale_SeqAIJKokkos(Mat A, Vec ll, Vec rr)
1093f4747e26SJunchao Zhang {
1094f4747e26SJunchao Zhang   Mat_SeqAIJ                *aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1095f4747e26SJunchao Zhang   PetscInt                   m = A->rmap->n, n = A->cmap->n, nz = aijseq->nz;
1096f4747e26SJunchao Zhang   ConstPetscScalarKokkosView lv, rv;
1097f4747e26SJunchao Zhang 
1098f4747e26SJunchao Zhang   PetscFunctionBegin;
1099f4747e26SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1100f4747e26SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1101f4747e26SJunchao Zhang   const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1102f4747e26SJunchao Zhang   const auto &Aa     = aijkok->a_dual.view_device();
1103f4747e26SJunchao Zhang   const auto &Ai     = aijkok->i_dual.view_device();
1104f4747e26SJunchao Zhang   const auto &Aj     = aijkok->j_dual.view_device();
1105f4747e26SJunchao Zhang   if (ll) {
1106f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(ll, &m));
1107f4747e26SJunchao Zhang     PetscCheck(m == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Left scaling vector wrong length");
1108f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(ll, &lv));
1109f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for( // for each row
1110d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1111f4747e26SJunchao Zhang         PetscInt i   = t.league_rank(); // row i
1112f4747e26SJunchao Zhang         PetscInt len = Ai(i + 1) - Ai(i);
1113f4747e26SJunchao Zhang         // scale entries on the row
1114f4747e26SJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt j) { Aa(Ai(i) + j) *= lv(i); });
1115f4747e26SJunchao Zhang       }));
1116f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(ll, &lv));
1117f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(nz));
1118f4747e26SJunchao Zhang   }
1119f4747e26SJunchao Zhang   if (rr) {
1120f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(rr, &n));
1121f4747e26SJunchao Zhang     PetscCheck(n == A->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Right scaling vector wrong length");
1122f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(rr, &rv));
1123f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for( // for each nonzero
1124d326c3f1SJunchao Zhang       Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt k) { Aa(k) *= rv(Aj(k)); }));
1125f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(rr, &lv));
1126f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(nz));
1127f4747e26SJunchao Zhang   }
1128f4747e26SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(A));
1129f4747e26SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1130f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1131f4747e26SJunchao Zhang }
1132f4747e26SJunchao Zhang 
1133d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatZeroEntries_SeqAIJKokkos(Mat A)
1134d71ae5a4SJacob Faibussowitsch {
1135076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1136a587d139SMark 
1137a587d139SMark   PetscFunctionBegin;
1138076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
11392328674fSJunchao Zhang   if (aijkok) { /* Only zero the device if data is already there */
1140d326c3f1SJunchao Zhang     KokkosBlas::fill(PetscGetKokkosExecutionSpace(), aijkok->a_dual.view_device(), 0.0);
11419566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(A));
11422328674fSJunchao Zhang   } else { /* Might be preallocated but not assembled */
11439566063dSJacob Faibussowitsch     PetscCall(MatZeroEntries_SeqAIJ(A));
11442328674fSJunchao Zhang   }
11453ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1146a587d139SMark }
1147a587d139SMark 
1148d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatGetDiagonal_SeqAIJKokkos(Mat A, Vec x)
1149d71ae5a4SJacob Faibussowitsch {
1150f78ce678SMark Adams   Mat_SeqAIJKokkos     *aijkok;
1151f78ce678SMark Adams   PetscInt              n;
1152f78ce678SMark Adams   PetscScalarKokkosView xv;
1153f78ce678SMark Adams 
1154f78ce678SMark Adams   PetscFunctionBegin;
1155f78ce678SMark Adams   PetscCall(VecGetLocalSize(x, &n));
1156f78ce678SMark Adams   PetscCheck(n == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Nonconforming matrix and vector");
1157f78ce678SMark Adams   PetscCheck(A->factortype == MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_SUP, "MatGetDiagonal_SeqAIJKokkos not supported on factored matrices");
1158f78ce678SMark Adams 
1159f78ce678SMark Adams   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1160f78ce678SMark Adams   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1161f78ce678SMark Adams 
1162f78ce678SMark Adams   const auto &Aa    = aijkok->a_dual.view_device();
1163f78ce678SMark Adams   const auto &Ai    = aijkok->i_dual.view_device();
1164f78ce678SMark Adams   const auto &Adiag = aijkok->diag_dual.view_device();
1165f78ce678SMark Adams 
1166f78ce678SMark Adams   PetscCall(VecGetKokkosViewWrite(x, &xv));
11679371c9d4SSatish Balay   Kokkos::parallel_for(
1168d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) {
1169f78ce678SMark Adams       if (Adiag(i) < Ai(i + 1)) xv(i) = Aa(Adiag(i));
1170f78ce678SMark Adams       else xv(i) = 0;
1171f78ce678SMark Adams     });
1172f78ce678SMark Adams   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
11733ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1174f78ce678SMark Adams }
1175f78ce678SMark Adams 
1176db78de30SJunchao Zhang /* Get a Kokkos View from a mat of type MatSeqAIJKokkos */
1177d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1178d71ae5a4SJacob Faibussowitsch {
1179db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1180db78de30SJunchao Zhang 
1181db78de30SJunchao Zhang   PetscFunctionBegin;
1182db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11834f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1184db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11859566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1186db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1187076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
11883ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1189db78de30SJunchao Zhang }
1190db78de30SJunchao Zhang 
1191d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1192d71ae5a4SJacob Faibussowitsch {
1193db78de30SJunchao Zhang   PetscFunctionBegin;
1194db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11954f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1196db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11973ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1198db78de30SJunchao Zhang }
1199db78de30SJunchao Zhang 
1200d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, MatScalarKokkosView *kv)
1201d71ae5a4SJacob Faibussowitsch {
1202db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1203db78de30SJunchao Zhang 
1204db78de30SJunchao Zhang   PetscFunctionBegin;
1205db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12064f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1207db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
12089566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1209db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1210076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
12113ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1212db78de30SJunchao Zhang }
1213db78de30SJunchao Zhang 
1214d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, MatScalarKokkosView *kv)
1215d71ae5a4SJacob Faibussowitsch {
1216db78de30SJunchao Zhang   PetscFunctionBegin;
1217db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12184f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1219db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
12209566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
12213ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1222db78de30SJunchao Zhang }
1223db78de30SJunchao Zhang 
1224d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1225d71ae5a4SJacob Faibussowitsch {
1226db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1227db78de30SJunchao Zhang 
1228db78de30SJunchao Zhang   PetscFunctionBegin;
1229db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12304f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1231db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1232db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1233076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
12343ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1235db78de30SJunchao Zhang }
1236db78de30SJunchao Zhang 
1237d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1238d71ae5a4SJacob Faibussowitsch {
1239db78de30SJunchao Zhang   PetscFunctionBegin;
1240db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12414f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1242db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
12439566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
12443ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1245db78de30SJunchao Zhang }
1246db78de30SJunchao Zhang 
1247c0c276a7Ssdargavi PetscErrorCode MatCreateSeqAIJKokkosWithKokkosViews(MPI_Comm comm, PetscInt m, PetscInt n, Kokkos::View<PetscInt *> &i_d, Kokkos::View<PetscInt *> &j_d, Kokkos::View<PetscScalar *> &a_d, Mat *A)
1248c0c276a7Ssdargavi {
1249c0c276a7Ssdargavi   Mat_SeqAIJKokkos *akok;
1250c0c276a7Ssdargavi 
1251c0c276a7Ssdargavi   PetscFunctionBegin;
1252ecd797f4SJunchao Zhang   PetscCallCXX(akok = new Mat_SeqAIJKokkos(m, n, j_d.extent(0), i_d, j_d, a_d));
1253c0c276a7Ssdargavi   PetscCall(MatCreate(comm, A));
1254c0c276a7Ssdargavi   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
1255c0c276a7Ssdargavi   PetscFunctionReturn(PETSC_SUCCESS);
1256c0c276a7Ssdargavi }
1257c0c276a7Ssdargavi 
1258c17cf699SJunchao Zhang /* Computes Y += alpha X */
1259d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAXPY_SeqAIJKokkos(Mat Y, PetscScalar alpha, Mat X, MatStructure pattern)
1260d71ae5a4SJacob Faibussowitsch {
1261a587d139SMark   Mat_SeqAIJ              *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data;
1262c17cf699SJunchao Zhang   Mat_SeqAIJKokkos        *xkok, *ykok, *zkok;
1263c17cf699SJunchao Zhang   ConstMatScalarKokkosView Xa;
1264c17cf699SJunchao Zhang   MatScalarKokkosView      Ya;
12654df4a32cSJunchao Zhang   auto                     exec = PetscGetKokkosExecutionSpace();
1266a587d139SMark 
1267a587d139SMark   PetscFunctionBegin;
1268c17cf699SJunchao Zhang   PetscCheckTypeName(Y, MATSEQAIJKOKKOS);
1269c17cf699SJunchao Zhang   PetscCheckTypeName(X, MATSEQAIJKOKKOS);
12709566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(Y));
12719566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(X));
12729566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1273db78de30SJunchao Zhang 
1274c17cf699SJunchao Zhang   if (pattern != SAME_NONZERO_PATTERN && x->nz == y->nz) {
1275a587d139SMark     PetscBool e;
12769566063dSJacob Faibussowitsch     PetscCall(PetscArraycmp(x->i, y->i, Y->rmap->n + 1, &e));
1277a587d139SMark     if (e) {
12789566063dSJacob Faibussowitsch       PetscCall(PetscArraycmp(x->j, y->j, y->nz, &e));
1279c17cf699SJunchao Zhang       if (e) pattern = SAME_NONZERO_PATTERN;
1280a587d139SMark     }
1281a587d139SMark   }
1282db78de30SJunchao Zhang 
1283c17cf699SJunchao Zhang   /* cusparseDcsrgeam2() computes C = alpha A + beta B. If one knew sparsity pattern of C, one can skip
1284c17cf699SJunchao Zhang     cusparseScsrgeam2_bufferSizeExt() / cusparseXcsrgeam2Nnz(), and directly call cusparseScsrgeam2().
1285c17cf699SJunchao Zhang     If X is SUBSET_NONZERO_PATTERN of Y, we could take advantage of this cusparse feature. However,
1286c17cf699SJunchao Zhang     KokkosSparse::spadd(alpha,A,beta,B,C) has symbolic and numeric phases, MatAXPY does not.
1287c17cf699SJunchao Zhang   */
1288c17cf699SJunchao Zhang   ykok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1289c17cf699SJunchao Zhang   xkok = static_cast<Mat_SeqAIJKokkos *>(X->spptr);
1290c17cf699SJunchao Zhang   Xa   = xkok->a_dual.view_device();
1291c17cf699SJunchao Zhang   Ya   = ykok->a_dual.view_device();
1292c17cf699SJunchao Zhang 
1293c17cf699SJunchao Zhang   if (pattern == SAME_NONZERO_PATTERN) {
1294d326c3f1SJunchao Zhang     KokkosBlas::axpy(exec, alpha, Xa, Ya);
12959566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1296c17cf699SJunchao Zhang   } else if (pattern == SUBSET_NONZERO_PATTERN) {
1297c17cf699SJunchao Zhang     MatRowMapKokkosView Xi = xkok->i_dual.view_device(), Yi = ykok->i_dual.view_device();
1298c17cf699SJunchao Zhang     MatColIdxKokkosView Xj = xkok->j_dual.view_device(), Yj = ykok->j_dual.view_device();
1299c17cf699SJunchao Zhang 
13009371c9d4SSatish Balay     Kokkos::parallel_for(
1301d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(exec, Y->rmap->n, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
13020e3ece09SJunchao Zhang         PetscInt i = t.league_rank(); // row i
13030e3ece09SJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() {
13040e3ece09SJunchao Zhang           // Only one thread works in a team
1305c17cf699SJunchao Zhang           PetscInt p, q = Yi(i);
13060e3ece09SJunchao Zhang           for (p = Xi(i); p < Xi(i + 1); p++) {          // For each nonzero on row i of X,
13070e3ece09SJunchao Zhang             while (Xj(p) != Yj(q) && q < Yi(i + 1)) q++; // find the matching nonzero on row i of Y.
13080e3ece09SJunchao Zhang             if (Xj(p) == Yj(q)) {                        // Found it
1309c17cf699SJunchao Zhang               Ya(q) += alpha * Xa(p);
1310c17cf699SJunchao Zhang               q++;
1311a587d139SMark             } else {
13120e3ece09SJunchao Zhang             // If not found, it indicates the input is wrong (X is not a SUBSET_NONZERO_PATTERN of Y).
13130e3ece09SJunchao Zhang             // Just insert a NaN at the beginning of row i if it is not empty, to make the result wrong.
1314257f855aSJunchao Zhang #if PETSC_PKG_KOKKOS_VERSION_GE(5, 0, 0)
1315257f855aSJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = KokkosKernels::ArithTraits<PetscScalar>::nan();
1316257f855aSJunchao Zhang #elif PETSC_PKG_KOKKOS_VERSION_GE(3, 7, 0)
13170e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::ArithTraits<PetscScalar>::nan();
13188b8b16f9SJunchao Zhang #else
13190e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::Experimental::nan("1");
13208b8b16f9SJunchao Zhang #endif
1321a587d139SMark             }
1322c17cf699SJunchao Zhang           }
1323c17cf699SJunchao Zhang         });
1324c17cf699SJunchao Zhang       });
13259566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
13260e3ece09SJunchao Zhang   } else { // different nonzero patterns
1327c17cf699SJunchao Zhang     Mat             Z;
1328c17cf699SJunchao Zhang     KokkosCsrMatrix zcsr;
1329c17cf699SJunchao Zhang     KernelHandle    kh;
13300e3ece09SJunchao Zhang     kh.create_spadd_handle(true); // X, Y are sorted
1331c17cf699SJunchao Zhang     KokkosSparse::spadd_symbolic(&kh, xkok->csrmat, ykok->csrmat, zcsr);
1332c17cf699SJunchao Zhang     KokkosSparse::spadd_numeric(&kh, alpha, xkok->csrmat, (PetscScalar)1.0, ykok->csrmat, zcsr);
1333c17cf699SJunchao Zhang     zkok = new Mat_SeqAIJKokkos(zcsr);
13349566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, zkok, &Z));
13359566063dSJacob Faibussowitsch     PetscCall(MatHeaderReplace(Y, &Z));
1336c17cf699SJunchao Zhang     kh.destroy_spadd_handle();
1337c17cf699SJunchao Zhang   }
13389566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
13390e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(xkok->a_dual.extent(0) * 2)); // Because we scaled X and then added it to Y
13403ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1341a587d139SMark }
1342a587d139SMark 
13432c4ab24aSJunchao Zhang struct MatCOOStruct_SeqAIJKokkos {
13442c4ab24aSJunchao Zhang   PetscCount           n;
13452c4ab24aSJunchao Zhang   PetscCount           Atot;
13462c4ab24aSJunchao Zhang   PetscInt             nz;
13472c4ab24aSJunchao Zhang   PetscCountKokkosView jmap;
13482c4ab24aSJunchao Zhang   PetscCountKokkosView perm;
13492c4ab24aSJunchao Zhang 
13502c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos(const MatCOOStruct_SeqAIJ *coo_h)
13512c4ab24aSJunchao Zhang   {
13522c4ab24aSJunchao Zhang     nz   = coo_h->nz;
13532c4ab24aSJunchao Zhang     n    = coo_h->n;
13542c4ab24aSJunchao Zhang     Atot = coo_h->Atot;
13552c4ab24aSJunchao Zhang     jmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->jmap, nz + 1));
13562c4ab24aSJunchao Zhang     perm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->perm, Atot));
13572c4ab24aSJunchao Zhang   }
13582c4ab24aSJunchao Zhang };
13592c4ab24aSJunchao Zhang 
1360*2a8381b2SBarry Smith static PetscErrorCode MatCOOStructDestroy_SeqAIJKokkos(PetscCtxRt data)
13612c4ab24aSJunchao Zhang {
13622c4ab24aSJunchao Zhang   PetscFunctionBegin;
1363*2a8381b2SBarry Smith   PetscCallCXX(delete *static_cast<MatCOOStruct_SeqAIJKokkos **>(data));
13642c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
13652c4ab24aSJunchao Zhang }
13662c4ab24aSJunchao Zhang 
1367d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_SeqAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1368d71ae5a4SJacob Faibussowitsch {
136942550becSJunchao Zhang   Mat_SeqAIJKokkos          *akok;
137042550becSJunchao Zhang   Mat_SeqAIJ                *aseq;
137103e76207SPierre Jolivet   PetscContainer             container_h;
13722c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJ       *coo_h;
13732c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo_d;
137442550becSJunchao Zhang 
137542550becSJunchao Zhang   PetscFunctionBegin;
13769566063dSJacob Faibussowitsch   PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, coo_i, coo_j));
1377394ed5ebSJunchao Zhang   aseq = static_cast<Mat_SeqAIJ *>(mat->data);
137842550becSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(mat->spptr);
1379cbc6b225SStefano Zampini   delete akok;
1380421480d9SBarry Smith   mat->spptr = akok = new Mat_SeqAIJKokkos(mat, mat->rmap->n, mat->cmap->n, aseq, mat->nonzerostate + 1, PETSC_FALSE);
13819566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries_SeqAIJKokkos(mat));
13822c4ab24aSJunchao Zhang 
13832c4ab24aSJunchao Zhang   // Copy the COO struct to device
13842c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
1385*2a8381b2SBarry Smith   PetscCall(PetscContainerGetPointer(container_h, &coo_h));
13862c4ab24aSJunchao Zhang   PetscCallCXX(coo_d = new MatCOOStruct_SeqAIJKokkos(coo_h));
13872c4ab24aSJunchao Zhang 
13882c4ab24aSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
138903e76207SPierre Jolivet   PetscCall(PetscObjectContainerCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", coo_d, MatCOOStructDestroy_SeqAIJKokkos));
13903ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
139142550becSJunchao Zhang }
139242550becSJunchao Zhang 
1393d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_SeqAIJKokkos(Mat A, const PetscScalar v[], InsertMode imode)
1394d71ae5a4SJacob Faibussowitsch {
139542550becSJunchao Zhang   MatScalarKokkosView        Aa;
139642550becSJunchao Zhang   ConstMatScalarKokkosView   kv;
139742550becSJunchao Zhang   PetscMemType               memtype;
13982c4ab24aSJunchao Zhang   PetscContainer             container;
13992c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo;
140042550becSJunchao Zhang 
140142550becSJunchao Zhang   PetscFunctionBegin;
14022c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)A, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
1403*2a8381b2SBarry Smith   PetscCall(PetscContainerGetPointer(container, &coo));
14042c4ab24aSJunchao Zhang 
14052c4ab24aSJunchao Zhang   const auto &n    = coo->n;
14062c4ab24aSJunchao Zhang   const auto &Annz = coo->nz;
14072c4ab24aSJunchao Zhang   const auto &jmap = coo->jmap;
14082c4ab24aSJunchao Zhang   const auto &perm = coo->perm;
14092c4ab24aSJunchao Zhang 
14109566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype));
141142550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */
14122c4ab24aSJunchao Zhang     kv = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatScalarKokkosViewHost(v, n));
141342550becSJunchao Zhang   } else {
14142c4ab24aSJunchao Zhang     kv = ConstMatScalarKokkosView(v, n); /* Directly use v[]'s memory */
141542550becSJunchao Zhang   }
141642550becSJunchao Zhang 
1417c7b718f4SJunchao Zhang   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1418c7b718f4SJunchao Zhang   else PetscCall(MatSeqAIJGetKokkosView(A, &Aa));                             /* read & write matrix values */
141942550becSJunchao Zhang 
142008bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
14219371c9d4SSatish Balay   Kokkos::parallel_for(
1422d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz), KOKKOS_LAMBDA(const PetscCount i) {
1423c7b718f4SJunchao Zhang       PetscScalar sum = 0.0;
1424c7b718f4SJunchao Zhang       for (PetscCount k = jmap(i); k < jmap(i + 1); k++) sum += kv(perm(k));
1425c7b718f4SJunchao Zhang       Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1426c7b718f4SJunchao Zhang     });
142708bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1428394ed5ebSJunchao Zhang 
14299566063dSJacob Faibussowitsch   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
14309566063dSJacob Faibussowitsch   else PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
14313ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
143242550becSJunchao Zhang }
143342550becSJunchao Zhang 
1434d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat A)
1435d71ae5a4SJacob Faibussowitsch {
1436076ba34aSJunchao Zhang   Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
1437076ba34aSJunchao Zhang 
14388c3ff71bSJunchao Zhang   PetscFunctionBegin;
1439076ba34aSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_KOKKOS; /* We do not really use this flag */
14406f3d89d0SStefano Zampini   A->boundtocpu  = PETSC_FALSE;
14416f3d89d0SStefano Zampini 
14428c3ff71bSJunchao Zhang   A->ops->assemblyend               = MatAssemblyEnd_SeqAIJKokkos;
14438c3ff71bSJunchao Zhang   A->ops->destroy                   = MatDestroy_SeqAIJKokkos;
14448c3ff71bSJunchao Zhang   A->ops->duplicate                 = MatDuplicate_SeqAIJKokkos;
1445a587d139SMark   A->ops->axpy                      = MatAXPY_SeqAIJKokkos;
1446f0cf5187SStefano Zampini   A->ops->scale                     = MatScale_SeqAIJKokkos;
1447a587d139SMark   A->ops->zeroentries               = MatZeroEntries_SeqAIJKokkos;
1448076ba34aSJunchao Zhang   A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJKokkos;
14498c3ff71bSJunchao Zhang   A->ops->mult                      = MatMult_SeqAIJKokkos;
14508c3ff71bSJunchao Zhang   A->ops->multadd                   = MatMultAdd_SeqAIJKokkos;
14518c3ff71bSJunchao Zhang   A->ops->multtranspose             = MatMultTranspose_SeqAIJKokkos;
14528c3ff71bSJunchao Zhang   A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJKokkos;
14538c3ff71bSJunchao Zhang   A->ops->multhermitiantranspose    = MatMultHermitianTranspose_SeqAIJKokkos;
14548c3ff71bSJunchao Zhang   A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJKokkos;
1455076ba34aSJunchao Zhang   A->ops->productnumeric            = MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos;
14560ecb592aSJunchao Zhang   A->ops->transpose                 = MatTranspose_SeqAIJKokkos;
1457152b3e56SJunchao Zhang   A->ops->setoption                 = MatSetOption_SeqAIJKokkos;
1458f78ce678SMark Adams   A->ops->getdiagonal               = MatGetDiagonal_SeqAIJKokkos;
1459f4747e26SJunchao Zhang   A->ops->shift                     = MatShift_SeqAIJKokkos;
1460f4747e26SJunchao Zhang   A->ops->diagonalset               = MatDiagonalSet_SeqAIJKokkos;
1461f4747e26SJunchao Zhang   A->ops->diagonalscale             = MatDiagonalScale_SeqAIJKokkos;
146203db1824SAlex Lindsay   A->ops->getcurrentmemtype         = MatGetCurrentMemType_SeqAIJKokkos;
1463076ba34aSJunchao Zhang   a->ops->getarray                  = MatSeqAIJGetArray_SeqAIJKokkos;
1464076ba34aSJunchao Zhang   a->ops->restorearray              = MatSeqAIJRestoreArray_SeqAIJKokkos;
1465076ba34aSJunchao Zhang   a->ops->getarrayread              = MatSeqAIJGetArrayRead_SeqAIJKokkos;
1466076ba34aSJunchao Zhang   a->ops->restorearrayread          = MatSeqAIJRestoreArrayRead_SeqAIJKokkos;
1467076ba34aSJunchao Zhang   a->ops->getarraywrite             = MatSeqAIJGetArrayWrite_SeqAIJKokkos;
1468076ba34aSJunchao Zhang   a->ops->restorearraywrite         = MatSeqAIJRestoreArrayWrite_SeqAIJKokkos;
14697ee59b9bSJunchao Zhang   a->ops->getcsrandmemtype          = MatSeqAIJGetCSRAndMemType_SeqAIJKokkos;
147042550becSJunchao Zhang 
14719566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJKokkos));
14729566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJKokkos));
147357761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
147457761e9aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijkokkos_hypre_C", MatConvert_AIJ_HYPRE));
147557761e9aSJunchao Zhang #endif
14763ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1477076ba34aSJunchao Zhang }
1478076ba34aSJunchao Zhang 
14799d13fa56SJunchao Zhang /*
14809d13fa56SJunchao Zhang    Extract the (prescribled) diagonal blocks of the matrix and then invert them
14819d13fa56SJunchao Zhang 
14829d13fa56SJunchao Zhang   Input Parameters:
14839d13fa56SJunchao Zhang +  A       - the MATSEQAIJKOKKOS matrix
14849d13fa56SJunchao Zhang .  bs      - block sizes in 'csr' format, i.e., the i-th block has size bs(i+1) - bs(i)
14859d13fa56SJunchao Zhang .  bs2     - square of block sizes in 'csr' format, i.e., the i-th block should be stored at offset bs2(i) in diagVal[]
14869d13fa56SJunchao Zhang .  blkMap  - map row ids to block ids, i.e., row i belongs to the block blkMap(i)
14879d13fa56SJunchao Zhang -  work    - a pre-allocated work buffer (as big as diagVal) for use by this routine
14889d13fa56SJunchao Zhang 
14899d13fa56SJunchao Zhang   Output Parameter:
14909d13fa56SJunchao Zhang .  diagVal - the (pre-allocated) buffer to store the inverted blocks (each block is stored in column-major order)
14919d13fa56SJunchao Zhang */
14929d13fa56SJunchao Zhang PETSC_INTERN PetscErrorCode MatInvertVariableBlockDiagonal_SeqAIJKokkos(Mat A, const PetscIntKokkosView &bs, const PetscIntKokkosView &bs2, const PetscIntKokkosView &blkMap, PetscScalarKokkosView &work, PetscScalarKokkosView &diagVal)
14939d13fa56SJunchao Zhang {
14949d13fa56SJunchao Zhang   Mat_SeqAIJKokkos *akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
14959d13fa56SJunchao Zhang   PetscInt          nblocks = bs.extent(0) - 1;
14969d13fa56SJunchao Zhang 
14979d13fa56SJunchao Zhang   PetscFunctionBegin;
14989d13fa56SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A)); // Since we'll access A's value on device
14999d13fa56SJunchao Zhang 
15009d13fa56SJunchao Zhang   // Pull out the diagonal blocks of the matrix and then invert the blocks
15019d13fa56SJunchao Zhang   auto Aa    = akok->a_dual.view_device();
15029d13fa56SJunchao Zhang   auto Ai    = akok->i_dual.view_device();
15039d13fa56SJunchao Zhang   auto Aj    = akok->j_dual.view_device();
15049d13fa56SJunchao Zhang   auto Adiag = akok->diag_dual.view_device();
15059d13fa56SJunchao Zhang   // TODO: how to tune the team size?
150645402d8aSJunchao Zhang #if defined(KOKKOS_ENABLE_UNIFIED_MEMORY)
15079d13fa56SJunchao Zhang   auto ts = Kokkos::AUTO();
15089d13fa56SJunchao Zhang #else
15099d13fa56SJunchao Zhang   auto ts = 16; // improved performance 30% over Kokkos::AUTO() with CUDA, but failed with "Kokkos::abort: Requested Team Size is too large!" on CPUs
15109d13fa56SJunchao Zhang #endif
15119d13fa56SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1512d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), nblocks, ts), KOKKOS_LAMBDA(const KokkosTeamMemberType &teamMember) {
15139d13fa56SJunchao Zhang       const PetscInt bid    = teamMember.league_rank();                                                   // block id
15149d13fa56SJunchao Zhang       const PetscInt rstart = bs(bid);                                                                    // this block starts from this row
15159d13fa56SJunchao Zhang       const PetscInt m      = bs(bid + 1) - bs(bid);                                                      // size of this block
15169d13fa56SJunchao Zhang       const auto    &B      = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft>(&diagVal(bs2(bid)), m, m); // column-major order
15179d13fa56SJunchao Zhang       const auto    &W      = PetscScalarKokkosView(&work(bs2(bid)), m * m);
15189d13fa56SJunchao Zhang 
15199d13fa56SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(teamMember, m), [=](const PetscInt &r) { // r-th row in B
15209d13fa56SJunchao Zhang         PetscInt i = rstart + r;                                                            // i-th row in A
15219d13fa56SJunchao Zhang 
15229d13fa56SJunchao Zhang         if (Ai(i) <= Adiag(i) && Adiag(i) < Ai(i + 1)) { // if the diagonal exists (common case)
15239d13fa56SJunchao Zhang           PetscInt first = Adiag(i) - r;                 // we start to check nonzeros from here along this row
15249d13fa56SJunchao Zhang 
15259d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) {                   // walk n steps to see what column indices we will meet
15269d13fa56SJunchao Zhang             if (first + c < Ai(i) || first + c >= Ai(i + 1)) { // this entry (first+c) is out of range of this row, in other words, its value is zero
15279d13fa56SJunchao Zhang               B(r, c) = 0.0;
15289d13fa56SJunchao Zhang             } else if (Aj(first + c) == rstart + c) { // this entry is right on the (rstart+c) column
15299d13fa56SJunchao Zhang               B(r, c) = Aa(first + c);
15309d13fa56SJunchao Zhang             } else { // this entry does not show up in the CSR
15319d13fa56SJunchao Zhang               B(r, c) = 0.0;
15329d13fa56SJunchao Zhang             }
15339d13fa56SJunchao Zhang           }
15349d13fa56SJunchao Zhang         } else { // rare case that the diagonal does not exist
15359d13fa56SJunchao Zhang           const PetscInt begin = Ai(i);
15369d13fa56SJunchao Zhang           const PetscInt end   = Ai(i + 1);
15379d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) B(r, c) = 0.0;
15389d13fa56SJunchao Zhang           for (PetscInt j = begin; j < end; j++) { // scan the whole row; could use binary search but this is a rare case so we did not.
15399d13fa56SJunchao Zhang             if (rstart <= Aj(j) && Aj(j) < rstart + m) B(r, Aj(j) - rstart) = Aa(j);
15409d13fa56SJunchao Zhang             else if (Aj(j) >= rstart + m) break;
15419d13fa56SJunchao Zhang           }
15429d13fa56SJunchao Zhang         }
15439d13fa56SJunchao Zhang       });
15449d13fa56SJunchao Zhang 
15459d13fa56SJunchao Zhang       // LU-decompose B (w/o pivoting) and then invert B
15469d13fa56SJunchao Zhang       KokkosBatched::TeamLU<KokkosTeamMemberType, KokkosBatched::Algo::LU::Unblocked>::invoke(teamMember, B, 0.0);
15479d13fa56SJunchao Zhang       KokkosBatched::TeamInverseLU<KokkosTeamMemberType, KokkosBatched::Algo::InverseLU::Unblocked>::invoke(teamMember, B, W);
15489d13fa56SJunchao Zhang     }));
15499d13fa56SJunchao Zhang   // PetscLogGpuFlops() is done in the caller PCSetUp_VPBJacobi_Kokkos as we don't want to compute the flops in kernels
15509d13fa56SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
15519d13fa56SJunchao Zhang }
15529d13fa56SJunchao Zhang 
1553d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSetSeqAIJKokkosWithCSRMatrix(Mat A, Mat_SeqAIJKokkos *akok)
1554d71ae5a4SJacob Faibussowitsch {
1555076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
1556076ba34aSJunchao Zhang   PetscInt    i, m, n;
15574df4a32cSJunchao Zhang   auto        exec = PetscGetKokkosExecutionSpace();
1558076ba34aSJunchao Zhang 
1559076ba34aSJunchao Zhang   PetscFunctionBegin;
15605f80ce2aSJacob Faibussowitsch   PetscCheck(!A->spptr, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A->spptr is supposed to be empty");
1561076ba34aSJunchao Zhang 
1562076ba34aSJunchao Zhang   m = akok->nrows();
1563076ba34aSJunchao Zhang   n = akok->ncols();
15649566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(A, m, n, m, n));
15659566063dSJacob Faibussowitsch   PetscCall(MatSetType(A, MATSEQAIJKOKKOS));
1566076ba34aSJunchao Zhang 
1567076ba34aSJunchao Zhang   /* Set up data structures of A as a MATSEQAIJ */
15689566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(A, MAT_SKIP_ALLOCATION, NULL));
156957508eceSPierre Jolivet   aseq = (Mat_SeqAIJ *)A->data;
1570076ba34aSJunchao Zhang 
1571f3d3cd90SZach Atkins   PetscCall(KokkosDualViewSyncHost(akok->i_dual, exec)); /* We always need sync'ed i, j on host */
1572f3d3cd90SZach Atkins   PetscCall(KokkosDualViewSyncHost(akok->j_dual, exec));
1573076ba34aSJunchao Zhang 
1574076ba34aSJunchao Zhang   aseq->i       = akok->i_host_data();
1575076ba34aSJunchao Zhang   aseq->j       = akok->j_host_data();
1576076ba34aSJunchao Zhang   aseq->a       = akok->a_host_data();
1577076ba34aSJunchao Zhang   aseq->nonew   = -1; /*this indicates that inserting a new value in the matrix that generates a new nonzero is an error*/
1578076ba34aSJunchao Zhang   aseq->free_a  = PETSC_FALSE;
1579076ba34aSJunchao Zhang   aseq->free_ij = PETSC_FALSE;
1580076ba34aSJunchao Zhang   aseq->nz      = akok->nnz();
1581076ba34aSJunchao Zhang   aseq->maxnz   = aseq->nz;
1582076ba34aSJunchao Zhang 
15839566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->imax));
15849566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->ilen));
1585ad540459SPierre Jolivet   for (i = 0; i < m; i++) aseq->ilen[i] = aseq->imax[i] = aseq->i[i + 1] - aseq->i[i];
1586076ba34aSJunchao Zhang 
1587076ba34aSJunchao Zhang   /* It is critical to set the nonzerostate, as we use it to check if sparsity pattern (hence data) has changed on host in MatAssemblyEnd */
1588076ba34aSJunchao Zhang   akok->nonzerostate = A->nonzerostate;
1589ff751488SJunchao Zhang   A->spptr           = akok; /* Set A->spptr before MatAssembly so that A->spptr won't be allocated again there */
15909566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
15919566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
15923ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1593076ba34aSJunchao Zhang }
1594076ba34aSJunchao Zhang 
15950e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetKokkosCsrMatrix(Mat A, KokkosCsrMatrix *csr)
15960e3ece09SJunchao Zhang {
15970e3ece09SJunchao Zhang   PetscFunctionBegin;
15980e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
15990e3ece09SJunchao Zhang   *csr = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
16000e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
16010e3ece09SJunchao Zhang }
16020e3ece09SJunchao Zhang 
16030e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithKokkosCsrMatrix(MPI_Comm comm, KokkosCsrMatrix csr, Mat *A)
16040e3ece09SJunchao Zhang {
16050e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok;
16064d86920dSPierre Jolivet 
16070e3ece09SJunchao Zhang   PetscFunctionBegin;
16080e3ece09SJunchao Zhang   PetscCallCXX(akok = new Mat_SeqAIJKokkos(csr));
16090e3ece09SJunchao Zhang   PetscCall(MatCreate(comm, A));
16100e3ece09SJunchao Zhang   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
16110e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
16120e3ece09SJunchao Zhang }
16130e3ece09SJunchao Zhang 
1614076ba34aSJunchao Zhang /* Crete a SEQAIJKOKKOS matrix with a Mat_SeqAIJKokkos data structure
1615076ba34aSJunchao Zhang 
1616076ba34aSJunchao Zhang    Note we have names like MatSeqAIJSetPreallocationCSR, so I use capitalized CSR
1617076ba34aSJunchao Zhang  */
1618d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithCSRMatrix(MPI_Comm comm, Mat_SeqAIJKokkos *akok, Mat *A)
1619d71ae5a4SJacob Faibussowitsch {
1620076ba34aSJunchao Zhang   PetscFunctionBegin;
16219566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
16229566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
16233ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
16248c3ff71bSJunchao Zhang }
16258c3ff71bSJunchao Zhang 
1626152b3e56SJunchao Zhang /*@C
162711a5261eSBarry Smith   MatCreateSeqAIJKokkos - Creates a sparse matrix in `MATSEQAIJKOKKOS` (compressed row) format
16288c3ff71bSJunchao Zhang   (the default parallel PETSc format). This matrix will ultimately be handled by
162920f4b53cSBarry Smith   Kokkos for calculations.
16308c3ff71bSJunchao Zhang 
16318c3ff71bSJunchao Zhang   Collective
16328c3ff71bSJunchao Zhang 
16338c3ff71bSJunchao Zhang   Input Parameters:
163411a5261eSBarry Smith + comm - MPI communicator, set to `PETSC_COMM_SELF`
16358c3ff71bSJunchao Zhang . m    - number of rows
16368c3ff71bSJunchao Zhang . n    - number of columns
163720f4b53cSBarry Smith . nz   - number of nonzeros per row (same for all rows), ignored if `nnz` is provided
163820f4b53cSBarry Smith - nnz  - array containing the number of nonzeros in the various rows (possibly different for each row) or `NULL`
16398c3ff71bSJunchao Zhang 
16408c3ff71bSJunchao Zhang   Output Parameter:
16418c3ff71bSJunchao Zhang . A - the matrix
16428c3ff71bSJunchao Zhang 
16432ef1f0ffSBarry Smith   Level: intermediate
16442ef1f0ffSBarry Smith 
16452ef1f0ffSBarry Smith   Notes:
164611a5261eSBarry Smith   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
16478c3ff71bSJunchao Zhang   MatXXXXSetPreallocation() paradgm instead of this routine directly.
164811a5261eSBarry Smith   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
16498c3ff71bSJunchao Zhang 
165011a5261eSBarry Smith   The AIJ format, also called
16512ef1f0ffSBarry Smith   compressed row storage, is fully compatible with standard Fortran
16528c3ff71bSJunchao Zhang   storage.  That is, the stored row and column indices can begin at
165320f4b53cSBarry Smith   either one (as in Fortran) or zero.
16548c3ff71bSJunchao Zhang 
16552ef1f0ffSBarry Smith   Specify the preallocated storage with either `nz` or `nnz` (not both).
16562ef1f0ffSBarry Smith   Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory
16572ef1f0ffSBarry Smith   allocation.
16588c3ff71bSJunchao Zhang 
1659fe59aa6dSJacob Faibussowitsch .seealso: [](ch_matrices), `Mat`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
16608c3ff71bSJunchao Zhang @*/
1661d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateSeqAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A)
1662d71ae5a4SJacob Faibussowitsch {
16638c3ff71bSJunchao Zhang   PetscFunctionBegin;
16649566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
16659566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
16669566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, m, n));
16679566063dSJacob Faibussowitsch   PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
16689566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
16693ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
16708c3ff71bSJunchao Zhang }
1671930e68a5SMark Adams 
1672aac854edSJunchao Zhang // After matrix numeric factorization, there are still steps to do before triangular solve can be called.
1673aac854edSJunchao Zhang // For example, for transpose solve, we might need to compute the transpose matrices if the solver does not support it (such as KK, while cusparse does).
1674aac854edSJunchao Zhang // In cusparse, one has to call cusparseSpSV_analysis() with updated triangular matrix values before calling cusparseSpSV_solve().
1675aac854edSJunchao Zhang // Simiarily, in KK sptrsv_symbolic() has to be called before sptrsv_solve(). We put these steps in MatSeqAIJKokkos{Transpose}SolveCheck.
1676aac854edSJunchao Zhang static PetscErrorCode MatSeqAIJKokkosSolveCheck(Mat A)
1677d71ae5a4SJacob Faibussowitsch {
167886a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors   = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1679aac854edSJunchao Zhang   const PetscBool             has_lower = factors->iL_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // false with Choleksy
1680aac854edSJunchao Zhang   const PetscBool             has_upper = factors->iU_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // true with LU and Choleksy
168186a27549SJunchao Zhang 
168286a27549SJunchao Zhang   PetscFunctionBegin;
1683aac854edSJunchao Zhang   if (!factors->sptrsv_symbolic_completed) { // If sptrsv_symbolic was not called yet
1684aac854edSJunchao Zhang     if (has_upper) PetscCallCXX(sptrsv_symbolic(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d));
1685aac854edSJunchao Zhang     if (has_lower) PetscCallCXX(sptrsv_symbolic(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d));
168686a27549SJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_TRUE;
168786a27549SJunchao Zhang   }
16883ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
168986a27549SJunchao Zhang }
169086a27549SJunchao Zhang 
1691d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosTransposeSolveCheck(Mat A)
1692d71ae5a4SJacob Faibussowitsch {
1693aac854edSJunchao Zhang   const PetscInt              n         = A->rmap->n;
169486a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors   = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1695aac854edSJunchao Zhang   const PetscBool             has_lower = factors->iL_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // false with Choleksy
1696aac854edSJunchao Zhang   const PetscBool             has_upper = factors->iU_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // true with LU or Choleksy
169786a27549SJunchao Zhang 
169886a27549SJunchao Zhang   PetscFunctionBegin;
1699aac854edSJunchao Zhang   if (!factors->transpose_updated) {
1700aac854edSJunchao Zhang     if (has_upper) {
1701aac854edSJunchao Zhang       if (!factors->iUt_d.extent(0)) {                                 // Allocate Ut on device if not yet
1702aac854edSJunchao Zhang         factors->iUt_d = MatRowMapKokkosView("factors->iUt_d", n + 1); // KK requires this view to be initialized to 0 to call transpose_matrix
17037b8d4ba6SJunchao Zhang         factors->jUt_d = MatColIdxKokkosView(NoInit("factors->jUt_d"), factors->jU_d.extent(0));
17047b8d4ba6SJunchao Zhang         factors->aUt_d = MatScalarKokkosView(NoInit("factors->aUt_d"), factors->aU_d.extent(0));
1705aac854edSJunchao Zhang       }
170686a27549SJunchao Zhang 
1707aac854edSJunchao Zhang       if (factors->iU_h.extent(0)) { // If U is on host (factorization was done on host), we also compute the transpose on host
1708aac854edSJunchao Zhang         if (!factors->U) {
1709aac854edSJunchao Zhang           Mat_SeqAIJ *seq;
171086a27549SJunchao Zhang 
1711aac854edSJunchao Zhang           PetscCall(MatCreateSeqAIJWithArrays(PETSC_COMM_SELF, n, n, factors->iU_h.data(), factors->jU_h.data(), factors->aU_h.data(), &factors->U));
1712aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->U, MAT_INITIAL_MATRIX, &factors->Ut));
171386a27549SJunchao Zhang 
1714aac854edSJunchao Zhang           seq            = static_cast<Mat_SeqAIJ *>(factors->Ut->data);
1715aac854edSJunchao Zhang           factors->iUt_h = MatRowMapKokkosViewHost(seq->i, n + 1);
1716aac854edSJunchao Zhang           factors->jUt_h = MatColIdxKokkosViewHost(seq->j, seq->nz);
1717aac854edSJunchao Zhang           factors->aUt_h = MatScalarKokkosViewHost(seq->a, seq->nz);
1718aac854edSJunchao Zhang         } else {
1719aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->U, MAT_REUSE_MATRIX, &factors->Ut)); // Matrix Ut' data is aliased with {i, j, a}Ut_h
1720aac854edSJunchao Zhang         }
1721aac854edSJunchao Zhang         // Copy Ut from host to device
1722aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->iUt_d, factors->iUt_h));
1723aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->jUt_d, factors->jUt_h));
1724aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->aUt_d, factors->aUt_h));
1725aac854edSJunchao Zhang       } else { // If U was computed on device, we also compute the transpose there
1726aac854edSJunchao Zhang         // TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices. We have to sort the indices, until KK provides finer control options.
1727aac854edSJunchao Zhang         PetscCallCXX(transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iU_d,
1728aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->jU_d, factors->aU_d,
1729aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->iUt_d, factors->jUt_d,
1730aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->aUt_d));
1731aac854edSJunchao Zhang         PetscCallCXX(sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iUt_d, factors->jUt_d, factors->aUt_d));
1732aac854edSJunchao Zhang       }
1733aac854edSJunchao Zhang       PetscCallCXX(sptrsv_symbolic(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d));
1734aac854edSJunchao Zhang     }
1735aac854edSJunchao Zhang 
1736aac854edSJunchao Zhang     // do the same for L with LU
1737aac854edSJunchao Zhang     if (has_lower) {
1738aac854edSJunchao Zhang       if (!factors->iLt_d.extent(0)) {                                 // Allocate Lt on device if not yet
1739aac854edSJunchao Zhang         factors->iLt_d = MatRowMapKokkosView("factors->iLt_d", n + 1); // KK requires this view to be initialized to 0 to call transpose_matrix
1740aac854edSJunchao Zhang         factors->jLt_d = MatColIdxKokkosView(NoInit("factors->jLt_d"), factors->jL_d.extent(0));
1741aac854edSJunchao Zhang         factors->aLt_d = MatScalarKokkosView(NoInit("factors->aLt_d"), factors->aL_d.extent(0));
1742aac854edSJunchao Zhang       }
1743aac854edSJunchao Zhang 
1744aac854edSJunchao Zhang       if (factors->iL_h.extent(0)) { // If L is on host, we also compute the transpose on host
1745aac854edSJunchao Zhang         if (!factors->L) {
1746aac854edSJunchao Zhang           Mat_SeqAIJ *seq;
1747aac854edSJunchao Zhang 
1748aac854edSJunchao Zhang           PetscCall(MatCreateSeqAIJWithArrays(PETSC_COMM_SELF, n, n, factors->iL_h.data(), factors->jL_h.data(), factors->aL_h.data(), &factors->L));
1749aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->L, MAT_INITIAL_MATRIX, &factors->Lt));
1750aac854edSJunchao Zhang 
1751aac854edSJunchao Zhang           seq            = static_cast<Mat_SeqAIJ *>(factors->Lt->data);
1752aac854edSJunchao Zhang           factors->iLt_h = MatRowMapKokkosViewHost(seq->i, n + 1);
1753aac854edSJunchao Zhang           factors->jLt_h = MatColIdxKokkosViewHost(seq->j, seq->nz);
1754aac854edSJunchao Zhang           factors->aLt_h = MatScalarKokkosViewHost(seq->a, seq->nz);
1755aac854edSJunchao Zhang         } else {
1756aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->L, MAT_REUSE_MATRIX, &factors->Lt)); // Matrix Lt' data is aliased with {i, j, a}Lt_h
1757aac854edSJunchao Zhang         }
1758aac854edSJunchao Zhang         // Copy Lt from host to device
1759aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->iLt_d, factors->iLt_h));
1760aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->jLt_d, factors->jLt_h));
1761aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->aLt_d, factors->aLt_h));
1762aac854edSJunchao Zhang       } else { // If L was computed on device, we also compute the transpose there
1763aac854edSJunchao Zhang         // TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices. We have to sort the indices, until KK provides finer control options.
1764aac854edSJunchao Zhang         PetscCallCXX(transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iL_d,
1765aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->jL_d, factors->aL_d,
1766aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->iLt_d, factors->jLt_d,
1767aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->aLt_d));
1768aac854edSJunchao Zhang         PetscCallCXX(sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iLt_d, factors->jLt_d, factors->aLt_d));
1769aac854edSJunchao Zhang       }
1770aac854edSJunchao Zhang       PetscCallCXX(sptrsv_symbolic(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d));
1771aac854edSJunchao Zhang     }
1772aac854edSJunchao Zhang 
177386a27549SJunchao Zhang     factors->transpose_updated = PETSC_TRUE;
177486a27549SJunchao Zhang   }
17753ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
177686a27549SJunchao Zhang }
177786a27549SJunchao Zhang 
1778aac854edSJunchao Zhang // Solve Ax = b, with RAR = U^T D U, where R is the row (and col) permutation matrix on A.
1779aac854edSJunchao Zhang // R is represented by rowperm in factors. If R is identity (i.e, no reordering), then rowperm is empty.
1780aac854edSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJKokkos_Cholesky(Mat A, Vec bb, Vec xx)
1781d71ae5a4SJacob Faibussowitsch {
1782aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
178386a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1784aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1785aac854edSJunchao Zhang   PetscScalarKokkosView       D       = factors->D_d;
1786aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1787aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1788aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1789aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm  = factors->rowperm;
1790aac854edSJunchao Zhang   PetscBool                   identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
179186a27549SJunchao Zhang 
179286a27549SJunchao Zhang   PetscFunctionBegin;
17939566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1794aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSolveCheck(A));          // for UX = T
1795aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A)); // for U^T Y = B
1796aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1797aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
1798aac854edSJunchao Zhang 
1799aac854edSJunchao Zhang   // Solve U^T Y = B
1800aac854edSJunchao Zhang   if (identity) { // Reorder b with the row permutation
1801aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1802aac854edSJunchao Zhang     Y = factors->workVector;
1803aac854edSJunchao Zhang   } else {
1804aac854edSJunchao Zhang     B = factors->workVector;
1805aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(rowperm(i)); }));
1806aac854edSJunchao Zhang     Y = x;
1807aac854edSJunchao Zhang   }
1808aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, B, Y));
1809aac854edSJunchao Zhang 
1810aac854edSJunchao Zhang   // Solve diag(D) Y' = Y.
1811aac854edSJunchao Zhang   // Actually just do Y' = Y*D since D is already inverted in MatCholeskyFactorNumeric_SeqAIJ(). It is basically a vector element-wise multiplication.
1812aac854edSJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { Y(i) = Y(i) * D(i); }));
1813aac854edSJunchao Zhang 
1814aac854edSJunchao Zhang   // Solve UX = Y
1815aac854edSJunchao Zhang   if (identity) {
1816aac854edSJunchao Zhang     X = x;
1817aac854edSJunchao Zhang   } else {
1818aac854edSJunchao Zhang     X = factors->workVector; // B is not needed anymore
1819aac854edSJunchao Zhang   }
1820aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, Y, X));
1821aac854edSJunchao Zhang 
1822aac854edSJunchao Zhang   // Reorder X with the inverse column (row) permutation
18233a7d0413SPierre Jolivet   if (!identity) PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(rowperm(i)) = X(i); }));
1824aac854edSJunchao Zhang 
1825aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1826aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
18279566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
18283ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
182986a27549SJunchao Zhang }
183086a27549SJunchao Zhang 
1831aac854edSJunchao Zhang // Solve Ax = b, with RAC = LU, where R and C are row and col permutation matrices on A respectively.
1832aac854edSJunchao Zhang // R and C are represented by rowperm and colperm in factors.
1833aac854edSJunchao Zhang // If R or C is identity (i.e, no reordering), then rowperm or colperm is empty.
1834aac854edSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJKokkos_LU(Mat A, Vec bb, Vec xx)
1835d71ae5a4SJacob Faibussowitsch {
1836aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
183786a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1838aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1839aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1840aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1841aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1842aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm      = factors->rowperm;
1843aac854edSJunchao Zhang   PetscIntKokkosView         &colperm      = factors->colperm;
1844aac854edSJunchao Zhang   PetscBool                   row_identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1845aac854edSJunchao Zhang   PetscBool                   col_identity = colperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
184686a27549SJunchao Zhang 
184786a27549SJunchao Zhang   PetscFunctionBegin;
18489566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1849aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSolveCheck(A));
1850aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1851aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
185286a27549SJunchao Zhang 
1853aac854edSJunchao Zhang   // Solve L Y = B (i.e., L (U C^- x) = R b).  R b indicates applying the row permutation on b.
1854aac854edSJunchao Zhang   if (row_identity) {
1855aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1856aac854edSJunchao Zhang     Y = factors->workVector;
1857aac854edSJunchao Zhang   } else {
1858aac854edSJunchao Zhang     B = factors->workVector;
1859aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(rowperm(i)); }));
1860aac854edSJunchao Zhang     Y = x;
1861aac854edSJunchao Zhang   }
1862aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khL, factors->iL_d, factors->jL_d, factors->aL_d, B, Y));
1863aac854edSJunchao Zhang 
1864aac854edSJunchao Zhang   // Solve U C^- x = Y
1865aac854edSJunchao Zhang   if (col_identity) {
1866aac854edSJunchao Zhang     X = x;
1867aac854edSJunchao Zhang   } else {
1868aac854edSJunchao Zhang     X = factors->workVector;
1869aac854edSJunchao Zhang   }
1870aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, Y, X));
1871aac854edSJunchao Zhang 
1872aac854edSJunchao Zhang   // x = C X; Reorder X with the inverse col permutation
18733a7d0413SPierre Jolivet   if (!col_identity) PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(colperm(i)) = X(i); }));
1874aac854edSJunchao Zhang 
1875aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1876aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
18779566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
18783ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
187986a27549SJunchao Zhang }
188086a27549SJunchao Zhang 
1881aac854edSJunchao Zhang // Solve A^T x = b, with RAC = LU, where R and C are row and col permutation matrices on A respectively.
1882aac854edSJunchao Zhang // R and C are represented by rowperm and colperm in factors.
1883aac854edSJunchao Zhang // If R or C is identity (i.e, no reordering), then rowperm or colperm is empty.
1884aac854edSJunchao Zhang // A = R^-1 L U C^-1, so A^T = C^-T U^T L^T R^-T. But since C^- = C^T, R^- = R^T, we have A^T = C U^T L^T R.
1885aac854edSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJKokkos_LU(Mat A, Vec bb, Vec xx)
1886aac854edSJunchao Zhang {
1887aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
1888aac854edSJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1889aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1890aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1891aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1892aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1893aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm      = factors->rowperm;
1894aac854edSJunchao Zhang   PetscIntKokkosView         &colperm      = factors->colperm;
1895aac854edSJunchao Zhang   PetscBool                   row_identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1896aac854edSJunchao Zhang   PetscBool                   col_identity = colperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1897aac854edSJunchao Zhang 
1898aac854edSJunchao Zhang   PetscFunctionBegin;
1899aac854edSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1900aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A)); // Update L^T, U^T if needed, and do sptrsv symbolic for L^T, U^T
1901aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1902aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
1903aac854edSJunchao Zhang 
1904aac854edSJunchao Zhang   // Solve U^T Y = B (i.e., U^T (L^T R x) = C^- b).  Note C^- b = C^T b, which means applying the column permutation on b.
1905aac854edSJunchao Zhang   if (col_identity) { // Reorder b with the col permutation
1906aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1907aac854edSJunchao Zhang     Y = factors->workVector;
1908aac854edSJunchao Zhang   } else {
1909aac854edSJunchao Zhang     B = factors->workVector;
1910aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(colperm(i)); }));
1911aac854edSJunchao Zhang     Y = x;
1912aac854edSJunchao Zhang   }
1913aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, B, Y));
1914aac854edSJunchao Zhang 
1915aac854edSJunchao Zhang   // Solve L^T X = Y
1916aac854edSJunchao Zhang   if (row_identity) {
1917aac854edSJunchao Zhang     X = x;
1918aac854edSJunchao Zhang   } else {
1919aac854edSJunchao Zhang     X = factors->workVector;
1920aac854edSJunchao Zhang   }
1921aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d, Y, X));
1922aac854edSJunchao Zhang 
1923aac854edSJunchao Zhang   // x = R^- X = R^T X; Reorder X with the inverse row permutation
19243a7d0413SPierre Jolivet   if (!row_identity) PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(rowperm(i)) = X(i); }));
1925aac854edSJunchao Zhang 
1926aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1927aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
1928aac854edSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1929aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1930aac854edSJunchao Zhang }
1931aac854edSJunchao Zhang 
1932aac854edSJunchao Zhang static PetscErrorCode MatLUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1933aac854edSJunchao Zhang {
1934aac854edSJunchao Zhang   PetscFunctionBegin;
1935aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncHost(A));
1936aac854edSJunchao Zhang   PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info));
1937aac854edSJunchao Zhang 
1938aac854edSJunchao Zhang   if (!info->solveonhost) { // if solve on host, then we don't need to copy L, U to device
1939aac854edSJunchao Zhang     Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
1940aac854edSJunchao Zhang     Mat_SeqAIJ                 *b       = static_cast<Mat_SeqAIJ *>(B->data);
1941aac854edSJunchao Zhang     const PetscInt             *Bi = b->i, *Bj = b->j, *Bdiag = b->diag;
1942aac854edSJunchao Zhang     const MatScalar            *Ba = b->a;
1943aac854edSJunchao Zhang     PetscInt                    m = B->rmap->n, n = B->cmap->n;
1944aac854edSJunchao Zhang 
1945aac854edSJunchao Zhang     if (factors->iL_h.extent(0) == 0) { // Allocate memory and copy the L, U structure for the first time
1946aac854edSJunchao Zhang       // Allocate memory and copy the structure
1947aac854edSJunchao Zhang       factors->iL_h = MatRowMapKokkosViewHost(NoInit("iL_h"), m + 1);
1948aac854edSJunchao Zhang       factors->jL_h = MatColIdxKokkosViewHost(NoInit("jL_h"), (Bi[m] - Bi[0]) + m); // + the diagonal entries
1949aac854edSJunchao Zhang       factors->aL_h = MatScalarKokkosViewHost(NoInit("aL_h"), (Bi[m] - Bi[0]) + m);
1950aac854edSJunchao Zhang       factors->iU_h = MatRowMapKokkosViewHost(NoInit("iU_h"), m + 1);
1951aac854edSJunchao Zhang       factors->jU_h = MatColIdxKokkosViewHost(NoInit("jU_h"), (Bdiag[0] - Bdiag[m]));
1952aac854edSJunchao Zhang       factors->aU_h = MatScalarKokkosViewHost(NoInit("aU_h"), (Bdiag[0] - Bdiag[m]));
1953aac854edSJunchao Zhang 
1954aac854edSJunchao Zhang       PetscInt *Li = factors->iL_h.data();
1955aac854edSJunchao Zhang       PetscInt *Lj = factors->jL_h.data();
1956aac854edSJunchao Zhang       PetscInt *Ui = factors->iU_h.data();
1957aac854edSJunchao Zhang       PetscInt *Uj = factors->jU_h.data();
1958aac854edSJunchao Zhang 
1959aac854edSJunchao Zhang       Li[0] = Ui[0] = 0;
1960aac854edSJunchao Zhang       for (PetscInt i = 0; i < m; i++) {
1961aac854edSJunchao Zhang         PetscInt llen = Bi[i + 1] - Bi[i];       // exclusive of the diagonal entry
1962aac854edSJunchao Zhang         PetscInt ulen = Bdiag[i] - Bdiag[i + 1]; // inclusive of the diagonal entry
1963aac854edSJunchao Zhang 
196464dc1d19SNuno Nobre         PetscCall(PetscArraycpy(Lj + Li[i], Bj + Bi[i], llen)); // entries of L on the left of the diagonal
1965aac854edSJunchao Zhang         Lj[Li[i] + llen] = i;                                   // diagonal entry of L
1966aac854edSJunchao Zhang 
1967aac854edSJunchao Zhang         Uj[Ui[i]] = i;                                                             // diagonal entry of U
196864dc1d19SNuno Nobre         PetscCall(PetscArraycpy(Uj + Ui[i] + 1, Bj + Bdiag[i + 1] + 1, ulen - 1)); // entries of U on  the right of the diagonal
1969aac854edSJunchao Zhang 
1970aac854edSJunchao Zhang         Li[i + 1] = Li[i] + llen + 1;
1971aac854edSJunchao Zhang         Ui[i + 1] = Ui[i] + ulen;
1972aac854edSJunchao Zhang       }
1973aac854edSJunchao Zhang 
1974aac854edSJunchao Zhang       factors->iL_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iL_h);
1975aac854edSJunchao Zhang       factors->jL_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jL_h);
1976aac854edSJunchao Zhang       factors->iU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iU_h);
1977aac854edSJunchao Zhang       factors->jU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jU_h);
1978aac854edSJunchao Zhang       factors->aL_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aL_h);
1979aac854edSJunchao Zhang       factors->aU_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aU_h);
1980aac854edSJunchao Zhang 
1981aac854edSJunchao Zhang       // Copy row/col permutation to device
1982aac854edSJunchao Zhang       IS        rowperm = ((Mat_SeqAIJ *)B->data)->row;
1983aac854edSJunchao Zhang       PetscBool row_identity;
1984aac854edSJunchao Zhang       PetscCall(ISIdentity(rowperm, &row_identity));
1985aac854edSJunchao Zhang       if (!row_identity) {
1986aac854edSJunchao Zhang         const PetscInt *ip;
1987aac854edSJunchao Zhang 
1988aac854edSJunchao Zhang         PetscCall(ISGetIndices(rowperm, &ip));
1989aac854edSJunchao Zhang         factors->rowperm = PetscIntKokkosView(NoInit("rowperm"), m);
1990aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->rowperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), m)));
1991aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(rowperm, &ip));
1992aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(m * sizeof(PetscInt)));
1993aac854edSJunchao Zhang       }
1994aac854edSJunchao Zhang 
1995aac854edSJunchao Zhang       IS        colperm = ((Mat_SeqAIJ *)B->data)->col;
1996aac854edSJunchao Zhang       PetscBool col_identity;
1997aac854edSJunchao Zhang       PetscCall(ISIdentity(colperm, &col_identity));
1998aac854edSJunchao Zhang       if (!col_identity) {
1999aac854edSJunchao Zhang         const PetscInt *ip;
2000aac854edSJunchao Zhang 
2001aac854edSJunchao Zhang         PetscCall(ISGetIndices(colperm, &ip));
2002aac854edSJunchao Zhang         factors->colperm = PetscIntKokkosView(NoInit("colperm"), n);
2003aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->colperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), n)));
2004aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(colperm, &ip));
2005aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt)));
2006aac854edSJunchao Zhang       }
2007aac854edSJunchao Zhang 
2008aac854edSJunchao Zhang       /* Create sptrsv handles for L, U and their transpose */
2009aac854edSJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2010aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
2011aac854edSJunchao Zhang #else
2012aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
2013aac854edSJunchao Zhang #endif
2014aac854edSJunchao Zhang       factors->khL.create_sptrsv_handle(sptrsv_alg, m, true /* L is lower tri */);
2015aac854edSJunchao Zhang       factors->khU.create_sptrsv_handle(sptrsv_alg, m, false /* U is not lower tri */);
2016aac854edSJunchao Zhang       factors->khLt.create_sptrsv_handle(sptrsv_alg, m, false /* L^T is not lower tri */);
2017aac854edSJunchao Zhang       factors->khUt.create_sptrsv_handle(sptrsv_alg, m, true /* U^T is lower tri */);
2018aac854edSJunchao Zhang     }
2019aac854edSJunchao Zhang 
2020aac854edSJunchao Zhang     // Copy the value
2021aac854edSJunchao Zhang     for (PetscInt i = 0; i < m; i++) {
2022aac854edSJunchao Zhang       PetscInt        llen = Bi[i + 1] - Bi[i];
2023aac854edSJunchao Zhang       PetscInt        ulen = Bdiag[i] - Bdiag[i + 1];
2024aac854edSJunchao Zhang       const PetscInt *Li   = factors->iL_h.data();
2025aac854edSJunchao Zhang       const PetscInt *Ui   = factors->iU_h.data();
2026aac854edSJunchao Zhang 
2027aac854edSJunchao Zhang       PetscScalar *La = factors->aL_h.data();
2028aac854edSJunchao Zhang       PetscScalar *Ua = factors->aU_h.data();
2029aac854edSJunchao Zhang 
203064dc1d19SNuno Nobre       PetscCall(PetscArraycpy(La + Li[i], Ba + Bi[i], llen)); // entries of L
2031aac854edSJunchao Zhang       La[Li[i] + llen] = 1.0;                                 // diagonal entry
2032aac854edSJunchao Zhang 
2033aac854edSJunchao Zhang       Ua[Ui[i]] = 1.0 / Ba[Bdiag[i]];                                            // diagonal entry
203464dc1d19SNuno Nobre       PetscCall(PetscArraycpy(Ua + Ui[i] + 1, Ba + Bdiag[i + 1] + 1, ulen - 1)); // entries of U
2035aac854edSJunchao Zhang     }
2036aac854edSJunchao Zhang 
2037aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aL_d, factors->aL_h));
2038aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aU_d, factors->aU_h));
2039aac854edSJunchao Zhang     // Once the factors' values have changed, we need to update their transpose and redo sptrsv symbolic
2040aac854edSJunchao Zhang     factors->transpose_updated         = PETSC_FALSE;
2041aac854edSJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_FALSE;
2042aac854edSJunchao Zhang 
2043aac854edSJunchao Zhang     B->ops->solve          = MatSolve_SeqAIJKokkos_LU;
2044aac854edSJunchao Zhang     B->ops->solvetranspose = MatSolveTranspose_SeqAIJKokkos_LU;
2045aac854edSJunchao Zhang   }
2046aac854edSJunchao Zhang 
2047aac854edSJunchao Zhang   B->ops->matsolve          = NULL;
2048aac854edSJunchao Zhang   B->ops->matsolvetranspose = NULL;
2049aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2050aac854edSJunchao Zhang }
2051aac854edSJunchao Zhang 
2052aac854edSJunchao Zhang static PetscErrorCode MatILUFactorNumeric_SeqAIJKokkos_ILU0(Mat B, Mat A, const MatFactorInfo *info)
2053d71ae5a4SJacob Faibussowitsch {
205486a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
205586a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
205686a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
205786a27549SJunchao Zhang 
205886a27549SJunchao Zhang   PetscFunctionBegin;
20599566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
2060aac854edSJunchao Zhang   PetscCheck(!info->factoronhost, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "MatFactorInfo.factoronhost should be false");
20619566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
2062076ba34aSJunchao Zhang 
2063076ba34aSJunchao Zhang   auto a_d = aijkok->a_dual.view_device();
2064076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
2065076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
2066076ba34aSJunchao Zhang 
2067aac854edSJunchao Zhang   PetscCallCXX(spiluk_numeric(&factors->kh, fill_lev, i_d, j_d, a_d, factors->iL_d, factors->jL_d, factors->aL_d, factors->iU_d, factors->jU_d, factors->aU_d));
206886a27549SJunchao Zhang 
206986a27549SJunchao Zhang   B->assembled              = PETSC_TRUE;
207086a27549SJunchao Zhang   B->preallocated           = PETSC_TRUE;
2071aac854edSJunchao Zhang   B->ops->solve             = MatSolve_SeqAIJKokkos_LU;
2072aac854edSJunchao Zhang   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJKokkos_LU;
207386a27549SJunchao Zhang   B->ops->matsolve          = NULL;
207486a27549SJunchao Zhang   B->ops->matsolvetranspose = NULL;
207586a27549SJunchao Zhang 
207686a27549SJunchao Zhang   /* Once the factors' value changed, we need to update their transpose and sptrsv handle */
207786a27549SJunchao Zhang   factors->transpose_updated         = PETSC_FALSE;
207886a27549SJunchao Zhang   factors->sptrsv_symbolic_completed = PETSC_FALSE;
2079eeadb341SJunchao Zhang   /* TODO: log flops, but how to know that? */
20809566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
20813ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
208286a27549SJunchao Zhang }
208386a27549SJunchao Zhang 
2084aac854edSJunchao Zhang // Use KK's spiluk_symbolic() to do ILU0 symbolic factorization, with no row/col reordering
2085aac854edSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos_ILU0(Mat B, Mat A, IS, IS, const MatFactorInfo *info)
2086d71ae5a4SJacob Faibussowitsch {
208786a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok;
208886a27549SJunchao Zhang   Mat_SeqAIJ                 *b;
208986a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
209086a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
209186a27549SJunchao Zhang   PetscInt                    nnzA     = ((Mat_SeqAIJ *)A->data)->nz, nnzL, nnzU;
209286a27549SJunchao Zhang   PetscInt                    n        = A->rmap->n;
209386a27549SJunchao Zhang 
209486a27549SJunchao Zhang   PetscFunctionBegin;
2095aac854edSJunchao Zhang   PetscCheck(!info->factoronhost, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "MatFactorInfo's factoronhost should be false as we are doing it on device right now");
20969566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
209786a27549SJunchao Zhang 
209886a27549SJunchao Zhang   /* Create a spiluk handle and then do symbolic factorization */
209986a27549SJunchao Zhang   nnzL = nnzU = PetscRealIntMultTruncate(info->fill, nnzA);
2100aac854edSJunchao Zhang   factors->kh.create_spiluk_handle(SPILUKAlgorithm::SEQLVLSCHD_TP1, n, nnzL, nnzU);
210186a27549SJunchao Zhang 
210286a27549SJunchao Zhang   auto spiluk_handle = factors->kh.get_spiluk_handle();
210386a27549SJunchao Zhang 
210486a27549SJunchao Zhang   Kokkos::realloc(factors->iL_d, n + 1); /* Free old arrays and realloc */
210586a27549SJunchao Zhang   Kokkos::realloc(factors->jL_d, spiluk_handle->get_nnzL());
210686a27549SJunchao Zhang   Kokkos::realloc(factors->iU_d, n + 1);
210786a27549SJunchao Zhang   Kokkos::realloc(factors->jU_d, spiluk_handle->get_nnzU());
210886a27549SJunchao Zhang 
210986a27549SJunchao Zhang   aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
2110076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
2111076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
2112aac854edSJunchao Zhang   PetscCallCXX(spiluk_symbolic(&factors->kh, fill_lev, i_d, j_d, factors->iL_d, factors->jL_d, factors->iU_d, factors->jU_d));
211386a27549SJunchao Zhang   /* TODO: if spiluk_symbolic is asynchronous, do we need to sync before calling get_nnzL()? */
211486a27549SJunchao Zhang 
211586a27549SJunchao Zhang   Kokkos::resize(factors->jL_d, spiluk_handle->get_nnzL()); /* Shrink or expand, and retain old value */
211686a27549SJunchao Zhang   Kokkos::resize(factors->jU_d, spiluk_handle->get_nnzU());
211786a27549SJunchao Zhang   Kokkos::realloc(factors->aL_d, spiluk_handle->get_nnzL()); /* No need to retain old value */
211886a27549SJunchao Zhang   Kokkos::realloc(factors->aU_d, spiluk_handle->get_nnzU());
211986a27549SJunchao Zhang 
212086a27549SJunchao Zhang   /* TODO: add options to select sptrsv algorithms */
212186a27549SJunchao Zhang   /* Create sptrsv handles for L, U and their transpose */
212286a27549SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2123aac854edSJunchao Zhang   auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
212486a27549SJunchao Zhang #else
2125aac854edSJunchao Zhang   auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
212686a27549SJunchao Zhang #endif
212786a27549SJunchao Zhang 
212886a27549SJunchao Zhang   factors->khL.create_sptrsv_handle(sptrsv_alg, n, true /* L is lower tri */);
212986a27549SJunchao Zhang   factors->khU.create_sptrsv_handle(sptrsv_alg, n, false /* U is not lower tri */);
213086a27549SJunchao Zhang   factors->khLt.create_sptrsv_handle(sptrsv_alg, n, false /* L^T is not lower tri */);
213186a27549SJunchao Zhang   factors->khUt.create_sptrsv_handle(sptrsv_alg, n, true /* U^T is lower tri */);
213286a27549SJunchao Zhang 
213386a27549SJunchao Zhang   /* Fill fields of the factor matrix B */
21349566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(B, MAT_SKIP_ALLOCATION, NULL));
213586a27549SJunchao Zhang   b     = (Mat_SeqAIJ *)B->data;
213686a27549SJunchao Zhang   b->nz = b->maxnz          = spiluk_handle->get_nnzL() + spiluk_handle->get_nnzU();
213786a27549SJunchao Zhang   B->info.fill_ratio_given  = info->fill;
2138a1e4e190SStefano Zampini   B->info.fill_ratio_needed = nnzA > 0 ? ((PetscReal)b->nz) / ((PetscReal)nnzA) : 1.0;
213986a27549SJunchao Zhang 
2140aac854edSJunchao Zhang   B->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJKokkos_ILU0;
21413ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2142930e68a5SMark Adams }
2143930e68a5SMark Adams 
2144aac854edSJunchao Zhang static PetscErrorCode MatLUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
2145aac854edSJunchao Zhang {
2146aac854edSJunchao Zhang   PetscFunctionBegin;
2147aac854edSJunchao Zhang   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
2148aac854edSJunchao Zhang   PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2149aac854edSJunchao Zhang   PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2150aac854edSJunchao Zhang   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
2151aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2152aac854edSJunchao Zhang }
2153aac854edSJunchao Zhang 
2154aac854edSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
2155aac854edSJunchao Zhang {
2156aac854edSJunchao Zhang   PetscBool row_identity = PETSC_FALSE, col_identity = PETSC_FALSE;
2157aac854edSJunchao Zhang 
2158aac854edSJunchao Zhang   PetscFunctionBegin;
2159aac854edSJunchao Zhang   if (!info->factoronhost) {
2160aac854edSJunchao Zhang     PetscCall(ISIdentity(isrow, &row_identity));
2161aac854edSJunchao Zhang     PetscCall(ISIdentity(iscol, &col_identity));
2162aac854edSJunchao Zhang   }
2163aac854edSJunchao Zhang 
2164aac854edSJunchao Zhang   PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2165aac854edSJunchao Zhang   PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2166aac854edSJunchao Zhang 
2167aac854edSJunchao Zhang   if (!info->factoronhost && !info->levels && row_identity && col_identity) { // if level 0 and no reordering
2168aac854edSJunchao Zhang     PetscCall(MatILUFactorSymbolic_SeqAIJKokkos_ILU0(B, A, isrow, iscol, info));
2169aac854edSJunchao Zhang   } else {
2170aac854edSJunchao Zhang     PetscCall(MatILUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info)); // otherwise, use PETSc's ILU on host
2171aac854edSJunchao Zhang     B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
2172aac854edSJunchao Zhang   }
2173aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2174aac854edSJunchao Zhang }
2175aac854edSJunchao Zhang 
2176aac854edSJunchao Zhang static PetscErrorCode MatCholeskyFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
2177aac854edSJunchao Zhang {
2178aac854edSJunchao Zhang   PetscFunctionBegin;
2179aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncHost(A));
2180aac854edSJunchao Zhang   PetscCall(MatCholeskyFactorNumeric_SeqAIJ(B, A, info));
2181aac854edSJunchao Zhang 
2182aac854edSJunchao Zhang   if (!info->solveonhost) { // if solve on host, then we don't need to copy L, U to device
2183aac854edSJunchao Zhang     Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
2184aac854edSJunchao Zhang     Mat_SeqAIJ                 *b       = static_cast<Mat_SeqAIJ *>(B->data);
2185aac854edSJunchao Zhang     const PetscInt             *Bi = b->i, *Bj = b->j, *Bdiag = b->diag;
2186aac854edSJunchao Zhang     const MatScalar            *Ba = b->a;
2187aac854edSJunchao Zhang     PetscInt                    m  = B->rmap->n;
2188aac854edSJunchao Zhang 
2189aac854edSJunchao Zhang     if (factors->iU_h.extent(0) == 0) { // First time of numeric factorization
2190aac854edSJunchao Zhang       // Allocate memory and copy the structure
2191aac854edSJunchao Zhang       factors->iU_h = PetscIntKokkosViewHost(const_cast<PetscInt *>(Bi), m + 1); // wrap Bi as iU_h
2192aac854edSJunchao Zhang       factors->jU_h = MatColIdxKokkosViewHost(NoInit("jU_h"), Bi[m]);
2193aac854edSJunchao Zhang       factors->aU_h = MatScalarKokkosViewHost(NoInit("aU_h"), Bi[m]);
2194aac854edSJunchao Zhang       factors->D_h  = MatScalarKokkosViewHost(NoInit("D_h"), m);
2195aac854edSJunchao Zhang       factors->aU_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aU_h);
2196aac854edSJunchao Zhang       factors->D_d  = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->D_h);
2197aac854edSJunchao Zhang 
2198aac854edSJunchao Zhang       // Build jU_h from the skewed Aj
2199aac854edSJunchao Zhang       PetscInt *Uj = factors->jU_h.data();
2200aac854edSJunchao Zhang       for (PetscInt i = 0; i < m; i++) {
2201aac854edSJunchao Zhang         PetscInt ulen = Bi[i + 1] - Bi[i];
2202aac854edSJunchao Zhang         Uj[Bi[i]]     = i;                                              // diagonal entry
2203aac854edSJunchao Zhang         PetscCall(PetscArraycpy(Uj + Bi[i] + 1, Bj + Bi[i], ulen - 1)); // entries of U on the right of the diagonal
2204aac854edSJunchao Zhang       }
2205aac854edSJunchao Zhang 
2206aac854edSJunchao Zhang       // Copy iU, jU to device
2207aac854edSJunchao Zhang       PetscCallCXX(factors->iU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iU_h));
2208aac854edSJunchao Zhang       PetscCallCXX(factors->jU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jU_h));
2209aac854edSJunchao Zhang 
2210aac854edSJunchao Zhang       // Copy row/col permutation to device
2211aac854edSJunchao Zhang       IS        rowperm = ((Mat_SeqAIJ *)B->data)->row;
2212aac854edSJunchao Zhang       PetscBool row_identity;
2213aac854edSJunchao Zhang       PetscCall(ISIdentity(rowperm, &row_identity));
2214aac854edSJunchao Zhang       if (!row_identity) {
2215aac854edSJunchao Zhang         const PetscInt *ip;
2216aac854edSJunchao Zhang 
2217aac854edSJunchao Zhang         PetscCall(ISGetIndices(rowperm, &ip));
2218aac854edSJunchao Zhang         PetscCallCXX(factors->rowperm = PetscIntKokkosView(NoInit("rowperm"), m));
2219aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->rowperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), m)));
2220aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(rowperm, &ip));
2221aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(m * sizeof(PetscInt)));
2222aac854edSJunchao Zhang       }
2223aac854edSJunchao Zhang 
2224aac854edSJunchao Zhang       // Create sptrsv handles for U and U^T
2225aac854edSJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2226aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
2227aac854edSJunchao Zhang #else
2228aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
2229aac854edSJunchao Zhang #endif
2230aac854edSJunchao Zhang       factors->khU.create_sptrsv_handle(sptrsv_alg, m, false /* U is not lower tri */);
2231aac854edSJunchao Zhang       factors->khUt.create_sptrsv_handle(sptrsv_alg, m, true /* U^T is lower tri */);
2232aac854edSJunchao Zhang     }
2233aac854edSJunchao Zhang     // These pointers were set MatCholeskyFactorNumeric_SeqAIJ(), so we always need to update them
2234aac854edSJunchao Zhang     B->ops->solve          = MatSolve_SeqAIJKokkos_Cholesky;
2235aac854edSJunchao Zhang     B->ops->solvetranspose = MatSolve_SeqAIJKokkos_Cholesky;
2236aac854edSJunchao Zhang 
2237aac854edSJunchao Zhang     // Copy the value
2238aac854edSJunchao Zhang     PetscScalar *Ua = factors->aU_h.data();
2239aac854edSJunchao Zhang     PetscScalar *D  = factors->D_h.data();
2240aac854edSJunchao Zhang     for (PetscInt i = 0; i < m; i++) {
2241aac854edSJunchao Zhang       D[i]      = Ba[Bdiag[i]];     // actually Aa[Adiag[i]] is the inverse of the diagonal
2242aac854edSJunchao Zhang       Ua[Bi[i]] = (PetscScalar)1.0; // set the unit diagonal for U
2243aac854edSJunchao Zhang       for (PetscInt k = 0; k < Bi[i + 1] - Bi[i] - 1; k++) Ua[Bi[i] + 1 + k] = -Ba[Bi[i] + k];
2244aac854edSJunchao Zhang     }
2245aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aU_d, factors->aU_h));
2246aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->D_d, factors->D_h));
2247aac854edSJunchao Zhang 
2248aac854edSJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_FALSE; // When numeric value changed, we must do these again
2249aac854edSJunchao Zhang     factors->transpose_updated         = PETSC_FALSE;
2250aac854edSJunchao Zhang   }
2251aac854edSJunchao Zhang 
2252aac854edSJunchao Zhang   B->ops->matsolve          = NULL;
2253aac854edSJunchao Zhang   B->ops->matsolvetranspose = NULL;
2254aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2255aac854edSJunchao Zhang }
2256aac854edSJunchao Zhang 
2257aac854edSJunchao Zhang static PetscErrorCode MatICCFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS perm, const MatFactorInfo *info)
2258aac854edSJunchao Zhang {
2259aac854edSJunchao Zhang   PetscFunctionBegin;
2260aac854edSJunchao Zhang   if (info->solveonhost) {
2261aac854edSJunchao Zhang     // If solve on host, we have to change the type, as eventually we need to call MatSolve_SeqSBAIJ_1_NaturalOrdering() etc.
2262aac854edSJunchao Zhang     PetscCall(MatSetType(B, MATSEQSBAIJ));
2263aac854edSJunchao Zhang     PetscCall(MatSeqSBAIJSetPreallocation(B, 1, MAT_SKIP_ALLOCATION, NULL));
2264aac854edSJunchao Zhang   }
2265aac854edSJunchao Zhang 
2266aac854edSJunchao Zhang   PetscCall(MatICCFactorSymbolic_SeqAIJ(B, A, perm, info));
2267aac854edSJunchao Zhang 
2268aac854edSJunchao Zhang   if (!info->solveonhost) {
2269bfe80ac4SPierre Jolivet     // If solve on device, B is still a MATSEQAIJKOKKOS, so we are good to allocate B->spptr
2270aac854edSJunchao Zhang     PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2271aac854edSJunchao Zhang     PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2272aac854edSJunchao Zhang     B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJKokkos;
2273aac854edSJunchao Zhang   }
2274aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2275aac854edSJunchao Zhang }
2276aac854edSJunchao Zhang 
2277aac854edSJunchao Zhang static PetscErrorCode MatCholeskyFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS perm, const MatFactorInfo *info)
2278aac854edSJunchao Zhang {
2279aac854edSJunchao Zhang   PetscFunctionBegin;
2280aac854edSJunchao Zhang   if (info->solveonhost) {
2281aac854edSJunchao Zhang     // If solve on host, we have to change the type, as eventually we need to call MatSolve_SeqSBAIJ_1_NaturalOrdering() etc.
2282aac854edSJunchao Zhang     PetscCall(MatSetType(B, MATSEQSBAIJ));
2283aac854edSJunchao Zhang     PetscCall(MatSeqSBAIJSetPreallocation(B, 1, MAT_SKIP_ALLOCATION, NULL));
2284aac854edSJunchao Zhang   }
2285aac854edSJunchao Zhang 
2286aac854edSJunchao Zhang   PetscCall(MatCholeskyFactorSymbolic_SeqAIJ(B, A, perm, info)); // it sets B's two ISes ((Mat_SeqAIJ*)B->data)->{row, col} to perm
2287aac854edSJunchao Zhang 
2288aac854edSJunchao Zhang   if (!info->solveonhost) {
2289bfe80ac4SPierre Jolivet     // If solve on device, B is still a MATSEQAIJKOKKOS, so we are good to allocate B->spptr
2290aac854edSJunchao Zhang     PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2291aac854edSJunchao Zhang     PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2292aac854edSJunchao Zhang     B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJKokkos;
2293aac854edSJunchao Zhang   }
2294aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2295aac854edSJunchao Zhang }
2296aac854edSJunchao Zhang 
2297aac854edSJunchao Zhang // The _Kokkos suffix means we will use Kokkos as a solver for the SeqAIJKokkos matrix
2298aac854edSJunchao Zhang static PetscErrorCode MatFactorGetSolverType_SeqAIJKokkos_Kokkos(Mat A, MatSolverType *type)
2299d71ae5a4SJacob Faibussowitsch {
2300930e68a5SMark Adams   PetscFunctionBegin;
2301930e68a5SMark Adams   *type = MATSOLVERKOKKOS;
23023ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2303930e68a5SMark Adams }
2304930e68a5SMark Adams 
2305930e68a5SMark Adams /*MC
230686a27549SJunchao Zhang   MATSOLVERKOKKOS = "Kokkos" - A matrix solver type providing triangular solvers for sequential matrices
230711a5261eSBarry Smith   on a single GPU of type, `MATSEQAIJKOKKOS`, `MATAIJKOKKOS`.
2308930e68a5SMark Adams 
2309930e68a5SMark Adams   Level: beginner
2310930e68a5SMark Adams 
2311e255e071SSatish Balay .seealso: [](ch_matrices), `Mat`, `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJKokkos()`, `MATAIJKOKKOS`
2312930e68a5SMark Adams M*/
231386a27549SJunchao Zhang PETSC_EXTERN PetscErrorCode MatGetFactor_SeqAIJKokkos_Kokkos(Mat A, MatFactorType ftype, Mat *B) /* MatGetFactor_<MatType>_<MatSolverType> */
2314930e68a5SMark Adams {
2315930e68a5SMark Adams   PetscInt n = A->rmap->n;
2316aac854edSJunchao Zhang   MPI_Comm comm;
2317930e68a5SMark Adams 
2318930e68a5SMark Adams   PetscFunctionBegin;
2319aac854edSJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
2320aac854edSJunchao Zhang   PetscCall(MatCreate(comm, B));
23219566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*B, n, n, n, n));
2322aac854edSJunchao Zhang   PetscCall(MatSetBlockSizesFromMats(*B, A, A));
2323930e68a5SMark Adams   (*B)->factortype = ftype;
23249566063dSJacob Faibussowitsch   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
23259566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
2326aac854edSJunchao Zhang   PetscCheck(!(*B)->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2327aac854edSJunchao Zhang 
2328aac854edSJunchao Zhang   if (ftype == MAT_FACTOR_LU || ftype == MAT_FACTOR_ILU || ftype == MAT_FACTOR_ILUDT) {
2329aac854edSJunchao Zhang     (*B)->ops->lufactorsymbolic  = MatLUFactorSymbolic_SeqAIJKokkos;
2330aac854edSJunchao Zhang     (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJKokkos;
2331aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
2332aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILU]));
2333aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILUDT]));
2334aac854edSJunchao Zhang   } else if (ftype == MAT_FACTOR_CHOLESKY || ftype == MAT_FACTOR_ICC) {
2335aac854edSJunchao Zhang     (*B)->ops->iccfactorsymbolic      = MatICCFactorSymbolic_SeqAIJKokkos;
2336aac854edSJunchao Zhang     (*B)->ops->choleskyfactorsymbolic = MatCholeskyFactorSymbolic_SeqAIJKokkos;
2337aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_CHOLESKY]));
2338aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ICC]));
2339aac854edSJunchao Zhang   } else SETERRQ(comm, PETSC_ERR_SUP, "MatFactorType %s is not supported by MatType SeqAIJKokkos", MatFactorTypes[ftype]);
2340aac854edSJunchao Zhang 
2341aac854edSJunchao Zhang   // The factorization can use the ordering provided in MatLUFactorSymbolic(), MatCholeskyFactorSymbolic() etc, though we do it on host
2342aac854edSJunchao Zhang   (*B)->canuseordering = PETSC_TRUE;
2343aac854edSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)*B, "MatFactorGetSolverType_C", MatFactorGetSolverType_SeqAIJKokkos_Kokkos));
23443ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2345930e68a5SMark Adams }
23468f7e8f9dSMark Adams 
2347aac854edSJunchao Zhang PETSC_INTERN PetscErrorCode MatSolverTypeRegister_Kokkos(void)
2348d71ae5a4SJacob Faibussowitsch {
234986a27549SJunchao Zhang   PetscFunctionBegin;
23509566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_SeqAIJKokkos_Kokkos));
2351aac854edSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_CHOLESKY, MatGetFactor_SeqAIJKokkos_Kokkos));
23529566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ILU, MatGetFactor_SeqAIJKokkos_Kokkos));
2353aac854edSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ICC, MatGetFactor_SeqAIJKokkos_Kokkos));
23543ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
235586a27549SJunchao Zhang }
235686a27549SJunchao Zhang 
2357076ba34aSJunchao Zhang /* Utility to print out a KokkosCsrMatrix for debugging */
2358d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode PrintCsrMatrix(const KokkosCsrMatrix &csrmat)
2359d71ae5a4SJacob Faibussowitsch {
236045402d8aSJunchao Zhang   const auto        &iv = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.graph.row_map);
236145402d8aSJunchao Zhang   const auto        &jv = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.graph.entries);
236245402d8aSJunchao Zhang   const auto        &av = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.values);
2363076ba34aSJunchao Zhang   const PetscInt    *i  = iv.data();
2364076ba34aSJunchao Zhang   const PetscInt    *j  = jv.data();
2365076ba34aSJunchao Zhang   const PetscScalar *a  = av.data();
2366076ba34aSJunchao Zhang   PetscInt           m = csrmat.numRows(), n = csrmat.numCols(), nnz = csrmat.nnz();
2367076ba34aSJunchao Zhang 
2368076ba34aSJunchao Zhang   PetscFunctionBegin;
23699566063dSJacob Faibussowitsch   PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT " x %" PetscInt_FMT " SeqAIJKokkos, with %" PetscInt_FMT " nonzeros\n", m, n, nnz));
2370076ba34aSJunchao Zhang   for (PetscInt k = 0; k < m; k++) {
23719566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT ": ", k));
237248a46eb9SPierre Jolivet     for (PetscInt p = i[k]; p < i[k + 1]; p++) PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT "(%.1f), ", j[p], (double)PetscRealPart(a[p])));
23739566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "\n"));
2374076ba34aSJunchao Zhang   }
23753ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2376076ba34aSJunchao Zhang }
2377