xref: /petsc/src/mat/impls/aij/seq/kokkos/aijkok.kokkos.cxx (revision d1c799ffc2c2dd0945dfd53da7d3f7c32cb9db4c)
1e36ced11SJunchao Zhang #include <petsc_kokkos.hpp>
211d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
3c0c276a7Ssdargavi #include <petscmat_kokkos.hpp>
4076ba34aSJunchao Zhang #include <petscpkg_version.h>
5152b3e56SJunchao Zhang #include <petsc/private/petscimpl.h>
642550becSJunchao Zhang #include <petsc/private/sfimpl.h>
7aac854edSJunchao Zhang #include <petsc/private/kokkosimpl.hpp>
87233ce55SJed Brown #include <petscsys.h>
98c3ff71bSJunchao Zhang 
108c3ff71bSJunchao Zhang #include <Kokkos_Core.hpp>
11f0cf5187SStefano Zampini #include <KokkosBlas.hpp>
128c3ff71bSJunchao Zhang #include <KokkosSparse_CrsMatrix.hpp>
13cc6e31f1SJunchao Zhang 
14cc6e31f1SJunchao Zhang // To suppress compiler warnings:
15cc6e31f1SJunchao Zhang // /path/include/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp:434:63:
16cc6e31f1SJunchao Zhang // warning: 'cusparseStatus_t cusparseDbsrmm(cusparseHandle_t, cusparseDirection_t, cusparseOperation_t,
17cc6e31f1SJunchao Zhang // cusparseOperation_t, int, int, int, int, const double*, cusparseMatDescr_t, const double*, const int*, const int*,
18cc6e31f1SJunchao Zhang // int, const double*, int, const double*, double*, int)' is deprecated: please use cusparseSpMM instead [-Wdeprecated-declarations]
19cc6e31f1SJunchao Zhang PETSC_PRAGMA_DIAGNOSTIC_IGNORED_BEGIN("-Wdeprecated-declarations")
208c3ff71bSJunchao Zhang #include <KokkosSparse_spmv.hpp>
21cc6e31f1SJunchao Zhang PETSC_PRAGMA_DIAGNOSTIC_IGNORED_END()
22cc6e31f1SJunchao Zhang 
2386a27549SJunchao Zhang #include <KokkosSparse_spiluk.hpp>
2486a27549SJunchao Zhang #include <KokkosSparse_sptrsv.hpp>
25076ba34aSJunchao Zhang #include <KokkosSparse_spgemm.hpp>
26076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
279d13fa56SJunchao Zhang #include <KokkosBatched_LU_Decl.hpp>
289d13fa56SJunchao Zhang #include <KokkosBatched_InverseLU_Decl.hpp>
2986a27549SJunchao Zhang 
3042550becSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
318c3ff71bSJunchao Zhang 
320e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(3, 7, 0)
33f98996d3SJunchao Zhang   #include <KokkosSparse_Utils.hpp>
34f98996d3SJunchao Zhang using KokkosSparse::sort_crs_matrix;
359371c9d4SSatish Balay using KokkosSparse::Impl::transpose_matrix;
36f98996d3SJunchao Zhang #else
37f98996d3SJunchao Zhang   #include <KokkosKernels_Sorting.hpp>
38f98996d3SJunchao Zhang using KokkosKernels::sort_crs_matrix;
399371c9d4SSatish Balay using KokkosKernels::Impl::transpose_matrix;
40f98996d3SJunchao Zhang #endif
41f98996d3SJunchao Zhang 
42aac854edSJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(4, 6, 0)
43aac854edSJunchao Zhang using KokkosSparse::spiluk_symbolic;
44aac854edSJunchao Zhang using KokkosSparse::spiluk_numeric;
45aac854edSJunchao Zhang using KokkosSparse::sptrsv_symbolic;
46aac854edSJunchao Zhang using KokkosSparse::sptrsv_solve;
47aac854edSJunchao Zhang using KokkosSparse::Experimental::SPTRSVAlgorithm;
48aac854edSJunchao Zhang using KokkosSparse::Experimental::SPILUKAlgorithm;
49aac854edSJunchao Zhang #else
50aac854edSJunchao Zhang using KokkosSparse::Experimental::spiluk_symbolic;
51aac854edSJunchao Zhang using KokkosSparse::Experimental::spiluk_numeric;
52aac854edSJunchao Zhang using KokkosSparse::Experimental::sptrsv_symbolic;
53aac854edSJunchao Zhang using KokkosSparse::Experimental::sptrsv_solve;
54aac854edSJunchao Zhang using KokkosSparse::Experimental::SPTRSVAlgorithm;
55aac854edSJunchao Zhang using KokkosSparse::Experimental::SPILUKAlgorithm;
56aac854edSJunchao Zhang #endif
57aac854edSJunchao Zhang 
588c3ff71bSJunchao Zhang static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat); /* Forward declaration */
598c3ff71bSJunchao Zhang 
60076ba34aSJunchao Zhang /* MatAssemblyEnd_SeqAIJKokkos() happens when we finalized nonzeros of the matrix, either after
61076ba34aSJunchao Zhang    we assembled the matrix on host, or after we directly produced the matrix data on device (ex., through MatMatMult).
62076ba34aSJunchao Zhang    In the latter case, it is important to set a_dual's sync state correctly.
63076ba34aSJunchao Zhang  */
64d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_SeqAIJKokkos(Mat A, MatAssemblyType mode)
65d71ae5a4SJacob Faibussowitsch {
66076ba34aSJunchao Zhang   Mat_SeqAIJ       *aijseq;
67076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
688c3ff71bSJunchao Zhang 
698c3ff71bSJunchao Zhang   PetscFunctionBegin;
703ba16761SJacob Faibussowitsch   if (mode == MAT_FLUSH_ASSEMBLY) PetscFunctionReturn(PETSC_SUCCESS);
719566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
72076ba34aSJunchao Zhang 
73076ba34aSJunchao Zhang   aijseq = static_cast<Mat_SeqAIJ *>(A->data);
74076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
75076ba34aSJunchao Zhang 
76076ba34aSJunchao Zhang   /* If aijkok does not exist, we just copy i, j to device.
77076ba34aSJunchao Zhang      If aijkok already exists, but the device's nonzero pattern does not match with the host's, we assume the latest data is on host.
78076ba34aSJunchao Zhang      In both cases, we build a new aijkok structure.
79076ba34aSJunchao Zhang   */
80076ba34aSJunchao Zhang   if (!aijkok || aijkok->nonzerostate != A->nonzerostate) { /* aijkok might not exist yet or nonzero pattern has changed */
81*d1c799ffSJunchao Zhang     if (aijkok && aijkok->host_aij_allocated_by_kokkos) {   /* Avoid accidently freeing much needed a,i,j on host when deleting aijkok */
82*d1c799ffSJunchao Zhang       PetscCall(PetscShmgetAllocateArray(aijkok->nrows() + 1, sizeof(PetscInt), (void **)&aijseq->i));
83*d1c799ffSJunchao Zhang       PetscCall(PetscShmgetAllocateArray(aijkok->nnz(), sizeof(PetscInt), (void **)&aijseq->j));
84*d1c799ffSJunchao Zhang       PetscCall(PetscShmgetAllocateArray(aijkok->nnz(), sizeof(PetscInt), (void **)&aijseq->a));
85*d1c799ffSJunchao Zhang       PetscCall(PetscArraycpy(aijseq->i, aijkok->i_host_data(), aijkok->nrows() + 1));
86*d1c799ffSJunchao Zhang       PetscCall(PetscArraycpy(aijseq->j, aijkok->j_host_data(), aijkok->nnz()));
87*d1c799ffSJunchao Zhang       PetscCall(PetscArraycpy(aijseq->a, aijkok->a_host_data(), aijkok->nnz()));
88*d1c799ffSJunchao Zhang       aijseq->free_a  = PETSC_TRUE;
89*d1c799ffSJunchao Zhang       aijseq->free_ij = PETSC_TRUE;
90*d1c799ffSJunchao Zhang       /* This arises from MatCreateSeqAIJKokkosWithKokkosCsrMatrix() used in MatMatMult, where
91*d1c799ffSJunchao Zhang          we have the CsrMatrix on device first and then copy to host, followed by
92*d1c799ffSJunchao Zhang          MatSetMPIAIJWithSplitSeqAIJ() with garray = NULL.
93*d1c799ffSJunchao Zhang          One could improve it by not using NULL garray.
94*d1c799ffSJunchao Zhang       */
95*d1c799ffSJunchao Zhang     }
96076ba34aSJunchao Zhang     delete aijkok;
97f4747e26SJunchao Zhang     aijkok   = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aijseq, A->nonzerostate, PETSC_FALSE /*don't copy mat values to device*/);
98076ba34aSJunchao Zhang     A->spptr = aijkok;
99f4747e26SJunchao Zhang   } else if (A->rmap->n && aijkok->diag_dual.extent(0) == 0) { // MatProduct might directly produce AIJ on device, but not the diag.
100f4747e26SJunchao Zhang     MatRowMapKokkosViewHost diag_h(aijseq->diag, A->rmap->n);
101f4747e26SJunchao Zhang     auto                    diag_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), diag_h);
102f4747e26SJunchao Zhang     aijkok->diag_dual              = MatRowMapKokkosDualView(diag_d, diag_h);
103076ba34aSJunchao Zhang   }
1043ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1058c3ff71bSJunchao Zhang }
1068c3ff71bSJunchao Zhang 
10786a27549SJunchao Zhang /* Sync CSR data to device if not yet */
108d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosSyncDevice(Mat A)
109d71ae5a4SJacob Faibussowitsch {
1108c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1118c3ff71bSJunchao Zhang 
1128c3ff71bSJunchao Zhang   PetscFunctionBegin;
113aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from host to device");
1145f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
115076ba34aSJunchao Zhang   if (aijkok->a_dual.need_sync_device()) {
116076ba34aSJunchao Zhang     aijkok->a_dual.sync_device();
117580c7c76SPierre Jolivet     aijkok->transpose_updated = PETSC_FALSE; /* values of the transpose is out-of-date */
11886a27549SJunchao Zhang     aijkok->hermitian_updated = PETSC_FALSE;
1198c3ff71bSJunchao Zhang   }
1203ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1218c3ff71bSJunchao Zhang }
1228c3ff71bSJunchao Zhang 
123076ba34aSJunchao Zhang /* Mark the CSR data on device as modified */
124d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosModifyDevice(Mat A)
125d71ae5a4SJacob Faibussowitsch {
12686a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
12786a27549SJunchao Zhang 
12886a27549SJunchao Zhang   PetscFunctionBegin;
1295f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Not supported for factorized matries");
13086a27549SJunchao Zhang   aijkok->a_dual.clear_sync_state();
13186a27549SJunchao Zhang   aijkok->a_dual.modify_device();
13286a27549SJunchao Zhang   aijkok->transpose_updated = PETSC_FALSE;
13386a27549SJunchao Zhang   aijkok->hermitian_updated = PETSC_FALSE;
1349566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJInvalidateDiagonal(A));
1359566063dSJacob Faibussowitsch   PetscCall(PetscObjectStateIncrease((PetscObject)A));
1363ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
13786a27549SJunchao Zhang }
13886a27549SJunchao Zhang 
139d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSyncHost(Mat A)
140d71ae5a4SJacob Faibussowitsch {
141f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1424df4a32cSJunchao Zhang   auto              exec   = PetscGetKokkosExecutionSpace();
143f0cf5187SStefano Zampini 
144f0cf5187SStefano Zampini   PetscFunctionBegin;
145f0cf5187SStefano Zampini   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
14686a27549SJunchao Zhang   /* We do not expect one needs factors on host  */
147aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from device to host");
1485f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing AIJKOK");
149aac854edSJunchao Zhang   PetscCall(KokkosDualViewSync<HostMirrorMemorySpace>(aijkok->a_dual, exec));
1503ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
151f0cf5187SStefano Zampini }
152f0cf5187SStefano Zampini 
153d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
154d71ae5a4SJacob Faibussowitsch {
155076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
156f0cf5187SStefano Zampini 
157f0cf5187SStefano Zampini   PetscFunctionBegin;
1585519a089SJose E. Roman   /* aijkok contains valid pointers only if the host's nonzerostate matches with the device's.
1595519a089SJose E. Roman     Calling MatSeqAIJSetPreallocation() or MatSetValues() on host, where aijseq->{i,j,a} might be
1605519a089SJose E. Roman     reallocated, will lead to stale {i,j,a}_dual in aijkok. In both operations, the hosts's nonzerostate
1615519a089SJose E. Roman     must have been updated. The stale aijkok will be rebuilt during MatAssemblyEnd.
1625519a089SJose E. Roman   */
1635519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
1644df4a32cSJunchao Zhang     auto exec = PetscGetKokkosExecutionSpace();
165e36ced11SJunchao Zhang     PetscCallCXX(aijkok->a_dual.sync_host(exec));
166e36ced11SJunchao Zhang     PetscCallCXX(exec.fence());
167076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
168076ba34aSJunchao Zhang   } else { /* Happens when calling MatSetValues on a newly created matrix */
169076ba34aSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
170076ba34aSJunchao Zhang   }
1713ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
172076ba34aSJunchao Zhang }
173076ba34aSJunchao Zhang 
174d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
175d71ae5a4SJacob Faibussowitsch {
176076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
177076ba34aSJunchao Zhang 
178076ba34aSJunchao Zhang   PetscFunctionBegin;
1795519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) aijkok->a_dual.modify_host();
1803ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
181076ba34aSJunchao Zhang }
182076ba34aSJunchao Zhang 
183d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
184d71ae5a4SJacob Faibussowitsch {
185076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
186076ba34aSJunchao Zhang 
187076ba34aSJunchao Zhang   PetscFunctionBegin;
1885519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
1894df4a32cSJunchao Zhang     auto exec = PetscGetKokkosExecutionSpace();
190e36ced11SJunchao Zhang     PetscCallCXX(aijkok->a_dual.sync_host(exec));
191e36ced11SJunchao Zhang     PetscCallCXX(exec.fence());
192076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1932328674fSJunchao Zhang   } else {
1942328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1952328674fSJunchao Zhang   }
1963ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
197076ba34aSJunchao Zhang }
198076ba34aSJunchao Zhang 
199d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
200d71ae5a4SJacob Faibussowitsch {
201076ba34aSJunchao Zhang   PetscFunctionBegin;
202076ba34aSJunchao Zhang   *array = NULL;
2033ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
204076ba34aSJunchao Zhang }
205076ba34aSJunchao Zhang 
206d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
207d71ae5a4SJacob Faibussowitsch {
208076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
209076ba34aSJunchao Zhang 
210076ba34aSJunchao Zhang   PetscFunctionBegin;
2115519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
212076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
2132328674fSJunchao Zhang   } else { /* Ex. happens with MatZeroEntries on a preallocated but not assembled matrix */
2142328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
2152328674fSJunchao Zhang   }
2163ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
217076ba34aSJunchao Zhang }
218076ba34aSJunchao Zhang 
219d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
220d71ae5a4SJacob Faibussowitsch {
221076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
222076ba34aSJunchao Zhang 
223076ba34aSJunchao Zhang   PetscFunctionBegin;
2245519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
225076ba34aSJunchao Zhang     aijkok->a_dual.clear_sync_state();
226076ba34aSJunchao Zhang     aijkok->a_dual.modify_host();
2272328674fSJunchao Zhang   }
2283ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
229f0cf5187SStefano Zampini }
230f0cf5187SStefano Zampini 
231d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJKokkos(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype)
232d71ae5a4SJacob Faibussowitsch {
2337ee59b9bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
2347ee59b9bSJunchao Zhang 
2357ee59b9bSJunchao Zhang   PetscFunctionBegin;
2367ee59b9bSJunchao Zhang   PetscCheck(aijkok != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "aijkok is NULL");
2377ee59b9bSJunchao Zhang 
2387ee59b9bSJunchao Zhang   if (i) *i = aijkok->i_device_data();
2397ee59b9bSJunchao Zhang   if (j) *j = aijkok->j_device_data();
2407ee59b9bSJunchao Zhang   if (a) {
2417ee59b9bSJunchao Zhang     aijkok->a_dual.sync_device();
2427ee59b9bSJunchao Zhang     *a = aijkok->a_device_data();
2437ee59b9bSJunchao Zhang   }
2447ee59b9bSJunchao Zhang   if (mtype) *mtype = PETSC_MEMTYPE_KOKKOS;
2453ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2467ee59b9bSJunchao Zhang }
2477ee59b9bSJunchao Zhang 
2480e3ece09SJunchao Zhang /*
2490e3ece09SJunchao Zhang   Generate the sparsity pattern of a MatSeqAIJKokkos matrix's transpose on device.
2500e3ece09SJunchao Zhang 
2510e3ece09SJunchao Zhang   Input Parameter:
2520e3ece09SJunchao Zhang .  A       - the MATSEQAIJKOKKOS matrix
2530e3ece09SJunchao Zhang 
2540e3ece09SJunchao Zhang   Output Parameters:
2550e3ece09SJunchao Zhang +  perm_d - the permutation array on device, which connects Ta(i) = Aa(perm(i))
256aaa8cc7dSPierre Jolivet -  T_d    - the transpose on device, whose value array is allocated but not initialized
2570e3ece09SJunchao Zhang */
2580e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateTransposeStructure(Mat A, MatRowMapKokkosView &perm_d, KokkosCsrMatrix &T_d)
259d71ae5a4SJacob Faibussowitsch {
2600e3ece09SJunchao Zhang   Mat_SeqAIJ             *aseq = static_cast<Mat_SeqAIJ *>(A->data);
2610e3ece09SJunchao Zhang   PetscInt                nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
2620e3ece09SJunchao Zhang   const PetscInt         *Ai = aseq->i, *Aj = aseq->j;
2637b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost Ti_h(NoInit("Ti"), n + 1);
2640e3ece09SJunchao Zhang   MatRowMapType          *Ti = Ti_h.data();
2657b8d4ba6SJunchao Zhang   MatColIdxKokkosViewHost Tj_h(NoInit("Tj"), nz);
2667b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost perm_h(NoInit("permutation"), nz);
2670e3ece09SJunchao Zhang   PetscInt               *Tj   = Tj_h.data();
2680e3ece09SJunchao Zhang   PetscInt               *perm = perm_h.data();
2690e3ece09SJunchao Zhang   PetscInt               *offset;
270152b3e56SJunchao Zhang 
271152b3e56SJunchao Zhang   PetscFunctionBegin;
2720e3ece09SJunchao Zhang   // Populate Ti
2730e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::deep_copy(Ti_h, 0));
2740e3ece09SJunchao Zhang   Ti++;
2750e3ece09SJunchao Zhang   for (PetscInt i = 0; i < nz; i++) Ti[Aj[i]]++;
2760e3ece09SJunchao Zhang   Ti--;
2770e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) Ti[i + 1] += Ti[i];
2780e3ece09SJunchao Zhang 
2790e3ece09SJunchao Zhang   // Populate Tj and the permutation array
2800e3ece09SJunchao Zhang   PetscCall(PetscCalloc1(n, &offset)); // offset in each T row to fill in its column indices
2810e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
2820e3ece09SJunchao Zhang     for (PetscInt j = Ai[i]; j < Ai[i + 1]; j++) { // A's (i,j) is T's (j,i)
2830e3ece09SJunchao Zhang       PetscInt r    = Aj[j];                       // row r of T
2840e3ece09SJunchao Zhang       PetscInt disp = Ti[r] + offset[r];
2850e3ece09SJunchao Zhang 
2860e3ece09SJunchao Zhang       Tj[disp]   = i; // col i of T
2870e3ece09SJunchao Zhang       perm[disp] = j;
2880e3ece09SJunchao Zhang       offset[r]++;
289076ba34aSJunchao Zhang     }
2900e3ece09SJunchao Zhang   }
2910e3ece09SJunchao Zhang   PetscCall(PetscFree(offset));
2920e3ece09SJunchao Zhang 
2930e3ece09SJunchao Zhang   // Sort each row of T, along with the permutation array
2940e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) PetscCall(PetscSortIntWithArray(Ti[i + 1] - Ti[i], Tj + Ti[i], perm + Ti[i]));
2950e3ece09SJunchao Zhang 
2960e3ece09SJunchao Zhang   // Output perm and T on device
2970e3ece09SJunchao Zhang   auto Ti_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ti_h);
2980e3ece09SJunchao Zhang   auto Tj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Tj_h);
2990e3ece09SJunchao Zhang   PetscCallCXX(T_d = KokkosCsrMatrix("csrmatT", n, m, nz, MatScalarKokkosView("Ta", nz), Ti_d, Tj_d));
3000e3ece09SJunchao Zhang   PetscCallCXX(perm_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), perm_h));
3013ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
302152b3e56SJunchao Zhang }
303152b3e56SJunchao Zhang 
3040e3ece09SJunchao Zhang // Generate the transpose on device and cache it internally
3050e3ece09SJunchao Zhang // Note: KK transpose_matrix() does not have support symbolic/numeric transpose, so we do it on our own
3060e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGenerateTranspose_Private(Mat A, KokkosCsrMatrix *csrmatT)
307d71ae5a4SJacob Faibussowitsch {
3080e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
3090e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3100e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
3110e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatT;
312152b3e56SJunchao Zhang 
313152b3e56SJunchao Zhang   PetscFunctionBegin;
3140e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
315145b44c9SPierre Jolivet   PetscCallCXX(akok->a_dual.sync_device()); // Sync A's values since we are going to access them on device
3160e3ece09SJunchao Zhang 
3170e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
3180e3ece09SJunchao Zhang 
3190e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE) {
3200e3ece09SJunchao Zhang     *csrmatT = akok->csrmat;
3210e3ece09SJunchao Zhang   } else {
3220e3ece09SJunchao Zhang     // See if we already have a cached transpose and its value is up to date
3230e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
3240e3ece09SJunchao Zhang       if (!akok->transpose_updated) {            // if the value is out of date, update the cached version
3250e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
3260e3ece09SJunchao Zhang         auto       &Ta   = T.values;
3270e3ece09SJunchao Zhang 
328d326c3f1SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = Aa(perm(i)); }));
329076ba34aSJunchao Zhang       }
3300e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
3310e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
3320e3ece09SJunchao Zhang 
3330e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
3340e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
335d326c3f1SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = Aa(perm(i)); }));
3360e3ece09SJunchao Zhang     }
3370e3ece09SJunchao Zhang     akok->transpose_updated = PETSC_TRUE;
3380e3ece09SJunchao Zhang     *csrmatT                = akok->csrmatT;
3390e3ece09SJunchao Zhang   }
3400e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3410e3ece09SJunchao Zhang }
3420e3ece09SJunchao Zhang 
3430e3ece09SJunchao Zhang // Generate the Hermitian on device and cache it internally
3440e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateHermitian_Private(Mat A, KokkosCsrMatrix *csrmatH)
3450e3ece09SJunchao Zhang {
3460e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
3470e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3480e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
3490e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatH;
3500e3ece09SJunchao Zhang 
3510e3ece09SJunchao Zhang   PetscFunctionBegin;
3520e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
3530e3ece09SJunchao Zhang   PetscCallCXX(akok->a_dual.sync_device()); // Sync A's values since we are going to access them on device
3540e3ece09SJunchao Zhang 
3550e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
3560e3ece09SJunchao Zhang 
3570e3ece09SJunchao Zhang   if (A->hermitian == PETSC_BOOL3_TRUE) {
3580e3ece09SJunchao Zhang     *csrmatH = akok->csrmat;
3590e3ece09SJunchao Zhang   } else {
3600e3ece09SJunchao Zhang     // See if we already have a cached hermitian and its value is up to date
3610e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
3620e3ece09SJunchao Zhang       if (!akok->hermitian_updated) {            // if the value is out of date, update the cached version
3630e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
3640e3ece09SJunchao Zhang         auto       &Ta   = T.values;
3650e3ece09SJunchao Zhang 
366d326c3f1SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = PetscConj(Aa(perm(i))); }));
3670e3ece09SJunchao Zhang       }
3680e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
3690e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
3700e3ece09SJunchao Zhang 
3710e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
3720e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
373d326c3f1SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = PetscConj(Aa(perm(i))); }));
3740e3ece09SJunchao Zhang     }
3750e3ece09SJunchao Zhang     akok->hermitian_updated = PETSC_TRUE;
3760e3ece09SJunchao Zhang     *csrmatH                = akok->csrmatH;
3770e3ece09SJunchao Zhang   }
3783ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
379152b3e56SJunchao Zhang }
380a587d139SMark 
3818c3ff71bSJunchao Zhang /* y = A x */
382d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMult_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
383d71ae5a4SJacob Faibussowitsch {
3848c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
385152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
386152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
3878c3ff71bSJunchao Zhang 
3888c3ff71bSJunchao Zhang   PetscFunctionBegin;
3899566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3909566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3919566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3929566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
3938c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
394d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), "N", 1.0 /*alpha*/, aijkok->csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A x + beta y */
3959566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3969566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
397076ba34aSJunchao Zhang   /* 2.0*nnz - numRows seems more accurate here but assumes there are no zero-rows. So a little sloppy here. */
3989566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
3999566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4003ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4018c3ff71bSJunchao Zhang }
4028c3ff71bSJunchao Zhang 
4038c3ff71bSJunchao Zhang /* y = A^T x */
404d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
405d71ae5a4SJacob Faibussowitsch {
4068c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
407152b3e56SJunchao Zhang   const char                *mode;
408152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
409152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
4100e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4118c3ff71bSJunchao Zhang 
4128c3ff71bSJunchao Zhang   PetscFunctionBegin;
4139566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4149566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4159566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4169566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
417152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4189566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
419152b3e56SJunchao Zhang     mode = "N";
420152b3e56SJunchao Zhang   } else {
421076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4220e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
423152b3e56SJunchao Zhang     mode   = "T";
424152b3e56SJunchao Zhang   }
425d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^T x + beta y */
4269566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4279566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
4280e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4299566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4303ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4318c3ff71bSJunchao Zhang }
4328c3ff71bSJunchao Zhang 
4338c3ff71bSJunchao Zhang /* y = A^H x */
434d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
435d71ae5a4SJacob Faibussowitsch {
4368c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
437152b3e56SJunchao Zhang   const char                *mode;
438152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
439152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
4400e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4418c3ff71bSJunchao Zhang 
4428c3ff71bSJunchao Zhang   PetscFunctionBegin;
4439566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4449566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4459566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4469566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
447152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4489566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
449152b3e56SJunchao Zhang     mode = "N";
450152b3e56SJunchao Zhang   } else {
451076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4520e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
453152b3e56SJunchao Zhang     mode   = "C";
454152b3e56SJunchao Zhang   }
455d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^H x + beta y */
4569566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4579566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
4580e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4599566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4603ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4618c3ff71bSJunchao Zhang }
4628c3ff71bSJunchao Zhang 
4638c3ff71bSJunchao Zhang /* z = A x + y */
464d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
465d71ae5a4SJacob Faibussowitsch {
4668c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
46792896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
468152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4698c3ff71bSJunchao Zhang 
4708c3ff71bSJunchao Zhang   PetscFunctionBegin;
4719566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4729566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
47392896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz)); // depending on yy's sync flags, zz might get its latest data on host
4749566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
47592896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv)); // do after VecCopy(yy, zz) to get the latest data on device
4768c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
477d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), "N", 1.0 /*alpha*/, aijkok->csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A x + beta z */
4789566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
47992896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
4809566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
4819566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4823ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4838c3ff71bSJunchao Zhang }
4848c3ff71bSJunchao Zhang 
4858c3ff71bSJunchao Zhang /* z = A^T x + y */
486d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
487d71ae5a4SJacob Faibussowitsch {
4888c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
489152b3e56SJunchao Zhang   const char                *mode;
49092896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
491152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4920e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4938c3ff71bSJunchao Zhang 
4948c3ff71bSJunchao Zhang   PetscFunctionBegin;
4959566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4969566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
49792896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz));
4989566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
49992896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv));
500152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
5019566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
502152b3e56SJunchao Zhang     mode = "N";
503152b3e56SJunchao Zhang   } else {
504076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
5050e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
506152b3e56SJunchao Zhang     mode   = "T";
507152b3e56SJunchao Zhang   }
508d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^T x + beta z */
5099566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
51092896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
5110e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
5129566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
5133ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5148c3ff71bSJunchao Zhang }
5158c3ff71bSJunchao Zhang 
5168c3ff71bSJunchao Zhang /* z = A^H x + y */
517d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
518d71ae5a4SJacob Faibussowitsch {
5198c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
520152b3e56SJunchao Zhang   const char                *mode;
52192896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
522152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
5230e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
5248c3ff71bSJunchao Zhang 
5258c3ff71bSJunchao Zhang   PetscFunctionBegin;
5269566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
5279566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
52892896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz));
5299566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
53092896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv));
531152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
5329566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
533152b3e56SJunchao Zhang     mode = "N";
534152b3e56SJunchao Zhang   } else {
535076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
5360e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
537152b3e56SJunchao Zhang     mode   = "C";
538152b3e56SJunchao Zhang   }
539d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^H x + beta z */
5409566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
54192896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
5420e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
5439566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
5443ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
545152b3e56SJunchao Zhang }
546152b3e56SJunchao Zhang 
54766976f2fSJacob Faibussowitsch static PetscErrorCode MatSetOption_SeqAIJKokkos(Mat A, MatOption op, PetscBool flg)
548d71ae5a4SJacob Faibussowitsch {
549152b3e56SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
550152b3e56SJunchao Zhang 
551152b3e56SJunchao Zhang   PetscFunctionBegin;
552152b3e56SJunchao Zhang   switch (op) {
553152b3e56SJunchao Zhang   case MAT_FORM_EXPLICIT_TRANSPOSE:
554152b3e56SJunchao Zhang     /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */
5559566063dSJacob Faibussowitsch     if (A->form_explicit_transpose && !flg && aijkok) PetscCall(aijkok->DestroyMatTranspose());
556152b3e56SJunchao Zhang     A->form_explicit_transpose = flg;
557152b3e56SJunchao Zhang     break;
558d71ae5a4SJacob Faibussowitsch   default:
559d71ae5a4SJacob Faibussowitsch     PetscCall(MatSetOption_SeqAIJ(A, op, flg));
560d71ae5a4SJacob Faibussowitsch     break;
561152b3e56SJunchao Zhang   }
5623ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5638c3ff71bSJunchao Zhang }
5648c3ff71bSJunchao Zhang 
565076ba34aSJunchao Zhang /* Depending on reuse, either build a new mat, or use the existing mat */
566d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
567d71ae5a4SJacob Faibussowitsch {
568076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
5698c3ff71bSJunchao Zhang 
5708c3ff71bSJunchao Zhang   PetscFunctionBegin;
5719566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
572076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) { /* Build a brand new mat */
57351ece73cSJunchao Zhang     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
57451ece73cSJunchao Zhang     PetscCall(MatSetType(*newmat, mtype));
5758c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {                 /* Reuse the mat created before */
5769566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); /* newmat is already a SeqAIJKokkos */
577076ba34aSJunchao Zhang   } else if (reuse == MAT_INPLACE_MATRIX) {               /* newmat is A */
5785f80ce2aSJacob Faibussowitsch     PetscCheck(A == *newmat, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "A != *newmat with MAT_INPLACE_MATRIX");
5799566063dSJacob Faibussowitsch     PetscCall(PetscFree(A->defaultvectype));
5809566063dSJacob Faibussowitsch     PetscCall(PetscStrallocpy(VECKOKKOS, &A->defaultvectype)); /* Allocate and copy the string */
5819566063dSJacob Faibussowitsch     PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATSEQAIJKOKKOS));
5829566063dSJacob Faibussowitsch     PetscCall(MatSetOps_SeqAIJKokkos(A));
583076ba34aSJunchao Zhang     aseq = static_cast<Mat_SeqAIJ *>(A->data);
584394ed5ebSJunchao Zhang     if (A->assembled) { /* Copy i, j (but not values) to device for an assembled matrix if not yet */
5855f80ce2aSJacob Faibussowitsch       PetscCheck(!A->spptr, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Expect NULL (Mat_SeqAIJKokkos*)A->spptr");
586f4747e26SJunchao Zhang       A->spptr = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aseq, A->nonzerostate, PETSC_FALSE);
5878c3ff71bSJunchao Zhang     }
588076ba34aSJunchao Zhang   }
5893ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5908c3ff71bSJunchao Zhang }
5918c3ff71bSJunchao Zhang 
592076ba34aSJunchao Zhang /* MatDuplicate always creates a new matrix. MatDuplicate can be called either on an assembled matrix or
593076ba34aSJunchao Zhang    an unassembled matrix, even though MAT_COPY_VALUES is not allowed for unassembled matrix.
594076ba34aSJunchao Zhang  */
595d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDuplicate_SeqAIJKokkos(Mat A, MatDuplicateOption dupOption, Mat *B)
596d71ae5a4SJacob Faibussowitsch {
597076ba34aSJunchao Zhang   Mat_SeqAIJ       *bseq;
598076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok;
599076ba34aSJunchao Zhang   Mat               mat;
6008c3ff71bSJunchao Zhang 
6018c3ff71bSJunchao Zhang   PetscFunctionBegin;
602076ba34aSJunchao Zhang   /* Do not copy values on host as A's latest values might be on device. We don't want to do sync blindly */
6039566063dSJacob Faibussowitsch   PetscCall(MatDuplicate_SeqAIJ(A, MAT_DO_NOT_COPY_VALUES, B));
604076ba34aSJunchao Zhang   mat = *B;
605f4747e26SJunchao Zhang   if (A->assembled) {
606076ba34aSJunchao Zhang     bseq = static_cast<Mat_SeqAIJ *>(mat->data);
607f4747e26SJunchao Zhang     bkok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, bseq, mat->nonzerostate, PETSC_FALSE);
608076ba34aSJunchao Zhang     bkok->a_dual.clear_sync_state(); /* Clear B's sync state as it will be decided below */
609076ba34aSJunchao Zhang     /* Now copy values to B if needed */
610076ba34aSJunchao Zhang     if (dupOption == MAT_COPY_VALUES) {
611076ba34aSJunchao Zhang       if (akok->a_dual.need_sync_device()) {
612076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_host(), akok->a_dual.view_host());
613076ba34aSJunchao Zhang         bkok->a_dual.modify_host();
614076ba34aSJunchao Zhang       } else { /* If device has the latest data, we only copy data on device */
615076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_device(), akok->a_dual.view_device());
616076ba34aSJunchao Zhang         bkok->a_dual.modify_device();
617076ba34aSJunchao Zhang       }
618076ba34aSJunchao Zhang     } else { /* MAT_DO_NOT_COPY_VALUES or MAT_SHARE_NONZERO_PATTERN. B's values should be zeroed */
619076ba34aSJunchao Zhang       /* B's values on host should be already zeroed by MatDuplicate_SeqAIJ() */
620076ba34aSJunchao Zhang       bkok->a_dual.modify_host();
621076ba34aSJunchao Zhang     }
622076ba34aSJunchao Zhang     mat->spptr = bkok;
623076ba34aSJunchao Zhang   }
624076ba34aSJunchao Zhang 
6259566063dSJacob Faibussowitsch   PetscCall(PetscFree(mat->defaultvectype));
6269566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &mat->defaultvectype)); /* Allocate and copy the string */
6279566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)mat, MATSEQAIJKOKKOS));
6289566063dSJacob Faibussowitsch   PetscCall(MatSetOps_SeqAIJKokkos(mat));
6293ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6308c3ff71bSJunchao Zhang }
6318c3ff71bSJunchao Zhang 
632d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatTranspose_SeqAIJKokkos(Mat A, MatReuse reuse, Mat *B)
633d71ae5a4SJacob Faibussowitsch {
6340ecb592aSJunchao Zhang   Mat               At;
6350e3ece09SJunchao Zhang   KokkosCsrMatrix   internT;
6360ecb592aSJunchao Zhang   Mat_SeqAIJKokkos *atkok, *bkok;
6370ecb592aSJunchao Zhang 
6380ecb592aSJunchao Zhang   PetscFunctionBegin;
6397fb60732SBarry Smith   if (reuse == MAT_REUSE_MATRIX) PetscCall(MatTransposeCheckNonzeroState_Private(A, *B));
6409566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &internT)); /* Generate a transpose internally */
6410ecb592aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX || reuse == MAT_INPLACE_MATRIX) {
642ff751488SJunchao Zhang     /* Deep copy internT, as we want to isolate the internal transpose */
6430e3ece09SJunchao Zhang     PetscCallCXX(atkok = new Mat_SeqAIJKokkos(KokkosCsrMatrix("csrmat", internT)));
6449566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PetscObjectComm((PetscObject)A), atkok, &At));
6450ecb592aSJunchao Zhang     if (reuse == MAT_INITIAL_MATRIX) *B = At;
6469566063dSJacob Faibussowitsch     else PetscCall(MatHeaderReplace(A, &At)); /* Replace A with At inplace */
6470ecb592aSJunchao Zhang   } else {                                    /* MAT_REUSE_MATRIX, just need to copy values to B on device */
6480ecb592aSJunchao Zhang     if ((*B)->assembled) {
6490ecb592aSJunchao Zhang       bkok = static_cast<Mat_SeqAIJKokkos *>((*B)->spptr);
6500e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(bkok->a_dual.view_device(), internT.values));
6519566063dSJacob Faibussowitsch       PetscCall(MatSeqAIJKokkosModifyDevice(*B));
6520ecb592aSJunchao Zhang     } else if ((*B)->preallocated) { /* It is ok for B to be only preallocated, as needed in MatTranspose_MPIAIJ */
6530ecb592aSJunchao Zhang       Mat_SeqAIJ             *bseq = static_cast<Mat_SeqAIJ *>((*B)->data);
6540e3ece09SJunchao Zhang       MatScalarKokkosViewHost a_h(bseq->a, internT.nnz()); /* bseq->nz = 0 if unassembled */
6550e3ece09SJunchao Zhang       MatColIdxKokkosViewHost j_h(bseq->j, internT.nnz());
6560e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(a_h, internT.values));
6570e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(j_h, internT.graph.entries));
6580ecb592aSJunchao Zhang     } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "B must be assembled or preallocated");
6590ecb592aSJunchao Zhang   }
6603ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6610ecb592aSJunchao Zhang }
6620ecb592aSJunchao Zhang 
663d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDestroy_SeqAIJKokkos(Mat A)
664d71ae5a4SJacob Faibussowitsch {
66586a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
6668c3ff71bSJunchao Zhang 
6678c3ff71bSJunchao Zhang   PetscFunctionBegin;
66886a27549SJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) {
66986a27549SJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
6708c3ff71bSJunchao Zhang     delete aijkok;
67186a27549SJunchao Zhang   } else {
67286a27549SJunchao Zhang     delete static_cast<Mat_SeqAIJKokkosTriFactors *>(A->spptr);
67386a27549SJunchao Zhang   }
674cbc6b225SStefano Zampini   A->spptr = NULL;
6759566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
6769566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
6779566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
67857761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
67957761e9aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijkokkos_hypre_C", NULL));
68057761e9aSJunchao Zhang #endif
6819566063dSJacob Faibussowitsch   PetscCall(MatDestroy_SeqAIJ(A));
6823ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6838c3ff71bSJunchao Zhang }
6848c3ff71bSJunchao Zhang 
6853f3ba80aSJunchao Zhang /*MC
6863f3ba80aSJunchao Zhang    MATSEQAIJKOKKOS - MATAIJKOKKOS = "(seq)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos
6873f3ba80aSJunchao Zhang 
68815229ffcSPierre Jolivet    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
6893f3ba80aSJunchao Zhang 
6902ef1f0ffSBarry Smith    Options Database Key:
69111a5261eSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATSEQAIJKOKKOS` during a call to `MatSetFromOptions()`
6923f3ba80aSJunchao Zhang 
6933f3ba80aSJunchao Zhang   Level: beginner
6943f3ba80aSJunchao Zhang 
6951cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MatCreateSeqAIJKokkos()`, `MATMPIAIJKOKKOS`
6963f3ba80aSJunchao Zhang M*/
697d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJKokkos(Mat A)
698d71ae5a4SJacob Faibussowitsch {
69986a27549SJunchao Zhang   PetscFunctionBegin;
7009566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
7019566063dSJacob Faibussowitsch   PetscCall(MatCreate_SeqAIJ(A));
7029566063dSJacob Faibussowitsch   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
7033ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
70486a27549SJunchao Zhang }
70586a27549SJunchao Zhang 
706076ba34aSJunchao Zhang /* Merge A, B into a matrix C. A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n) */
707d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C)
708d71ae5a4SJacob Faibussowitsch {
709076ba34aSJunchao Zhang   Mat_SeqAIJ         *a, *b;
710076ba34aSJunchao Zhang   Mat_SeqAIJKokkos   *akok, *bkok, *ckok;
711076ba34aSJunchao Zhang   MatScalarKokkosView aa, ba, ca;
712076ba34aSJunchao Zhang   MatRowMapKokkosView ai, bi, ci;
713076ba34aSJunchao Zhang   MatColIdxKokkosView aj, bj, cj;
714076ba34aSJunchao Zhang   PetscInt            m, n, nnz, aN;
715a3f881fbSStefano Zampini 
716a3f881fbSStefano Zampini   PetscFunctionBegin;
717076ba34aSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
718076ba34aSJunchao Zhang   PetscValidHeaderSpecific(B, MAT_CLASSID, 2);
7194f572ea9SToby Isaac   PetscAssertPointer(C, 4);
720076ba34aSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
721076ba34aSJunchao Zhang   PetscCheckTypeName(B, MATSEQAIJKOKKOS);
7225f80ce2aSJacob Faibussowitsch   PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n);
7235f80ce2aSJacob Faibussowitsch   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
724076ba34aSJunchao Zhang 
7259566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
7269566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
727076ba34aSJunchao Zhang   a    = static_cast<Mat_SeqAIJ *>(A->data);
728076ba34aSJunchao Zhang   b    = static_cast<Mat_SeqAIJ *>(B->data);
729076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
730076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
731076ba34aSJunchao Zhang   aa   = akok->a_dual.view_device();
732076ba34aSJunchao Zhang   ai   = akok->i_dual.view_device();
733076ba34aSJunchao Zhang   ba   = bkok->a_dual.view_device();
734076ba34aSJunchao Zhang   bi   = bkok->i_dual.view_device();
735076ba34aSJunchao Zhang   m    = A->rmap->n; /* M, N and nnz of C */
736076ba34aSJunchao Zhang   n    = A->cmap->n + B->cmap->n;
737076ba34aSJunchao Zhang   nnz  = a->nz + b->nz;
738076ba34aSJunchao Zhang   aN   = A->cmap->n; /* N of A */
739076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
740076ba34aSJunchao Zhang     aj           = akok->j_dual.view_device();
741076ba34aSJunchao Zhang     bj           = bkok->j_dual.view_device();
742076ba34aSJunchao Zhang     auto ca_dual = MatScalarKokkosDualView("a", aa.extent(0) + ba.extent(0));
743076ba34aSJunchao Zhang     auto ci_dual = MatRowMapKokkosDualView("i", ai.extent(0));
744076ba34aSJunchao Zhang     auto cj_dual = MatColIdxKokkosDualView("j", aj.extent(0) + bj.extent(0));
745076ba34aSJunchao Zhang     ca           = ca_dual.view_device();
746076ba34aSJunchao Zhang     ci           = ci_dual.view_device();
747076ba34aSJunchao Zhang     cj           = cj_dual.view_device();
748076ba34aSJunchao Zhang 
749076ba34aSJunchao Zhang     /* Concatenate A and B in parallel using Kokkos hierarchical parallelism */
7509371c9d4SSatish Balay     Kokkos::parallel_for(
751d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
752076ba34aSJunchao Zhang         PetscInt i       = t.league_rank(); /* row i */
753076ba34aSJunchao Zhang         PetscInt coffset = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
754076ba34aSJunchao Zhang 
755076ba34aSJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() { /* this side effect only happens once per whole team */
756076ba34aSJunchao Zhang                                                    ci(i) = coffset;
757076ba34aSJunchao Zhang                                                    if (i == m - 1) ci(m) = ai(m) + bi(m);
758076ba34aSJunchao Zhang         });
759076ba34aSJunchao Zhang 
760076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
761076ba34aSJunchao Zhang           if (k < alen) {
762076ba34aSJunchao Zhang             ca(coffset + k) = aa(ai(i) + k);
763076ba34aSJunchao Zhang             cj(coffset + k) = aj(ai(i) + k);
764076ba34aSJunchao Zhang           } else {
765076ba34aSJunchao Zhang             ca(coffset + k) = ba(bi(i) + k - alen);
766076ba34aSJunchao Zhang             cj(coffset + k) = bj(bi(i) + k - alen) + aN; /* Entries in B get new column indices in C */
767076ba34aSJunchao Zhang           }
768076ba34aSJunchao Zhang         });
769076ba34aSJunchao Zhang       });
770076ba34aSJunchao Zhang     ca_dual.modify_device();
771076ba34aSJunchao Zhang     ci_dual.modify_device();
772076ba34aSJunchao Zhang     cj_dual.modify_device();
7739566063dSJacob Faibussowitsch     PetscCallCXX(ckok = new Mat_SeqAIJKokkos(m, n, nnz, ci_dual, cj_dual, ca_dual));
7749566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, C));
775076ba34aSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
776076ba34aSJunchao Zhang     PetscValidHeaderSpecific(*C, MAT_CLASSID, 4);
777076ba34aSJunchao Zhang     PetscCheckTypeName(*C, MATSEQAIJKOKKOS);
778076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos *>((*C)->spptr);
779076ba34aSJunchao Zhang     ca   = ckok->a_dual.view_device();
780076ba34aSJunchao Zhang     ci   = ckok->i_dual.view_device();
781076ba34aSJunchao Zhang 
7829371c9d4SSatish Balay     Kokkos::parallel_for(
783d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
784076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* row i */
785076ba34aSJunchao Zhang         PetscInt alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
786076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
787076ba34aSJunchao Zhang           if (k < alen) ca(ci(i) + k) = aa(ai(i) + k);
788076ba34aSJunchao Zhang           else ca(ci(i) + k) = ba(bi(i) + k - alen);
789076ba34aSJunchao Zhang         });
790076ba34aSJunchao Zhang       });
7919566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(*C));
792076ba34aSJunchao Zhang   }
7933ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
794076ba34aSJunchao Zhang }
795076ba34aSJunchao Zhang 
796d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_SeqAIJKokkos(void *pdata)
797d71ae5a4SJacob Faibussowitsch {
798076ba34aSJunchao Zhang   PetscFunctionBegin;
799076ba34aSJunchao Zhang   delete static_cast<MatProductData_SeqAIJKokkos *>(pdata);
8003ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
801a3f881fbSStefano Zampini }
802a3f881fbSStefano Zampini 
803d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos(Mat C)
804d71ae5a4SJacob Faibussowitsch {
805a3f881fbSStefano Zampini   Mat_Product                 *product = C->product;
806a3f881fbSStefano Zampini   Mat                          A, B;
807076ba34aSJunchao Zhang   bool                         transA, transB; /* use bool, since KK needs this type */
808a3f881fbSStefano Zampini   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
809a3f881fbSStefano Zampini   Mat_SeqAIJ                  *c;
810076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
8110e3ece09SJunchao Zhang   KokkosCsrMatrix              csrmatA, csrmatB;
812a3f881fbSStefano Zampini 
813a3f881fbSStefano Zampini   PetscFunctionBegin;
814a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
8155f80ce2aSJacob Faibussowitsch   PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data empty");
816076ba34aSJunchao Zhang   pdata = static_cast<MatProductData_SeqAIJKokkos *>(C->product->data);
817076ba34aSJunchao Zhang 
8180e3ece09SJunchao Zhang   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
8190e3ece09SJunchao Zhang   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
8200e3ece09SJunchao Zhang   // we still do numeric.
8210e3ece09SJunchao Zhang   if (pdata->reusesym) { // numeric reuses results from symbolic
8220e3ece09SJunchao Zhang     pdata->reusesym = PETSC_FALSE;
8233ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
824076ba34aSJunchao Zhang   }
825076ba34aSJunchao Zhang 
826076ba34aSJunchao Zhang   switch (product->type) {
8279371c9d4SSatish Balay   case MATPRODUCT_AB:
8289371c9d4SSatish Balay     transA = false;
8299371c9d4SSatish Balay     transB = false;
8309371c9d4SSatish Balay     break;
8319371c9d4SSatish Balay   case MATPRODUCT_AtB:
8329371c9d4SSatish Balay     transA = true;
8339371c9d4SSatish Balay     transB = false;
8349371c9d4SSatish Balay     break;
8359371c9d4SSatish Balay   case MATPRODUCT_ABt:
8369371c9d4SSatish Balay     transA = false;
8379371c9d4SSatish Balay     transB = true;
8389371c9d4SSatish Balay     break;
839d71ae5a4SJacob Faibussowitsch   default:
840d71ae5a4SJacob Faibussowitsch     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
841076ba34aSJunchao Zhang   }
842076ba34aSJunchao Zhang 
843a3f881fbSStefano Zampini   A = product->A;
844a3f881fbSStefano Zampini   B = product->B;
8459566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
8469566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
847a3f881fbSStefano Zampini   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
848a3f881fbSStefano Zampini   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
849a3f881fbSStefano Zampini   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
850076ba34aSJunchao Zhang 
8515f80ce2aSJacob Faibussowitsch   PetscCheck(ckok, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Device data structure spptr is empty");
852076ba34aSJunchao Zhang 
8530e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
8540e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
855076ba34aSJunchao Zhang 
856076ba34aSJunchao Zhang   /* TODO: Once KK spgemm implements transpose, we can get rid of the explicit transpose here */
857076ba34aSJunchao Zhang   if (transA) {
8589566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
859076ba34aSJunchao Zhang     transA = false;
860a3f881fbSStefano Zampini   }
861a3f881fbSStefano Zampini 
862076ba34aSJunchao Zhang   if (transB) {
8639566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
864076ba34aSJunchao Zhang     transB = false;
865076ba34aSJunchao Zhang   }
8669566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
8670e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, ckok->csrmat));
8680e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
869866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
870866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(ckok->csrmat)); /* without sort, mat_tests-ex62_14_seqaijkokkos fails */
871e944a159SJunchao Zhang #endif
872866eb059SJunchao Zhang 
8739566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
8749566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(C));
875a3f881fbSStefano Zampini   /* shorter version of MatAssemblyEnd_SeqAIJ */
876a3f881fbSStefano Zampini   c = (Mat_SeqAIJ *)C->data;
8779566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded,%" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz));
8789566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n"));
8799566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax));
880a3f881fbSStefano Zampini   c->reallocs         = 0;
881076ba34aSJunchao Zhang   C->info.mallocs     = 0;
882a3f881fbSStefano Zampini   C->info.nz_unneeded = 0;
883a3f881fbSStefano Zampini   C->assembled = C->was_assembled = PETSC_TRUE;
884a3f881fbSStefano Zampini   C->num_ass++;
8853ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
886a3f881fbSStefano Zampini }
887a3f881fbSStefano Zampini 
888d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos(Mat C)
889d71ae5a4SJacob Faibussowitsch {
890076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
891076ba34aSJunchao Zhang   MatProductType               ptype;
892076ba34aSJunchao Zhang   Mat                          A, B;
893076ba34aSJunchao Zhang   bool                         transA, transB;
894076ba34aSJunchao Zhang   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
895076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
896076ba34aSJunchao Zhang   MPI_Comm                     comm;
8970e3ece09SJunchao Zhang   KokkosCsrMatrix              csrmatA, csrmatB, csrmatC;
898a3f881fbSStefano Zampini 
899a3f881fbSStefano Zampini   PetscFunctionBegin;
900a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
9019566063dSJacob Faibussowitsch   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
9025f80ce2aSJacob Faibussowitsch   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
903a3f881fbSStefano Zampini   A = product->A;
904a3f881fbSStefano Zampini   B = product->B;
9059566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
9069566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
907a3f881fbSStefano Zampini   akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
908a3f881fbSStefano Zampini   bkok    = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
9090e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
9100e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
911076ba34aSJunchao Zhang 
912a3f881fbSStefano Zampini   ptype = product->type;
9130e3ece09SJunchao Zhang   // Take advantage of the symmetry if true
9140e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) {
9150e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
9160e3ece09SJunchao Zhang     product->symbolic_used_the_fact_A_is_symmetric = PETSC_TRUE;
9170e3ece09SJunchao Zhang   }
9180e3ece09SJunchao Zhang   if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) {
9190e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
9200e3ece09SJunchao Zhang     product->symbolic_used_the_fact_B_is_symmetric = PETSC_TRUE;
9210e3ece09SJunchao Zhang   }
9220e3ece09SJunchao Zhang 
923a3f881fbSStefano Zampini   switch (ptype) {
9249371c9d4SSatish Balay   case MATPRODUCT_AB:
9259371c9d4SSatish Balay     transA = false;
9269371c9d4SSatish Balay     transB = false;
9270e6a1e94SMark Adams     PetscCall(MatSetBlockSizesFromMats(C, A, B));
9289371c9d4SSatish Balay     break;
9299371c9d4SSatish Balay   case MATPRODUCT_AtB:
9309371c9d4SSatish Balay     transA = true;
9319371c9d4SSatish Balay     transB = false;
9320e6a1e94SMark Adams     if (A->cmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->rmap, A->cmap->bs));
9330e6a1e94SMark Adams     if (B->cmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->cmap, B->cmap->bs));
9349371c9d4SSatish Balay     break;
9359371c9d4SSatish Balay   case MATPRODUCT_ABt:
9369371c9d4SSatish Balay     transA = false;
9379371c9d4SSatish Balay     transB = true;
9380e6a1e94SMark Adams     if (A->rmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->rmap, A->rmap->bs));
9390e6a1e94SMark Adams     if (B->rmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->cmap, B->rmap->bs));
9409371c9d4SSatish Balay     break;
941d71ae5a4SJacob Faibussowitsch   default:
942d71ae5a4SJacob Faibussowitsch     SETERRQ(comm, PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
943a3f881fbSStefano Zampini   }
9440e3ece09SJunchao Zhang   PetscCallCXX(product->data = pdata = new MatProductData_SeqAIJKokkos());
945076ba34aSJunchao Zhang   pdata->reusesym = product->api_user;
946a3f881fbSStefano Zampini 
947076ba34aSJunchao Zhang   /* TODO: add command line options to select spgemm algorithms */
948866eb059SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; /* default alg is TPL if enabled, otherwise KK */
949866eb059SJunchao Zhang 
950866eb059SJunchao Zhang   /* CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 */
951866eb059SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
952866eb059SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
953866eb059SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
954866eb059SJunchao Zhang   #endif
955866eb059SJunchao Zhang #endif
9560e3ece09SJunchao Zhang   PetscCallCXX(pdata->kh.create_spgemm_handle(spgemm_alg));
957076ba34aSJunchao Zhang 
9589566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
959076ba34aSJunchao Zhang   /* TODO: Get rid of the explicit transpose once KK-spgemm implements the transpose option */
960076ba34aSJunchao Zhang   if (transA) {
9619566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
962076ba34aSJunchao Zhang     transA = false;
963076ba34aSJunchao Zhang   }
964076ba34aSJunchao Zhang 
965076ba34aSJunchao Zhang   if (transB) {
9669566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
967076ba34aSJunchao Zhang     transB = false;
968076ba34aSJunchao Zhang   }
969076ba34aSJunchao Zhang 
9700e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
971076ba34aSJunchao Zhang   /* spgemm_symbolic() only populates C's rowmap, but not C's column indices.
972076ba34aSJunchao Zhang     So we have to do a fake spgemm_numeric() here to get csrmatC.j_d setup, before
973076ba34aSJunchao Zhang     calling new Mat_SeqAIJKokkos().
974076ba34aSJunchao Zhang     TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
975076ba34aSJunchao Zhang   */
9760e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
9770e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
978866eb059SJunchao Zhang   /* Query if KK outputs a sorted matrix. If not, we need to sort it */
979866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
980866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(csrmatC)); /* sort_option defaults to -1 in KK!*/
981e944a159SJunchao Zhang #endif
9829566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
983076ba34aSJunchao Zhang 
9849566063dSJacob Faibussowitsch   PetscCallCXX(ckok = new Mat_SeqAIJKokkos(csrmatC));
9859566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(C, ckok));
986076ba34aSJunchao Zhang   C->product->destroy = MatProductDataDestroy_SeqAIJKokkos;
9873ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
988a3f881fbSStefano Zampini }
989a3f881fbSStefano Zampini 
990a3f881fbSStefano Zampini /* handles sparse matrix matrix ops */
991d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSetFromOptions_SeqAIJKokkos(Mat mat)
992d71ae5a4SJacob Faibussowitsch {
993076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
994a3f881fbSStefano Zampini   PetscBool    Biskok = PETSC_FALSE, Ciskok = PETSC_TRUE;
995a3f881fbSStefano Zampini 
996a3f881fbSStefano Zampini   PetscFunctionBegin;
997a3f881fbSStefano Zampini   MatCheckProduct(mat, 1);
9989566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJKOKKOS, &Biskok));
99948a46eb9SPierre Jolivet   if (product->type == MATPRODUCT_ABC) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJKOKKOS, &Ciskok));
1000a3f881fbSStefano Zampini   if (Biskok && Ciskok) {
1001a3f881fbSStefano Zampini     switch (product->type) {
1002a3f881fbSStefano Zampini     case MATPRODUCT_AB:
1003a3f881fbSStefano Zampini     case MATPRODUCT_AtB:
1004d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABt:
1005d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos;
1006d71ae5a4SJacob Faibussowitsch       break;
1007a3f881fbSStefano Zampini     case MATPRODUCT_PtAP:
1008a3f881fbSStefano Zampini     case MATPRODUCT_RARt:
1009d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABC:
1010d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
1011d71ae5a4SJacob Faibussowitsch       break;
1012d71ae5a4SJacob Faibussowitsch     default:
1013d71ae5a4SJacob Faibussowitsch       break;
1014a3f881fbSStefano Zampini     }
1015a3f881fbSStefano Zampini   } else { /* fallback for AIJ */
10169566063dSJacob Faibussowitsch     PetscCall(MatProductSetFromOptions_SeqAIJ(mat));
1017a3f881fbSStefano Zampini   }
10183ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1019a3f881fbSStefano Zampini }
1020a587d139SMark 
1021d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatScale_SeqAIJKokkos(Mat A, PetscScalar a)
1022d71ae5a4SJacob Faibussowitsch {
1023f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok;
1024f0cf5187SStefano Zampini 
1025f0cf5187SStefano Zampini   PetscFunctionBegin;
10269566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
10279566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1028f0cf5187SStefano Zampini   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1029d326c3f1SJunchao Zhang   KokkosBlas::scal(PetscGetKokkosExecutionSpace(), aijkok->a_dual.view_device(), a, aijkok->a_dual.view_device());
10309566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
10319566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(aijkok->a_dual.extent(0)));
10329566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
10333ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1034f0cf5187SStefano Zampini }
1035f0cf5187SStefano Zampini 
1036f4747e26SJunchao Zhang // add a to A's diagonal (if A is square) or main diagonal (if A is rectangular)
1037f4747e26SJunchao Zhang static PetscErrorCode MatShift_SeqAIJKokkos(Mat A, PetscScalar a)
1038f4747e26SJunchao Zhang {
1039f4747e26SJunchao Zhang   Mat_SeqAIJ *aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1040f4747e26SJunchao Zhang 
1041f4747e26SJunchao Zhang   PetscFunctionBegin;
1042f4747e26SJunchao Zhang   if (A->assembled && aijseq->diagonaldense) { // no missing diagonals
1043f4747e26SJunchao Zhang     PetscInt n = PetscMin(A->rmap->n, A->cmap->n);
1044f4747e26SJunchao Zhang 
1045f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
1046f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosSyncDevice(A));
1047f4747e26SJunchao Zhang     const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1048f4747e26SJunchao Zhang     const auto &Aa     = aijkok->a_dual.view_device();
1049f4747e26SJunchao Zhang     const auto &Adiag  = aijkok->diag_dual.view_device();
1050d326c3f1SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) { Aa(Adiag(i)) += a; }));
1051f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosModifyDevice(A));
1052f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(n));
1053f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
1054f4747e26SJunchao Zhang   } else { // need reassembly, very slow!
1055f4747e26SJunchao Zhang     PetscCall(MatShift_Basic(A, a));
1056f4747e26SJunchao Zhang   }
1057f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1058f4747e26SJunchao Zhang }
1059f4747e26SJunchao Zhang 
1060f4747e26SJunchao Zhang static PetscErrorCode MatDiagonalSet_SeqAIJKokkos(Mat Y, Vec D, InsertMode is)
1061f4747e26SJunchao Zhang {
1062f4747e26SJunchao Zhang   Mat_SeqAIJ *aijseq = static_cast<Mat_SeqAIJ *>(Y->data);
1063f4747e26SJunchao Zhang 
1064f4747e26SJunchao Zhang   PetscFunctionBegin;
1065f4747e26SJunchao Zhang   if (Y->assembled && aijseq->diagonaldense) { // no missing diagonals
1066f4747e26SJunchao Zhang     ConstPetscScalarKokkosView dv;
1067f4747e26SJunchao Zhang     PetscInt                   n, nv;
1068f4747e26SJunchao Zhang 
1069f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
1070f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosSyncDevice(Y));
1071f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(D, &dv));
1072f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(D, &nv));
1073f4747e26SJunchao Zhang     n = PetscMin(Y->rmap->n, Y->cmap->n);
1074f4747e26SJunchao Zhang     PetscCheck(n == nv, PetscObjectComm((PetscObject)Y), PETSC_ERR_ARG_SIZ, "Matrix size and vector size do not match");
1075f4747e26SJunchao Zhang 
1076f4747e26SJunchao Zhang     const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1077f4747e26SJunchao Zhang     const auto &Aa     = aijkok->a_dual.view_device();
1078f4747e26SJunchao Zhang     const auto &Adiag  = aijkok->diag_dual.view_device();
1079f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(
1080d326c3f1SJunchao Zhang       Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) {
1081f4747e26SJunchao Zhang         if (is == INSERT_VALUES) Aa(Adiag(i)) = dv(i);
1082f4747e26SJunchao Zhang         else Aa(Adiag(i)) += dv(i);
1083f4747e26SJunchao Zhang       }));
1084f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(D, &dv));
1085f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1086f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(n));
1087f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
1088f4747e26SJunchao Zhang   } else { // need reassembly, very slow!
1089f4747e26SJunchao Zhang     PetscCall(MatDiagonalSet_Default(Y, D, is));
1090f4747e26SJunchao Zhang   }
1091f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1092f4747e26SJunchao Zhang }
1093f4747e26SJunchao Zhang 
1094f4747e26SJunchao Zhang static PetscErrorCode MatDiagonalScale_SeqAIJKokkos(Mat A, Vec ll, Vec rr)
1095f4747e26SJunchao Zhang {
1096f4747e26SJunchao Zhang   Mat_SeqAIJ                *aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1097f4747e26SJunchao Zhang   PetscInt                   m = A->rmap->n, n = A->cmap->n, nz = aijseq->nz;
1098f4747e26SJunchao Zhang   ConstPetscScalarKokkosView lv, rv;
1099f4747e26SJunchao Zhang 
1100f4747e26SJunchao Zhang   PetscFunctionBegin;
1101f4747e26SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1102f4747e26SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1103f4747e26SJunchao Zhang   const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1104f4747e26SJunchao Zhang   const auto &Aa     = aijkok->a_dual.view_device();
1105f4747e26SJunchao Zhang   const auto &Ai     = aijkok->i_dual.view_device();
1106f4747e26SJunchao Zhang   const auto &Aj     = aijkok->j_dual.view_device();
1107f4747e26SJunchao Zhang   if (ll) {
1108f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(ll, &m));
1109f4747e26SJunchao Zhang     PetscCheck(m == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Left scaling vector wrong length");
1110f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(ll, &lv));
1111f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for( // for each row
1112d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1113f4747e26SJunchao Zhang         PetscInt i   = t.league_rank(); // row i
1114f4747e26SJunchao Zhang         PetscInt len = Ai(i + 1) - Ai(i);
1115f4747e26SJunchao Zhang         // scale entries on the row
1116f4747e26SJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt j) { Aa(Ai(i) + j) *= lv(i); });
1117f4747e26SJunchao Zhang       }));
1118f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(ll, &lv));
1119f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(nz));
1120f4747e26SJunchao Zhang   }
1121f4747e26SJunchao Zhang   if (rr) {
1122f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(rr, &n));
1123f4747e26SJunchao Zhang     PetscCheck(n == A->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Right scaling vector wrong length");
1124f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(rr, &rv));
1125f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for( // for each nonzero
1126d326c3f1SJunchao Zhang       Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt k) { Aa(k) *= rv(Aj(k)); }));
1127f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(rr, &lv));
1128f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(nz));
1129f4747e26SJunchao Zhang   }
1130f4747e26SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(A));
1131f4747e26SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1132f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1133f4747e26SJunchao Zhang }
1134f4747e26SJunchao Zhang 
1135d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatZeroEntries_SeqAIJKokkos(Mat A)
1136d71ae5a4SJacob Faibussowitsch {
1137076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1138a587d139SMark 
1139a587d139SMark   PetscFunctionBegin;
1140076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
11412328674fSJunchao Zhang   if (aijkok) { /* Only zero the device if data is already there */
1142d326c3f1SJunchao Zhang     KokkosBlas::fill(PetscGetKokkosExecutionSpace(), aijkok->a_dual.view_device(), 0.0);
11439566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(A));
11442328674fSJunchao Zhang   } else { /* Might be preallocated but not assembled */
11459566063dSJacob Faibussowitsch     PetscCall(MatZeroEntries_SeqAIJ(A));
11462328674fSJunchao Zhang   }
11473ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1148a587d139SMark }
1149a587d139SMark 
1150d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatGetDiagonal_SeqAIJKokkos(Mat A, Vec x)
1151d71ae5a4SJacob Faibussowitsch {
1152f78ce678SMark Adams   Mat_SeqAIJKokkos     *aijkok;
1153f78ce678SMark Adams   PetscInt              n;
1154f78ce678SMark Adams   PetscScalarKokkosView xv;
1155f78ce678SMark Adams 
1156f78ce678SMark Adams   PetscFunctionBegin;
1157f78ce678SMark Adams   PetscCall(VecGetLocalSize(x, &n));
1158f78ce678SMark Adams   PetscCheck(n == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Nonconforming matrix and vector");
1159f78ce678SMark Adams   PetscCheck(A->factortype == MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_SUP, "MatGetDiagonal_SeqAIJKokkos not supported on factored matrices");
1160f78ce678SMark Adams 
1161f78ce678SMark Adams   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1162f78ce678SMark Adams   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1163f78ce678SMark Adams 
1164f78ce678SMark Adams   const auto &Aa    = aijkok->a_dual.view_device();
1165f78ce678SMark Adams   const auto &Ai    = aijkok->i_dual.view_device();
1166f78ce678SMark Adams   const auto &Adiag = aijkok->diag_dual.view_device();
1167f78ce678SMark Adams 
1168f78ce678SMark Adams   PetscCall(VecGetKokkosViewWrite(x, &xv));
11699371c9d4SSatish Balay   Kokkos::parallel_for(
1170d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) {
1171f78ce678SMark Adams       if (Adiag(i) < Ai(i + 1)) xv(i) = Aa(Adiag(i));
1172f78ce678SMark Adams       else xv(i) = 0;
1173f78ce678SMark Adams     });
1174f78ce678SMark Adams   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
11753ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1176f78ce678SMark Adams }
1177f78ce678SMark Adams 
1178db78de30SJunchao Zhang /* Get a Kokkos View from a mat of type MatSeqAIJKokkos */
1179d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1180d71ae5a4SJacob Faibussowitsch {
1181db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1182db78de30SJunchao Zhang 
1183db78de30SJunchao Zhang   PetscFunctionBegin;
1184db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11854f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1186db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11879566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1188db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1189076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
11903ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1191db78de30SJunchao Zhang }
1192db78de30SJunchao Zhang 
1193d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1194d71ae5a4SJacob Faibussowitsch {
1195db78de30SJunchao Zhang   PetscFunctionBegin;
1196db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11974f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1198db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11993ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1200db78de30SJunchao Zhang }
1201db78de30SJunchao Zhang 
1202d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, MatScalarKokkosView *kv)
1203d71ae5a4SJacob Faibussowitsch {
1204db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1205db78de30SJunchao Zhang 
1206db78de30SJunchao Zhang   PetscFunctionBegin;
1207db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12084f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1209db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
12109566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1211db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1212076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
12133ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1214db78de30SJunchao Zhang }
1215db78de30SJunchao Zhang 
1216d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, MatScalarKokkosView *kv)
1217d71ae5a4SJacob Faibussowitsch {
1218db78de30SJunchao Zhang   PetscFunctionBegin;
1219db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12204f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1221db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
12229566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
12233ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1224db78de30SJunchao Zhang }
1225db78de30SJunchao Zhang 
1226d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1227d71ae5a4SJacob Faibussowitsch {
1228db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1229db78de30SJunchao Zhang 
1230db78de30SJunchao Zhang   PetscFunctionBegin;
1231db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12324f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1233db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1234db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1235076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
12363ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1237db78de30SJunchao Zhang }
1238db78de30SJunchao Zhang 
1239d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1240d71ae5a4SJacob Faibussowitsch {
1241db78de30SJunchao Zhang   PetscFunctionBegin;
1242db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12434f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1244db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
12459566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
12463ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1247db78de30SJunchao Zhang }
1248db78de30SJunchao Zhang 
1249c0c276a7Ssdargavi PetscErrorCode MatCreateSeqAIJKokkosWithKokkosViews(MPI_Comm comm, PetscInt m, PetscInt n, Kokkos::View<PetscInt *> &i_d, Kokkos::View<PetscInt *> &j_d, Kokkos::View<PetscScalar *> &a_d, Mat *A)
1250c0c276a7Ssdargavi {
1251c0c276a7Ssdargavi   Mat_SeqAIJKokkos *akok;
1252c0c276a7Ssdargavi 
1253c0c276a7Ssdargavi   PetscFunctionBegin;
1254c0c276a7Ssdargavi   auto exec = PetscGetKokkosExecutionSpace();
1255c0c276a7Ssdargavi   // Create host copies of the input aij
1256c0c276a7Ssdargavi   auto i_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), i_d);
1257c0c276a7Ssdargavi   auto j_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), j_d);
1258c0c276a7Ssdargavi   // Don't copy the vals to the host now
1259c0c276a7Ssdargavi   auto a_h = Kokkos::create_mirror_view(HostMirrorMemorySpace(), a_d);
1260c0c276a7Ssdargavi 
1261c0c276a7Ssdargavi   MatScalarKokkosDualView a_dual = MatScalarKokkosDualView(a_d, a_h);
1262c0c276a7Ssdargavi   // Note we have modified device data so it will copy lazily
1263c0c276a7Ssdargavi   a_dual.modify_device();
1264c0c276a7Ssdargavi   MatRowMapKokkosDualView i_dual = MatRowMapKokkosDualView(i_d, i_h);
1265c0c276a7Ssdargavi   MatColIdxKokkosDualView j_dual = MatColIdxKokkosDualView(j_d, j_h);
1266c0c276a7Ssdargavi 
1267c0c276a7Ssdargavi   PetscCallCXX(akok = new Mat_SeqAIJKokkos(m, n, j_dual.extent(0), i_dual, j_dual, a_dual));
1268c0c276a7Ssdargavi   PetscCall(MatCreate(comm, A));
1269c0c276a7Ssdargavi   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
1270c0c276a7Ssdargavi   PetscFunctionReturn(PETSC_SUCCESS);
1271c0c276a7Ssdargavi }
1272c0c276a7Ssdargavi 
1273c17cf699SJunchao Zhang /* Computes Y += alpha X */
1274d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAXPY_SeqAIJKokkos(Mat Y, PetscScalar alpha, Mat X, MatStructure pattern)
1275d71ae5a4SJacob Faibussowitsch {
1276a587d139SMark   Mat_SeqAIJ              *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data;
1277c17cf699SJunchao Zhang   Mat_SeqAIJKokkos        *xkok, *ykok, *zkok;
1278c17cf699SJunchao Zhang   ConstMatScalarKokkosView Xa;
1279c17cf699SJunchao Zhang   MatScalarKokkosView      Ya;
12804df4a32cSJunchao Zhang   auto                     exec = PetscGetKokkosExecutionSpace();
1281a587d139SMark 
1282a587d139SMark   PetscFunctionBegin;
1283c17cf699SJunchao Zhang   PetscCheckTypeName(Y, MATSEQAIJKOKKOS);
1284c17cf699SJunchao Zhang   PetscCheckTypeName(X, MATSEQAIJKOKKOS);
12859566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(Y));
12869566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(X));
12879566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1288db78de30SJunchao Zhang 
1289c17cf699SJunchao Zhang   if (pattern != SAME_NONZERO_PATTERN && x->nz == y->nz) {
1290a587d139SMark     PetscBool e;
12919566063dSJacob Faibussowitsch     PetscCall(PetscArraycmp(x->i, y->i, Y->rmap->n + 1, &e));
1292a587d139SMark     if (e) {
12939566063dSJacob Faibussowitsch       PetscCall(PetscArraycmp(x->j, y->j, y->nz, &e));
1294c17cf699SJunchao Zhang       if (e) pattern = SAME_NONZERO_PATTERN;
1295a587d139SMark     }
1296a587d139SMark   }
1297db78de30SJunchao Zhang 
1298c17cf699SJunchao Zhang   /* cusparseDcsrgeam2() computes C = alpha A + beta B. If one knew sparsity pattern of C, one can skip
1299c17cf699SJunchao Zhang     cusparseScsrgeam2_bufferSizeExt() / cusparseXcsrgeam2Nnz(), and directly call cusparseScsrgeam2().
1300c17cf699SJunchao Zhang     If X is SUBSET_NONZERO_PATTERN of Y, we could take advantage of this cusparse feature. However,
1301c17cf699SJunchao Zhang     KokkosSparse::spadd(alpha,A,beta,B,C) has symbolic and numeric phases, MatAXPY does not.
1302c17cf699SJunchao Zhang   */
1303c17cf699SJunchao Zhang   ykok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1304c17cf699SJunchao Zhang   xkok = static_cast<Mat_SeqAIJKokkos *>(X->spptr);
1305c17cf699SJunchao Zhang   Xa   = xkok->a_dual.view_device();
1306c17cf699SJunchao Zhang   Ya   = ykok->a_dual.view_device();
1307c17cf699SJunchao Zhang 
1308c17cf699SJunchao Zhang   if (pattern == SAME_NONZERO_PATTERN) {
1309d326c3f1SJunchao Zhang     KokkosBlas::axpy(exec, alpha, Xa, Ya);
13109566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1311c17cf699SJunchao Zhang   } else if (pattern == SUBSET_NONZERO_PATTERN) {
1312c17cf699SJunchao Zhang     MatRowMapKokkosView Xi = xkok->i_dual.view_device(), Yi = ykok->i_dual.view_device();
1313c17cf699SJunchao Zhang     MatColIdxKokkosView Xj = xkok->j_dual.view_device(), Yj = ykok->j_dual.view_device();
1314c17cf699SJunchao Zhang 
13159371c9d4SSatish Balay     Kokkos::parallel_for(
1316d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(exec, Y->rmap->n, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
13170e3ece09SJunchao Zhang         PetscInt i = t.league_rank(); // row i
13180e3ece09SJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() {
13190e3ece09SJunchao Zhang           // Only one thread works in a team
1320c17cf699SJunchao Zhang           PetscInt p, q = Yi(i);
13210e3ece09SJunchao Zhang           for (p = Xi(i); p < Xi(i + 1); p++) {          // For each nonzero on row i of X,
13220e3ece09SJunchao Zhang             while (Xj(p) != Yj(q) && q < Yi(i + 1)) q++; // find the matching nonzero on row i of Y.
13230e3ece09SJunchao Zhang             if (Xj(p) == Yj(q)) {                        // Found it
1324c17cf699SJunchao Zhang               Ya(q) += alpha * Xa(p);
1325c17cf699SJunchao Zhang               q++;
1326a587d139SMark             } else {
13270e3ece09SJunchao Zhang             // If not found, it indicates the input is wrong (X is not a SUBSET_NONZERO_PATTERN of Y).
13280e3ece09SJunchao Zhang             // Just insert a NaN at the beginning of row i if it is not empty, to make the result wrong.
13290e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_VERSION_GE(3, 7, 0)
13300e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::ArithTraits<PetscScalar>::nan();
13318b8b16f9SJunchao Zhang #else
13320e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::Experimental::nan("1");
13338b8b16f9SJunchao Zhang #endif
1334a587d139SMark             }
1335c17cf699SJunchao Zhang           }
1336c17cf699SJunchao Zhang         });
1337c17cf699SJunchao Zhang       });
13389566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
13390e3ece09SJunchao Zhang   } else { // different nonzero patterns
1340c17cf699SJunchao Zhang     Mat             Z;
1341c17cf699SJunchao Zhang     KokkosCsrMatrix zcsr;
1342c17cf699SJunchao Zhang     KernelHandle    kh;
13430e3ece09SJunchao Zhang     kh.create_spadd_handle(true); // X, Y are sorted
1344c17cf699SJunchao Zhang     KokkosSparse::spadd_symbolic(&kh, xkok->csrmat, ykok->csrmat, zcsr);
1345c17cf699SJunchao Zhang     KokkosSparse::spadd_numeric(&kh, alpha, xkok->csrmat, (PetscScalar)1.0, ykok->csrmat, zcsr);
1346c17cf699SJunchao Zhang     zkok = new Mat_SeqAIJKokkos(zcsr);
13479566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, zkok, &Z));
13489566063dSJacob Faibussowitsch     PetscCall(MatHeaderReplace(Y, &Z));
1349c17cf699SJunchao Zhang     kh.destroy_spadd_handle();
1350c17cf699SJunchao Zhang   }
13519566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
13520e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(xkok->a_dual.extent(0) * 2)); // Because we scaled X and then added it to Y
13533ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1354a587d139SMark }
1355a587d139SMark 
13562c4ab24aSJunchao Zhang struct MatCOOStruct_SeqAIJKokkos {
13572c4ab24aSJunchao Zhang   PetscCount           n;
13582c4ab24aSJunchao Zhang   PetscCount           Atot;
13592c4ab24aSJunchao Zhang   PetscInt             nz;
13602c4ab24aSJunchao Zhang   PetscCountKokkosView jmap;
13612c4ab24aSJunchao Zhang   PetscCountKokkosView perm;
13622c4ab24aSJunchao Zhang 
13632c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos(const MatCOOStruct_SeqAIJ *coo_h)
13642c4ab24aSJunchao Zhang   {
13652c4ab24aSJunchao Zhang     nz   = coo_h->nz;
13662c4ab24aSJunchao Zhang     n    = coo_h->n;
13672c4ab24aSJunchao Zhang     Atot = coo_h->Atot;
13682c4ab24aSJunchao Zhang     jmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->jmap, nz + 1));
13692c4ab24aSJunchao Zhang     perm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->perm, Atot));
13702c4ab24aSJunchao Zhang   }
13712c4ab24aSJunchao Zhang };
13722c4ab24aSJunchao Zhang 
137349abdd8aSBarry Smith static PetscErrorCode MatCOOStructDestroy_SeqAIJKokkos(void **data)
13742c4ab24aSJunchao Zhang {
13752c4ab24aSJunchao Zhang   PetscFunctionBegin;
137649abdd8aSBarry Smith   PetscCallCXX(delete static_cast<MatCOOStruct_SeqAIJKokkos *>(*data));
13772c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
13782c4ab24aSJunchao Zhang }
13792c4ab24aSJunchao Zhang 
1380d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_SeqAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1381d71ae5a4SJacob Faibussowitsch {
138242550becSJunchao Zhang   Mat_SeqAIJKokkos          *akok;
138342550becSJunchao Zhang   Mat_SeqAIJ                *aseq;
138403e76207SPierre Jolivet   PetscContainer             container_h;
13852c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJ       *coo_h;
13862c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo_d;
138742550becSJunchao Zhang 
138842550becSJunchao Zhang   PetscFunctionBegin;
13899566063dSJacob Faibussowitsch   PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, coo_i, coo_j));
1390394ed5ebSJunchao Zhang   aseq = static_cast<Mat_SeqAIJ *>(mat->data);
139142550becSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(mat->spptr);
1392cbc6b225SStefano Zampini   delete akok;
1393f4747e26SJunchao Zhang   mat->spptr = akok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, aseq, mat->nonzerostate + 1, PETSC_FALSE);
13949566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries_SeqAIJKokkos(mat));
13952c4ab24aSJunchao Zhang 
13962c4ab24aSJunchao Zhang   // Copy the COO struct to device
13972c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
13982c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
13992c4ab24aSJunchao Zhang   PetscCallCXX(coo_d = new MatCOOStruct_SeqAIJKokkos(coo_h));
14002c4ab24aSJunchao Zhang 
14012c4ab24aSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
140203e76207SPierre Jolivet   PetscCall(PetscObjectContainerCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", coo_d, MatCOOStructDestroy_SeqAIJKokkos));
14033ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
140442550becSJunchao Zhang }
140542550becSJunchao Zhang 
1406d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_SeqAIJKokkos(Mat A, const PetscScalar v[], InsertMode imode)
1407d71ae5a4SJacob Faibussowitsch {
140842550becSJunchao Zhang   MatScalarKokkosView        Aa;
140942550becSJunchao Zhang   ConstMatScalarKokkosView   kv;
141042550becSJunchao Zhang   PetscMemType               memtype;
14112c4ab24aSJunchao Zhang   PetscContainer             container;
14122c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo;
141342550becSJunchao Zhang 
141442550becSJunchao Zhang   PetscFunctionBegin;
14152c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)A, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
14162c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
14172c4ab24aSJunchao Zhang 
14182c4ab24aSJunchao Zhang   const auto &n    = coo->n;
14192c4ab24aSJunchao Zhang   const auto &Annz = coo->nz;
14202c4ab24aSJunchao Zhang   const auto &jmap = coo->jmap;
14212c4ab24aSJunchao Zhang   const auto &perm = coo->perm;
14222c4ab24aSJunchao Zhang 
14239566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype));
142442550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */
14252c4ab24aSJunchao Zhang     kv = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatScalarKokkosViewHost(v, n));
142642550becSJunchao Zhang   } else {
14272c4ab24aSJunchao Zhang     kv = ConstMatScalarKokkosView(v, n); /* Directly use v[]'s memory */
142842550becSJunchao Zhang   }
142942550becSJunchao Zhang 
1430c7b718f4SJunchao Zhang   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1431c7b718f4SJunchao Zhang   else PetscCall(MatSeqAIJGetKokkosView(A, &Aa));                             /* read & write matrix values */
143242550becSJunchao Zhang 
143308bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
14349371c9d4SSatish Balay   Kokkos::parallel_for(
1435d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz), KOKKOS_LAMBDA(const PetscCount i) {
1436c7b718f4SJunchao Zhang       PetscScalar sum = 0.0;
1437c7b718f4SJunchao Zhang       for (PetscCount k = jmap(i); k < jmap(i + 1); k++) sum += kv(perm(k));
1438c7b718f4SJunchao Zhang       Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1439c7b718f4SJunchao Zhang     });
144008bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1441394ed5ebSJunchao Zhang 
14429566063dSJacob Faibussowitsch   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
14439566063dSJacob Faibussowitsch   else PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
14443ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
144542550becSJunchao Zhang }
144642550becSJunchao Zhang 
1447d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat A)
1448d71ae5a4SJacob Faibussowitsch {
1449076ba34aSJunchao Zhang   Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
1450076ba34aSJunchao Zhang 
14518c3ff71bSJunchao Zhang   PetscFunctionBegin;
1452076ba34aSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_KOKKOS; /* We do not really use this flag */
14536f3d89d0SStefano Zampini   A->boundtocpu  = PETSC_FALSE;
14546f3d89d0SStefano Zampini 
14558c3ff71bSJunchao Zhang   A->ops->assemblyend               = MatAssemblyEnd_SeqAIJKokkos;
14568c3ff71bSJunchao Zhang   A->ops->destroy                   = MatDestroy_SeqAIJKokkos;
14578c3ff71bSJunchao Zhang   A->ops->duplicate                 = MatDuplicate_SeqAIJKokkos;
1458a587d139SMark   A->ops->axpy                      = MatAXPY_SeqAIJKokkos;
1459f0cf5187SStefano Zampini   A->ops->scale                     = MatScale_SeqAIJKokkos;
1460a587d139SMark   A->ops->zeroentries               = MatZeroEntries_SeqAIJKokkos;
1461076ba34aSJunchao Zhang   A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJKokkos;
14628c3ff71bSJunchao Zhang   A->ops->mult                      = MatMult_SeqAIJKokkos;
14638c3ff71bSJunchao Zhang   A->ops->multadd                   = MatMultAdd_SeqAIJKokkos;
14648c3ff71bSJunchao Zhang   A->ops->multtranspose             = MatMultTranspose_SeqAIJKokkos;
14658c3ff71bSJunchao Zhang   A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJKokkos;
14668c3ff71bSJunchao Zhang   A->ops->multhermitiantranspose    = MatMultHermitianTranspose_SeqAIJKokkos;
14678c3ff71bSJunchao Zhang   A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJKokkos;
1468076ba34aSJunchao Zhang   A->ops->productnumeric            = MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos;
14690ecb592aSJunchao Zhang   A->ops->transpose                 = MatTranspose_SeqAIJKokkos;
1470152b3e56SJunchao Zhang   A->ops->setoption                 = MatSetOption_SeqAIJKokkos;
1471f78ce678SMark Adams   A->ops->getdiagonal               = MatGetDiagonal_SeqAIJKokkos;
1472f4747e26SJunchao Zhang   A->ops->shift                     = MatShift_SeqAIJKokkos;
1473f4747e26SJunchao Zhang   A->ops->diagonalset               = MatDiagonalSet_SeqAIJKokkos;
1474f4747e26SJunchao Zhang   A->ops->diagonalscale             = MatDiagonalScale_SeqAIJKokkos;
1475076ba34aSJunchao Zhang   a->ops->getarray                  = MatSeqAIJGetArray_SeqAIJKokkos;
1476076ba34aSJunchao Zhang   a->ops->restorearray              = MatSeqAIJRestoreArray_SeqAIJKokkos;
1477076ba34aSJunchao Zhang   a->ops->getarrayread              = MatSeqAIJGetArrayRead_SeqAIJKokkos;
1478076ba34aSJunchao Zhang   a->ops->restorearrayread          = MatSeqAIJRestoreArrayRead_SeqAIJKokkos;
1479076ba34aSJunchao Zhang   a->ops->getarraywrite             = MatSeqAIJGetArrayWrite_SeqAIJKokkos;
1480076ba34aSJunchao Zhang   a->ops->restorearraywrite         = MatSeqAIJRestoreArrayWrite_SeqAIJKokkos;
14817ee59b9bSJunchao Zhang   a->ops->getcsrandmemtype          = MatSeqAIJGetCSRAndMemType_SeqAIJKokkos;
148242550becSJunchao Zhang 
14839566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJKokkos));
14849566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJKokkos));
148557761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
148657761e9aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijkokkos_hypre_C", MatConvert_AIJ_HYPRE));
148757761e9aSJunchao Zhang #endif
14883ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1489076ba34aSJunchao Zhang }
1490076ba34aSJunchao Zhang 
14919d13fa56SJunchao Zhang /*
14929d13fa56SJunchao Zhang    Extract the (prescribled) diagonal blocks of the matrix and then invert them
14939d13fa56SJunchao Zhang 
14949d13fa56SJunchao Zhang   Input Parameters:
14959d13fa56SJunchao Zhang +  A       - the MATSEQAIJKOKKOS matrix
14969d13fa56SJunchao Zhang .  bs      - block sizes in 'csr' format, i.e., the i-th block has size bs(i+1) - bs(i)
14979d13fa56SJunchao Zhang .  bs2     - square of block sizes in 'csr' format, i.e., the i-th block should be stored at offset bs2(i) in diagVal[]
14989d13fa56SJunchao Zhang .  blkMap  - map row ids to block ids, i.e., row i belongs to the block blkMap(i)
14999d13fa56SJunchao Zhang -  work    - a pre-allocated work buffer (as big as diagVal) for use by this routine
15009d13fa56SJunchao Zhang 
15019d13fa56SJunchao Zhang   Output Parameter:
15029d13fa56SJunchao Zhang .  diagVal - the (pre-allocated) buffer to store the inverted blocks (each block is stored in column-major order)
15039d13fa56SJunchao Zhang */
15049d13fa56SJunchao Zhang PETSC_INTERN PetscErrorCode MatInvertVariableBlockDiagonal_SeqAIJKokkos(Mat A, const PetscIntKokkosView &bs, const PetscIntKokkosView &bs2, const PetscIntKokkosView &blkMap, PetscScalarKokkosView &work, PetscScalarKokkosView &diagVal)
15059d13fa56SJunchao Zhang {
15069d13fa56SJunchao Zhang   Mat_SeqAIJKokkos *akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
15079d13fa56SJunchao Zhang   PetscInt          nblocks = bs.extent(0) - 1;
15089d13fa56SJunchao Zhang 
15099d13fa56SJunchao Zhang   PetscFunctionBegin;
15109d13fa56SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A)); // Since we'll access A's value on device
15119d13fa56SJunchao Zhang 
15129d13fa56SJunchao Zhang   // Pull out the diagonal blocks of the matrix and then invert the blocks
15139d13fa56SJunchao Zhang   auto Aa    = akok->a_dual.view_device();
15149d13fa56SJunchao Zhang   auto Ai    = akok->i_dual.view_device();
15159d13fa56SJunchao Zhang   auto Aj    = akok->j_dual.view_device();
15169d13fa56SJunchao Zhang   auto Adiag = akok->diag_dual.view_device();
15179d13fa56SJunchao Zhang   // TODO: how to tune the team size?
151845402d8aSJunchao Zhang #if defined(KOKKOS_ENABLE_UNIFIED_MEMORY)
15199d13fa56SJunchao Zhang   auto ts = Kokkos::AUTO();
15209d13fa56SJunchao Zhang #else
15219d13fa56SJunchao Zhang   auto ts = 16; // improved performance 30% over Kokkos::AUTO() with CUDA, but failed with "Kokkos::abort: Requested Team Size is too large!" on CPUs
15229d13fa56SJunchao Zhang #endif
15239d13fa56SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1524d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), nblocks, ts), KOKKOS_LAMBDA(const KokkosTeamMemberType &teamMember) {
15259d13fa56SJunchao Zhang       const PetscInt bid    = teamMember.league_rank();                                                   // block id
15269d13fa56SJunchao Zhang       const PetscInt rstart = bs(bid);                                                                    // this block starts from this row
15279d13fa56SJunchao Zhang       const PetscInt m      = bs(bid + 1) - bs(bid);                                                      // size of this block
15289d13fa56SJunchao Zhang       const auto    &B      = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft>(&diagVal(bs2(bid)), m, m); // column-major order
15299d13fa56SJunchao Zhang       const auto    &W      = PetscScalarKokkosView(&work(bs2(bid)), m * m);
15309d13fa56SJunchao Zhang 
15319d13fa56SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(teamMember, m), [=](const PetscInt &r) { // r-th row in B
15329d13fa56SJunchao Zhang         PetscInt i = rstart + r;                                                            // i-th row in A
15339d13fa56SJunchao Zhang 
15349d13fa56SJunchao Zhang         if (Ai(i) <= Adiag(i) && Adiag(i) < Ai(i + 1)) { // if the diagonal exists (common case)
15359d13fa56SJunchao Zhang           PetscInt first = Adiag(i) - r;                 // we start to check nonzeros from here along this row
15369d13fa56SJunchao Zhang 
15379d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) {                   // walk n steps to see what column indices we will meet
15389d13fa56SJunchao Zhang             if (first + c < Ai(i) || first + c >= Ai(i + 1)) { // this entry (first+c) is out of range of this row, in other words, its value is zero
15399d13fa56SJunchao Zhang               B(r, c) = 0.0;
15409d13fa56SJunchao Zhang             } else if (Aj(first + c) == rstart + c) { // this entry is right on the (rstart+c) column
15419d13fa56SJunchao Zhang               B(r, c) = Aa(first + c);
15429d13fa56SJunchao Zhang             } else { // this entry does not show up in the CSR
15439d13fa56SJunchao Zhang               B(r, c) = 0.0;
15449d13fa56SJunchao Zhang             }
15459d13fa56SJunchao Zhang           }
15469d13fa56SJunchao Zhang         } else { // rare case that the diagonal does not exist
15479d13fa56SJunchao Zhang           const PetscInt begin = Ai(i);
15489d13fa56SJunchao Zhang           const PetscInt end   = Ai(i + 1);
15499d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) B(r, c) = 0.0;
15509d13fa56SJunchao Zhang           for (PetscInt j = begin; j < end; j++) { // scan the whole row; could use binary search but this is a rare case so we did not.
15519d13fa56SJunchao Zhang             if (rstart <= Aj(j) && Aj(j) < rstart + m) B(r, Aj(j) - rstart) = Aa(j);
15529d13fa56SJunchao Zhang             else if (Aj(j) >= rstart + m) break;
15539d13fa56SJunchao Zhang           }
15549d13fa56SJunchao Zhang         }
15559d13fa56SJunchao Zhang       });
15569d13fa56SJunchao Zhang 
15579d13fa56SJunchao Zhang       // LU-decompose B (w/o pivoting) and then invert B
15589d13fa56SJunchao Zhang       KokkosBatched::TeamLU<KokkosTeamMemberType, KokkosBatched::Algo::LU::Unblocked>::invoke(teamMember, B, 0.0);
15599d13fa56SJunchao Zhang       KokkosBatched::TeamInverseLU<KokkosTeamMemberType, KokkosBatched::Algo::InverseLU::Unblocked>::invoke(teamMember, B, W);
15609d13fa56SJunchao Zhang     }));
15619d13fa56SJunchao Zhang   // PetscLogGpuFlops() is done in the caller PCSetUp_VPBJacobi_Kokkos as we don't want to compute the flops in kernels
15629d13fa56SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
15639d13fa56SJunchao Zhang }
15649d13fa56SJunchao Zhang 
1565d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSetSeqAIJKokkosWithCSRMatrix(Mat A, Mat_SeqAIJKokkos *akok)
1566d71ae5a4SJacob Faibussowitsch {
1567076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
1568076ba34aSJunchao Zhang   PetscInt    i, m, n;
15694df4a32cSJunchao Zhang   auto        exec = PetscGetKokkosExecutionSpace();
1570076ba34aSJunchao Zhang 
1571076ba34aSJunchao Zhang   PetscFunctionBegin;
15725f80ce2aSJacob Faibussowitsch   PetscCheck(!A->spptr, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A->spptr is supposed to be empty");
1573076ba34aSJunchao Zhang 
1574076ba34aSJunchao Zhang   m = akok->nrows();
1575076ba34aSJunchao Zhang   n = akok->ncols();
15769566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(A, m, n, m, n));
15779566063dSJacob Faibussowitsch   PetscCall(MatSetType(A, MATSEQAIJKOKKOS));
1578076ba34aSJunchao Zhang 
1579076ba34aSJunchao Zhang   /* Set up data structures of A as a MATSEQAIJ */
15809566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(A, MAT_SKIP_ALLOCATION, NULL));
158157508eceSPierre Jolivet   aseq = (Mat_SeqAIJ *)A->data;
1582076ba34aSJunchao Zhang 
1583e36ced11SJunchao Zhang   PetscCallCXX(akok->i_dual.sync_host(exec)); /* We always need sync'ed i, j on host */
1584e36ced11SJunchao Zhang   PetscCallCXX(akok->j_dual.sync_host(exec));
1585e36ced11SJunchao Zhang   PetscCallCXX(exec.fence());
1586076ba34aSJunchao Zhang 
1587076ba34aSJunchao Zhang   aseq->i       = akok->i_host_data();
1588076ba34aSJunchao Zhang   aseq->j       = akok->j_host_data();
1589076ba34aSJunchao Zhang   aseq->a       = akok->a_host_data();
1590076ba34aSJunchao Zhang   aseq->nonew   = -1; /*this indicates that inserting a new value in the matrix that generates a new nonzero is an error*/
1591076ba34aSJunchao Zhang   aseq->free_a  = PETSC_FALSE;
1592076ba34aSJunchao Zhang   aseq->free_ij = PETSC_FALSE;
1593076ba34aSJunchao Zhang   aseq->nz      = akok->nnz();
1594076ba34aSJunchao Zhang   aseq->maxnz   = aseq->nz;
1595076ba34aSJunchao Zhang 
15969566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->imax));
15979566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->ilen));
1598ad540459SPierre Jolivet   for (i = 0; i < m; i++) aseq->ilen[i] = aseq->imax[i] = aseq->i[i + 1] - aseq->i[i];
1599076ba34aSJunchao Zhang 
1600076ba34aSJunchao Zhang   /* It is critical to set the nonzerostate, as we use it to check if sparsity pattern (hence data) has changed on host in MatAssemblyEnd */
1601076ba34aSJunchao Zhang   akok->nonzerostate = A->nonzerostate;
1602ff751488SJunchao Zhang   A->spptr           = akok; /* Set A->spptr before MatAssembly so that A->spptr won't be allocated again there */
16039566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
16049566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
16053ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1606076ba34aSJunchao Zhang }
1607076ba34aSJunchao Zhang 
16080e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetKokkosCsrMatrix(Mat A, KokkosCsrMatrix *csr)
16090e3ece09SJunchao Zhang {
16100e3ece09SJunchao Zhang   PetscFunctionBegin;
16110e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
16120e3ece09SJunchao Zhang   *csr = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
16130e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
16140e3ece09SJunchao Zhang }
16150e3ece09SJunchao Zhang 
16160e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithKokkosCsrMatrix(MPI_Comm comm, KokkosCsrMatrix csr, Mat *A)
16170e3ece09SJunchao Zhang {
16180e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok;
16194d86920dSPierre Jolivet 
16200e3ece09SJunchao Zhang   PetscFunctionBegin;
16210e3ece09SJunchao Zhang   PetscCallCXX(akok = new Mat_SeqAIJKokkos(csr));
16220e3ece09SJunchao Zhang   PetscCall(MatCreate(comm, A));
16230e3ece09SJunchao Zhang   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
16240e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
16250e3ece09SJunchao Zhang }
16260e3ece09SJunchao Zhang 
1627076ba34aSJunchao Zhang /* Crete a SEQAIJKOKKOS matrix with a Mat_SeqAIJKokkos data structure
1628076ba34aSJunchao Zhang 
1629076ba34aSJunchao Zhang    Note we have names like MatSeqAIJSetPreallocationCSR, so I use capitalized CSR
1630076ba34aSJunchao Zhang  */
1631d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithCSRMatrix(MPI_Comm comm, Mat_SeqAIJKokkos *akok, Mat *A)
1632d71ae5a4SJacob Faibussowitsch {
1633076ba34aSJunchao Zhang   PetscFunctionBegin;
16349566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
16359566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
16363ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
16378c3ff71bSJunchao Zhang }
16388c3ff71bSJunchao Zhang 
1639152b3e56SJunchao Zhang /*@C
164011a5261eSBarry Smith   MatCreateSeqAIJKokkos - Creates a sparse matrix in `MATSEQAIJKOKKOS` (compressed row) format
16418c3ff71bSJunchao Zhang   (the default parallel PETSc format). This matrix will ultimately be handled by
164220f4b53cSBarry Smith   Kokkos for calculations.
16438c3ff71bSJunchao Zhang 
16448c3ff71bSJunchao Zhang   Collective
16458c3ff71bSJunchao Zhang 
16468c3ff71bSJunchao Zhang   Input Parameters:
164711a5261eSBarry Smith + comm - MPI communicator, set to `PETSC_COMM_SELF`
16488c3ff71bSJunchao Zhang . m    - number of rows
16498c3ff71bSJunchao Zhang . n    - number of columns
165020f4b53cSBarry Smith . nz   - number of nonzeros per row (same for all rows), ignored if `nnz` is provided
165120f4b53cSBarry Smith - nnz  - array containing the number of nonzeros in the various rows (possibly different for each row) or `NULL`
16528c3ff71bSJunchao Zhang 
16538c3ff71bSJunchao Zhang   Output Parameter:
16548c3ff71bSJunchao Zhang . A - the matrix
16558c3ff71bSJunchao Zhang 
16562ef1f0ffSBarry Smith   Level: intermediate
16572ef1f0ffSBarry Smith 
16582ef1f0ffSBarry Smith   Notes:
165911a5261eSBarry Smith   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
16608c3ff71bSJunchao Zhang   MatXXXXSetPreallocation() paradgm instead of this routine directly.
166111a5261eSBarry Smith   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
16628c3ff71bSJunchao Zhang 
166311a5261eSBarry Smith   The AIJ format, also called
16642ef1f0ffSBarry Smith   compressed row storage, is fully compatible with standard Fortran
16658c3ff71bSJunchao Zhang   storage.  That is, the stored row and column indices can begin at
166620f4b53cSBarry Smith   either one (as in Fortran) or zero.
16678c3ff71bSJunchao Zhang 
16682ef1f0ffSBarry Smith   Specify the preallocated storage with either `nz` or `nnz` (not both).
16692ef1f0ffSBarry Smith   Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory
16702ef1f0ffSBarry Smith   allocation.
16718c3ff71bSJunchao Zhang 
1672fe59aa6dSJacob Faibussowitsch .seealso: [](ch_matrices), `Mat`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
16738c3ff71bSJunchao Zhang @*/
1674d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateSeqAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A)
1675d71ae5a4SJacob Faibussowitsch {
16768c3ff71bSJunchao Zhang   PetscFunctionBegin;
16779566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
16789566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
16799566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, m, n));
16809566063dSJacob Faibussowitsch   PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
16819566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
16823ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
16838c3ff71bSJunchao Zhang }
1684930e68a5SMark Adams 
1685aac854edSJunchao Zhang // After matrix numeric factorization, there are still steps to do before triangular solve can be called.
1686aac854edSJunchao Zhang // For example, for transpose solve, we might need to compute the transpose matrices if the solver does not support it (such as KK, while cusparse does).
1687aac854edSJunchao Zhang // In cusparse, one has to call cusparseSpSV_analysis() with updated triangular matrix values before calling cusparseSpSV_solve().
1688aac854edSJunchao Zhang // Simiarily, in KK sptrsv_symbolic() has to be called before sptrsv_solve(). We put these steps in MatSeqAIJKokkos{Transpose}SolveCheck.
1689aac854edSJunchao Zhang static PetscErrorCode MatSeqAIJKokkosSolveCheck(Mat A)
1690d71ae5a4SJacob Faibussowitsch {
169186a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors   = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1692aac854edSJunchao Zhang   const PetscBool             has_lower = factors->iL_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // false with Choleksy
1693aac854edSJunchao Zhang   const PetscBool             has_upper = factors->iU_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // true with LU and Choleksy
169486a27549SJunchao Zhang 
169586a27549SJunchao Zhang   PetscFunctionBegin;
1696aac854edSJunchao Zhang   if (!factors->sptrsv_symbolic_completed) { // If sptrsv_symbolic was not called yet
1697aac854edSJunchao Zhang     if (has_upper) PetscCallCXX(sptrsv_symbolic(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d));
1698aac854edSJunchao Zhang     if (has_lower) PetscCallCXX(sptrsv_symbolic(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d));
169986a27549SJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_TRUE;
170086a27549SJunchao Zhang   }
17013ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
170286a27549SJunchao Zhang }
170386a27549SJunchao Zhang 
1704d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosTransposeSolveCheck(Mat A)
1705d71ae5a4SJacob Faibussowitsch {
1706aac854edSJunchao Zhang   const PetscInt              n         = A->rmap->n;
170786a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors   = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1708aac854edSJunchao Zhang   const PetscBool             has_lower = factors->iL_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // false with Choleksy
1709aac854edSJunchao Zhang   const PetscBool             has_upper = factors->iU_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // true with LU or Choleksy
171086a27549SJunchao Zhang 
171186a27549SJunchao Zhang   PetscFunctionBegin;
1712aac854edSJunchao Zhang   if (!factors->transpose_updated) {
1713aac854edSJunchao Zhang     if (has_upper) {
1714aac854edSJunchao Zhang       if (!factors->iUt_d.extent(0)) {                                 // Allocate Ut on device if not yet
1715aac854edSJunchao Zhang         factors->iUt_d = MatRowMapKokkosView("factors->iUt_d", n + 1); // KK requires this view to be initialized to 0 to call transpose_matrix
17167b8d4ba6SJunchao Zhang         factors->jUt_d = MatColIdxKokkosView(NoInit("factors->jUt_d"), factors->jU_d.extent(0));
17177b8d4ba6SJunchao Zhang         factors->aUt_d = MatScalarKokkosView(NoInit("factors->aUt_d"), factors->aU_d.extent(0));
1718aac854edSJunchao Zhang       }
171986a27549SJunchao Zhang 
1720aac854edSJunchao Zhang       if (factors->iU_h.extent(0)) { // If U is on host (factorization was done on host), we also compute the transpose on host
1721aac854edSJunchao Zhang         if (!factors->U) {
1722aac854edSJunchao Zhang           Mat_SeqAIJ *seq;
172386a27549SJunchao Zhang 
1724aac854edSJunchao Zhang           PetscCall(MatCreateSeqAIJWithArrays(PETSC_COMM_SELF, n, n, factors->iU_h.data(), factors->jU_h.data(), factors->aU_h.data(), &factors->U));
1725aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->U, MAT_INITIAL_MATRIX, &factors->Ut));
172686a27549SJunchao Zhang 
1727aac854edSJunchao Zhang           seq            = static_cast<Mat_SeqAIJ *>(factors->Ut->data);
1728aac854edSJunchao Zhang           factors->iUt_h = MatRowMapKokkosViewHost(seq->i, n + 1);
1729aac854edSJunchao Zhang           factors->jUt_h = MatColIdxKokkosViewHost(seq->j, seq->nz);
1730aac854edSJunchao Zhang           factors->aUt_h = MatScalarKokkosViewHost(seq->a, seq->nz);
1731aac854edSJunchao Zhang         } else {
1732aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->U, MAT_REUSE_MATRIX, &factors->Ut)); // Matrix Ut' data is aliased with {i, j, a}Ut_h
1733aac854edSJunchao Zhang         }
1734aac854edSJunchao Zhang         // Copy Ut from host to device
1735aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->iUt_d, factors->iUt_h));
1736aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->jUt_d, factors->jUt_h));
1737aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->aUt_d, factors->aUt_h));
1738aac854edSJunchao Zhang       } else { // If U was computed on device, we also compute the transpose there
1739aac854edSJunchao Zhang         // TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices. We have to sort the indices, until KK provides finer control options.
1740aac854edSJunchao Zhang         PetscCallCXX(transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iU_d,
1741aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->jU_d, factors->aU_d,
1742aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->iUt_d, factors->jUt_d,
1743aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->aUt_d));
1744aac854edSJunchao Zhang         PetscCallCXX(sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iUt_d, factors->jUt_d, factors->aUt_d));
1745aac854edSJunchao Zhang       }
1746aac854edSJunchao Zhang       PetscCallCXX(sptrsv_symbolic(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d));
1747aac854edSJunchao Zhang     }
1748aac854edSJunchao Zhang 
1749aac854edSJunchao Zhang     // do the same for L with LU
1750aac854edSJunchao Zhang     if (has_lower) {
1751aac854edSJunchao Zhang       if (!factors->iLt_d.extent(0)) {                                 // Allocate Lt on device if not yet
1752aac854edSJunchao Zhang         factors->iLt_d = MatRowMapKokkosView("factors->iLt_d", n + 1); // KK requires this view to be initialized to 0 to call transpose_matrix
1753aac854edSJunchao Zhang         factors->jLt_d = MatColIdxKokkosView(NoInit("factors->jLt_d"), factors->jL_d.extent(0));
1754aac854edSJunchao Zhang         factors->aLt_d = MatScalarKokkosView(NoInit("factors->aLt_d"), factors->aL_d.extent(0));
1755aac854edSJunchao Zhang       }
1756aac854edSJunchao Zhang 
1757aac854edSJunchao Zhang       if (factors->iL_h.extent(0)) { // If L is on host, we also compute the transpose on host
1758aac854edSJunchao Zhang         if (!factors->L) {
1759aac854edSJunchao Zhang           Mat_SeqAIJ *seq;
1760aac854edSJunchao Zhang 
1761aac854edSJunchao Zhang           PetscCall(MatCreateSeqAIJWithArrays(PETSC_COMM_SELF, n, n, factors->iL_h.data(), factors->jL_h.data(), factors->aL_h.data(), &factors->L));
1762aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->L, MAT_INITIAL_MATRIX, &factors->Lt));
1763aac854edSJunchao Zhang 
1764aac854edSJunchao Zhang           seq            = static_cast<Mat_SeqAIJ *>(factors->Lt->data);
1765aac854edSJunchao Zhang           factors->iLt_h = MatRowMapKokkosViewHost(seq->i, n + 1);
1766aac854edSJunchao Zhang           factors->jLt_h = MatColIdxKokkosViewHost(seq->j, seq->nz);
1767aac854edSJunchao Zhang           factors->aLt_h = MatScalarKokkosViewHost(seq->a, seq->nz);
1768aac854edSJunchao Zhang         } else {
1769aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->L, MAT_REUSE_MATRIX, &factors->Lt)); // Matrix Lt' data is aliased with {i, j, a}Lt_h
1770aac854edSJunchao Zhang         }
1771aac854edSJunchao Zhang         // Copy Lt from host to device
1772aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->iLt_d, factors->iLt_h));
1773aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->jLt_d, factors->jLt_h));
1774aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->aLt_d, factors->aLt_h));
1775aac854edSJunchao Zhang       } else { // If L was computed on device, we also compute the transpose there
1776aac854edSJunchao Zhang         // TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices. We have to sort the indices, until KK provides finer control options.
1777aac854edSJunchao Zhang         PetscCallCXX(transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iL_d,
1778aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->jL_d, factors->aL_d,
1779aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->iLt_d, factors->jLt_d,
1780aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->aLt_d));
1781aac854edSJunchao Zhang         PetscCallCXX(sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iLt_d, factors->jLt_d, factors->aLt_d));
1782aac854edSJunchao Zhang       }
1783aac854edSJunchao Zhang       PetscCallCXX(sptrsv_symbolic(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d));
1784aac854edSJunchao Zhang     }
1785aac854edSJunchao Zhang 
178686a27549SJunchao Zhang     factors->transpose_updated = PETSC_TRUE;
178786a27549SJunchao Zhang   }
17883ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
178986a27549SJunchao Zhang }
179086a27549SJunchao Zhang 
1791aac854edSJunchao Zhang // Solve Ax = b, with RAR = U^T D U, where R is the row (and col) permutation matrix on A.
1792aac854edSJunchao Zhang // R is represented by rowperm in factors. If R is identity (i.e, no reordering), then rowperm is empty.
1793aac854edSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJKokkos_Cholesky(Mat A, Vec bb, Vec xx)
1794d71ae5a4SJacob Faibussowitsch {
1795aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
179686a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1797aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1798aac854edSJunchao Zhang   PetscScalarKokkosView       D       = factors->D_d;
1799aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1800aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1801aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1802aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm  = factors->rowperm;
1803aac854edSJunchao Zhang   PetscBool                   identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
180486a27549SJunchao Zhang 
180586a27549SJunchao Zhang   PetscFunctionBegin;
18069566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1807aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSolveCheck(A));          // for UX = T
1808aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A)); // for U^T Y = B
1809aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1810aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
1811aac854edSJunchao Zhang 
1812aac854edSJunchao Zhang   // Solve U^T Y = B
1813aac854edSJunchao Zhang   if (identity) { // Reorder b with the row permutation
1814aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1815aac854edSJunchao Zhang     Y = factors->workVector;
1816aac854edSJunchao Zhang   } else {
1817aac854edSJunchao Zhang     B = factors->workVector;
1818aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(rowperm(i)); }));
1819aac854edSJunchao Zhang     Y = x;
1820aac854edSJunchao Zhang   }
1821aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, B, Y));
1822aac854edSJunchao Zhang 
1823aac854edSJunchao Zhang   // Solve diag(D) Y' = Y.
1824aac854edSJunchao Zhang   // Actually just do Y' = Y*D since D is already inverted in MatCholeskyFactorNumeric_SeqAIJ(). It is basically a vector element-wise multiplication.
1825aac854edSJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { Y(i) = Y(i) * D(i); }));
1826aac854edSJunchao Zhang 
1827aac854edSJunchao Zhang   // Solve UX = Y
1828aac854edSJunchao Zhang   if (identity) {
1829aac854edSJunchao Zhang     X = x;
1830aac854edSJunchao Zhang   } else {
1831aac854edSJunchao Zhang     X = factors->workVector; // B is not needed anymore
1832aac854edSJunchao Zhang   }
1833aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, Y, X));
1834aac854edSJunchao Zhang 
1835aac854edSJunchao Zhang   // Reorder X with the inverse column (row) permutation
1836aac854edSJunchao Zhang   if (!identity) {
1837aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(rowperm(i)) = X(i); }));
1838aac854edSJunchao Zhang   }
1839aac854edSJunchao Zhang 
1840aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1841aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
18429566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
18433ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
184486a27549SJunchao Zhang }
184586a27549SJunchao Zhang 
1846aac854edSJunchao Zhang // Solve Ax = b, with RAC = LU, where R and C are row and col permutation matrices on A respectively.
1847aac854edSJunchao Zhang // R and C are represented by rowperm and colperm in factors.
1848aac854edSJunchao Zhang // If R or C is identity (i.e, no reordering), then rowperm or colperm is empty.
1849aac854edSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJKokkos_LU(Mat A, Vec bb, Vec xx)
1850d71ae5a4SJacob Faibussowitsch {
1851aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
185286a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1853aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1854aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1855aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1856aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1857aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm      = factors->rowperm;
1858aac854edSJunchao Zhang   PetscIntKokkosView         &colperm      = factors->colperm;
1859aac854edSJunchao Zhang   PetscBool                   row_identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1860aac854edSJunchao Zhang   PetscBool                   col_identity = colperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
186186a27549SJunchao Zhang 
186286a27549SJunchao Zhang   PetscFunctionBegin;
18639566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1864aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSolveCheck(A));
1865aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1866aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
186786a27549SJunchao Zhang 
1868aac854edSJunchao Zhang   // Solve L Y = B (i.e., L (U C^- x) = R b).  R b indicates applying the row permutation on b.
1869aac854edSJunchao Zhang   if (row_identity) {
1870aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1871aac854edSJunchao Zhang     Y = factors->workVector;
1872aac854edSJunchao Zhang   } else {
1873aac854edSJunchao Zhang     B = factors->workVector;
1874aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(rowperm(i)); }));
1875aac854edSJunchao Zhang     Y = x;
1876aac854edSJunchao Zhang   }
1877aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khL, factors->iL_d, factors->jL_d, factors->aL_d, B, Y));
1878aac854edSJunchao Zhang 
1879aac854edSJunchao Zhang   // Solve U C^- x = Y
1880aac854edSJunchao Zhang   if (col_identity) {
1881aac854edSJunchao Zhang     X = x;
1882aac854edSJunchao Zhang   } else {
1883aac854edSJunchao Zhang     X = factors->workVector;
1884aac854edSJunchao Zhang   }
1885aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, Y, X));
1886aac854edSJunchao Zhang 
1887aac854edSJunchao Zhang   // x = C X; Reorder X with the inverse col permutation
1888aac854edSJunchao Zhang   if (!col_identity) {
1889aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(colperm(i)) = X(i); }));
1890aac854edSJunchao Zhang   }
1891aac854edSJunchao Zhang 
1892aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1893aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
18949566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
18953ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
189686a27549SJunchao Zhang }
189786a27549SJunchao Zhang 
1898aac854edSJunchao Zhang // Solve A^T x = b, with RAC = LU, where R and C are row and col permutation matrices on A respectively.
1899aac854edSJunchao Zhang // R and C are represented by rowperm and colperm in factors.
1900aac854edSJunchao Zhang // If R or C is identity (i.e, no reordering), then rowperm or colperm is empty.
1901aac854edSJunchao Zhang // A = R^-1 L U C^-1, so A^T = C^-T U^T L^T R^-T. But since C^- = C^T, R^- = R^T, we have A^T = C U^T L^T R.
1902aac854edSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJKokkos_LU(Mat A, Vec bb, Vec xx)
1903aac854edSJunchao Zhang {
1904aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
1905aac854edSJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1906aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1907aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1908aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1909aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1910aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm      = factors->rowperm;
1911aac854edSJunchao Zhang   PetscIntKokkosView         &colperm      = factors->colperm;
1912aac854edSJunchao Zhang   PetscBool                   row_identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1913aac854edSJunchao Zhang   PetscBool                   col_identity = colperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1914aac854edSJunchao Zhang 
1915aac854edSJunchao Zhang   PetscFunctionBegin;
1916aac854edSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1917aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A)); // Update L^T, U^T if needed, and do sptrsv symbolic for L^T, U^T
1918aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1919aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
1920aac854edSJunchao Zhang 
1921aac854edSJunchao Zhang   // Solve U^T Y = B (i.e., U^T (L^T R x) = C^- b).  Note C^- b = C^T b, which means applying the column permutation on b.
1922aac854edSJunchao Zhang   if (col_identity) { // Reorder b with the col permutation
1923aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1924aac854edSJunchao Zhang     Y = factors->workVector;
1925aac854edSJunchao Zhang   } else {
1926aac854edSJunchao Zhang     B = factors->workVector;
1927aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(colperm(i)); }));
1928aac854edSJunchao Zhang     Y = x;
1929aac854edSJunchao Zhang   }
1930aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, B, Y));
1931aac854edSJunchao Zhang 
1932aac854edSJunchao Zhang   // Solve L^T X = Y
1933aac854edSJunchao Zhang   if (row_identity) {
1934aac854edSJunchao Zhang     X = x;
1935aac854edSJunchao Zhang   } else {
1936aac854edSJunchao Zhang     X = factors->workVector;
1937aac854edSJunchao Zhang   }
1938aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d, Y, X));
1939aac854edSJunchao Zhang 
1940aac854edSJunchao Zhang   // x = R^- X = R^T X; Reorder X with the inverse row permutation
1941aac854edSJunchao Zhang   if (!row_identity) {
1942aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(rowperm(i)) = X(i); }));
1943aac854edSJunchao Zhang   }
1944aac854edSJunchao Zhang 
1945aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1946aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
1947aac854edSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1948aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1949aac854edSJunchao Zhang }
1950aac854edSJunchao Zhang 
1951aac854edSJunchao Zhang static PetscErrorCode MatLUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1952aac854edSJunchao Zhang {
1953aac854edSJunchao Zhang   PetscFunctionBegin;
1954aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncHost(A));
1955aac854edSJunchao Zhang   PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info));
1956aac854edSJunchao Zhang 
1957aac854edSJunchao Zhang   if (!info->solveonhost) { // if solve on host, then we don't need to copy L, U to device
1958aac854edSJunchao Zhang     Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
1959aac854edSJunchao Zhang     Mat_SeqAIJ                 *b       = static_cast<Mat_SeqAIJ *>(B->data);
1960aac854edSJunchao Zhang     const PetscInt             *Bi = b->i, *Bj = b->j, *Bdiag = b->diag;
1961aac854edSJunchao Zhang     const MatScalar            *Ba = b->a;
1962aac854edSJunchao Zhang     PetscInt                    m = B->rmap->n, n = B->cmap->n;
1963aac854edSJunchao Zhang 
1964aac854edSJunchao Zhang     if (factors->iL_h.extent(0) == 0) { // Allocate memory and copy the L, U structure for the first time
1965aac854edSJunchao Zhang       // Allocate memory and copy the structure
1966aac854edSJunchao Zhang       factors->iL_h = MatRowMapKokkosViewHost(NoInit("iL_h"), m + 1);
1967aac854edSJunchao Zhang       factors->jL_h = MatColIdxKokkosViewHost(NoInit("jL_h"), (Bi[m] - Bi[0]) + m); // + the diagonal entries
1968aac854edSJunchao Zhang       factors->aL_h = MatScalarKokkosViewHost(NoInit("aL_h"), (Bi[m] - Bi[0]) + m);
1969aac854edSJunchao Zhang       factors->iU_h = MatRowMapKokkosViewHost(NoInit("iU_h"), m + 1);
1970aac854edSJunchao Zhang       factors->jU_h = MatColIdxKokkosViewHost(NoInit("jU_h"), (Bdiag[0] - Bdiag[m]));
1971aac854edSJunchao Zhang       factors->aU_h = MatScalarKokkosViewHost(NoInit("aU_h"), (Bdiag[0] - Bdiag[m]));
1972aac854edSJunchao Zhang 
1973aac854edSJunchao Zhang       PetscInt *Li = factors->iL_h.data();
1974aac854edSJunchao Zhang       PetscInt *Lj = factors->jL_h.data();
1975aac854edSJunchao Zhang       PetscInt *Ui = factors->iU_h.data();
1976aac854edSJunchao Zhang       PetscInt *Uj = factors->jU_h.data();
1977aac854edSJunchao Zhang 
1978aac854edSJunchao Zhang       Li[0] = Ui[0] = 0;
1979aac854edSJunchao Zhang       for (PetscInt i = 0; i < m; i++) {
1980aac854edSJunchao Zhang         PetscInt llen = Bi[i + 1] - Bi[i];       // exclusive of the diagonal entry
1981aac854edSJunchao Zhang         PetscInt ulen = Bdiag[i] - Bdiag[i + 1]; // inclusive of the diagonal entry
1982aac854edSJunchao Zhang 
1983aac854edSJunchao Zhang         PetscArraycpy(Lj + Li[i], Bj + Bi[i], llen); // entries of L on the left of the diagonal
1984aac854edSJunchao Zhang         Lj[Li[i] + llen] = i;                        // diagonal entry of L
1985aac854edSJunchao Zhang 
1986aac854edSJunchao Zhang         Uj[Ui[i]] = i;                                                  // diagonal entry of U
1987aac854edSJunchao Zhang         PetscArraycpy(Uj + Ui[i] + 1, Bj + Bdiag[i + 1] + 1, ulen - 1); // entries of U on  the right of the diagonal
1988aac854edSJunchao Zhang 
1989aac854edSJunchao Zhang         Li[i + 1] = Li[i] + llen + 1;
1990aac854edSJunchao Zhang         Ui[i + 1] = Ui[i] + ulen;
1991aac854edSJunchao Zhang       }
1992aac854edSJunchao Zhang 
1993aac854edSJunchao Zhang       factors->iL_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iL_h);
1994aac854edSJunchao Zhang       factors->jL_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jL_h);
1995aac854edSJunchao Zhang       factors->iU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iU_h);
1996aac854edSJunchao Zhang       factors->jU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jU_h);
1997aac854edSJunchao Zhang       factors->aL_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aL_h);
1998aac854edSJunchao Zhang       factors->aU_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aU_h);
1999aac854edSJunchao Zhang 
2000aac854edSJunchao Zhang       // Copy row/col permutation to device
2001aac854edSJunchao Zhang       IS        rowperm = ((Mat_SeqAIJ *)B->data)->row;
2002aac854edSJunchao Zhang       PetscBool row_identity;
2003aac854edSJunchao Zhang       PetscCall(ISIdentity(rowperm, &row_identity));
2004aac854edSJunchao Zhang       if (!row_identity) {
2005aac854edSJunchao Zhang         const PetscInt *ip;
2006aac854edSJunchao Zhang 
2007aac854edSJunchao Zhang         PetscCall(ISGetIndices(rowperm, &ip));
2008aac854edSJunchao Zhang         factors->rowperm = PetscIntKokkosView(NoInit("rowperm"), m);
2009aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->rowperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), m)));
2010aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(rowperm, &ip));
2011aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(m * sizeof(PetscInt)));
2012aac854edSJunchao Zhang       }
2013aac854edSJunchao Zhang 
2014aac854edSJunchao Zhang       IS        colperm = ((Mat_SeqAIJ *)B->data)->col;
2015aac854edSJunchao Zhang       PetscBool col_identity;
2016aac854edSJunchao Zhang       PetscCall(ISIdentity(colperm, &col_identity));
2017aac854edSJunchao Zhang       if (!col_identity) {
2018aac854edSJunchao Zhang         const PetscInt *ip;
2019aac854edSJunchao Zhang 
2020aac854edSJunchao Zhang         PetscCall(ISGetIndices(colperm, &ip));
2021aac854edSJunchao Zhang         factors->colperm = PetscIntKokkosView(NoInit("colperm"), n);
2022aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->colperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), n)));
2023aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(colperm, &ip));
2024aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt)));
2025aac854edSJunchao Zhang       }
2026aac854edSJunchao Zhang 
2027aac854edSJunchao Zhang       /* Create sptrsv handles for L, U and their transpose */
2028aac854edSJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2029aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
2030aac854edSJunchao Zhang #else
2031aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
2032aac854edSJunchao Zhang #endif
2033aac854edSJunchao Zhang       factors->khL.create_sptrsv_handle(sptrsv_alg, m, true /* L is lower tri */);
2034aac854edSJunchao Zhang       factors->khU.create_sptrsv_handle(sptrsv_alg, m, false /* U is not lower tri */);
2035aac854edSJunchao Zhang       factors->khLt.create_sptrsv_handle(sptrsv_alg, m, false /* L^T is not lower tri */);
2036aac854edSJunchao Zhang       factors->khUt.create_sptrsv_handle(sptrsv_alg, m, true /* U^T is lower tri */);
2037aac854edSJunchao Zhang     }
2038aac854edSJunchao Zhang 
2039aac854edSJunchao Zhang     // Copy the value
2040aac854edSJunchao Zhang     for (PetscInt i = 0; i < m; i++) {
2041aac854edSJunchao Zhang       PetscInt        llen = Bi[i + 1] - Bi[i];
2042aac854edSJunchao Zhang       PetscInt        ulen = Bdiag[i] - Bdiag[i + 1];
2043aac854edSJunchao Zhang       const PetscInt *Li   = factors->iL_h.data();
2044aac854edSJunchao Zhang       const PetscInt *Ui   = factors->iU_h.data();
2045aac854edSJunchao Zhang 
2046aac854edSJunchao Zhang       PetscScalar *La = factors->aL_h.data();
2047aac854edSJunchao Zhang       PetscScalar *Ua = factors->aU_h.data();
2048aac854edSJunchao Zhang 
2049aac854edSJunchao Zhang       PetscArraycpy(La + Li[i], Ba + Bi[i], llen); // entries of L
2050aac854edSJunchao Zhang       La[Li[i] + llen] = 1.0;                      // diagonal entry
2051aac854edSJunchao Zhang 
2052aac854edSJunchao Zhang       Ua[Ui[i]] = 1.0 / Ba[Bdiag[i]];                                 // diagonal entry
2053aac854edSJunchao Zhang       PetscArraycpy(Ua + Ui[i] + 1, Ba + Bdiag[i + 1] + 1, ulen - 1); // entries of U
2054aac854edSJunchao Zhang     }
2055aac854edSJunchao Zhang 
2056aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aL_d, factors->aL_h));
2057aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aU_d, factors->aU_h));
2058aac854edSJunchao Zhang     // Once the factors' values have changed, we need to update their transpose and redo sptrsv symbolic
2059aac854edSJunchao Zhang     factors->transpose_updated         = PETSC_FALSE;
2060aac854edSJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_FALSE;
2061aac854edSJunchao Zhang 
2062aac854edSJunchao Zhang     B->ops->solve          = MatSolve_SeqAIJKokkos_LU;
2063aac854edSJunchao Zhang     B->ops->solvetranspose = MatSolveTranspose_SeqAIJKokkos_LU;
2064aac854edSJunchao Zhang   }
2065aac854edSJunchao Zhang 
2066aac854edSJunchao Zhang   B->ops->matsolve          = NULL;
2067aac854edSJunchao Zhang   B->ops->matsolvetranspose = NULL;
2068aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2069aac854edSJunchao Zhang }
2070aac854edSJunchao Zhang 
2071aac854edSJunchao Zhang static PetscErrorCode MatILUFactorNumeric_SeqAIJKokkos_ILU0(Mat B, Mat A, const MatFactorInfo *info)
2072d71ae5a4SJacob Faibussowitsch {
207386a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
207486a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
207586a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
207686a27549SJunchao Zhang 
207786a27549SJunchao Zhang   PetscFunctionBegin;
20789566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
2079aac854edSJunchao Zhang   PetscCheck(!info->factoronhost, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "MatFactorInfo.factoronhost should be false");
20809566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
2081076ba34aSJunchao Zhang 
2082076ba34aSJunchao Zhang   auto a_d = aijkok->a_dual.view_device();
2083076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
2084076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
2085076ba34aSJunchao Zhang 
2086aac854edSJunchao Zhang   PetscCallCXX(spiluk_numeric(&factors->kh, fill_lev, i_d, j_d, a_d, factors->iL_d, factors->jL_d, factors->aL_d, factors->iU_d, factors->jU_d, factors->aU_d));
208786a27549SJunchao Zhang 
208886a27549SJunchao Zhang   B->assembled              = PETSC_TRUE;
208986a27549SJunchao Zhang   B->preallocated           = PETSC_TRUE;
2090aac854edSJunchao Zhang   B->ops->solve             = MatSolve_SeqAIJKokkos_LU;
2091aac854edSJunchao Zhang   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJKokkos_LU;
209286a27549SJunchao Zhang   B->ops->matsolve          = NULL;
209386a27549SJunchao Zhang   B->ops->matsolvetranspose = NULL;
209486a27549SJunchao Zhang 
209586a27549SJunchao Zhang   /* Once the factors' value changed, we need to update their transpose and sptrsv handle */
209686a27549SJunchao Zhang   factors->transpose_updated         = PETSC_FALSE;
209786a27549SJunchao Zhang   factors->sptrsv_symbolic_completed = PETSC_FALSE;
2098eeadb341SJunchao Zhang   /* TODO: log flops, but how to know that? */
20999566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
21003ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
210186a27549SJunchao Zhang }
210286a27549SJunchao Zhang 
2103aac854edSJunchao Zhang // Use KK's spiluk_symbolic() to do ILU0 symbolic factorization, with no row/col reordering
2104aac854edSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos_ILU0(Mat B, Mat A, IS, IS, const MatFactorInfo *info)
2105d71ae5a4SJacob Faibussowitsch {
210686a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok;
210786a27549SJunchao Zhang   Mat_SeqAIJ                 *b;
210886a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
210986a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
211086a27549SJunchao Zhang   PetscInt                    nnzA     = ((Mat_SeqAIJ *)A->data)->nz, nnzL, nnzU;
211186a27549SJunchao Zhang   PetscInt                    n        = A->rmap->n;
211286a27549SJunchao Zhang 
211386a27549SJunchao Zhang   PetscFunctionBegin;
2114aac854edSJunchao Zhang   PetscCheck(!info->factoronhost, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "MatFactorInfo's factoronhost should be false as we are doing it on device right now");
21159566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
211686a27549SJunchao Zhang 
211786a27549SJunchao Zhang   /* Create a spiluk handle and then do symbolic factorization */
211886a27549SJunchao Zhang   nnzL = nnzU = PetscRealIntMultTruncate(info->fill, nnzA);
2119aac854edSJunchao Zhang   factors->kh.create_spiluk_handle(SPILUKAlgorithm::SEQLVLSCHD_TP1, n, nnzL, nnzU);
212086a27549SJunchao Zhang 
212186a27549SJunchao Zhang   auto spiluk_handle = factors->kh.get_spiluk_handle();
212286a27549SJunchao Zhang 
212386a27549SJunchao Zhang   Kokkos::realloc(factors->iL_d, n + 1); /* Free old arrays and realloc */
212486a27549SJunchao Zhang   Kokkos::realloc(factors->jL_d, spiluk_handle->get_nnzL());
212586a27549SJunchao Zhang   Kokkos::realloc(factors->iU_d, n + 1);
212686a27549SJunchao Zhang   Kokkos::realloc(factors->jU_d, spiluk_handle->get_nnzU());
212786a27549SJunchao Zhang 
212886a27549SJunchao Zhang   aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
2129076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
2130076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
2131aac854edSJunchao Zhang   PetscCallCXX(spiluk_symbolic(&factors->kh, fill_lev, i_d, j_d, factors->iL_d, factors->jL_d, factors->iU_d, factors->jU_d));
213286a27549SJunchao Zhang   /* TODO: if spiluk_symbolic is asynchronous, do we need to sync before calling get_nnzL()? */
213386a27549SJunchao Zhang 
213486a27549SJunchao Zhang   Kokkos::resize(factors->jL_d, spiluk_handle->get_nnzL()); /* Shrink or expand, and retain old value */
213586a27549SJunchao Zhang   Kokkos::resize(factors->jU_d, spiluk_handle->get_nnzU());
213686a27549SJunchao Zhang   Kokkos::realloc(factors->aL_d, spiluk_handle->get_nnzL()); /* No need to retain old value */
213786a27549SJunchao Zhang   Kokkos::realloc(factors->aU_d, spiluk_handle->get_nnzU());
213886a27549SJunchao Zhang 
213986a27549SJunchao Zhang   /* TODO: add options to select sptrsv algorithms */
214086a27549SJunchao Zhang   /* Create sptrsv handles for L, U and their transpose */
214186a27549SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2142aac854edSJunchao Zhang   auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
214386a27549SJunchao Zhang #else
2144aac854edSJunchao Zhang   auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
214586a27549SJunchao Zhang #endif
214686a27549SJunchao Zhang 
214786a27549SJunchao Zhang   factors->khL.create_sptrsv_handle(sptrsv_alg, n, true /* L is lower tri */);
214886a27549SJunchao Zhang   factors->khU.create_sptrsv_handle(sptrsv_alg, n, false /* U is not lower tri */);
214986a27549SJunchao Zhang   factors->khLt.create_sptrsv_handle(sptrsv_alg, n, false /* L^T is not lower tri */);
215086a27549SJunchao Zhang   factors->khUt.create_sptrsv_handle(sptrsv_alg, n, true /* U^T is lower tri */);
215186a27549SJunchao Zhang 
215286a27549SJunchao Zhang   /* Fill fields of the factor matrix B */
21539566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(B, MAT_SKIP_ALLOCATION, NULL));
215486a27549SJunchao Zhang   b     = (Mat_SeqAIJ *)B->data;
215586a27549SJunchao Zhang   b->nz = b->maxnz          = spiluk_handle->get_nnzL() + spiluk_handle->get_nnzU();
215686a27549SJunchao Zhang   B->info.fill_ratio_given  = info->fill;
2157a1e4e190SStefano Zampini   B->info.fill_ratio_needed = nnzA > 0 ? ((PetscReal)b->nz) / ((PetscReal)nnzA) : 1.0;
215886a27549SJunchao Zhang 
2159aac854edSJunchao Zhang   B->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJKokkos_ILU0;
21603ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2161930e68a5SMark Adams }
2162930e68a5SMark Adams 
2163aac854edSJunchao Zhang static PetscErrorCode MatLUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
2164aac854edSJunchao Zhang {
2165aac854edSJunchao Zhang   PetscFunctionBegin;
2166aac854edSJunchao Zhang   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
2167aac854edSJunchao Zhang   PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2168aac854edSJunchao Zhang   PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2169aac854edSJunchao Zhang   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
2170aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2171aac854edSJunchao Zhang }
2172aac854edSJunchao Zhang 
2173aac854edSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
2174aac854edSJunchao Zhang {
2175aac854edSJunchao Zhang   PetscBool row_identity = PETSC_FALSE, col_identity = PETSC_FALSE;
2176aac854edSJunchao Zhang 
2177aac854edSJunchao Zhang   PetscFunctionBegin;
2178aac854edSJunchao Zhang   if (!info->factoronhost) {
2179aac854edSJunchao Zhang     PetscCall(ISIdentity(isrow, &row_identity));
2180aac854edSJunchao Zhang     PetscCall(ISIdentity(iscol, &col_identity));
2181aac854edSJunchao Zhang   }
2182aac854edSJunchao Zhang 
2183aac854edSJunchao Zhang   PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2184aac854edSJunchao Zhang   PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2185aac854edSJunchao Zhang 
2186aac854edSJunchao Zhang   if (!info->factoronhost && !info->levels && row_identity && col_identity) { // if level 0 and no reordering
2187aac854edSJunchao Zhang     PetscCall(MatILUFactorSymbolic_SeqAIJKokkos_ILU0(B, A, isrow, iscol, info));
2188aac854edSJunchao Zhang   } else {
2189aac854edSJunchao Zhang     PetscCall(MatILUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info)); // otherwise, use PETSc's ILU on host
2190aac854edSJunchao Zhang     B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
2191aac854edSJunchao Zhang   }
2192aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2193aac854edSJunchao Zhang }
2194aac854edSJunchao Zhang 
2195aac854edSJunchao Zhang static PetscErrorCode MatCholeskyFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
2196aac854edSJunchao Zhang {
2197aac854edSJunchao Zhang   PetscFunctionBegin;
2198aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncHost(A));
2199aac854edSJunchao Zhang   PetscCall(MatCholeskyFactorNumeric_SeqAIJ(B, A, info));
2200aac854edSJunchao Zhang 
2201aac854edSJunchao Zhang   if (!info->solveonhost) { // if solve on host, then we don't need to copy L, U to device
2202aac854edSJunchao Zhang     Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
2203aac854edSJunchao Zhang     Mat_SeqAIJ                 *b       = static_cast<Mat_SeqAIJ *>(B->data);
2204aac854edSJunchao Zhang     const PetscInt             *Bi = b->i, *Bj = b->j, *Bdiag = b->diag;
2205aac854edSJunchao Zhang     const MatScalar            *Ba = b->a;
2206aac854edSJunchao Zhang     PetscInt                    m  = B->rmap->n;
2207aac854edSJunchao Zhang 
2208aac854edSJunchao Zhang     if (factors->iU_h.extent(0) == 0) { // First time of numeric factorization
2209aac854edSJunchao Zhang       // Allocate memory and copy the structure
2210aac854edSJunchao Zhang       factors->iU_h = PetscIntKokkosViewHost(const_cast<PetscInt *>(Bi), m + 1); // wrap Bi as iU_h
2211aac854edSJunchao Zhang       factors->jU_h = MatColIdxKokkosViewHost(NoInit("jU_h"), Bi[m]);
2212aac854edSJunchao Zhang       factors->aU_h = MatScalarKokkosViewHost(NoInit("aU_h"), Bi[m]);
2213aac854edSJunchao Zhang       factors->D_h  = MatScalarKokkosViewHost(NoInit("D_h"), m);
2214aac854edSJunchao Zhang       factors->aU_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aU_h);
2215aac854edSJunchao Zhang       factors->D_d  = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->D_h);
2216aac854edSJunchao Zhang 
2217aac854edSJunchao Zhang       // Build jU_h from the skewed Aj
2218aac854edSJunchao Zhang       PetscInt *Uj = factors->jU_h.data();
2219aac854edSJunchao Zhang       for (PetscInt i = 0; i < m; i++) {
2220aac854edSJunchao Zhang         PetscInt ulen = Bi[i + 1] - Bi[i];
2221aac854edSJunchao Zhang         Uj[Bi[i]]     = i;                                              // diagonal entry
2222aac854edSJunchao Zhang         PetscCall(PetscArraycpy(Uj + Bi[i] + 1, Bj + Bi[i], ulen - 1)); // entries of U on the right of the diagonal
2223aac854edSJunchao Zhang       }
2224aac854edSJunchao Zhang 
2225aac854edSJunchao Zhang       // Copy iU, jU to device
2226aac854edSJunchao Zhang       PetscCallCXX(factors->iU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iU_h));
2227aac854edSJunchao Zhang       PetscCallCXX(factors->jU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jU_h));
2228aac854edSJunchao Zhang 
2229aac854edSJunchao Zhang       // Copy row/col permutation to device
2230aac854edSJunchao Zhang       IS        rowperm = ((Mat_SeqAIJ *)B->data)->row;
2231aac854edSJunchao Zhang       PetscBool row_identity;
2232aac854edSJunchao Zhang       PetscCall(ISIdentity(rowperm, &row_identity));
2233aac854edSJunchao Zhang       if (!row_identity) {
2234aac854edSJunchao Zhang         const PetscInt *ip;
2235aac854edSJunchao Zhang 
2236aac854edSJunchao Zhang         PetscCall(ISGetIndices(rowperm, &ip));
2237aac854edSJunchao Zhang         PetscCallCXX(factors->rowperm = PetscIntKokkosView(NoInit("rowperm"), m));
2238aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->rowperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), m)));
2239aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(rowperm, &ip));
2240aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(m * sizeof(PetscInt)));
2241aac854edSJunchao Zhang       }
2242aac854edSJunchao Zhang 
2243aac854edSJunchao Zhang       // Create sptrsv handles for U and U^T
2244aac854edSJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2245aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
2246aac854edSJunchao Zhang #else
2247aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
2248aac854edSJunchao Zhang #endif
2249aac854edSJunchao Zhang       factors->khU.create_sptrsv_handle(sptrsv_alg, m, false /* U is not lower tri */);
2250aac854edSJunchao Zhang       factors->khUt.create_sptrsv_handle(sptrsv_alg, m, true /* U^T is lower tri */);
2251aac854edSJunchao Zhang     }
2252aac854edSJunchao Zhang     // These pointers were set MatCholeskyFactorNumeric_SeqAIJ(), so we always need to update them
2253aac854edSJunchao Zhang     B->ops->solve          = MatSolve_SeqAIJKokkos_Cholesky;
2254aac854edSJunchao Zhang     B->ops->solvetranspose = MatSolve_SeqAIJKokkos_Cholesky;
2255aac854edSJunchao Zhang 
2256aac854edSJunchao Zhang     // Copy the value
2257aac854edSJunchao Zhang     PetscScalar *Ua = factors->aU_h.data();
2258aac854edSJunchao Zhang     PetscScalar *D  = factors->D_h.data();
2259aac854edSJunchao Zhang     for (PetscInt i = 0; i < m; i++) {
2260aac854edSJunchao Zhang       D[i]      = Ba[Bdiag[i]];     // actually Aa[Adiag[i]] is the inverse of the diagonal
2261aac854edSJunchao Zhang       Ua[Bi[i]] = (PetscScalar)1.0; // set the unit diagonal for U
2262aac854edSJunchao Zhang       for (PetscInt k = 0; k < Bi[i + 1] - Bi[i] - 1; k++) Ua[Bi[i] + 1 + k] = -Ba[Bi[i] + k];
2263aac854edSJunchao Zhang     }
2264aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aU_d, factors->aU_h));
2265aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->D_d, factors->D_h));
2266aac854edSJunchao Zhang 
2267aac854edSJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_FALSE; // When numeric value changed, we must do these again
2268aac854edSJunchao Zhang     factors->transpose_updated         = PETSC_FALSE;
2269aac854edSJunchao Zhang   }
2270aac854edSJunchao Zhang 
2271aac854edSJunchao Zhang   B->ops->matsolve          = NULL;
2272aac854edSJunchao Zhang   B->ops->matsolvetranspose = NULL;
2273aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2274aac854edSJunchao Zhang }
2275aac854edSJunchao Zhang 
2276aac854edSJunchao Zhang static PetscErrorCode MatICCFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS perm, const MatFactorInfo *info)
2277aac854edSJunchao Zhang {
2278aac854edSJunchao Zhang   PetscFunctionBegin;
2279aac854edSJunchao Zhang   if (info->solveonhost) {
2280aac854edSJunchao Zhang     // If solve on host, we have to change the type, as eventually we need to call MatSolve_SeqSBAIJ_1_NaturalOrdering() etc.
2281aac854edSJunchao Zhang     PetscCall(MatSetType(B, MATSEQSBAIJ));
2282aac854edSJunchao Zhang     PetscCall(MatSeqSBAIJSetPreallocation(B, 1, MAT_SKIP_ALLOCATION, NULL));
2283aac854edSJunchao Zhang   }
2284aac854edSJunchao Zhang 
2285aac854edSJunchao Zhang   PetscCall(MatICCFactorSymbolic_SeqAIJ(B, A, perm, info));
2286aac854edSJunchao Zhang 
2287aac854edSJunchao Zhang   if (!info->solveonhost) {
2288bfe80ac4SPierre Jolivet     // If solve on device, B is still a MATSEQAIJKOKKOS, so we are good to allocate B->spptr
2289aac854edSJunchao Zhang     PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2290aac854edSJunchao Zhang     PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2291aac854edSJunchao Zhang     B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJKokkos;
2292aac854edSJunchao Zhang   }
2293aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2294aac854edSJunchao Zhang }
2295aac854edSJunchao Zhang 
2296aac854edSJunchao Zhang static PetscErrorCode MatCholeskyFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS perm, const MatFactorInfo *info)
2297aac854edSJunchao Zhang {
2298aac854edSJunchao Zhang   PetscFunctionBegin;
2299aac854edSJunchao Zhang   if (info->solveonhost) {
2300aac854edSJunchao Zhang     // If solve on host, we have to change the type, as eventually we need to call MatSolve_SeqSBAIJ_1_NaturalOrdering() etc.
2301aac854edSJunchao Zhang     PetscCall(MatSetType(B, MATSEQSBAIJ));
2302aac854edSJunchao Zhang     PetscCall(MatSeqSBAIJSetPreallocation(B, 1, MAT_SKIP_ALLOCATION, NULL));
2303aac854edSJunchao Zhang   }
2304aac854edSJunchao Zhang 
2305aac854edSJunchao Zhang   PetscCall(MatCholeskyFactorSymbolic_SeqAIJ(B, A, perm, info)); // it sets B's two ISes ((Mat_SeqAIJ*)B->data)->{row, col} to perm
2306aac854edSJunchao Zhang 
2307aac854edSJunchao Zhang   if (!info->solveonhost) {
2308bfe80ac4SPierre Jolivet     // If solve on device, B is still a MATSEQAIJKOKKOS, so we are good to allocate B->spptr
2309aac854edSJunchao Zhang     PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2310aac854edSJunchao Zhang     PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2311aac854edSJunchao Zhang     B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJKokkos;
2312aac854edSJunchao Zhang   }
2313aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2314aac854edSJunchao Zhang }
2315aac854edSJunchao Zhang 
2316aac854edSJunchao Zhang // The _Kokkos suffix means we will use Kokkos as a solver for the SeqAIJKokkos matrix
2317aac854edSJunchao Zhang static PetscErrorCode MatFactorGetSolverType_SeqAIJKokkos_Kokkos(Mat A, MatSolverType *type)
2318d71ae5a4SJacob Faibussowitsch {
2319930e68a5SMark Adams   PetscFunctionBegin;
2320930e68a5SMark Adams   *type = MATSOLVERKOKKOS;
23213ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2322930e68a5SMark Adams }
2323930e68a5SMark Adams 
2324930e68a5SMark Adams /*MC
232586a27549SJunchao Zhang   MATSOLVERKOKKOS = "Kokkos" - A matrix solver type providing triangular solvers for sequential matrices
232611a5261eSBarry Smith   on a single GPU of type, `MATSEQAIJKOKKOS`, `MATAIJKOKKOS`.
2327930e68a5SMark Adams 
2328930e68a5SMark Adams   Level: beginner
2329930e68a5SMark Adams 
23301cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJKokkos()`, `MATAIJKOKKOS`, `MatKokkosSetFormat()`, `MatKokkosStorageFormat`, `MatKokkosFormatOperation`
2331930e68a5SMark Adams M*/
233286a27549SJunchao Zhang PETSC_EXTERN PetscErrorCode MatGetFactor_SeqAIJKokkos_Kokkos(Mat A, MatFactorType ftype, Mat *B) /* MatGetFactor_<MatType>_<MatSolverType> */
2333930e68a5SMark Adams {
2334930e68a5SMark Adams   PetscInt n = A->rmap->n;
2335aac854edSJunchao Zhang   MPI_Comm comm;
2336930e68a5SMark Adams 
2337930e68a5SMark Adams   PetscFunctionBegin;
2338aac854edSJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
2339aac854edSJunchao Zhang   PetscCall(MatCreate(comm, B));
23409566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*B, n, n, n, n));
2341aac854edSJunchao Zhang   PetscCall(MatSetBlockSizesFromMats(*B, A, A));
2342930e68a5SMark Adams   (*B)->factortype = ftype;
23439566063dSJacob Faibussowitsch   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
23449566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
2345aac854edSJunchao Zhang   PetscCheck(!(*B)->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2346aac854edSJunchao Zhang 
2347aac854edSJunchao Zhang   if (ftype == MAT_FACTOR_LU || ftype == MAT_FACTOR_ILU || ftype == MAT_FACTOR_ILUDT) {
2348aac854edSJunchao Zhang     (*B)->ops->lufactorsymbolic  = MatLUFactorSymbolic_SeqAIJKokkos;
2349aac854edSJunchao Zhang     (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJKokkos;
2350aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
2351aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILU]));
2352aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILUDT]));
2353aac854edSJunchao Zhang   } else if (ftype == MAT_FACTOR_CHOLESKY || ftype == MAT_FACTOR_ICC) {
2354aac854edSJunchao Zhang     (*B)->ops->iccfactorsymbolic      = MatICCFactorSymbolic_SeqAIJKokkos;
2355aac854edSJunchao Zhang     (*B)->ops->choleskyfactorsymbolic = MatCholeskyFactorSymbolic_SeqAIJKokkos;
2356aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_CHOLESKY]));
2357aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ICC]));
2358aac854edSJunchao Zhang   } else SETERRQ(comm, PETSC_ERR_SUP, "MatFactorType %s is not supported by MatType SeqAIJKokkos", MatFactorTypes[ftype]);
2359aac854edSJunchao Zhang 
2360aac854edSJunchao Zhang   // The factorization can use the ordering provided in MatLUFactorSymbolic(), MatCholeskyFactorSymbolic() etc, though we do it on host
2361aac854edSJunchao Zhang   (*B)->canuseordering = PETSC_TRUE;
2362aac854edSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)*B, "MatFactorGetSolverType_C", MatFactorGetSolverType_SeqAIJKokkos_Kokkos));
23633ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2364930e68a5SMark Adams }
23658f7e8f9dSMark Adams 
2366aac854edSJunchao Zhang PETSC_INTERN PetscErrorCode MatSolverTypeRegister_Kokkos(void)
2367d71ae5a4SJacob Faibussowitsch {
236886a27549SJunchao Zhang   PetscFunctionBegin;
23699566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_SeqAIJKokkos_Kokkos));
2370aac854edSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_CHOLESKY, MatGetFactor_SeqAIJKokkos_Kokkos));
23719566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ILU, MatGetFactor_SeqAIJKokkos_Kokkos));
2372aac854edSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ICC, MatGetFactor_SeqAIJKokkos_Kokkos));
23733ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
237486a27549SJunchao Zhang }
237586a27549SJunchao Zhang 
2376076ba34aSJunchao Zhang /* Utility to print out a KokkosCsrMatrix for debugging */
2377d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode PrintCsrMatrix(const KokkosCsrMatrix &csrmat)
2378d71ae5a4SJacob Faibussowitsch {
237945402d8aSJunchao Zhang   const auto        &iv = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.graph.row_map);
238045402d8aSJunchao Zhang   const auto        &jv = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.graph.entries);
238145402d8aSJunchao Zhang   const auto        &av = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.values);
2382076ba34aSJunchao Zhang   const PetscInt    *i  = iv.data();
2383076ba34aSJunchao Zhang   const PetscInt    *j  = jv.data();
2384076ba34aSJunchao Zhang   const PetscScalar *a  = av.data();
2385076ba34aSJunchao Zhang   PetscInt           m = csrmat.numRows(), n = csrmat.numCols(), nnz = csrmat.nnz();
2386076ba34aSJunchao Zhang 
2387076ba34aSJunchao Zhang   PetscFunctionBegin;
23889566063dSJacob Faibussowitsch   PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT " x %" PetscInt_FMT " SeqAIJKokkos, with %" PetscInt_FMT " nonzeros\n", m, n, nnz));
2389076ba34aSJunchao Zhang   for (PetscInt k = 0; k < m; k++) {
23909566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT ": ", k));
239148a46eb9SPierre Jolivet     for (PetscInt p = i[k]; p < i[k + 1]; p++) PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT "(%.1f), ", j[p], (double)PetscRealPart(a[p])));
23929566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "\n"));
2393076ba34aSJunchao Zhang   }
23943ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2395076ba34aSJunchao Zhang }
2396