xref: /petsc/src/mat/impls/aij/seq/kokkos/aijkok.kokkos.cxx (revision 03db1824571927542203d3572ba93afdd6bab0ec)
1e36ced11SJunchao Zhang #include <petsc_kokkos.hpp>
211d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
3c0c276a7Ssdargavi #include <petscmat_kokkos.hpp>
4076ba34aSJunchao Zhang #include <petscpkg_version.h>
5152b3e56SJunchao Zhang #include <petsc/private/petscimpl.h>
642550becSJunchao Zhang #include <petsc/private/sfimpl.h>
7aac854edSJunchao Zhang #include <petsc/private/kokkosimpl.hpp>
87233ce55SJed Brown #include <petscsys.h>
98c3ff71bSJunchao Zhang 
108c3ff71bSJunchao Zhang #include <Kokkos_Core.hpp>
11f0cf5187SStefano Zampini #include <KokkosBlas.hpp>
128c3ff71bSJunchao Zhang #include <KokkosSparse_CrsMatrix.hpp>
13cc6e31f1SJunchao Zhang 
14cc6e31f1SJunchao Zhang // To suppress compiler warnings:
15cc6e31f1SJunchao Zhang // /path/include/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp:434:63:
16cc6e31f1SJunchao Zhang // warning: 'cusparseStatus_t cusparseDbsrmm(cusparseHandle_t, cusparseDirection_t, cusparseOperation_t,
17cc6e31f1SJunchao Zhang // cusparseOperation_t, int, int, int, int, const double*, cusparseMatDescr_t, const double*, const int*, const int*,
18cc6e31f1SJunchao Zhang // int, const double*, int, const double*, double*, int)' is deprecated: please use cusparseSpMM instead [-Wdeprecated-declarations]
19cc6e31f1SJunchao Zhang PETSC_PRAGMA_DIAGNOSTIC_IGNORED_BEGIN("-Wdeprecated-declarations")
208c3ff71bSJunchao Zhang #include <KokkosSparse_spmv.hpp>
21cc6e31f1SJunchao Zhang PETSC_PRAGMA_DIAGNOSTIC_IGNORED_END()
22cc6e31f1SJunchao Zhang 
2386a27549SJunchao Zhang #include <KokkosSparse_spiluk.hpp>
2486a27549SJunchao Zhang #include <KokkosSparse_sptrsv.hpp>
25076ba34aSJunchao Zhang #include <KokkosSparse_spgemm.hpp>
26076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
279d13fa56SJunchao Zhang #include <KokkosBatched_LU_Decl.hpp>
289d13fa56SJunchao Zhang #include <KokkosBatched_InverseLU_Decl.hpp>
2986a27549SJunchao Zhang 
3042550becSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
318c3ff71bSJunchao Zhang 
320e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(3, 7, 0)
33f98996d3SJunchao Zhang   #include <KokkosSparse_Utils.hpp>
34f98996d3SJunchao Zhang using KokkosSparse::sort_crs_matrix;
359371c9d4SSatish Balay using KokkosSparse::Impl::transpose_matrix;
36f98996d3SJunchao Zhang #else
37f98996d3SJunchao Zhang   #include <KokkosKernels_Sorting.hpp>
38f98996d3SJunchao Zhang using KokkosKernels::sort_crs_matrix;
399371c9d4SSatish Balay using KokkosKernels::Impl::transpose_matrix;
40f98996d3SJunchao Zhang #endif
41f98996d3SJunchao Zhang 
42aac854edSJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(4, 6, 0)
43aac854edSJunchao Zhang using KokkosSparse::spiluk_symbolic;
44aac854edSJunchao Zhang using KokkosSparse::spiluk_numeric;
45aac854edSJunchao Zhang using KokkosSparse::sptrsv_symbolic;
46aac854edSJunchao Zhang using KokkosSparse::sptrsv_solve;
47aac854edSJunchao Zhang using KokkosSparse::Experimental::SPTRSVAlgorithm;
48aac854edSJunchao Zhang using KokkosSparse::Experimental::SPILUKAlgorithm;
49aac854edSJunchao Zhang #else
50aac854edSJunchao Zhang using KokkosSparse::Experimental::spiluk_symbolic;
51aac854edSJunchao Zhang using KokkosSparse::Experimental::spiluk_numeric;
52aac854edSJunchao Zhang using KokkosSparse::Experimental::sptrsv_symbolic;
53aac854edSJunchao Zhang using KokkosSparse::Experimental::sptrsv_solve;
54aac854edSJunchao Zhang using KokkosSparse::Experimental::SPTRSVAlgorithm;
55aac854edSJunchao Zhang using KokkosSparse::Experimental::SPILUKAlgorithm;
56aac854edSJunchao Zhang #endif
57aac854edSJunchao Zhang 
588c3ff71bSJunchao Zhang static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat); /* Forward declaration */
598c3ff71bSJunchao Zhang 
60076ba34aSJunchao Zhang /* MatAssemblyEnd_SeqAIJKokkos() happens when we finalized nonzeros of the matrix, either after
61076ba34aSJunchao Zhang    we assembled the matrix on host, or after we directly produced the matrix data on device (ex., through MatMatMult).
62076ba34aSJunchao Zhang    In the latter case, it is important to set a_dual's sync state correctly.
63076ba34aSJunchao Zhang  */
64d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_SeqAIJKokkos(Mat A, MatAssemblyType mode)
65d71ae5a4SJacob Faibussowitsch {
66076ba34aSJunchao Zhang   Mat_SeqAIJ       *aijseq;
67076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
688c3ff71bSJunchao Zhang 
698c3ff71bSJunchao Zhang   PetscFunctionBegin;
703ba16761SJacob Faibussowitsch   if (mode == MAT_FLUSH_ASSEMBLY) PetscFunctionReturn(PETSC_SUCCESS);
719566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
72076ba34aSJunchao Zhang 
73076ba34aSJunchao Zhang   aijseq = static_cast<Mat_SeqAIJ *>(A->data);
74076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
75076ba34aSJunchao Zhang 
76076ba34aSJunchao Zhang   /* If aijkok does not exist, we just copy i, j to device.
77076ba34aSJunchao Zhang      If aijkok already exists, but the device's nonzero pattern does not match with the host's, we assume the latest data is on host.
78076ba34aSJunchao Zhang      In both cases, we build a new aijkok structure.
79076ba34aSJunchao Zhang   */
80076ba34aSJunchao Zhang   if (!aijkok || aijkok->nonzerostate != A->nonzerostate) { /* aijkok might not exist yet or nonzero pattern has changed */
81d1c799ffSJunchao Zhang     if (aijkok && aijkok->host_aij_allocated_by_kokkos) {   /* Avoid accidently freeing much needed a,i,j on host when deleting aijkok */
82d1c799ffSJunchao Zhang       PetscCall(PetscShmgetAllocateArray(aijkok->nrows() + 1, sizeof(PetscInt), (void **)&aijseq->i));
83d1c799ffSJunchao Zhang       PetscCall(PetscShmgetAllocateArray(aijkok->nnz(), sizeof(PetscInt), (void **)&aijseq->j));
84d1c799ffSJunchao Zhang       PetscCall(PetscShmgetAllocateArray(aijkok->nnz(), sizeof(PetscInt), (void **)&aijseq->a));
85d1c799ffSJunchao Zhang       PetscCall(PetscArraycpy(aijseq->i, aijkok->i_host_data(), aijkok->nrows() + 1));
86d1c799ffSJunchao Zhang       PetscCall(PetscArraycpy(aijseq->j, aijkok->j_host_data(), aijkok->nnz()));
87d1c799ffSJunchao Zhang       PetscCall(PetscArraycpy(aijseq->a, aijkok->a_host_data(), aijkok->nnz()));
88d1c799ffSJunchao Zhang       aijseq->free_a  = PETSC_TRUE;
89d1c799ffSJunchao Zhang       aijseq->free_ij = PETSC_TRUE;
90d1c799ffSJunchao Zhang       /* This arises from MatCreateSeqAIJKokkosWithKokkosCsrMatrix() used in MatMatMult, where
91d1c799ffSJunchao Zhang          we have the CsrMatrix on device first and then copy to host, followed by
92d1c799ffSJunchao Zhang          MatSetMPIAIJWithSplitSeqAIJ() with garray = NULL.
93d1c799ffSJunchao Zhang          One could improve it by not using NULL garray.
94d1c799ffSJunchao Zhang       */
95d1c799ffSJunchao Zhang     }
96076ba34aSJunchao Zhang     delete aijkok;
97f4747e26SJunchao Zhang     aijkok   = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aijseq, A->nonzerostate, PETSC_FALSE /*don't copy mat values to device*/);
98076ba34aSJunchao Zhang     A->spptr = aijkok;
99f4747e26SJunchao Zhang   } else if (A->rmap->n && aijkok->diag_dual.extent(0) == 0) { // MatProduct might directly produce AIJ on device, but not the diag.
100f4747e26SJunchao Zhang     MatRowMapKokkosViewHost diag_h(aijseq->diag, A->rmap->n);
101f4747e26SJunchao Zhang     auto                    diag_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), diag_h);
102f4747e26SJunchao Zhang     aijkok->diag_dual              = MatRowMapKokkosDualView(diag_d, diag_h);
103076ba34aSJunchao Zhang   }
1043ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1058c3ff71bSJunchao Zhang }
1068c3ff71bSJunchao Zhang 
10786a27549SJunchao Zhang /* Sync CSR data to device if not yet */
108d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosSyncDevice(Mat A)
109d71ae5a4SJacob Faibussowitsch {
1108c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1118c3ff71bSJunchao Zhang 
1128c3ff71bSJunchao Zhang   PetscFunctionBegin;
113aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from host to device");
1145f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
115076ba34aSJunchao Zhang   if (aijkok->a_dual.need_sync_device()) {
116f3d3cd90SZach Atkins     PetscCall(KokkosDualViewSyncDevice(aijkok->a_dual, PetscGetKokkosExecutionSpace()));
117580c7c76SPierre Jolivet     aijkok->transpose_updated = PETSC_FALSE; /* values of the transpose is out-of-date */
11886a27549SJunchao Zhang     aijkok->hermitian_updated = PETSC_FALSE;
1198c3ff71bSJunchao Zhang   }
1203ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1218c3ff71bSJunchao Zhang }
1228c3ff71bSJunchao Zhang 
123076ba34aSJunchao Zhang /* Mark the CSR data on device as modified */
124d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosModifyDevice(Mat A)
125d71ae5a4SJacob Faibussowitsch {
12686a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
12786a27549SJunchao Zhang 
12886a27549SJunchao Zhang   PetscFunctionBegin;
1295f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Not supported for factorized matries");
13086a27549SJunchao Zhang   aijkok->a_dual.clear_sync_state();
13186a27549SJunchao Zhang   aijkok->a_dual.modify_device();
13286a27549SJunchao Zhang   aijkok->transpose_updated = PETSC_FALSE;
13386a27549SJunchao Zhang   aijkok->hermitian_updated = PETSC_FALSE;
1349566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJInvalidateDiagonal(A));
1359566063dSJacob Faibussowitsch   PetscCall(PetscObjectStateIncrease((PetscObject)A));
1363ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
13786a27549SJunchao Zhang }
13886a27549SJunchao Zhang 
139d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSyncHost(Mat A)
140d71ae5a4SJacob Faibussowitsch {
141f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1424df4a32cSJunchao Zhang   auto              exec   = PetscGetKokkosExecutionSpace();
143f0cf5187SStefano Zampini 
144f0cf5187SStefano Zampini   PetscFunctionBegin;
145f0cf5187SStefano Zampini   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
14686a27549SJunchao Zhang   /* We do not expect one needs factors on host  */
147aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from device to host");
1485f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing AIJKOK");
149f3d3cd90SZach Atkins   PetscCall(KokkosDualViewSyncHost(aijkok->a_dual, exec));
1503ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
151f0cf5187SStefano Zampini }
152f0cf5187SStefano Zampini 
153d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
154d71ae5a4SJacob Faibussowitsch {
155076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
156f0cf5187SStefano Zampini 
157f0cf5187SStefano Zampini   PetscFunctionBegin;
1585519a089SJose E. Roman   /* aijkok contains valid pointers only if the host's nonzerostate matches with the device's.
1595519a089SJose E. Roman     Calling MatSeqAIJSetPreallocation() or MatSetValues() on host, where aijseq->{i,j,a} might be
1605519a089SJose E. Roman     reallocated, will lead to stale {i,j,a}_dual in aijkok. In both operations, the hosts's nonzerostate
1615519a089SJose E. Roman     must have been updated. The stale aijkok will be rebuilt during MatAssemblyEnd.
1625519a089SJose E. Roman   */
1635519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
164f3d3cd90SZach Atkins     PetscCall(KokkosDualViewSyncHost(aijkok->a_dual, PetscGetKokkosExecutionSpace()));
165076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
166076ba34aSJunchao Zhang   } else { /* Happens when calling MatSetValues on a newly created matrix */
167076ba34aSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
168076ba34aSJunchao Zhang   }
1693ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
170076ba34aSJunchao Zhang }
171076ba34aSJunchao Zhang 
172d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
173d71ae5a4SJacob Faibussowitsch {
174076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
175076ba34aSJunchao Zhang 
176076ba34aSJunchao Zhang   PetscFunctionBegin;
1775519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) aijkok->a_dual.modify_host();
1783ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
179076ba34aSJunchao Zhang }
180076ba34aSJunchao Zhang 
181d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
182d71ae5a4SJacob Faibussowitsch {
183076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
184076ba34aSJunchao Zhang 
185076ba34aSJunchao Zhang   PetscFunctionBegin;
1865519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
187f3d3cd90SZach Atkins     PetscCall(KokkosDualViewSyncHost(aijkok->a_dual, PetscGetKokkosExecutionSpace()));
188076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1892328674fSJunchao Zhang   } else {
1902328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1912328674fSJunchao Zhang   }
1923ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
193076ba34aSJunchao Zhang }
194076ba34aSJunchao Zhang 
195d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
196d71ae5a4SJacob Faibussowitsch {
197076ba34aSJunchao Zhang   PetscFunctionBegin;
198076ba34aSJunchao Zhang   *array = NULL;
1993ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
200076ba34aSJunchao Zhang }
201076ba34aSJunchao Zhang 
202d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
203d71ae5a4SJacob Faibussowitsch {
204076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
205076ba34aSJunchao Zhang 
206076ba34aSJunchao Zhang   PetscFunctionBegin;
2075519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
208076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
2092328674fSJunchao Zhang   } else { /* Ex. happens with MatZeroEntries on a preallocated but not assembled matrix */
2102328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
2112328674fSJunchao Zhang   }
2123ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
213076ba34aSJunchao Zhang }
214076ba34aSJunchao Zhang 
215d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
216d71ae5a4SJacob Faibussowitsch {
217076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
218076ba34aSJunchao Zhang 
219076ba34aSJunchao Zhang   PetscFunctionBegin;
2205519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
221076ba34aSJunchao Zhang     aijkok->a_dual.clear_sync_state();
222076ba34aSJunchao Zhang     aijkok->a_dual.modify_host();
2232328674fSJunchao Zhang   }
2243ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
225f0cf5187SStefano Zampini }
226f0cf5187SStefano Zampini 
227d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJKokkos(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype)
228d71ae5a4SJacob Faibussowitsch {
2297ee59b9bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
2307ee59b9bSJunchao Zhang 
2317ee59b9bSJunchao Zhang   PetscFunctionBegin;
2327ee59b9bSJunchao Zhang   PetscCheck(aijkok != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "aijkok is NULL");
2337ee59b9bSJunchao Zhang 
2347ee59b9bSJunchao Zhang   if (i) *i = aijkok->i_device_data();
2357ee59b9bSJunchao Zhang   if (j) *j = aijkok->j_device_data();
2367ee59b9bSJunchao Zhang   if (a) {
237f3d3cd90SZach Atkins     PetscCall(KokkosDualViewSyncDevice(aijkok->a_dual, PetscGetKokkosExecutionSpace()));
2387ee59b9bSJunchao Zhang     *a = aijkok->a_device_data();
2397ee59b9bSJunchao Zhang   }
2407ee59b9bSJunchao Zhang   if (mtype) *mtype = PETSC_MEMTYPE_KOKKOS;
2413ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2427ee59b9bSJunchao Zhang }
2437ee59b9bSJunchao Zhang 
244*03db1824SAlex Lindsay static PetscErrorCode MatGetCurrentMemType_SeqAIJKokkos(PETSC_UNUSED Mat A, PetscMemType *mtype)
245*03db1824SAlex Lindsay {
246*03db1824SAlex Lindsay   PetscFunctionBegin;
247*03db1824SAlex Lindsay   *mtype = PETSC_MEMTYPE_KOKKOS;
248*03db1824SAlex Lindsay   PetscFunctionReturn(PETSC_SUCCESS);
249*03db1824SAlex Lindsay }
250*03db1824SAlex Lindsay 
2510e3ece09SJunchao Zhang /*
2520e3ece09SJunchao Zhang   Generate the sparsity pattern of a MatSeqAIJKokkos matrix's transpose on device.
2530e3ece09SJunchao Zhang 
2540e3ece09SJunchao Zhang   Input Parameter:
2550e3ece09SJunchao Zhang .  A       - the MATSEQAIJKOKKOS matrix
2560e3ece09SJunchao Zhang 
2570e3ece09SJunchao Zhang   Output Parameters:
2580e3ece09SJunchao Zhang +  perm_d - the permutation array on device, which connects Ta(i) = Aa(perm(i))
259aaa8cc7dSPierre Jolivet -  T_d    - the transpose on device, whose value array is allocated but not initialized
2600e3ece09SJunchao Zhang */
2610e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateTransposeStructure(Mat A, MatRowMapKokkosView &perm_d, KokkosCsrMatrix &T_d)
262d71ae5a4SJacob Faibussowitsch {
2630e3ece09SJunchao Zhang   Mat_SeqAIJ             *aseq = static_cast<Mat_SeqAIJ *>(A->data);
2640e3ece09SJunchao Zhang   PetscInt                nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
2650e3ece09SJunchao Zhang   const PetscInt         *Ai = aseq->i, *Aj = aseq->j;
2667b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost Ti_h(NoInit("Ti"), n + 1);
2670e3ece09SJunchao Zhang   MatRowMapType          *Ti = Ti_h.data();
2687b8d4ba6SJunchao Zhang   MatColIdxKokkosViewHost Tj_h(NoInit("Tj"), nz);
2697b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost perm_h(NoInit("permutation"), nz);
2700e3ece09SJunchao Zhang   PetscInt               *Tj   = Tj_h.data();
2710e3ece09SJunchao Zhang   PetscInt               *perm = perm_h.data();
2720e3ece09SJunchao Zhang   PetscInt               *offset;
273152b3e56SJunchao Zhang 
274152b3e56SJunchao Zhang   PetscFunctionBegin;
2750e3ece09SJunchao Zhang   // Populate Ti
2760e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::deep_copy(Ti_h, 0));
2770e3ece09SJunchao Zhang   Ti++;
2780e3ece09SJunchao Zhang   for (PetscInt i = 0; i < nz; i++) Ti[Aj[i]]++;
2790e3ece09SJunchao Zhang   Ti--;
2800e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) Ti[i + 1] += Ti[i];
2810e3ece09SJunchao Zhang 
2820e3ece09SJunchao Zhang   // Populate Tj and the permutation array
2830e3ece09SJunchao Zhang   PetscCall(PetscCalloc1(n, &offset)); // offset in each T row to fill in its column indices
2840e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
2850e3ece09SJunchao Zhang     for (PetscInt j = Ai[i]; j < Ai[i + 1]; j++) { // A's (i,j) is T's (j,i)
2860e3ece09SJunchao Zhang       PetscInt r    = Aj[j];                       // row r of T
2870e3ece09SJunchao Zhang       PetscInt disp = Ti[r] + offset[r];
2880e3ece09SJunchao Zhang 
2890e3ece09SJunchao Zhang       Tj[disp]   = i; // col i of T
2900e3ece09SJunchao Zhang       perm[disp] = j;
2910e3ece09SJunchao Zhang       offset[r]++;
292076ba34aSJunchao Zhang     }
2930e3ece09SJunchao Zhang   }
2940e3ece09SJunchao Zhang   PetscCall(PetscFree(offset));
2950e3ece09SJunchao Zhang 
2960e3ece09SJunchao Zhang   // Sort each row of T, along with the permutation array
2970e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) PetscCall(PetscSortIntWithArray(Ti[i + 1] - Ti[i], Tj + Ti[i], perm + Ti[i]));
2980e3ece09SJunchao Zhang 
2990e3ece09SJunchao Zhang   // Output perm and T on device
3000e3ece09SJunchao Zhang   auto Ti_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ti_h);
3010e3ece09SJunchao Zhang   auto Tj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Tj_h);
3020e3ece09SJunchao Zhang   PetscCallCXX(T_d = KokkosCsrMatrix("csrmatT", n, m, nz, MatScalarKokkosView("Ta", nz), Ti_d, Tj_d));
3030e3ece09SJunchao Zhang   PetscCallCXX(perm_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), perm_h));
3043ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
305152b3e56SJunchao Zhang }
306152b3e56SJunchao Zhang 
3070e3ece09SJunchao Zhang // Generate the transpose on device and cache it internally
3080e3ece09SJunchao Zhang // Note: KK transpose_matrix() does not have support symbolic/numeric transpose, so we do it on our own
3090e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGenerateTranspose_Private(Mat A, KokkosCsrMatrix *csrmatT)
310d71ae5a4SJacob Faibussowitsch {
3110e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
3120e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3130e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
3140e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatT;
315152b3e56SJunchao Zhang 
316152b3e56SJunchao Zhang   PetscFunctionBegin;
3170e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
318f3d3cd90SZach Atkins   PetscCall(KokkosDualViewSyncDevice(akok->a_dual, PetscGetKokkosExecutionSpace())); // Sync A's values since we are going to access them on device
3190e3ece09SJunchao Zhang 
3200e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
3210e3ece09SJunchao Zhang 
3220e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE) {
3230e3ece09SJunchao Zhang     *csrmatT = akok->csrmat;
3240e3ece09SJunchao Zhang   } else {
3250e3ece09SJunchao Zhang     // See if we already have a cached transpose and its value is up to date
3260e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
3270e3ece09SJunchao Zhang       if (!akok->transpose_updated) {            // if the value is out of date, update the cached version
3280e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
3290e3ece09SJunchao Zhang         auto       &Ta   = T.values;
3300e3ece09SJunchao Zhang 
331d326c3f1SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = Aa(perm(i)); }));
332076ba34aSJunchao Zhang       }
3330e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
3340e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
3350e3ece09SJunchao Zhang 
3360e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
3370e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
338d326c3f1SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = Aa(perm(i)); }));
3390e3ece09SJunchao Zhang     }
3400e3ece09SJunchao Zhang     akok->transpose_updated = PETSC_TRUE;
3410e3ece09SJunchao Zhang     *csrmatT                = akok->csrmatT;
3420e3ece09SJunchao Zhang   }
3430e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3440e3ece09SJunchao Zhang }
3450e3ece09SJunchao Zhang 
3460e3ece09SJunchao Zhang // Generate the Hermitian on device and cache it internally
3470e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateHermitian_Private(Mat A, KokkosCsrMatrix *csrmatH)
3480e3ece09SJunchao Zhang {
3490e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
3500e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3510e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
3520e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatH;
3530e3ece09SJunchao Zhang 
3540e3ece09SJunchao Zhang   PetscFunctionBegin;
3550e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
356f3d3cd90SZach Atkins   PetscCall(KokkosDualViewSyncDevice(akok->a_dual, PetscGetKokkosExecutionSpace())); // Sync A's values since we are going to access them on device
3570e3ece09SJunchao Zhang 
3580e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
3590e3ece09SJunchao Zhang 
3600e3ece09SJunchao Zhang   if (A->hermitian == PETSC_BOOL3_TRUE) {
3610e3ece09SJunchao Zhang     *csrmatH = akok->csrmat;
3620e3ece09SJunchao Zhang   } else {
3630e3ece09SJunchao Zhang     // See if we already have a cached hermitian and its value is up to date
3640e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
3650e3ece09SJunchao Zhang       if (!akok->hermitian_updated) {            // if the value is out of date, update the cached version
3660e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
3670e3ece09SJunchao Zhang         auto       &Ta   = T.values;
3680e3ece09SJunchao Zhang 
369d326c3f1SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = PetscConj(Aa(perm(i))); }));
3700e3ece09SJunchao Zhang       }
3710e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
3720e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
3730e3ece09SJunchao Zhang 
3740e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
3750e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
376d326c3f1SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = PetscConj(Aa(perm(i))); }));
3770e3ece09SJunchao Zhang     }
3780e3ece09SJunchao Zhang     akok->hermitian_updated = PETSC_TRUE;
3790e3ece09SJunchao Zhang     *csrmatH                = akok->csrmatH;
3800e3ece09SJunchao Zhang   }
3813ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
382152b3e56SJunchao Zhang }
383a587d139SMark 
3848c3ff71bSJunchao Zhang /* y = A x */
385d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMult_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
386d71ae5a4SJacob Faibussowitsch {
3878c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
388152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
389152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
3908c3ff71bSJunchao Zhang 
3918c3ff71bSJunchao Zhang   PetscFunctionBegin;
3929566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3939566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3949566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3959566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
3968c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
397d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), "N", 1.0 /*alpha*/, aijkok->csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A x + beta y */
3989566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3999566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
400076ba34aSJunchao Zhang   /* 2.0*nnz - numRows seems more accurate here but assumes there are no zero-rows. So a little sloppy here. */
4019566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
4029566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4033ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4048c3ff71bSJunchao Zhang }
4058c3ff71bSJunchao Zhang 
4068c3ff71bSJunchao Zhang /* y = A^T x */
407d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
408d71ae5a4SJacob Faibussowitsch {
4098c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
410152b3e56SJunchao Zhang   const char                *mode;
411152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
412152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
4130e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4148c3ff71bSJunchao Zhang 
4158c3ff71bSJunchao Zhang   PetscFunctionBegin;
4169566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4179566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4189566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4199566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
420152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4219566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
422152b3e56SJunchao Zhang     mode = "N";
423152b3e56SJunchao Zhang   } else {
424076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4250e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
426152b3e56SJunchao Zhang     mode   = "T";
427152b3e56SJunchao Zhang   }
428d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^T x + beta y */
4299566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4309566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
4310e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4329566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4333ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4348c3ff71bSJunchao Zhang }
4358c3ff71bSJunchao Zhang 
4368c3ff71bSJunchao Zhang /* y = A^H x */
437d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
438d71ae5a4SJacob Faibussowitsch {
4398c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
440152b3e56SJunchao Zhang   const char                *mode;
441152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
442152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
4430e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4448c3ff71bSJunchao Zhang 
4458c3ff71bSJunchao Zhang   PetscFunctionBegin;
4469566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4479566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4489566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4499566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
450152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4519566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
452152b3e56SJunchao Zhang     mode = "N";
453152b3e56SJunchao Zhang   } else {
454076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4550e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
456152b3e56SJunchao Zhang     mode   = "C";
457152b3e56SJunchao Zhang   }
458d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^H x + beta y */
4599566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4609566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
4610e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4629566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4633ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4648c3ff71bSJunchao Zhang }
4658c3ff71bSJunchao Zhang 
4668c3ff71bSJunchao Zhang /* z = A x + y */
467d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
468d71ae5a4SJacob Faibussowitsch {
4698c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
47092896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
471152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4728c3ff71bSJunchao Zhang 
4738c3ff71bSJunchao Zhang   PetscFunctionBegin;
4749566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4759566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
47692896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz)); // depending on yy's sync flags, zz might get its latest data on host
4779566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
47892896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv)); // do after VecCopy(yy, zz) to get the latest data on device
4798c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
480d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), "N", 1.0 /*alpha*/, aijkok->csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A x + beta z */
4819566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
48292896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
4839566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
4849566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4853ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4868c3ff71bSJunchao Zhang }
4878c3ff71bSJunchao Zhang 
4888c3ff71bSJunchao Zhang /* z = A^T x + y */
489d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
490d71ae5a4SJacob Faibussowitsch {
4918c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
492152b3e56SJunchao Zhang   const char                *mode;
49392896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
494152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4950e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4968c3ff71bSJunchao Zhang 
4978c3ff71bSJunchao Zhang   PetscFunctionBegin;
4989566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4999566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
50092896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz));
5019566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
50292896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv));
503152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
5049566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
505152b3e56SJunchao Zhang     mode = "N";
506152b3e56SJunchao Zhang   } else {
507076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
5080e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
509152b3e56SJunchao Zhang     mode   = "T";
510152b3e56SJunchao Zhang   }
511d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^T x + beta z */
5129566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
51392896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
5140e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
5159566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
5163ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5178c3ff71bSJunchao Zhang }
5188c3ff71bSJunchao Zhang 
5198c3ff71bSJunchao Zhang /* z = A^H x + y */
520d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
521d71ae5a4SJacob Faibussowitsch {
5228c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
523152b3e56SJunchao Zhang   const char                *mode;
52492896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
525152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
5260e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
5278c3ff71bSJunchao Zhang 
5288c3ff71bSJunchao Zhang   PetscFunctionBegin;
5299566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
5309566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
53192896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz));
5329566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
53392896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv));
534152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
5359566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
536152b3e56SJunchao Zhang     mode = "N";
537152b3e56SJunchao Zhang   } else {
538076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
5390e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
540152b3e56SJunchao Zhang     mode   = "C";
541152b3e56SJunchao Zhang   }
542d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^H x + beta z */
5439566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
54492896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
5450e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
5469566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
5473ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
548152b3e56SJunchao Zhang }
549152b3e56SJunchao Zhang 
55066976f2fSJacob Faibussowitsch static PetscErrorCode MatSetOption_SeqAIJKokkos(Mat A, MatOption op, PetscBool flg)
551d71ae5a4SJacob Faibussowitsch {
552152b3e56SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
553152b3e56SJunchao Zhang 
554152b3e56SJunchao Zhang   PetscFunctionBegin;
555152b3e56SJunchao Zhang   switch (op) {
556152b3e56SJunchao Zhang   case MAT_FORM_EXPLICIT_TRANSPOSE:
557152b3e56SJunchao Zhang     /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */
5589566063dSJacob Faibussowitsch     if (A->form_explicit_transpose && !flg && aijkok) PetscCall(aijkok->DestroyMatTranspose());
559152b3e56SJunchao Zhang     A->form_explicit_transpose = flg;
560152b3e56SJunchao Zhang     break;
561d71ae5a4SJacob Faibussowitsch   default:
562d71ae5a4SJacob Faibussowitsch     PetscCall(MatSetOption_SeqAIJ(A, op, flg));
563d71ae5a4SJacob Faibussowitsch     break;
564152b3e56SJunchao Zhang   }
5653ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5668c3ff71bSJunchao Zhang }
5678c3ff71bSJunchao Zhang 
568076ba34aSJunchao Zhang /* Depending on reuse, either build a new mat, or use the existing mat */
569d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
570d71ae5a4SJacob Faibussowitsch {
571076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
5728c3ff71bSJunchao Zhang 
5738c3ff71bSJunchao Zhang   PetscFunctionBegin;
5749566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
575076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) { /* Build a brand new mat */
57651ece73cSJunchao Zhang     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
57751ece73cSJunchao Zhang     PetscCall(MatSetType(*newmat, mtype));
5788c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {                 /* Reuse the mat created before */
5799566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); /* newmat is already a SeqAIJKokkos */
580076ba34aSJunchao Zhang   } else if (reuse == MAT_INPLACE_MATRIX) {               /* newmat is A */
5815f80ce2aSJacob Faibussowitsch     PetscCheck(A == *newmat, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "A != *newmat with MAT_INPLACE_MATRIX");
5829566063dSJacob Faibussowitsch     PetscCall(PetscFree(A->defaultvectype));
5839566063dSJacob Faibussowitsch     PetscCall(PetscStrallocpy(VECKOKKOS, &A->defaultvectype)); /* Allocate and copy the string */
5849566063dSJacob Faibussowitsch     PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATSEQAIJKOKKOS));
5859566063dSJacob Faibussowitsch     PetscCall(MatSetOps_SeqAIJKokkos(A));
586076ba34aSJunchao Zhang     aseq = static_cast<Mat_SeqAIJ *>(A->data);
587394ed5ebSJunchao Zhang     if (A->assembled) { /* Copy i, j (but not values) to device for an assembled matrix if not yet */
5885f80ce2aSJacob Faibussowitsch       PetscCheck(!A->spptr, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Expect NULL (Mat_SeqAIJKokkos*)A->spptr");
589f4747e26SJunchao Zhang       A->spptr = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aseq, A->nonzerostate, PETSC_FALSE);
5908c3ff71bSJunchao Zhang     }
591076ba34aSJunchao Zhang   }
5923ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5938c3ff71bSJunchao Zhang }
5948c3ff71bSJunchao Zhang 
595076ba34aSJunchao Zhang /* MatDuplicate always creates a new matrix. MatDuplicate can be called either on an assembled matrix or
596076ba34aSJunchao Zhang    an unassembled matrix, even though MAT_COPY_VALUES is not allowed for unassembled matrix.
597076ba34aSJunchao Zhang  */
598d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDuplicate_SeqAIJKokkos(Mat A, MatDuplicateOption dupOption, Mat *B)
599d71ae5a4SJacob Faibussowitsch {
600076ba34aSJunchao Zhang   Mat_SeqAIJ       *bseq;
601076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok;
602076ba34aSJunchao Zhang   Mat               mat;
6038c3ff71bSJunchao Zhang 
6048c3ff71bSJunchao Zhang   PetscFunctionBegin;
605076ba34aSJunchao Zhang   /* Do not copy values on host as A's latest values might be on device. We don't want to do sync blindly */
6069566063dSJacob Faibussowitsch   PetscCall(MatDuplicate_SeqAIJ(A, MAT_DO_NOT_COPY_VALUES, B));
607076ba34aSJunchao Zhang   mat = *B;
608f4747e26SJunchao Zhang   if (A->assembled) {
609076ba34aSJunchao Zhang     bseq = static_cast<Mat_SeqAIJ *>(mat->data);
610f4747e26SJunchao Zhang     bkok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, bseq, mat->nonzerostate, PETSC_FALSE);
611076ba34aSJunchao Zhang     bkok->a_dual.clear_sync_state(); /* Clear B's sync state as it will be decided below */
612076ba34aSJunchao Zhang     /* Now copy values to B if needed */
613076ba34aSJunchao Zhang     if (dupOption == MAT_COPY_VALUES) {
614076ba34aSJunchao Zhang       if (akok->a_dual.need_sync_device()) {
615076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_host(), akok->a_dual.view_host());
616076ba34aSJunchao Zhang         bkok->a_dual.modify_host();
617076ba34aSJunchao Zhang       } else { /* If device has the latest data, we only copy data on device */
618076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_device(), akok->a_dual.view_device());
619076ba34aSJunchao Zhang         bkok->a_dual.modify_device();
620076ba34aSJunchao Zhang       }
621076ba34aSJunchao Zhang     } else { /* MAT_DO_NOT_COPY_VALUES or MAT_SHARE_NONZERO_PATTERN. B's values should be zeroed */
622076ba34aSJunchao Zhang       /* B's values on host should be already zeroed by MatDuplicate_SeqAIJ() */
623076ba34aSJunchao Zhang       bkok->a_dual.modify_host();
624076ba34aSJunchao Zhang     }
625076ba34aSJunchao Zhang     mat->spptr = bkok;
626076ba34aSJunchao Zhang   }
627076ba34aSJunchao Zhang 
6289566063dSJacob Faibussowitsch   PetscCall(PetscFree(mat->defaultvectype));
6299566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &mat->defaultvectype)); /* Allocate and copy the string */
6309566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)mat, MATSEQAIJKOKKOS));
6319566063dSJacob Faibussowitsch   PetscCall(MatSetOps_SeqAIJKokkos(mat));
6323ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6338c3ff71bSJunchao Zhang }
6348c3ff71bSJunchao Zhang 
635d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatTranspose_SeqAIJKokkos(Mat A, MatReuse reuse, Mat *B)
636d71ae5a4SJacob Faibussowitsch {
6370ecb592aSJunchao Zhang   Mat               At;
6380e3ece09SJunchao Zhang   KokkosCsrMatrix   internT;
6390ecb592aSJunchao Zhang   Mat_SeqAIJKokkos *atkok, *bkok;
6400ecb592aSJunchao Zhang 
6410ecb592aSJunchao Zhang   PetscFunctionBegin;
6427fb60732SBarry Smith   if (reuse == MAT_REUSE_MATRIX) PetscCall(MatTransposeCheckNonzeroState_Private(A, *B));
6439566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &internT)); /* Generate a transpose internally */
6440ecb592aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX || reuse == MAT_INPLACE_MATRIX) {
645ff751488SJunchao Zhang     /* Deep copy internT, as we want to isolate the internal transpose */
6460e3ece09SJunchao Zhang     PetscCallCXX(atkok = new Mat_SeqAIJKokkos(KokkosCsrMatrix("csrmat", internT)));
6479566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PetscObjectComm((PetscObject)A), atkok, &At));
6480ecb592aSJunchao Zhang     if (reuse == MAT_INITIAL_MATRIX) *B = At;
6499566063dSJacob Faibussowitsch     else PetscCall(MatHeaderReplace(A, &At)); /* Replace A with At inplace */
6500ecb592aSJunchao Zhang   } else {                                    /* MAT_REUSE_MATRIX, just need to copy values to B on device */
6510ecb592aSJunchao Zhang     if ((*B)->assembled) {
6520ecb592aSJunchao Zhang       bkok = static_cast<Mat_SeqAIJKokkos *>((*B)->spptr);
6530e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(bkok->a_dual.view_device(), internT.values));
6549566063dSJacob Faibussowitsch       PetscCall(MatSeqAIJKokkosModifyDevice(*B));
6550ecb592aSJunchao Zhang     } else if ((*B)->preallocated) { /* It is ok for B to be only preallocated, as needed in MatTranspose_MPIAIJ */
6560ecb592aSJunchao Zhang       Mat_SeqAIJ             *bseq = static_cast<Mat_SeqAIJ *>((*B)->data);
6570e3ece09SJunchao Zhang       MatScalarKokkosViewHost a_h(bseq->a, internT.nnz()); /* bseq->nz = 0 if unassembled */
6580e3ece09SJunchao Zhang       MatColIdxKokkosViewHost j_h(bseq->j, internT.nnz());
6590e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(a_h, internT.values));
6600e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(j_h, internT.graph.entries));
6610ecb592aSJunchao Zhang     } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "B must be assembled or preallocated");
6620ecb592aSJunchao Zhang   }
6633ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6640ecb592aSJunchao Zhang }
6650ecb592aSJunchao Zhang 
666d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDestroy_SeqAIJKokkos(Mat A)
667d71ae5a4SJacob Faibussowitsch {
66886a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
6698c3ff71bSJunchao Zhang 
6708c3ff71bSJunchao Zhang   PetscFunctionBegin;
67186a27549SJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) {
67286a27549SJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
6738c3ff71bSJunchao Zhang     delete aijkok;
67486a27549SJunchao Zhang   } else {
67586a27549SJunchao Zhang     delete static_cast<Mat_SeqAIJKokkosTriFactors *>(A->spptr);
67686a27549SJunchao Zhang   }
677cbc6b225SStefano Zampini   A->spptr = NULL;
6789566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
6799566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
6809566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
68157761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
68257761e9aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijkokkos_hypre_C", NULL));
68357761e9aSJunchao Zhang #endif
6849566063dSJacob Faibussowitsch   PetscCall(MatDestroy_SeqAIJ(A));
6853ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6868c3ff71bSJunchao Zhang }
6878c3ff71bSJunchao Zhang 
6883f3ba80aSJunchao Zhang /*MC
6893f3ba80aSJunchao Zhang    MATSEQAIJKOKKOS - MATAIJKOKKOS = "(seq)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos
6903f3ba80aSJunchao Zhang 
69115229ffcSPierre Jolivet    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
6923f3ba80aSJunchao Zhang 
6932ef1f0ffSBarry Smith    Options Database Key:
69411a5261eSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATSEQAIJKOKKOS` during a call to `MatSetFromOptions()`
6953f3ba80aSJunchao Zhang 
6963f3ba80aSJunchao Zhang   Level: beginner
6973f3ba80aSJunchao Zhang 
6981cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MatCreateSeqAIJKokkos()`, `MATMPIAIJKOKKOS`
6993f3ba80aSJunchao Zhang M*/
700d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJKokkos(Mat A)
701d71ae5a4SJacob Faibussowitsch {
70286a27549SJunchao Zhang   PetscFunctionBegin;
7039566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
7049566063dSJacob Faibussowitsch   PetscCall(MatCreate_SeqAIJ(A));
7059566063dSJacob Faibussowitsch   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
7063ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
70786a27549SJunchao Zhang }
70886a27549SJunchao Zhang 
709076ba34aSJunchao Zhang /* Merge A, B into a matrix C. A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n) */
710d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C)
711d71ae5a4SJacob Faibussowitsch {
712076ba34aSJunchao Zhang   Mat_SeqAIJ         *a, *b;
713076ba34aSJunchao Zhang   Mat_SeqAIJKokkos   *akok, *bkok, *ckok;
714076ba34aSJunchao Zhang   MatScalarKokkosView aa, ba, ca;
715076ba34aSJunchao Zhang   MatRowMapKokkosView ai, bi, ci;
716076ba34aSJunchao Zhang   MatColIdxKokkosView aj, bj, cj;
717076ba34aSJunchao Zhang   PetscInt            m, n, nnz, aN;
718a3f881fbSStefano Zampini 
719a3f881fbSStefano Zampini   PetscFunctionBegin;
720076ba34aSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
721076ba34aSJunchao Zhang   PetscValidHeaderSpecific(B, MAT_CLASSID, 2);
7224f572ea9SToby Isaac   PetscAssertPointer(C, 4);
723076ba34aSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
724076ba34aSJunchao Zhang   PetscCheckTypeName(B, MATSEQAIJKOKKOS);
7255f80ce2aSJacob Faibussowitsch   PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n);
7265f80ce2aSJacob Faibussowitsch   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
727076ba34aSJunchao Zhang 
7289566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
7299566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
730076ba34aSJunchao Zhang   a    = static_cast<Mat_SeqAIJ *>(A->data);
731076ba34aSJunchao Zhang   b    = static_cast<Mat_SeqAIJ *>(B->data);
732076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
733076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
734076ba34aSJunchao Zhang   aa   = akok->a_dual.view_device();
735076ba34aSJunchao Zhang   ai   = akok->i_dual.view_device();
736076ba34aSJunchao Zhang   ba   = bkok->a_dual.view_device();
737076ba34aSJunchao Zhang   bi   = bkok->i_dual.view_device();
738076ba34aSJunchao Zhang   m    = A->rmap->n; /* M, N and nnz of C */
739076ba34aSJunchao Zhang   n    = A->cmap->n + B->cmap->n;
740076ba34aSJunchao Zhang   nnz  = a->nz + b->nz;
741076ba34aSJunchao Zhang   aN   = A->cmap->n; /* N of A */
742076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
743076ba34aSJunchao Zhang     aj           = akok->j_dual.view_device();
744076ba34aSJunchao Zhang     bj           = bkok->j_dual.view_device();
745076ba34aSJunchao Zhang     auto ca_dual = MatScalarKokkosDualView("a", aa.extent(0) + ba.extent(0));
746076ba34aSJunchao Zhang     auto ci_dual = MatRowMapKokkosDualView("i", ai.extent(0));
747076ba34aSJunchao Zhang     auto cj_dual = MatColIdxKokkosDualView("j", aj.extent(0) + bj.extent(0));
748076ba34aSJunchao Zhang     ca           = ca_dual.view_device();
749076ba34aSJunchao Zhang     ci           = ci_dual.view_device();
750076ba34aSJunchao Zhang     cj           = cj_dual.view_device();
751076ba34aSJunchao Zhang 
752076ba34aSJunchao Zhang     /* Concatenate A and B in parallel using Kokkos hierarchical parallelism */
7539371c9d4SSatish Balay     Kokkos::parallel_for(
754d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
755076ba34aSJunchao Zhang         PetscInt i       = t.league_rank(); /* row i */
756076ba34aSJunchao Zhang         PetscInt coffset = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
757076ba34aSJunchao Zhang 
758076ba34aSJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() { /* this side effect only happens once per whole team */
759076ba34aSJunchao Zhang                                                    ci(i) = coffset;
760076ba34aSJunchao Zhang                                                    if (i == m - 1) ci(m) = ai(m) + bi(m);
761076ba34aSJunchao Zhang         });
762076ba34aSJunchao Zhang 
763076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
764076ba34aSJunchao Zhang           if (k < alen) {
765076ba34aSJunchao Zhang             ca(coffset + k) = aa(ai(i) + k);
766076ba34aSJunchao Zhang             cj(coffset + k) = aj(ai(i) + k);
767076ba34aSJunchao Zhang           } else {
768076ba34aSJunchao Zhang             ca(coffset + k) = ba(bi(i) + k - alen);
769076ba34aSJunchao Zhang             cj(coffset + k) = bj(bi(i) + k - alen) + aN; /* Entries in B get new column indices in C */
770076ba34aSJunchao Zhang           }
771076ba34aSJunchao Zhang         });
772076ba34aSJunchao Zhang       });
773076ba34aSJunchao Zhang     ca_dual.modify_device();
774076ba34aSJunchao Zhang     ci_dual.modify_device();
775076ba34aSJunchao Zhang     cj_dual.modify_device();
7769566063dSJacob Faibussowitsch     PetscCallCXX(ckok = new Mat_SeqAIJKokkos(m, n, nnz, ci_dual, cj_dual, ca_dual));
7779566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, C));
778076ba34aSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
779076ba34aSJunchao Zhang     PetscValidHeaderSpecific(*C, MAT_CLASSID, 4);
780076ba34aSJunchao Zhang     PetscCheckTypeName(*C, MATSEQAIJKOKKOS);
781076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos *>((*C)->spptr);
782076ba34aSJunchao Zhang     ca   = ckok->a_dual.view_device();
783076ba34aSJunchao Zhang     ci   = ckok->i_dual.view_device();
784076ba34aSJunchao Zhang 
7859371c9d4SSatish Balay     Kokkos::parallel_for(
786d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
787076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* row i */
788076ba34aSJunchao Zhang         PetscInt alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
789076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
790076ba34aSJunchao Zhang           if (k < alen) ca(ci(i) + k) = aa(ai(i) + k);
791076ba34aSJunchao Zhang           else ca(ci(i) + k) = ba(bi(i) + k - alen);
792076ba34aSJunchao Zhang         });
793076ba34aSJunchao Zhang       });
7949566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(*C));
795076ba34aSJunchao Zhang   }
7963ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
797076ba34aSJunchao Zhang }
798076ba34aSJunchao Zhang 
799d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_SeqAIJKokkos(void *pdata)
800d71ae5a4SJacob Faibussowitsch {
801076ba34aSJunchao Zhang   PetscFunctionBegin;
802076ba34aSJunchao Zhang   delete static_cast<MatProductData_SeqAIJKokkos *>(pdata);
8033ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
804a3f881fbSStefano Zampini }
805a3f881fbSStefano Zampini 
806d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos(Mat C)
807d71ae5a4SJacob Faibussowitsch {
808a3f881fbSStefano Zampini   Mat_Product                 *product = C->product;
809a3f881fbSStefano Zampini   Mat                          A, B;
810076ba34aSJunchao Zhang   bool                         transA, transB; /* use bool, since KK needs this type */
811a3f881fbSStefano Zampini   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
812a3f881fbSStefano Zampini   Mat_SeqAIJ                  *c;
813076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
8140e3ece09SJunchao Zhang   KokkosCsrMatrix              csrmatA, csrmatB;
815a3f881fbSStefano Zampini 
816a3f881fbSStefano Zampini   PetscFunctionBegin;
817a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
8185f80ce2aSJacob Faibussowitsch   PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data empty");
819076ba34aSJunchao Zhang   pdata = static_cast<MatProductData_SeqAIJKokkos *>(C->product->data);
820076ba34aSJunchao Zhang 
8210e3ece09SJunchao Zhang   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
8220e3ece09SJunchao Zhang   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
8230e3ece09SJunchao Zhang   // we still do numeric.
8240e3ece09SJunchao Zhang   if (pdata->reusesym) { // numeric reuses results from symbolic
8250e3ece09SJunchao Zhang     pdata->reusesym = PETSC_FALSE;
8263ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
827076ba34aSJunchao Zhang   }
828076ba34aSJunchao Zhang 
829076ba34aSJunchao Zhang   switch (product->type) {
8309371c9d4SSatish Balay   case MATPRODUCT_AB:
8319371c9d4SSatish Balay     transA = false;
8329371c9d4SSatish Balay     transB = false;
8339371c9d4SSatish Balay     break;
8349371c9d4SSatish Balay   case MATPRODUCT_AtB:
8359371c9d4SSatish Balay     transA = true;
8369371c9d4SSatish Balay     transB = false;
8379371c9d4SSatish Balay     break;
8389371c9d4SSatish Balay   case MATPRODUCT_ABt:
8399371c9d4SSatish Balay     transA = false;
8409371c9d4SSatish Balay     transB = true;
8419371c9d4SSatish Balay     break;
842d71ae5a4SJacob Faibussowitsch   default:
843d71ae5a4SJacob Faibussowitsch     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
844076ba34aSJunchao Zhang   }
845076ba34aSJunchao Zhang 
846a3f881fbSStefano Zampini   A = product->A;
847a3f881fbSStefano Zampini   B = product->B;
8489566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
8499566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
850a3f881fbSStefano Zampini   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
851a3f881fbSStefano Zampini   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
852a3f881fbSStefano Zampini   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
853076ba34aSJunchao Zhang 
8545f80ce2aSJacob Faibussowitsch   PetscCheck(ckok, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Device data structure spptr is empty");
855076ba34aSJunchao Zhang 
8560e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
8570e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
858076ba34aSJunchao Zhang 
859076ba34aSJunchao Zhang   /* TODO: Once KK spgemm implements transpose, we can get rid of the explicit transpose here */
860076ba34aSJunchao Zhang   if (transA) {
8619566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
862076ba34aSJunchao Zhang     transA = false;
863a3f881fbSStefano Zampini   }
864a3f881fbSStefano Zampini 
865076ba34aSJunchao Zhang   if (transB) {
8669566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
867076ba34aSJunchao Zhang     transB = false;
868076ba34aSJunchao Zhang   }
8699566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
8700e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, ckok->csrmat));
8710e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
872866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
873866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(ckok->csrmat)); /* without sort, mat_tests-ex62_14_seqaijkokkos fails */
874e944a159SJunchao Zhang #endif
875866eb059SJunchao Zhang 
8769566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
8779566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(C));
878a3f881fbSStefano Zampini   /* shorter version of MatAssemblyEnd_SeqAIJ */
879a3f881fbSStefano Zampini   c = (Mat_SeqAIJ *)C->data;
8809566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded,%" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz));
8819566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n"));
8829566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax));
883a3f881fbSStefano Zampini   c->reallocs         = 0;
884076ba34aSJunchao Zhang   C->info.mallocs     = 0;
885a3f881fbSStefano Zampini   C->info.nz_unneeded = 0;
886a3f881fbSStefano Zampini   C->assembled = C->was_assembled = PETSC_TRUE;
887a3f881fbSStefano Zampini   C->num_ass++;
8883ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
889a3f881fbSStefano Zampini }
890a3f881fbSStefano Zampini 
891d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos(Mat C)
892d71ae5a4SJacob Faibussowitsch {
893076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
894076ba34aSJunchao Zhang   MatProductType               ptype;
895076ba34aSJunchao Zhang   Mat                          A, B;
896076ba34aSJunchao Zhang   bool                         transA, transB;
897076ba34aSJunchao Zhang   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
898076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
899076ba34aSJunchao Zhang   MPI_Comm                     comm;
9000e3ece09SJunchao Zhang   KokkosCsrMatrix              csrmatA, csrmatB, csrmatC;
901a3f881fbSStefano Zampini 
902a3f881fbSStefano Zampini   PetscFunctionBegin;
903a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
9049566063dSJacob Faibussowitsch   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
9055f80ce2aSJacob Faibussowitsch   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
906a3f881fbSStefano Zampini   A = product->A;
907a3f881fbSStefano Zampini   B = product->B;
9089566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
9099566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
910a3f881fbSStefano Zampini   akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
911a3f881fbSStefano Zampini   bkok    = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
9120e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
9130e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
914076ba34aSJunchao Zhang 
915a3f881fbSStefano Zampini   ptype = product->type;
9160e3ece09SJunchao Zhang   // Take advantage of the symmetry if true
9170e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) {
9180e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
9190e3ece09SJunchao Zhang     product->symbolic_used_the_fact_A_is_symmetric = PETSC_TRUE;
9200e3ece09SJunchao Zhang   }
9210e3ece09SJunchao Zhang   if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) {
9220e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
9230e3ece09SJunchao Zhang     product->symbolic_used_the_fact_B_is_symmetric = PETSC_TRUE;
9240e3ece09SJunchao Zhang   }
9250e3ece09SJunchao Zhang 
926a3f881fbSStefano Zampini   switch (ptype) {
9279371c9d4SSatish Balay   case MATPRODUCT_AB:
9289371c9d4SSatish Balay     transA = false;
9299371c9d4SSatish Balay     transB = false;
9300e6a1e94SMark Adams     PetscCall(MatSetBlockSizesFromMats(C, A, B));
9319371c9d4SSatish Balay     break;
9329371c9d4SSatish Balay   case MATPRODUCT_AtB:
9339371c9d4SSatish Balay     transA = true;
9349371c9d4SSatish Balay     transB = false;
9350e6a1e94SMark Adams     if (A->cmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->rmap, A->cmap->bs));
9360e6a1e94SMark Adams     if (B->cmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->cmap, B->cmap->bs));
9379371c9d4SSatish Balay     break;
9389371c9d4SSatish Balay   case MATPRODUCT_ABt:
9399371c9d4SSatish Balay     transA = false;
9409371c9d4SSatish Balay     transB = true;
9410e6a1e94SMark Adams     if (A->rmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->rmap, A->rmap->bs));
9420e6a1e94SMark Adams     if (B->rmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->cmap, B->rmap->bs));
9439371c9d4SSatish Balay     break;
944d71ae5a4SJacob Faibussowitsch   default:
945d71ae5a4SJacob Faibussowitsch     SETERRQ(comm, PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
946a3f881fbSStefano Zampini   }
9470e3ece09SJunchao Zhang   PetscCallCXX(product->data = pdata = new MatProductData_SeqAIJKokkos());
948076ba34aSJunchao Zhang   pdata->reusesym = product->api_user;
949a3f881fbSStefano Zampini 
950076ba34aSJunchao Zhang   /* TODO: add command line options to select spgemm algorithms */
951866eb059SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; /* default alg is TPL if enabled, otherwise KK */
952866eb059SJunchao Zhang 
953866eb059SJunchao Zhang   /* CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 */
954866eb059SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
955866eb059SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
956866eb059SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
957866eb059SJunchao Zhang   #endif
958866eb059SJunchao Zhang #endif
9590e3ece09SJunchao Zhang   PetscCallCXX(pdata->kh.create_spgemm_handle(spgemm_alg));
960076ba34aSJunchao Zhang 
9619566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
962076ba34aSJunchao Zhang   /* TODO: Get rid of the explicit transpose once KK-spgemm implements the transpose option */
963076ba34aSJunchao Zhang   if (transA) {
9649566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
965076ba34aSJunchao Zhang     transA = false;
966076ba34aSJunchao Zhang   }
967076ba34aSJunchao Zhang 
968076ba34aSJunchao Zhang   if (transB) {
9699566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
970076ba34aSJunchao Zhang     transB = false;
971076ba34aSJunchao Zhang   }
972076ba34aSJunchao Zhang 
9730e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
974076ba34aSJunchao Zhang   /* spgemm_symbolic() only populates C's rowmap, but not C's column indices.
975076ba34aSJunchao Zhang     So we have to do a fake spgemm_numeric() here to get csrmatC.j_d setup, before
976076ba34aSJunchao Zhang     calling new Mat_SeqAIJKokkos().
977076ba34aSJunchao Zhang     TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
978076ba34aSJunchao Zhang   */
9790e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
9800e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
981866eb059SJunchao Zhang   /* Query if KK outputs a sorted matrix. If not, we need to sort it */
982866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
983866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(csrmatC)); /* sort_option defaults to -1 in KK!*/
984e944a159SJunchao Zhang #endif
9859566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
986076ba34aSJunchao Zhang 
9879566063dSJacob Faibussowitsch   PetscCallCXX(ckok = new Mat_SeqAIJKokkos(csrmatC));
9889566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(C, ckok));
989076ba34aSJunchao Zhang   C->product->destroy = MatProductDataDestroy_SeqAIJKokkos;
9903ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
991a3f881fbSStefano Zampini }
992a3f881fbSStefano Zampini 
993a3f881fbSStefano Zampini /* handles sparse matrix matrix ops */
994d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSetFromOptions_SeqAIJKokkos(Mat mat)
995d71ae5a4SJacob Faibussowitsch {
996076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
997a3f881fbSStefano Zampini   PetscBool    Biskok = PETSC_FALSE, Ciskok = PETSC_TRUE;
998a3f881fbSStefano Zampini 
999a3f881fbSStefano Zampini   PetscFunctionBegin;
1000a3f881fbSStefano Zampini   MatCheckProduct(mat, 1);
10019566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJKOKKOS, &Biskok));
100248a46eb9SPierre Jolivet   if (product->type == MATPRODUCT_ABC) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJKOKKOS, &Ciskok));
1003a3f881fbSStefano Zampini   if (Biskok && Ciskok) {
1004a3f881fbSStefano Zampini     switch (product->type) {
1005a3f881fbSStefano Zampini     case MATPRODUCT_AB:
1006a3f881fbSStefano Zampini     case MATPRODUCT_AtB:
1007d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABt:
1008d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos;
1009d71ae5a4SJacob Faibussowitsch       break;
1010a3f881fbSStefano Zampini     case MATPRODUCT_PtAP:
1011a3f881fbSStefano Zampini     case MATPRODUCT_RARt:
1012d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABC:
1013d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
1014d71ae5a4SJacob Faibussowitsch       break;
1015d71ae5a4SJacob Faibussowitsch     default:
1016d71ae5a4SJacob Faibussowitsch       break;
1017a3f881fbSStefano Zampini     }
1018a3f881fbSStefano Zampini   } else { /* fallback for AIJ */
10199566063dSJacob Faibussowitsch     PetscCall(MatProductSetFromOptions_SeqAIJ(mat));
1020a3f881fbSStefano Zampini   }
10213ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1022a3f881fbSStefano Zampini }
1023a587d139SMark 
1024d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatScale_SeqAIJKokkos(Mat A, PetscScalar a)
1025d71ae5a4SJacob Faibussowitsch {
1026f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok;
1027f0cf5187SStefano Zampini 
1028f0cf5187SStefano Zampini   PetscFunctionBegin;
10299566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
10309566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1031f0cf5187SStefano Zampini   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1032d326c3f1SJunchao Zhang   KokkosBlas::scal(PetscGetKokkosExecutionSpace(), aijkok->a_dual.view_device(), a, aijkok->a_dual.view_device());
10339566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
10349566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(aijkok->a_dual.extent(0)));
10359566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
10363ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1037f0cf5187SStefano Zampini }
1038f0cf5187SStefano Zampini 
1039f4747e26SJunchao Zhang // add a to A's diagonal (if A is square) or main diagonal (if A is rectangular)
1040f4747e26SJunchao Zhang static PetscErrorCode MatShift_SeqAIJKokkos(Mat A, PetscScalar a)
1041f4747e26SJunchao Zhang {
1042f4747e26SJunchao Zhang   Mat_SeqAIJ *aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1043f4747e26SJunchao Zhang 
1044f4747e26SJunchao Zhang   PetscFunctionBegin;
1045f4747e26SJunchao Zhang   if (A->assembled && aijseq->diagonaldense) { // no missing diagonals
1046f4747e26SJunchao Zhang     PetscInt n = PetscMin(A->rmap->n, A->cmap->n);
1047f4747e26SJunchao Zhang 
1048f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
1049f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosSyncDevice(A));
1050f4747e26SJunchao Zhang     const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1051f4747e26SJunchao Zhang     const auto &Aa     = aijkok->a_dual.view_device();
1052f4747e26SJunchao Zhang     const auto &Adiag  = aijkok->diag_dual.view_device();
1053d326c3f1SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) { Aa(Adiag(i)) += a; }));
1054f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosModifyDevice(A));
1055f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(n));
1056f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
1057f4747e26SJunchao Zhang   } else { // need reassembly, very slow!
1058f4747e26SJunchao Zhang     PetscCall(MatShift_Basic(A, a));
1059f4747e26SJunchao Zhang   }
1060f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1061f4747e26SJunchao Zhang }
1062f4747e26SJunchao Zhang 
1063f4747e26SJunchao Zhang static PetscErrorCode MatDiagonalSet_SeqAIJKokkos(Mat Y, Vec D, InsertMode is)
1064f4747e26SJunchao Zhang {
1065f4747e26SJunchao Zhang   Mat_SeqAIJ *aijseq = static_cast<Mat_SeqAIJ *>(Y->data);
1066f4747e26SJunchao Zhang 
1067f4747e26SJunchao Zhang   PetscFunctionBegin;
1068f4747e26SJunchao Zhang   if (Y->assembled && aijseq->diagonaldense) { // no missing diagonals
1069f4747e26SJunchao Zhang     ConstPetscScalarKokkosView dv;
1070f4747e26SJunchao Zhang     PetscInt                   n, nv;
1071f4747e26SJunchao Zhang 
1072f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
1073f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosSyncDevice(Y));
1074f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(D, &dv));
1075f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(D, &nv));
1076f4747e26SJunchao Zhang     n = PetscMin(Y->rmap->n, Y->cmap->n);
1077f4747e26SJunchao Zhang     PetscCheck(n == nv, PetscObjectComm((PetscObject)Y), PETSC_ERR_ARG_SIZ, "Matrix size and vector size do not match");
1078f4747e26SJunchao Zhang 
1079f4747e26SJunchao Zhang     const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1080f4747e26SJunchao Zhang     const auto &Aa     = aijkok->a_dual.view_device();
1081f4747e26SJunchao Zhang     const auto &Adiag  = aijkok->diag_dual.view_device();
1082f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(
1083d326c3f1SJunchao Zhang       Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) {
1084f4747e26SJunchao Zhang         if (is == INSERT_VALUES) Aa(Adiag(i)) = dv(i);
1085f4747e26SJunchao Zhang         else Aa(Adiag(i)) += dv(i);
1086f4747e26SJunchao Zhang       }));
1087f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(D, &dv));
1088f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1089f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(n));
1090f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
1091f4747e26SJunchao Zhang   } else { // need reassembly, very slow!
1092f4747e26SJunchao Zhang     PetscCall(MatDiagonalSet_Default(Y, D, is));
1093f4747e26SJunchao Zhang   }
1094f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1095f4747e26SJunchao Zhang }
1096f4747e26SJunchao Zhang 
1097f4747e26SJunchao Zhang static PetscErrorCode MatDiagonalScale_SeqAIJKokkos(Mat A, Vec ll, Vec rr)
1098f4747e26SJunchao Zhang {
1099f4747e26SJunchao Zhang   Mat_SeqAIJ                *aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1100f4747e26SJunchao Zhang   PetscInt                   m = A->rmap->n, n = A->cmap->n, nz = aijseq->nz;
1101f4747e26SJunchao Zhang   ConstPetscScalarKokkosView lv, rv;
1102f4747e26SJunchao Zhang 
1103f4747e26SJunchao Zhang   PetscFunctionBegin;
1104f4747e26SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1105f4747e26SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1106f4747e26SJunchao Zhang   const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1107f4747e26SJunchao Zhang   const auto &Aa     = aijkok->a_dual.view_device();
1108f4747e26SJunchao Zhang   const auto &Ai     = aijkok->i_dual.view_device();
1109f4747e26SJunchao Zhang   const auto &Aj     = aijkok->j_dual.view_device();
1110f4747e26SJunchao Zhang   if (ll) {
1111f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(ll, &m));
1112f4747e26SJunchao Zhang     PetscCheck(m == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Left scaling vector wrong length");
1113f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(ll, &lv));
1114f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for( // for each row
1115d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1116f4747e26SJunchao Zhang         PetscInt i   = t.league_rank(); // row i
1117f4747e26SJunchao Zhang         PetscInt len = Ai(i + 1) - Ai(i);
1118f4747e26SJunchao Zhang         // scale entries on the row
1119f4747e26SJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt j) { Aa(Ai(i) + j) *= lv(i); });
1120f4747e26SJunchao Zhang       }));
1121f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(ll, &lv));
1122f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(nz));
1123f4747e26SJunchao Zhang   }
1124f4747e26SJunchao Zhang   if (rr) {
1125f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(rr, &n));
1126f4747e26SJunchao Zhang     PetscCheck(n == A->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Right scaling vector wrong length");
1127f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(rr, &rv));
1128f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for( // for each nonzero
1129d326c3f1SJunchao Zhang       Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt k) { Aa(k) *= rv(Aj(k)); }));
1130f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(rr, &lv));
1131f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(nz));
1132f4747e26SJunchao Zhang   }
1133f4747e26SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(A));
1134f4747e26SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1135f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1136f4747e26SJunchao Zhang }
1137f4747e26SJunchao Zhang 
1138d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatZeroEntries_SeqAIJKokkos(Mat A)
1139d71ae5a4SJacob Faibussowitsch {
1140076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1141a587d139SMark 
1142a587d139SMark   PetscFunctionBegin;
1143076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
11442328674fSJunchao Zhang   if (aijkok) { /* Only zero the device if data is already there */
1145d326c3f1SJunchao Zhang     KokkosBlas::fill(PetscGetKokkosExecutionSpace(), aijkok->a_dual.view_device(), 0.0);
11469566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(A));
11472328674fSJunchao Zhang   } else { /* Might be preallocated but not assembled */
11489566063dSJacob Faibussowitsch     PetscCall(MatZeroEntries_SeqAIJ(A));
11492328674fSJunchao Zhang   }
11503ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1151a587d139SMark }
1152a587d139SMark 
1153d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatGetDiagonal_SeqAIJKokkos(Mat A, Vec x)
1154d71ae5a4SJacob Faibussowitsch {
1155f78ce678SMark Adams   Mat_SeqAIJKokkos     *aijkok;
1156f78ce678SMark Adams   PetscInt              n;
1157f78ce678SMark Adams   PetscScalarKokkosView xv;
1158f78ce678SMark Adams 
1159f78ce678SMark Adams   PetscFunctionBegin;
1160f78ce678SMark Adams   PetscCall(VecGetLocalSize(x, &n));
1161f78ce678SMark Adams   PetscCheck(n == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Nonconforming matrix and vector");
1162f78ce678SMark Adams   PetscCheck(A->factortype == MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_SUP, "MatGetDiagonal_SeqAIJKokkos not supported on factored matrices");
1163f78ce678SMark Adams 
1164f78ce678SMark Adams   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1165f78ce678SMark Adams   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1166f78ce678SMark Adams 
1167f78ce678SMark Adams   const auto &Aa    = aijkok->a_dual.view_device();
1168f78ce678SMark Adams   const auto &Ai    = aijkok->i_dual.view_device();
1169f78ce678SMark Adams   const auto &Adiag = aijkok->diag_dual.view_device();
1170f78ce678SMark Adams 
1171f78ce678SMark Adams   PetscCall(VecGetKokkosViewWrite(x, &xv));
11729371c9d4SSatish Balay   Kokkos::parallel_for(
1173d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) {
1174f78ce678SMark Adams       if (Adiag(i) < Ai(i + 1)) xv(i) = Aa(Adiag(i));
1175f78ce678SMark Adams       else xv(i) = 0;
1176f78ce678SMark Adams     });
1177f78ce678SMark Adams   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
11783ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1179f78ce678SMark Adams }
1180f78ce678SMark Adams 
1181db78de30SJunchao Zhang /* Get a Kokkos View from a mat of type MatSeqAIJKokkos */
1182d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1183d71ae5a4SJacob Faibussowitsch {
1184db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1185db78de30SJunchao Zhang 
1186db78de30SJunchao Zhang   PetscFunctionBegin;
1187db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11884f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1189db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11909566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1191db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1192076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
11933ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1194db78de30SJunchao Zhang }
1195db78de30SJunchao Zhang 
1196d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1197d71ae5a4SJacob Faibussowitsch {
1198db78de30SJunchao Zhang   PetscFunctionBegin;
1199db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12004f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1201db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
12023ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1203db78de30SJunchao Zhang }
1204db78de30SJunchao Zhang 
1205d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, MatScalarKokkosView *kv)
1206d71ae5a4SJacob Faibussowitsch {
1207db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1208db78de30SJunchao Zhang 
1209db78de30SJunchao Zhang   PetscFunctionBegin;
1210db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12114f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1212db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
12139566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1214db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1215076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
12163ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1217db78de30SJunchao Zhang }
1218db78de30SJunchao Zhang 
1219d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, MatScalarKokkosView *kv)
1220d71ae5a4SJacob Faibussowitsch {
1221db78de30SJunchao Zhang   PetscFunctionBegin;
1222db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12234f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1224db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
12259566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
12263ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1227db78de30SJunchao Zhang }
1228db78de30SJunchao Zhang 
1229d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1230d71ae5a4SJacob Faibussowitsch {
1231db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1232db78de30SJunchao Zhang 
1233db78de30SJunchao Zhang   PetscFunctionBegin;
1234db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12354f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1236db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1237db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1238076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
12393ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1240db78de30SJunchao Zhang }
1241db78de30SJunchao Zhang 
1242d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1243d71ae5a4SJacob Faibussowitsch {
1244db78de30SJunchao Zhang   PetscFunctionBegin;
1245db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12464f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1247db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
12489566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
12493ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1250db78de30SJunchao Zhang }
1251db78de30SJunchao Zhang 
1252c0c276a7Ssdargavi PetscErrorCode MatCreateSeqAIJKokkosWithKokkosViews(MPI_Comm comm, PetscInt m, PetscInt n, Kokkos::View<PetscInt *> &i_d, Kokkos::View<PetscInt *> &j_d, Kokkos::View<PetscScalar *> &a_d, Mat *A)
1253c0c276a7Ssdargavi {
1254c0c276a7Ssdargavi   Mat_SeqAIJKokkos *akok;
1255c0c276a7Ssdargavi 
1256c0c276a7Ssdargavi   PetscFunctionBegin;
1257c0c276a7Ssdargavi   // Create host copies of the input aij
1258c0c276a7Ssdargavi   auto i_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), i_d);
1259c0c276a7Ssdargavi   auto j_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), j_d);
1260c0c276a7Ssdargavi   // Don't copy the vals to the host now
1261c0c276a7Ssdargavi   auto a_h = Kokkos::create_mirror_view(HostMirrorMemorySpace(), a_d);
1262c0c276a7Ssdargavi 
1263c0c276a7Ssdargavi   MatScalarKokkosDualView a_dual = MatScalarKokkosDualView(a_d, a_h);
1264c0c276a7Ssdargavi   // Note we have modified device data so it will copy lazily
1265c0c276a7Ssdargavi   a_dual.modify_device();
1266c0c276a7Ssdargavi   MatRowMapKokkosDualView i_dual = MatRowMapKokkosDualView(i_d, i_h);
1267c0c276a7Ssdargavi   MatColIdxKokkosDualView j_dual = MatColIdxKokkosDualView(j_d, j_h);
1268c0c276a7Ssdargavi 
1269c0c276a7Ssdargavi   PetscCallCXX(akok = new Mat_SeqAIJKokkos(m, n, j_dual.extent(0), i_dual, j_dual, a_dual));
1270c0c276a7Ssdargavi   PetscCall(MatCreate(comm, A));
1271c0c276a7Ssdargavi   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
1272c0c276a7Ssdargavi   PetscFunctionReturn(PETSC_SUCCESS);
1273c0c276a7Ssdargavi }
1274c0c276a7Ssdargavi 
1275c17cf699SJunchao Zhang /* Computes Y += alpha X */
1276d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAXPY_SeqAIJKokkos(Mat Y, PetscScalar alpha, Mat X, MatStructure pattern)
1277d71ae5a4SJacob Faibussowitsch {
1278a587d139SMark   Mat_SeqAIJ              *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data;
1279c17cf699SJunchao Zhang   Mat_SeqAIJKokkos        *xkok, *ykok, *zkok;
1280c17cf699SJunchao Zhang   ConstMatScalarKokkosView Xa;
1281c17cf699SJunchao Zhang   MatScalarKokkosView      Ya;
12824df4a32cSJunchao Zhang   auto                     exec = PetscGetKokkosExecutionSpace();
1283a587d139SMark 
1284a587d139SMark   PetscFunctionBegin;
1285c17cf699SJunchao Zhang   PetscCheckTypeName(Y, MATSEQAIJKOKKOS);
1286c17cf699SJunchao Zhang   PetscCheckTypeName(X, MATSEQAIJKOKKOS);
12879566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(Y));
12889566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(X));
12899566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1290db78de30SJunchao Zhang 
1291c17cf699SJunchao Zhang   if (pattern != SAME_NONZERO_PATTERN && x->nz == y->nz) {
1292a587d139SMark     PetscBool e;
12939566063dSJacob Faibussowitsch     PetscCall(PetscArraycmp(x->i, y->i, Y->rmap->n + 1, &e));
1294a587d139SMark     if (e) {
12959566063dSJacob Faibussowitsch       PetscCall(PetscArraycmp(x->j, y->j, y->nz, &e));
1296c17cf699SJunchao Zhang       if (e) pattern = SAME_NONZERO_PATTERN;
1297a587d139SMark     }
1298a587d139SMark   }
1299db78de30SJunchao Zhang 
1300c17cf699SJunchao Zhang   /* cusparseDcsrgeam2() computes C = alpha A + beta B. If one knew sparsity pattern of C, one can skip
1301c17cf699SJunchao Zhang     cusparseScsrgeam2_bufferSizeExt() / cusparseXcsrgeam2Nnz(), and directly call cusparseScsrgeam2().
1302c17cf699SJunchao Zhang     If X is SUBSET_NONZERO_PATTERN of Y, we could take advantage of this cusparse feature. However,
1303c17cf699SJunchao Zhang     KokkosSparse::spadd(alpha,A,beta,B,C) has symbolic and numeric phases, MatAXPY does not.
1304c17cf699SJunchao Zhang   */
1305c17cf699SJunchao Zhang   ykok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1306c17cf699SJunchao Zhang   xkok = static_cast<Mat_SeqAIJKokkos *>(X->spptr);
1307c17cf699SJunchao Zhang   Xa   = xkok->a_dual.view_device();
1308c17cf699SJunchao Zhang   Ya   = ykok->a_dual.view_device();
1309c17cf699SJunchao Zhang 
1310c17cf699SJunchao Zhang   if (pattern == SAME_NONZERO_PATTERN) {
1311d326c3f1SJunchao Zhang     KokkosBlas::axpy(exec, alpha, Xa, Ya);
13129566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1313c17cf699SJunchao Zhang   } else if (pattern == SUBSET_NONZERO_PATTERN) {
1314c17cf699SJunchao Zhang     MatRowMapKokkosView Xi = xkok->i_dual.view_device(), Yi = ykok->i_dual.view_device();
1315c17cf699SJunchao Zhang     MatColIdxKokkosView Xj = xkok->j_dual.view_device(), Yj = ykok->j_dual.view_device();
1316c17cf699SJunchao Zhang 
13179371c9d4SSatish Balay     Kokkos::parallel_for(
1318d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(exec, Y->rmap->n, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
13190e3ece09SJunchao Zhang         PetscInt i = t.league_rank(); // row i
13200e3ece09SJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() {
13210e3ece09SJunchao Zhang           // Only one thread works in a team
1322c17cf699SJunchao Zhang           PetscInt p, q = Yi(i);
13230e3ece09SJunchao Zhang           for (p = Xi(i); p < Xi(i + 1); p++) {          // For each nonzero on row i of X,
13240e3ece09SJunchao Zhang             while (Xj(p) != Yj(q) && q < Yi(i + 1)) q++; // find the matching nonzero on row i of Y.
13250e3ece09SJunchao Zhang             if (Xj(p) == Yj(q)) {                        // Found it
1326c17cf699SJunchao Zhang               Ya(q) += alpha * Xa(p);
1327c17cf699SJunchao Zhang               q++;
1328a587d139SMark             } else {
13290e3ece09SJunchao Zhang             // If not found, it indicates the input is wrong (X is not a SUBSET_NONZERO_PATTERN of Y).
13300e3ece09SJunchao Zhang             // Just insert a NaN at the beginning of row i if it is not empty, to make the result wrong.
13310e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_VERSION_GE(3, 7, 0)
13320e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::ArithTraits<PetscScalar>::nan();
13338b8b16f9SJunchao Zhang #else
13340e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::Experimental::nan("1");
13358b8b16f9SJunchao Zhang #endif
1336a587d139SMark             }
1337c17cf699SJunchao Zhang           }
1338c17cf699SJunchao Zhang         });
1339c17cf699SJunchao Zhang       });
13409566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
13410e3ece09SJunchao Zhang   } else { // different nonzero patterns
1342c17cf699SJunchao Zhang     Mat             Z;
1343c17cf699SJunchao Zhang     KokkosCsrMatrix zcsr;
1344c17cf699SJunchao Zhang     KernelHandle    kh;
13450e3ece09SJunchao Zhang     kh.create_spadd_handle(true); // X, Y are sorted
1346c17cf699SJunchao Zhang     KokkosSparse::spadd_symbolic(&kh, xkok->csrmat, ykok->csrmat, zcsr);
1347c17cf699SJunchao Zhang     KokkosSparse::spadd_numeric(&kh, alpha, xkok->csrmat, (PetscScalar)1.0, ykok->csrmat, zcsr);
1348c17cf699SJunchao Zhang     zkok = new Mat_SeqAIJKokkos(zcsr);
13499566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, zkok, &Z));
13509566063dSJacob Faibussowitsch     PetscCall(MatHeaderReplace(Y, &Z));
1351c17cf699SJunchao Zhang     kh.destroy_spadd_handle();
1352c17cf699SJunchao Zhang   }
13539566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
13540e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(xkok->a_dual.extent(0) * 2)); // Because we scaled X and then added it to Y
13553ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1356a587d139SMark }
1357a587d139SMark 
13582c4ab24aSJunchao Zhang struct MatCOOStruct_SeqAIJKokkos {
13592c4ab24aSJunchao Zhang   PetscCount           n;
13602c4ab24aSJunchao Zhang   PetscCount           Atot;
13612c4ab24aSJunchao Zhang   PetscInt             nz;
13622c4ab24aSJunchao Zhang   PetscCountKokkosView jmap;
13632c4ab24aSJunchao Zhang   PetscCountKokkosView perm;
13642c4ab24aSJunchao Zhang 
13652c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos(const MatCOOStruct_SeqAIJ *coo_h)
13662c4ab24aSJunchao Zhang   {
13672c4ab24aSJunchao Zhang     nz   = coo_h->nz;
13682c4ab24aSJunchao Zhang     n    = coo_h->n;
13692c4ab24aSJunchao Zhang     Atot = coo_h->Atot;
13702c4ab24aSJunchao Zhang     jmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->jmap, nz + 1));
13712c4ab24aSJunchao Zhang     perm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->perm, Atot));
13722c4ab24aSJunchao Zhang   }
13732c4ab24aSJunchao Zhang };
13742c4ab24aSJunchao Zhang 
137549abdd8aSBarry Smith static PetscErrorCode MatCOOStructDestroy_SeqAIJKokkos(void **data)
13762c4ab24aSJunchao Zhang {
13772c4ab24aSJunchao Zhang   PetscFunctionBegin;
137849abdd8aSBarry Smith   PetscCallCXX(delete static_cast<MatCOOStruct_SeqAIJKokkos *>(*data));
13792c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
13802c4ab24aSJunchao Zhang }
13812c4ab24aSJunchao Zhang 
1382d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_SeqAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1383d71ae5a4SJacob Faibussowitsch {
138442550becSJunchao Zhang   Mat_SeqAIJKokkos          *akok;
138542550becSJunchao Zhang   Mat_SeqAIJ                *aseq;
138603e76207SPierre Jolivet   PetscContainer             container_h;
13872c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJ       *coo_h;
13882c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo_d;
138942550becSJunchao Zhang 
139042550becSJunchao Zhang   PetscFunctionBegin;
13919566063dSJacob Faibussowitsch   PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, coo_i, coo_j));
1392394ed5ebSJunchao Zhang   aseq = static_cast<Mat_SeqAIJ *>(mat->data);
139342550becSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(mat->spptr);
1394cbc6b225SStefano Zampini   delete akok;
1395f4747e26SJunchao Zhang   mat->spptr = akok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, aseq, mat->nonzerostate + 1, PETSC_FALSE);
13969566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries_SeqAIJKokkos(mat));
13972c4ab24aSJunchao Zhang 
13982c4ab24aSJunchao Zhang   // Copy the COO struct to device
13992c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
14002c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
14012c4ab24aSJunchao Zhang   PetscCallCXX(coo_d = new MatCOOStruct_SeqAIJKokkos(coo_h));
14022c4ab24aSJunchao Zhang 
14032c4ab24aSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
140403e76207SPierre Jolivet   PetscCall(PetscObjectContainerCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", coo_d, MatCOOStructDestroy_SeqAIJKokkos));
14053ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
140642550becSJunchao Zhang }
140742550becSJunchao Zhang 
1408d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_SeqAIJKokkos(Mat A, const PetscScalar v[], InsertMode imode)
1409d71ae5a4SJacob Faibussowitsch {
141042550becSJunchao Zhang   MatScalarKokkosView        Aa;
141142550becSJunchao Zhang   ConstMatScalarKokkosView   kv;
141242550becSJunchao Zhang   PetscMemType               memtype;
14132c4ab24aSJunchao Zhang   PetscContainer             container;
14142c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo;
141542550becSJunchao Zhang 
141642550becSJunchao Zhang   PetscFunctionBegin;
14172c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)A, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
14182c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
14192c4ab24aSJunchao Zhang 
14202c4ab24aSJunchao Zhang   const auto &n    = coo->n;
14212c4ab24aSJunchao Zhang   const auto &Annz = coo->nz;
14222c4ab24aSJunchao Zhang   const auto &jmap = coo->jmap;
14232c4ab24aSJunchao Zhang   const auto &perm = coo->perm;
14242c4ab24aSJunchao Zhang 
14259566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype));
142642550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */
14272c4ab24aSJunchao Zhang     kv = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatScalarKokkosViewHost(v, n));
142842550becSJunchao Zhang   } else {
14292c4ab24aSJunchao Zhang     kv = ConstMatScalarKokkosView(v, n); /* Directly use v[]'s memory */
143042550becSJunchao Zhang   }
143142550becSJunchao Zhang 
1432c7b718f4SJunchao Zhang   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1433c7b718f4SJunchao Zhang   else PetscCall(MatSeqAIJGetKokkosView(A, &Aa));                             /* read & write matrix values */
143442550becSJunchao Zhang 
143508bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
14369371c9d4SSatish Balay   Kokkos::parallel_for(
1437d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz), KOKKOS_LAMBDA(const PetscCount i) {
1438c7b718f4SJunchao Zhang       PetscScalar sum = 0.0;
1439c7b718f4SJunchao Zhang       for (PetscCount k = jmap(i); k < jmap(i + 1); k++) sum += kv(perm(k));
1440c7b718f4SJunchao Zhang       Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1441c7b718f4SJunchao Zhang     });
144208bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1443394ed5ebSJunchao Zhang 
14449566063dSJacob Faibussowitsch   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
14459566063dSJacob Faibussowitsch   else PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
14463ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
144742550becSJunchao Zhang }
144842550becSJunchao Zhang 
1449d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat A)
1450d71ae5a4SJacob Faibussowitsch {
1451076ba34aSJunchao Zhang   Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
1452076ba34aSJunchao Zhang 
14538c3ff71bSJunchao Zhang   PetscFunctionBegin;
1454076ba34aSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_KOKKOS; /* We do not really use this flag */
14556f3d89d0SStefano Zampini   A->boundtocpu  = PETSC_FALSE;
14566f3d89d0SStefano Zampini 
14578c3ff71bSJunchao Zhang   A->ops->assemblyend               = MatAssemblyEnd_SeqAIJKokkos;
14588c3ff71bSJunchao Zhang   A->ops->destroy                   = MatDestroy_SeqAIJKokkos;
14598c3ff71bSJunchao Zhang   A->ops->duplicate                 = MatDuplicate_SeqAIJKokkos;
1460a587d139SMark   A->ops->axpy                      = MatAXPY_SeqAIJKokkos;
1461f0cf5187SStefano Zampini   A->ops->scale                     = MatScale_SeqAIJKokkos;
1462a587d139SMark   A->ops->zeroentries               = MatZeroEntries_SeqAIJKokkos;
1463076ba34aSJunchao Zhang   A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJKokkos;
14648c3ff71bSJunchao Zhang   A->ops->mult                      = MatMult_SeqAIJKokkos;
14658c3ff71bSJunchao Zhang   A->ops->multadd                   = MatMultAdd_SeqAIJKokkos;
14668c3ff71bSJunchao Zhang   A->ops->multtranspose             = MatMultTranspose_SeqAIJKokkos;
14678c3ff71bSJunchao Zhang   A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJKokkos;
14688c3ff71bSJunchao Zhang   A->ops->multhermitiantranspose    = MatMultHermitianTranspose_SeqAIJKokkos;
14698c3ff71bSJunchao Zhang   A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJKokkos;
1470076ba34aSJunchao Zhang   A->ops->productnumeric            = MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos;
14710ecb592aSJunchao Zhang   A->ops->transpose                 = MatTranspose_SeqAIJKokkos;
1472152b3e56SJunchao Zhang   A->ops->setoption                 = MatSetOption_SeqAIJKokkos;
1473f78ce678SMark Adams   A->ops->getdiagonal               = MatGetDiagonal_SeqAIJKokkos;
1474f4747e26SJunchao Zhang   A->ops->shift                     = MatShift_SeqAIJKokkos;
1475f4747e26SJunchao Zhang   A->ops->diagonalset               = MatDiagonalSet_SeqAIJKokkos;
1476f4747e26SJunchao Zhang   A->ops->diagonalscale             = MatDiagonalScale_SeqAIJKokkos;
1477*03db1824SAlex Lindsay   A->ops->getcurrentmemtype         = MatGetCurrentMemType_SeqAIJKokkos;
1478076ba34aSJunchao Zhang   a->ops->getarray                  = MatSeqAIJGetArray_SeqAIJKokkos;
1479076ba34aSJunchao Zhang   a->ops->restorearray              = MatSeqAIJRestoreArray_SeqAIJKokkos;
1480076ba34aSJunchao Zhang   a->ops->getarrayread              = MatSeqAIJGetArrayRead_SeqAIJKokkos;
1481076ba34aSJunchao Zhang   a->ops->restorearrayread          = MatSeqAIJRestoreArrayRead_SeqAIJKokkos;
1482076ba34aSJunchao Zhang   a->ops->getarraywrite             = MatSeqAIJGetArrayWrite_SeqAIJKokkos;
1483076ba34aSJunchao Zhang   a->ops->restorearraywrite         = MatSeqAIJRestoreArrayWrite_SeqAIJKokkos;
14847ee59b9bSJunchao Zhang   a->ops->getcsrandmemtype          = MatSeqAIJGetCSRAndMemType_SeqAIJKokkos;
148542550becSJunchao Zhang 
14869566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJKokkos));
14879566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJKokkos));
148857761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
148957761e9aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijkokkos_hypre_C", MatConvert_AIJ_HYPRE));
149057761e9aSJunchao Zhang #endif
14913ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1492076ba34aSJunchao Zhang }
1493076ba34aSJunchao Zhang 
14949d13fa56SJunchao Zhang /*
14959d13fa56SJunchao Zhang    Extract the (prescribled) diagonal blocks of the matrix and then invert them
14969d13fa56SJunchao Zhang 
14979d13fa56SJunchao Zhang   Input Parameters:
14989d13fa56SJunchao Zhang +  A       - the MATSEQAIJKOKKOS matrix
14999d13fa56SJunchao Zhang .  bs      - block sizes in 'csr' format, i.e., the i-th block has size bs(i+1) - bs(i)
15009d13fa56SJunchao Zhang .  bs2     - square of block sizes in 'csr' format, i.e., the i-th block should be stored at offset bs2(i) in diagVal[]
15019d13fa56SJunchao Zhang .  blkMap  - map row ids to block ids, i.e., row i belongs to the block blkMap(i)
15029d13fa56SJunchao Zhang -  work    - a pre-allocated work buffer (as big as diagVal) for use by this routine
15039d13fa56SJunchao Zhang 
15049d13fa56SJunchao Zhang   Output Parameter:
15059d13fa56SJunchao Zhang .  diagVal - the (pre-allocated) buffer to store the inverted blocks (each block is stored in column-major order)
15069d13fa56SJunchao Zhang */
15079d13fa56SJunchao Zhang PETSC_INTERN PetscErrorCode MatInvertVariableBlockDiagonal_SeqAIJKokkos(Mat A, const PetscIntKokkosView &bs, const PetscIntKokkosView &bs2, const PetscIntKokkosView &blkMap, PetscScalarKokkosView &work, PetscScalarKokkosView &diagVal)
15089d13fa56SJunchao Zhang {
15099d13fa56SJunchao Zhang   Mat_SeqAIJKokkos *akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
15109d13fa56SJunchao Zhang   PetscInt          nblocks = bs.extent(0) - 1;
15119d13fa56SJunchao Zhang 
15129d13fa56SJunchao Zhang   PetscFunctionBegin;
15139d13fa56SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A)); // Since we'll access A's value on device
15149d13fa56SJunchao Zhang 
15159d13fa56SJunchao Zhang   // Pull out the diagonal blocks of the matrix and then invert the blocks
15169d13fa56SJunchao Zhang   auto Aa    = akok->a_dual.view_device();
15179d13fa56SJunchao Zhang   auto Ai    = akok->i_dual.view_device();
15189d13fa56SJunchao Zhang   auto Aj    = akok->j_dual.view_device();
15199d13fa56SJunchao Zhang   auto Adiag = akok->diag_dual.view_device();
15209d13fa56SJunchao Zhang   // TODO: how to tune the team size?
152145402d8aSJunchao Zhang #if defined(KOKKOS_ENABLE_UNIFIED_MEMORY)
15229d13fa56SJunchao Zhang   auto ts = Kokkos::AUTO();
15239d13fa56SJunchao Zhang #else
15249d13fa56SJunchao Zhang   auto ts = 16; // improved performance 30% over Kokkos::AUTO() with CUDA, but failed with "Kokkos::abort: Requested Team Size is too large!" on CPUs
15259d13fa56SJunchao Zhang #endif
15269d13fa56SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1527d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), nblocks, ts), KOKKOS_LAMBDA(const KokkosTeamMemberType &teamMember) {
15289d13fa56SJunchao Zhang       const PetscInt bid    = teamMember.league_rank();                                                   // block id
15299d13fa56SJunchao Zhang       const PetscInt rstart = bs(bid);                                                                    // this block starts from this row
15309d13fa56SJunchao Zhang       const PetscInt m      = bs(bid + 1) - bs(bid);                                                      // size of this block
15319d13fa56SJunchao Zhang       const auto    &B      = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft>(&diagVal(bs2(bid)), m, m); // column-major order
15329d13fa56SJunchao Zhang       const auto    &W      = PetscScalarKokkosView(&work(bs2(bid)), m * m);
15339d13fa56SJunchao Zhang 
15349d13fa56SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(teamMember, m), [=](const PetscInt &r) { // r-th row in B
15359d13fa56SJunchao Zhang         PetscInt i = rstart + r;                                                            // i-th row in A
15369d13fa56SJunchao Zhang 
15379d13fa56SJunchao Zhang         if (Ai(i) <= Adiag(i) && Adiag(i) < Ai(i + 1)) { // if the diagonal exists (common case)
15389d13fa56SJunchao Zhang           PetscInt first = Adiag(i) - r;                 // we start to check nonzeros from here along this row
15399d13fa56SJunchao Zhang 
15409d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) {                   // walk n steps to see what column indices we will meet
15419d13fa56SJunchao Zhang             if (first + c < Ai(i) || first + c >= Ai(i + 1)) { // this entry (first+c) is out of range of this row, in other words, its value is zero
15429d13fa56SJunchao Zhang               B(r, c) = 0.0;
15439d13fa56SJunchao Zhang             } else if (Aj(first + c) == rstart + c) { // this entry is right on the (rstart+c) column
15449d13fa56SJunchao Zhang               B(r, c) = Aa(first + c);
15459d13fa56SJunchao Zhang             } else { // this entry does not show up in the CSR
15469d13fa56SJunchao Zhang               B(r, c) = 0.0;
15479d13fa56SJunchao Zhang             }
15489d13fa56SJunchao Zhang           }
15499d13fa56SJunchao Zhang         } else { // rare case that the diagonal does not exist
15509d13fa56SJunchao Zhang           const PetscInt begin = Ai(i);
15519d13fa56SJunchao Zhang           const PetscInt end   = Ai(i + 1);
15529d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) B(r, c) = 0.0;
15539d13fa56SJunchao Zhang           for (PetscInt j = begin; j < end; j++) { // scan the whole row; could use binary search but this is a rare case so we did not.
15549d13fa56SJunchao Zhang             if (rstart <= Aj(j) && Aj(j) < rstart + m) B(r, Aj(j) - rstart) = Aa(j);
15559d13fa56SJunchao Zhang             else if (Aj(j) >= rstart + m) break;
15569d13fa56SJunchao Zhang           }
15579d13fa56SJunchao Zhang         }
15589d13fa56SJunchao Zhang       });
15599d13fa56SJunchao Zhang 
15609d13fa56SJunchao Zhang       // LU-decompose B (w/o pivoting) and then invert B
15619d13fa56SJunchao Zhang       KokkosBatched::TeamLU<KokkosTeamMemberType, KokkosBatched::Algo::LU::Unblocked>::invoke(teamMember, B, 0.0);
15629d13fa56SJunchao Zhang       KokkosBatched::TeamInverseLU<KokkosTeamMemberType, KokkosBatched::Algo::InverseLU::Unblocked>::invoke(teamMember, B, W);
15639d13fa56SJunchao Zhang     }));
15649d13fa56SJunchao Zhang   // PetscLogGpuFlops() is done in the caller PCSetUp_VPBJacobi_Kokkos as we don't want to compute the flops in kernels
15659d13fa56SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
15669d13fa56SJunchao Zhang }
15679d13fa56SJunchao Zhang 
1568d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSetSeqAIJKokkosWithCSRMatrix(Mat A, Mat_SeqAIJKokkos *akok)
1569d71ae5a4SJacob Faibussowitsch {
1570076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
1571076ba34aSJunchao Zhang   PetscInt    i, m, n;
15724df4a32cSJunchao Zhang   auto        exec = PetscGetKokkosExecutionSpace();
1573076ba34aSJunchao Zhang 
1574076ba34aSJunchao Zhang   PetscFunctionBegin;
15755f80ce2aSJacob Faibussowitsch   PetscCheck(!A->spptr, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A->spptr is supposed to be empty");
1576076ba34aSJunchao Zhang 
1577076ba34aSJunchao Zhang   m = akok->nrows();
1578076ba34aSJunchao Zhang   n = akok->ncols();
15799566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(A, m, n, m, n));
15809566063dSJacob Faibussowitsch   PetscCall(MatSetType(A, MATSEQAIJKOKKOS));
1581076ba34aSJunchao Zhang 
1582076ba34aSJunchao Zhang   /* Set up data structures of A as a MATSEQAIJ */
15839566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(A, MAT_SKIP_ALLOCATION, NULL));
158457508eceSPierre Jolivet   aseq = (Mat_SeqAIJ *)A->data;
1585076ba34aSJunchao Zhang 
1586f3d3cd90SZach Atkins   PetscCall(KokkosDualViewSyncHost(akok->i_dual, exec)); /* We always need sync'ed i, j on host */
1587f3d3cd90SZach Atkins   PetscCall(KokkosDualViewSyncHost(akok->j_dual, exec));
1588076ba34aSJunchao Zhang 
1589076ba34aSJunchao Zhang   aseq->i       = akok->i_host_data();
1590076ba34aSJunchao Zhang   aseq->j       = akok->j_host_data();
1591076ba34aSJunchao Zhang   aseq->a       = akok->a_host_data();
1592076ba34aSJunchao Zhang   aseq->nonew   = -1; /*this indicates that inserting a new value in the matrix that generates a new nonzero is an error*/
1593076ba34aSJunchao Zhang   aseq->free_a  = PETSC_FALSE;
1594076ba34aSJunchao Zhang   aseq->free_ij = PETSC_FALSE;
1595076ba34aSJunchao Zhang   aseq->nz      = akok->nnz();
1596076ba34aSJunchao Zhang   aseq->maxnz   = aseq->nz;
1597076ba34aSJunchao Zhang 
15989566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->imax));
15999566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->ilen));
1600ad540459SPierre Jolivet   for (i = 0; i < m; i++) aseq->ilen[i] = aseq->imax[i] = aseq->i[i + 1] - aseq->i[i];
1601076ba34aSJunchao Zhang 
1602076ba34aSJunchao Zhang   /* It is critical to set the nonzerostate, as we use it to check if sparsity pattern (hence data) has changed on host in MatAssemblyEnd */
1603076ba34aSJunchao Zhang   akok->nonzerostate = A->nonzerostate;
1604ff751488SJunchao Zhang   A->spptr           = akok; /* Set A->spptr before MatAssembly so that A->spptr won't be allocated again there */
16059566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
16069566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
16073ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1608076ba34aSJunchao Zhang }
1609076ba34aSJunchao Zhang 
16100e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetKokkosCsrMatrix(Mat A, KokkosCsrMatrix *csr)
16110e3ece09SJunchao Zhang {
16120e3ece09SJunchao Zhang   PetscFunctionBegin;
16130e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
16140e3ece09SJunchao Zhang   *csr = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
16150e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
16160e3ece09SJunchao Zhang }
16170e3ece09SJunchao Zhang 
16180e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithKokkosCsrMatrix(MPI_Comm comm, KokkosCsrMatrix csr, Mat *A)
16190e3ece09SJunchao Zhang {
16200e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok;
16214d86920dSPierre Jolivet 
16220e3ece09SJunchao Zhang   PetscFunctionBegin;
16230e3ece09SJunchao Zhang   PetscCallCXX(akok = new Mat_SeqAIJKokkos(csr));
16240e3ece09SJunchao Zhang   PetscCall(MatCreate(comm, A));
16250e3ece09SJunchao Zhang   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
16260e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
16270e3ece09SJunchao Zhang }
16280e3ece09SJunchao Zhang 
1629076ba34aSJunchao Zhang /* Crete a SEQAIJKOKKOS matrix with a Mat_SeqAIJKokkos data structure
1630076ba34aSJunchao Zhang 
1631076ba34aSJunchao Zhang    Note we have names like MatSeqAIJSetPreallocationCSR, so I use capitalized CSR
1632076ba34aSJunchao Zhang  */
1633d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithCSRMatrix(MPI_Comm comm, Mat_SeqAIJKokkos *akok, Mat *A)
1634d71ae5a4SJacob Faibussowitsch {
1635076ba34aSJunchao Zhang   PetscFunctionBegin;
16369566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
16379566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
16383ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
16398c3ff71bSJunchao Zhang }
16408c3ff71bSJunchao Zhang 
1641152b3e56SJunchao Zhang /*@C
164211a5261eSBarry Smith   MatCreateSeqAIJKokkos - Creates a sparse matrix in `MATSEQAIJKOKKOS` (compressed row) format
16438c3ff71bSJunchao Zhang   (the default parallel PETSc format). This matrix will ultimately be handled by
164420f4b53cSBarry Smith   Kokkos for calculations.
16458c3ff71bSJunchao Zhang 
16468c3ff71bSJunchao Zhang   Collective
16478c3ff71bSJunchao Zhang 
16488c3ff71bSJunchao Zhang   Input Parameters:
164911a5261eSBarry Smith + comm - MPI communicator, set to `PETSC_COMM_SELF`
16508c3ff71bSJunchao Zhang . m    - number of rows
16518c3ff71bSJunchao Zhang . n    - number of columns
165220f4b53cSBarry Smith . nz   - number of nonzeros per row (same for all rows), ignored if `nnz` is provided
165320f4b53cSBarry Smith - nnz  - array containing the number of nonzeros in the various rows (possibly different for each row) or `NULL`
16548c3ff71bSJunchao Zhang 
16558c3ff71bSJunchao Zhang   Output Parameter:
16568c3ff71bSJunchao Zhang . A - the matrix
16578c3ff71bSJunchao Zhang 
16582ef1f0ffSBarry Smith   Level: intermediate
16592ef1f0ffSBarry Smith 
16602ef1f0ffSBarry Smith   Notes:
166111a5261eSBarry Smith   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
16628c3ff71bSJunchao Zhang   MatXXXXSetPreallocation() paradgm instead of this routine directly.
166311a5261eSBarry Smith   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
16648c3ff71bSJunchao Zhang 
166511a5261eSBarry Smith   The AIJ format, also called
16662ef1f0ffSBarry Smith   compressed row storage, is fully compatible with standard Fortran
16678c3ff71bSJunchao Zhang   storage.  That is, the stored row and column indices can begin at
166820f4b53cSBarry Smith   either one (as in Fortran) or zero.
16698c3ff71bSJunchao Zhang 
16702ef1f0ffSBarry Smith   Specify the preallocated storage with either `nz` or `nnz` (not both).
16712ef1f0ffSBarry Smith   Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory
16722ef1f0ffSBarry Smith   allocation.
16738c3ff71bSJunchao Zhang 
1674fe59aa6dSJacob Faibussowitsch .seealso: [](ch_matrices), `Mat`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
16758c3ff71bSJunchao Zhang @*/
1676d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateSeqAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A)
1677d71ae5a4SJacob Faibussowitsch {
16788c3ff71bSJunchao Zhang   PetscFunctionBegin;
16799566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
16809566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
16819566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, m, n));
16829566063dSJacob Faibussowitsch   PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
16839566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
16843ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
16858c3ff71bSJunchao Zhang }
1686930e68a5SMark Adams 
1687aac854edSJunchao Zhang // After matrix numeric factorization, there are still steps to do before triangular solve can be called.
1688aac854edSJunchao Zhang // For example, for transpose solve, we might need to compute the transpose matrices if the solver does not support it (such as KK, while cusparse does).
1689aac854edSJunchao Zhang // In cusparse, one has to call cusparseSpSV_analysis() with updated triangular matrix values before calling cusparseSpSV_solve().
1690aac854edSJunchao Zhang // Simiarily, in KK sptrsv_symbolic() has to be called before sptrsv_solve(). We put these steps in MatSeqAIJKokkos{Transpose}SolveCheck.
1691aac854edSJunchao Zhang static PetscErrorCode MatSeqAIJKokkosSolveCheck(Mat A)
1692d71ae5a4SJacob Faibussowitsch {
169386a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors   = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1694aac854edSJunchao Zhang   const PetscBool             has_lower = factors->iL_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // false with Choleksy
1695aac854edSJunchao Zhang   const PetscBool             has_upper = factors->iU_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // true with LU and Choleksy
169686a27549SJunchao Zhang 
169786a27549SJunchao Zhang   PetscFunctionBegin;
1698aac854edSJunchao Zhang   if (!factors->sptrsv_symbolic_completed) { // If sptrsv_symbolic was not called yet
1699aac854edSJunchao Zhang     if (has_upper) PetscCallCXX(sptrsv_symbolic(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d));
1700aac854edSJunchao Zhang     if (has_lower) PetscCallCXX(sptrsv_symbolic(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d));
170186a27549SJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_TRUE;
170286a27549SJunchao Zhang   }
17033ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
170486a27549SJunchao Zhang }
170586a27549SJunchao Zhang 
1706d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosTransposeSolveCheck(Mat A)
1707d71ae5a4SJacob Faibussowitsch {
1708aac854edSJunchao Zhang   const PetscInt              n         = A->rmap->n;
170986a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors   = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1710aac854edSJunchao Zhang   const PetscBool             has_lower = factors->iL_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // false with Choleksy
1711aac854edSJunchao Zhang   const PetscBool             has_upper = factors->iU_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // true with LU or Choleksy
171286a27549SJunchao Zhang 
171386a27549SJunchao Zhang   PetscFunctionBegin;
1714aac854edSJunchao Zhang   if (!factors->transpose_updated) {
1715aac854edSJunchao Zhang     if (has_upper) {
1716aac854edSJunchao Zhang       if (!factors->iUt_d.extent(0)) {                                 // Allocate Ut on device if not yet
1717aac854edSJunchao Zhang         factors->iUt_d = MatRowMapKokkosView("factors->iUt_d", n + 1); // KK requires this view to be initialized to 0 to call transpose_matrix
17187b8d4ba6SJunchao Zhang         factors->jUt_d = MatColIdxKokkosView(NoInit("factors->jUt_d"), factors->jU_d.extent(0));
17197b8d4ba6SJunchao Zhang         factors->aUt_d = MatScalarKokkosView(NoInit("factors->aUt_d"), factors->aU_d.extent(0));
1720aac854edSJunchao Zhang       }
172186a27549SJunchao Zhang 
1722aac854edSJunchao Zhang       if (factors->iU_h.extent(0)) { // If U is on host (factorization was done on host), we also compute the transpose on host
1723aac854edSJunchao Zhang         if (!factors->U) {
1724aac854edSJunchao Zhang           Mat_SeqAIJ *seq;
172586a27549SJunchao Zhang 
1726aac854edSJunchao Zhang           PetscCall(MatCreateSeqAIJWithArrays(PETSC_COMM_SELF, n, n, factors->iU_h.data(), factors->jU_h.data(), factors->aU_h.data(), &factors->U));
1727aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->U, MAT_INITIAL_MATRIX, &factors->Ut));
172886a27549SJunchao Zhang 
1729aac854edSJunchao Zhang           seq            = static_cast<Mat_SeqAIJ *>(factors->Ut->data);
1730aac854edSJunchao Zhang           factors->iUt_h = MatRowMapKokkosViewHost(seq->i, n + 1);
1731aac854edSJunchao Zhang           factors->jUt_h = MatColIdxKokkosViewHost(seq->j, seq->nz);
1732aac854edSJunchao Zhang           factors->aUt_h = MatScalarKokkosViewHost(seq->a, seq->nz);
1733aac854edSJunchao Zhang         } else {
1734aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->U, MAT_REUSE_MATRIX, &factors->Ut)); // Matrix Ut' data is aliased with {i, j, a}Ut_h
1735aac854edSJunchao Zhang         }
1736aac854edSJunchao Zhang         // Copy Ut from host to device
1737aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->iUt_d, factors->iUt_h));
1738aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->jUt_d, factors->jUt_h));
1739aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->aUt_d, factors->aUt_h));
1740aac854edSJunchao Zhang       } else { // If U was computed on device, we also compute the transpose there
1741aac854edSJunchao Zhang         // TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices. We have to sort the indices, until KK provides finer control options.
1742aac854edSJunchao Zhang         PetscCallCXX(transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iU_d,
1743aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->jU_d, factors->aU_d,
1744aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->iUt_d, factors->jUt_d,
1745aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->aUt_d));
1746aac854edSJunchao Zhang         PetscCallCXX(sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iUt_d, factors->jUt_d, factors->aUt_d));
1747aac854edSJunchao Zhang       }
1748aac854edSJunchao Zhang       PetscCallCXX(sptrsv_symbolic(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d));
1749aac854edSJunchao Zhang     }
1750aac854edSJunchao Zhang 
1751aac854edSJunchao Zhang     // do the same for L with LU
1752aac854edSJunchao Zhang     if (has_lower) {
1753aac854edSJunchao Zhang       if (!factors->iLt_d.extent(0)) {                                 // Allocate Lt on device if not yet
1754aac854edSJunchao Zhang         factors->iLt_d = MatRowMapKokkosView("factors->iLt_d", n + 1); // KK requires this view to be initialized to 0 to call transpose_matrix
1755aac854edSJunchao Zhang         factors->jLt_d = MatColIdxKokkosView(NoInit("factors->jLt_d"), factors->jL_d.extent(0));
1756aac854edSJunchao Zhang         factors->aLt_d = MatScalarKokkosView(NoInit("factors->aLt_d"), factors->aL_d.extent(0));
1757aac854edSJunchao Zhang       }
1758aac854edSJunchao Zhang 
1759aac854edSJunchao Zhang       if (factors->iL_h.extent(0)) { // If L is on host, we also compute the transpose on host
1760aac854edSJunchao Zhang         if (!factors->L) {
1761aac854edSJunchao Zhang           Mat_SeqAIJ *seq;
1762aac854edSJunchao Zhang 
1763aac854edSJunchao Zhang           PetscCall(MatCreateSeqAIJWithArrays(PETSC_COMM_SELF, n, n, factors->iL_h.data(), factors->jL_h.data(), factors->aL_h.data(), &factors->L));
1764aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->L, MAT_INITIAL_MATRIX, &factors->Lt));
1765aac854edSJunchao Zhang 
1766aac854edSJunchao Zhang           seq            = static_cast<Mat_SeqAIJ *>(factors->Lt->data);
1767aac854edSJunchao Zhang           factors->iLt_h = MatRowMapKokkosViewHost(seq->i, n + 1);
1768aac854edSJunchao Zhang           factors->jLt_h = MatColIdxKokkosViewHost(seq->j, seq->nz);
1769aac854edSJunchao Zhang           factors->aLt_h = MatScalarKokkosViewHost(seq->a, seq->nz);
1770aac854edSJunchao Zhang         } else {
1771aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->L, MAT_REUSE_MATRIX, &factors->Lt)); // Matrix Lt' data is aliased with {i, j, a}Lt_h
1772aac854edSJunchao Zhang         }
1773aac854edSJunchao Zhang         // Copy Lt from host to device
1774aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->iLt_d, factors->iLt_h));
1775aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->jLt_d, factors->jLt_h));
1776aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->aLt_d, factors->aLt_h));
1777aac854edSJunchao Zhang       } else { // If L was computed on device, we also compute the transpose there
1778aac854edSJunchao Zhang         // TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices. We have to sort the indices, until KK provides finer control options.
1779aac854edSJunchao Zhang         PetscCallCXX(transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iL_d,
1780aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->jL_d, factors->aL_d,
1781aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->iLt_d, factors->jLt_d,
1782aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->aLt_d));
1783aac854edSJunchao Zhang         PetscCallCXX(sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iLt_d, factors->jLt_d, factors->aLt_d));
1784aac854edSJunchao Zhang       }
1785aac854edSJunchao Zhang       PetscCallCXX(sptrsv_symbolic(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d));
1786aac854edSJunchao Zhang     }
1787aac854edSJunchao Zhang 
178886a27549SJunchao Zhang     factors->transpose_updated = PETSC_TRUE;
178986a27549SJunchao Zhang   }
17903ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
179186a27549SJunchao Zhang }
179286a27549SJunchao Zhang 
1793aac854edSJunchao Zhang // Solve Ax = b, with RAR = U^T D U, where R is the row (and col) permutation matrix on A.
1794aac854edSJunchao Zhang // R is represented by rowperm in factors. If R is identity (i.e, no reordering), then rowperm is empty.
1795aac854edSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJKokkos_Cholesky(Mat A, Vec bb, Vec xx)
1796d71ae5a4SJacob Faibussowitsch {
1797aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
179886a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1799aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1800aac854edSJunchao Zhang   PetscScalarKokkosView       D       = factors->D_d;
1801aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1802aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1803aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1804aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm  = factors->rowperm;
1805aac854edSJunchao Zhang   PetscBool                   identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
180686a27549SJunchao Zhang 
180786a27549SJunchao Zhang   PetscFunctionBegin;
18089566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1809aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSolveCheck(A));          // for UX = T
1810aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A)); // for U^T Y = B
1811aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1812aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
1813aac854edSJunchao Zhang 
1814aac854edSJunchao Zhang   // Solve U^T Y = B
1815aac854edSJunchao Zhang   if (identity) { // Reorder b with the row permutation
1816aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1817aac854edSJunchao Zhang     Y = factors->workVector;
1818aac854edSJunchao Zhang   } else {
1819aac854edSJunchao Zhang     B = factors->workVector;
1820aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(rowperm(i)); }));
1821aac854edSJunchao Zhang     Y = x;
1822aac854edSJunchao Zhang   }
1823aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, B, Y));
1824aac854edSJunchao Zhang 
1825aac854edSJunchao Zhang   // Solve diag(D) Y' = Y.
1826aac854edSJunchao Zhang   // Actually just do Y' = Y*D since D is already inverted in MatCholeskyFactorNumeric_SeqAIJ(). It is basically a vector element-wise multiplication.
1827aac854edSJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { Y(i) = Y(i) * D(i); }));
1828aac854edSJunchao Zhang 
1829aac854edSJunchao Zhang   // Solve UX = Y
1830aac854edSJunchao Zhang   if (identity) {
1831aac854edSJunchao Zhang     X = x;
1832aac854edSJunchao Zhang   } else {
1833aac854edSJunchao Zhang     X = factors->workVector; // B is not needed anymore
1834aac854edSJunchao Zhang   }
1835aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, Y, X));
1836aac854edSJunchao Zhang 
1837aac854edSJunchao Zhang   // Reorder X with the inverse column (row) permutation
1838aac854edSJunchao Zhang   if (!identity) {
1839aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(rowperm(i)) = X(i); }));
1840aac854edSJunchao Zhang   }
1841aac854edSJunchao Zhang 
1842aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1843aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
18449566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
18453ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
184686a27549SJunchao Zhang }
184786a27549SJunchao Zhang 
1848aac854edSJunchao Zhang // Solve Ax = b, with RAC = LU, where R and C are row and col permutation matrices on A respectively.
1849aac854edSJunchao Zhang // R and C are represented by rowperm and colperm in factors.
1850aac854edSJunchao Zhang // If R or C is identity (i.e, no reordering), then rowperm or colperm is empty.
1851aac854edSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJKokkos_LU(Mat A, Vec bb, Vec xx)
1852d71ae5a4SJacob Faibussowitsch {
1853aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
185486a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1855aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1856aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1857aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1858aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1859aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm      = factors->rowperm;
1860aac854edSJunchao Zhang   PetscIntKokkosView         &colperm      = factors->colperm;
1861aac854edSJunchao Zhang   PetscBool                   row_identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1862aac854edSJunchao Zhang   PetscBool                   col_identity = colperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
186386a27549SJunchao Zhang 
186486a27549SJunchao Zhang   PetscFunctionBegin;
18659566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1866aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSolveCheck(A));
1867aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1868aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
186986a27549SJunchao Zhang 
1870aac854edSJunchao Zhang   // Solve L Y = B (i.e., L (U C^- x) = R b).  R b indicates applying the row permutation on b.
1871aac854edSJunchao Zhang   if (row_identity) {
1872aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1873aac854edSJunchao Zhang     Y = factors->workVector;
1874aac854edSJunchao Zhang   } else {
1875aac854edSJunchao Zhang     B = factors->workVector;
1876aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(rowperm(i)); }));
1877aac854edSJunchao Zhang     Y = x;
1878aac854edSJunchao Zhang   }
1879aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khL, factors->iL_d, factors->jL_d, factors->aL_d, B, Y));
1880aac854edSJunchao Zhang 
1881aac854edSJunchao Zhang   // Solve U C^- x = Y
1882aac854edSJunchao Zhang   if (col_identity) {
1883aac854edSJunchao Zhang     X = x;
1884aac854edSJunchao Zhang   } else {
1885aac854edSJunchao Zhang     X = factors->workVector;
1886aac854edSJunchao Zhang   }
1887aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, Y, X));
1888aac854edSJunchao Zhang 
1889aac854edSJunchao Zhang   // x = C X; Reorder X with the inverse col permutation
1890aac854edSJunchao Zhang   if (!col_identity) {
1891aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(colperm(i)) = X(i); }));
1892aac854edSJunchao Zhang   }
1893aac854edSJunchao Zhang 
1894aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1895aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
18969566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
18973ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
189886a27549SJunchao Zhang }
189986a27549SJunchao Zhang 
1900aac854edSJunchao Zhang // Solve A^T x = b, with RAC = LU, where R and C are row and col permutation matrices on A respectively.
1901aac854edSJunchao Zhang // R and C are represented by rowperm and colperm in factors.
1902aac854edSJunchao Zhang // If R or C is identity (i.e, no reordering), then rowperm or colperm is empty.
1903aac854edSJunchao Zhang // A = R^-1 L U C^-1, so A^T = C^-T U^T L^T R^-T. But since C^- = C^T, R^- = R^T, we have A^T = C U^T L^T R.
1904aac854edSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJKokkos_LU(Mat A, Vec bb, Vec xx)
1905aac854edSJunchao Zhang {
1906aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
1907aac854edSJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1908aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1909aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1910aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1911aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1912aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm      = factors->rowperm;
1913aac854edSJunchao Zhang   PetscIntKokkosView         &colperm      = factors->colperm;
1914aac854edSJunchao Zhang   PetscBool                   row_identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1915aac854edSJunchao Zhang   PetscBool                   col_identity = colperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1916aac854edSJunchao Zhang 
1917aac854edSJunchao Zhang   PetscFunctionBegin;
1918aac854edSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1919aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A)); // Update L^T, U^T if needed, and do sptrsv symbolic for L^T, U^T
1920aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1921aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
1922aac854edSJunchao Zhang 
1923aac854edSJunchao Zhang   // Solve U^T Y = B (i.e., U^T (L^T R x) = C^- b).  Note C^- b = C^T b, which means applying the column permutation on b.
1924aac854edSJunchao Zhang   if (col_identity) { // Reorder b with the col permutation
1925aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1926aac854edSJunchao Zhang     Y = factors->workVector;
1927aac854edSJunchao Zhang   } else {
1928aac854edSJunchao Zhang     B = factors->workVector;
1929aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(colperm(i)); }));
1930aac854edSJunchao Zhang     Y = x;
1931aac854edSJunchao Zhang   }
1932aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, B, Y));
1933aac854edSJunchao Zhang 
1934aac854edSJunchao Zhang   // Solve L^T X = Y
1935aac854edSJunchao Zhang   if (row_identity) {
1936aac854edSJunchao Zhang     X = x;
1937aac854edSJunchao Zhang   } else {
1938aac854edSJunchao Zhang     X = factors->workVector;
1939aac854edSJunchao Zhang   }
1940aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d, Y, X));
1941aac854edSJunchao Zhang 
1942aac854edSJunchao Zhang   // x = R^- X = R^T X; Reorder X with the inverse row permutation
1943aac854edSJunchao Zhang   if (!row_identity) {
1944aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(rowperm(i)) = X(i); }));
1945aac854edSJunchao Zhang   }
1946aac854edSJunchao Zhang 
1947aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1948aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
1949aac854edSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1950aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1951aac854edSJunchao Zhang }
1952aac854edSJunchao Zhang 
1953aac854edSJunchao Zhang static PetscErrorCode MatLUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1954aac854edSJunchao Zhang {
1955aac854edSJunchao Zhang   PetscFunctionBegin;
1956aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncHost(A));
1957aac854edSJunchao Zhang   PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info));
1958aac854edSJunchao Zhang 
1959aac854edSJunchao Zhang   if (!info->solveonhost) { // if solve on host, then we don't need to copy L, U to device
1960aac854edSJunchao Zhang     Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
1961aac854edSJunchao Zhang     Mat_SeqAIJ                 *b       = static_cast<Mat_SeqAIJ *>(B->data);
1962aac854edSJunchao Zhang     const PetscInt             *Bi = b->i, *Bj = b->j, *Bdiag = b->diag;
1963aac854edSJunchao Zhang     const MatScalar            *Ba = b->a;
1964aac854edSJunchao Zhang     PetscInt                    m = B->rmap->n, n = B->cmap->n;
1965aac854edSJunchao Zhang 
1966aac854edSJunchao Zhang     if (factors->iL_h.extent(0) == 0) { // Allocate memory and copy the L, U structure for the first time
1967aac854edSJunchao Zhang       // Allocate memory and copy the structure
1968aac854edSJunchao Zhang       factors->iL_h = MatRowMapKokkosViewHost(NoInit("iL_h"), m + 1);
1969aac854edSJunchao Zhang       factors->jL_h = MatColIdxKokkosViewHost(NoInit("jL_h"), (Bi[m] - Bi[0]) + m); // + the diagonal entries
1970aac854edSJunchao Zhang       factors->aL_h = MatScalarKokkosViewHost(NoInit("aL_h"), (Bi[m] - Bi[0]) + m);
1971aac854edSJunchao Zhang       factors->iU_h = MatRowMapKokkosViewHost(NoInit("iU_h"), m + 1);
1972aac854edSJunchao Zhang       factors->jU_h = MatColIdxKokkosViewHost(NoInit("jU_h"), (Bdiag[0] - Bdiag[m]));
1973aac854edSJunchao Zhang       factors->aU_h = MatScalarKokkosViewHost(NoInit("aU_h"), (Bdiag[0] - Bdiag[m]));
1974aac854edSJunchao Zhang 
1975aac854edSJunchao Zhang       PetscInt *Li = factors->iL_h.data();
1976aac854edSJunchao Zhang       PetscInt *Lj = factors->jL_h.data();
1977aac854edSJunchao Zhang       PetscInt *Ui = factors->iU_h.data();
1978aac854edSJunchao Zhang       PetscInt *Uj = factors->jU_h.data();
1979aac854edSJunchao Zhang 
1980aac854edSJunchao Zhang       Li[0] = Ui[0] = 0;
1981aac854edSJunchao Zhang       for (PetscInt i = 0; i < m; i++) {
1982aac854edSJunchao Zhang         PetscInt llen = Bi[i + 1] - Bi[i];       // exclusive of the diagonal entry
1983aac854edSJunchao Zhang         PetscInt ulen = Bdiag[i] - Bdiag[i + 1]; // inclusive of the diagonal entry
1984aac854edSJunchao Zhang 
1985aac854edSJunchao Zhang         PetscArraycpy(Lj + Li[i], Bj + Bi[i], llen); // entries of L on the left of the diagonal
1986aac854edSJunchao Zhang         Lj[Li[i] + llen] = i;                        // diagonal entry of L
1987aac854edSJunchao Zhang 
1988aac854edSJunchao Zhang         Uj[Ui[i]] = i;                                                  // diagonal entry of U
1989aac854edSJunchao Zhang         PetscArraycpy(Uj + Ui[i] + 1, Bj + Bdiag[i + 1] + 1, ulen - 1); // entries of U on  the right of the diagonal
1990aac854edSJunchao Zhang 
1991aac854edSJunchao Zhang         Li[i + 1] = Li[i] + llen + 1;
1992aac854edSJunchao Zhang         Ui[i + 1] = Ui[i] + ulen;
1993aac854edSJunchao Zhang       }
1994aac854edSJunchao Zhang 
1995aac854edSJunchao Zhang       factors->iL_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iL_h);
1996aac854edSJunchao Zhang       factors->jL_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jL_h);
1997aac854edSJunchao Zhang       factors->iU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iU_h);
1998aac854edSJunchao Zhang       factors->jU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jU_h);
1999aac854edSJunchao Zhang       factors->aL_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aL_h);
2000aac854edSJunchao Zhang       factors->aU_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aU_h);
2001aac854edSJunchao Zhang 
2002aac854edSJunchao Zhang       // Copy row/col permutation to device
2003aac854edSJunchao Zhang       IS        rowperm = ((Mat_SeqAIJ *)B->data)->row;
2004aac854edSJunchao Zhang       PetscBool row_identity;
2005aac854edSJunchao Zhang       PetscCall(ISIdentity(rowperm, &row_identity));
2006aac854edSJunchao Zhang       if (!row_identity) {
2007aac854edSJunchao Zhang         const PetscInt *ip;
2008aac854edSJunchao Zhang 
2009aac854edSJunchao Zhang         PetscCall(ISGetIndices(rowperm, &ip));
2010aac854edSJunchao Zhang         factors->rowperm = PetscIntKokkosView(NoInit("rowperm"), m);
2011aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->rowperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), m)));
2012aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(rowperm, &ip));
2013aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(m * sizeof(PetscInt)));
2014aac854edSJunchao Zhang       }
2015aac854edSJunchao Zhang 
2016aac854edSJunchao Zhang       IS        colperm = ((Mat_SeqAIJ *)B->data)->col;
2017aac854edSJunchao Zhang       PetscBool col_identity;
2018aac854edSJunchao Zhang       PetscCall(ISIdentity(colperm, &col_identity));
2019aac854edSJunchao Zhang       if (!col_identity) {
2020aac854edSJunchao Zhang         const PetscInt *ip;
2021aac854edSJunchao Zhang 
2022aac854edSJunchao Zhang         PetscCall(ISGetIndices(colperm, &ip));
2023aac854edSJunchao Zhang         factors->colperm = PetscIntKokkosView(NoInit("colperm"), n);
2024aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->colperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), n)));
2025aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(colperm, &ip));
2026aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt)));
2027aac854edSJunchao Zhang       }
2028aac854edSJunchao Zhang 
2029aac854edSJunchao Zhang       /* Create sptrsv handles for L, U and their transpose */
2030aac854edSJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2031aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
2032aac854edSJunchao Zhang #else
2033aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
2034aac854edSJunchao Zhang #endif
2035aac854edSJunchao Zhang       factors->khL.create_sptrsv_handle(sptrsv_alg, m, true /* L is lower tri */);
2036aac854edSJunchao Zhang       factors->khU.create_sptrsv_handle(sptrsv_alg, m, false /* U is not lower tri */);
2037aac854edSJunchao Zhang       factors->khLt.create_sptrsv_handle(sptrsv_alg, m, false /* L^T is not lower tri */);
2038aac854edSJunchao Zhang       factors->khUt.create_sptrsv_handle(sptrsv_alg, m, true /* U^T is lower tri */);
2039aac854edSJunchao Zhang     }
2040aac854edSJunchao Zhang 
2041aac854edSJunchao Zhang     // Copy the value
2042aac854edSJunchao Zhang     for (PetscInt i = 0; i < m; i++) {
2043aac854edSJunchao Zhang       PetscInt        llen = Bi[i + 1] - Bi[i];
2044aac854edSJunchao Zhang       PetscInt        ulen = Bdiag[i] - Bdiag[i + 1];
2045aac854edSJunchao Zhang       const PetscInt *Li   = factors->iL_h.data();
2046aac854edSJunchao Zhang       const PetscInt *Ui   = factors->iU_h.data();
2047aac854edSJunchao Zhang 
2048aac854edSJunchao Zhang       PetscScalar *La = factors->aL_h.data();
2049aac854edSJunchao Zhang       PetscScalar *Ua = factors->aU_h.data();
2050aac854edSJunchao Zhang 
2051aac854edSJunchao Zhang       PetscArraycpy(La + Li[i], Ba + Bi[i], llen); // entries of L
2052aac854edSJunchao Zhang       La[Li[i] + llen] = 1.0;                      // diagonal entry
2053aac854edSJunchao Zhang 
2054aac854edSJunchao Zhang       Ua[Ui[i]] = 1.0 / Ba[Bdiag[i]];                                 // diagonal entry
2055aac854edSJunchao Zhang       PetscArraycpy(Ua + Ui[i] + 1, Ba + Bdiag[i + 1] + 1, ulen - 1); // entries of U
2056aac854edSJunchao Zhang     }
2057aac854edSJunchao Zhang 
2058aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aL_d, factors->aL_h));
2059aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aU_d, factors->aU_h));
2060aac854edSJunchao Zhang     // Once the factors' values have changed, we need to update their transpose and redo sptrsv symbolic
2061aac854edSJunchao Zhang     factors->transpose_updated         = PETSC_FALSE;
2062aac854edSJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_FALSE;
2063aac854edSJunchao Zhang 
2064aac854edSJunchao Zhang     B->ops->solve          = MatSolve_SeqAIJKokkos_LU;
2065aac854edSJunchao Zhang     B->ops->solvetranspose = MatSolveTranspose_SeqAIJKokkos_LU;
2066aac854edSJunchao Zhang   }
2067aac854edSJunchao Zhang 
2068aac854edSJunchao Zhang   B->ops->matsolve          = NULL;
2069aac854edSJunchao Zhang   B->ops->matsolvetranspose = NULL;
2070aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2071aac854edSJunchao Zhang }
2072aac854edSJunchao Zhang 
2073aac854edSJunchao Zhang static PetscErrorCode MatILUFactorNumeric_SeqAIJKokkos_ILU0(Mat B, Mat A, const MatFactorInfo *info)
2074d71ae5a4SJacob Faibussowitsch {
207586a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
207686a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
207786a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
207886a27549SJunchao Zhang 
207986a27549SJunchao Zhang   PetscFunctionBegin;
20809566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
2081aac854edSJunchao Zhang   PetscCheck(!info->factoronhost, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "MatFactorInfo.factoronhost should be false");
20829566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
2083076ba34aSJunchao Zhang 
2084076ba34aSJunchao Zhang   auto a_d = aijkok->a_dual.view_device();
2085076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
2086076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
2087076ba34aSJunchao Zhang 
2088aac854edSJunchao Zhang   PetscCallCXX(spiluk_numeric(&factors->kh, fill_lev, i_d, j_d, a_d, factors->iL_d, factors->jL_d, factors->aL_d, factors->iU_d, factors->jU_d, factors->aU_d));
208986a27549SJunchao Zhang 
209086a27549SJunchao Zhang   B->assembled              = PETSC_TRUE;
209186a27549SJunchao Zhang   B->preallocated           = PETSC_TRUE;
2092aac854edSJunchao Zhang   B->ops->solve             = MatSolve_SeqAIJKokkos_LU;
2093aac854edSJunchao Zhang   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJKokkos_LU;
209486a27549SJunchao Zhang   B->ops->matsolve          = NULL;
209586a27549SJunchao Zhang   B->ops->matsolvetranspose = NULL;
209686a27549SJunchao Zhang 
209786a27549SJunchao Zhang   /* Once the factors' value changed, we need to update their transpose and sptrsv handle */
209886a27549SJunchao Zhang   factors->transpose_updated         = PETSC_FALSE;
209986a27549SJunchao Zhang   factors->sptrsv_symbolic_completed = PETSC_FALSE;
2100eeadb341SJunchao Zhang   /* TODO: log flops, but how to know that? */
21019566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
21023ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
210386a27549SJunchao Zhang }
210486a27549SJunchao Zhang 
2105aac854edSJunchao Zhang // Use KK's spiluk_symbolic() to do ILU0 symbolic factorization, with no row/col reordering
2106aac854edSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos_ILU0(Mat B, Mat A, IS, IS, const MatFactorInfo *info)
2107d71ae5a4SJacob Faibussowitsch {
210886a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok;
210986a27549SJunchao Zhang   Mat_SeqAIJ                 *b;
211086a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
211186a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
211286a27549SJunchao Zhang   PetscInt                    nnzA     = ((Mat_SeqAIJ *)A->data)->nz, nnzL, nnzU;
211386a27549SJunchao Zhang   PetscInt                    n        = A->rmap->n;
211486a27549SJunchao Zhang 
211586a27549SJunchao Zhang   PetscFunctionBegin;
2116aac854edSJunchao Zhang   PetscCheck(!info->factoronhost, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "MatFactorInfo's factoronhost should be false as we are doing it on device right now");
21179566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
211886a27549SJunchao Zhang 
211986a27549SJunchao Zhang   /* Create a spiluk handle and then do symbolic factorization */
212086a27549SJunchao Zhang   nnzL = nnzU = PetscRealIntMultTruncate(info->fill, nnzA);
2121aac854edSJunchao Zhang   factors->kh.create_spiluk_handle(SPILUKAlgorithm::SEQLVLSCHD_TP1, n, nnzL, nnzU);
212286a27549SJunchao Zhang 
212386a27549SJunchao Zhang   auto spiluk_handle = factors->kh.get_spiluk_handle();
212486a27549SJunchao Zhang 
212586a27549SJunchao Zhang   Kokkos::realloc(factors->iL_d, n + 1); /* Free old arrays and realloc */
212686a27549SJunchao Zhang   Kokkos::realloc(factors->jL_d, spiluk_handle->get_nnzL());
212786a27549SJunchao Zhang   Kokkos::realloc(factors->iU_d, n + 1);
212886a27549SJunchao Zhang   Kokkos::realloc(factors->jU_d, spiluk_handle->get_nnzU());
212986a27549SJunchao Zhang 
213086a27549SJunchao Zhang   aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
2131076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
2132076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
2133aac854edSJunchao Zhang   PetscCallCXX(spiluk_symbolic(&factors->kh, fill_lev, i_d, j_d, factors->iL_d, factors->jL_d, factors->iU_d, factors->jU_d));
213486a27549SJunchao Zhang   /* TODO: if spiluk_symbolic is asynchronous, do we need to sync before calling get_nnzL()? */
213586a27549SJunchao Zhang 
213686a27549SJunchao Zhang   Kokkos::resize(factors->jL_d, spiluk_handle->get_nnzL()); /* Shrink or expand, and retain old value */
213786a27549SJunchao Zhang   Kokkos::resize(factors->jU_d, spiluk_handle->get_nnzU());
213886a27549SJunchao Zhang   Kokkos::realloc(factors->aL_d, spiluk_handle->get_nnzL()); /* No need to retain old value */
213986a27549SJunchao Zhang   Kokkos::realloc(factors->aU_d, spiluk_handle->get_nnzU());
214086a27549SJunchao Zhang 
214186a27549SJunchao Zhang   /* TODO: add options to select sptrsv algorithms */
214286a27549SJunchao Zhang   /* Create sptrsv handles for L, U and their transpose */
214386a27549SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2144aac854edSJunchao Zhang   auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
214586a27549SJunchao Zhang #else
2146aac854edSJunchao Zhang   auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
214786a27549SJunchao Zhang #endif
214886a27549SJunchao Zhang 
214986a27549SJunchao Zhang   factors->khL.create_sptrsv_handle(sptrsv_alg, n, true /* L is lower tri */);
215086a27549SJunchao Zhang   factors->khU.create_sptrsv_handle(sptrsv_alg, n, false /* U is not lower tri */);
215186a27549SJunchao Zhang   factors->khLt.create_sptrsv_handle(sptrsv_alg, n, false /* L^T is not lower tri */);
215286a27549SJunchao Zhang   factors->khUt.create_sptrsv_handle(sptrsv_alg, n, true /* U^T is lower tri */);
215386a27549SJunchao Zhang 
215486a27549SJunchao Zhang   /* Fill fields of the factor matrix B */
21559566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(B, MAT_SKIP_ALLOCATION, NULL));
215686a27549SJunchao Zhang   b     = (Mat_SeqAIJ *)B->data;
215786a27549SJunchao Zhang   b->nz = b->maxnz          = spiluk_handle->get_nnzL() + spiluk_handle->get_nnzU();
215886a27549SJunchao Zhang   B->info.fill_ratio_given  = info->fill;
2159a1e4e190SStefano Zampini   B->info.fill_ratio_needed = nnzA > 0 ? ((PetscReal)b->nz) / ((PetscReal)nnzA) : 1.0;
216086a27549SJunchao Zhang 
2161aac854edSJunchao Zhang   B->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJKokkos_ILU0;
21623ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2163930e68a5SMark Adams }
2164930e68a5SMark Adams 
2165aac854edSJunchao Zhang static PetscErrorCode MatLUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
2166aac854edSJunchao Zhang {
2167aac854edSJunchao Zhang   PetscFunctionBegin;
2168aac854edSJunchao Zhang   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
2169aac854edSJunchao Zhang   PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2170aac854edSJunchao Zhang   PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2171aac854edSJunchao Zhang   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
2172aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2173aac854edSJunchao Zhang }
2174aac854edSJunchao Zhang 
2175aac854edSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
2176aac854edSJunchao Zhang {
2177aac854edSJunchao Zhang   PetscBool row_identity = PETSC_FALSE, col_identity = PETSC_FALSE;
2178aac854edSJunchao Zhang 
2179aac854edSJunchao Zhang   PetscFunctionBegin;
2180aac854edSJunchao Zhang   if (!info->factoronhost) {
2181aac854edSJunchao Zhang     PetscCall(ISIdentity(isrow, &row_identity));
2182aac854edSJunchao Zhang     PetscCall(ISIdentity(iscol, &col_identity));
2183aac854edSJunchao Zhang   }
2184aac854edSJunchao Zhang 
2185aac854edSJunchao Zhang   PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2186aac854edSJunchao Zhang   PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2187aac854edSJunchao Zhang 
2188aac854edSJunchao Zhang   if (!info->factoronhost && !info->levels && row_identity && col_identity) { // if level 0 and no reordering
2189aac854edSJunchao Zhang     PetscCall(MatILUFactorSymbolic_SeqAIJKokkos_ILU0(B, A, isrow, iscol, info));
2190aac854edSJunchao Zhang   } else {
2191aac854edSJunchao Zhang     PetscCall(MatILUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info)); // otherwise, use PETSc's ILU on host
2192aac854edSJunchao Zhang     B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
2193aac854edSJunchao Zhang   }
2194aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2195aac854edSJunchao Zhang }
2196aac854edSJunchao Zhang 
2197aac854edSJunchao Zhang static PetscErrorCode MatCholeskyFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
2198aac854edSJunchao Zhang {
2199aac854edSJunchao Zhang   PetscFunctionBegin;
2200aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncHost(A));
2201aac854edSJunchao Zhang   PetscCall(MatCholeskyFactorNumeric_SeqAIJ(B, A, info));
2202aac854edSJunchao Zhang 
2203aac854edSJunchao Zhang   if (!info->solveonhost) { // if solve on host, then we don't need to copy L, U to device
2204aac854edSJunchao Zhang     Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
2205aac854edSJunchao Zhang     Mat_SeqAIJ                 *b       = static_cast<Mat_SeqAIJ *>(B->data);
2206aac854edSJunchao Zhang     const PetscInt             *Bi = b->i, *Bj = b->j, *Bdiag = b->diag;
2207aac854edSJunchao Zhang     const MatScalar            *Ba = b->a;
2208aac854edSJunchao Zhang     PetscInt                    m  = B->rmap->n;
2209aac854edSJunchao Zhang 
2210aac854edSJunchao Zhang     if (factors->iU_h.extent(0) == 0) { // First time of numeric factorization
2211aac854edSJunchao Zhang       // Allocate memory and copy the structure
2212aac854edSJunchao Zhang       factors->iU_h = PetscIntKokkosViewHost(const_cast<PetscInt *>(Bi), m + 1); // wrap Bi as iU_h
2213aac854edSJunchao Zhang       factors->jU_h = MatColIdxKokkosViewHost(NoInit("jU_h"), Bi[m]);
2214aac854edSJunchao Zhang       factors->aU_h = MatScalarKokkosViewHost(NoInit("aU_h"), Bi[m]);
2215aac854edSJunchao Zhang       factors->D_h  = MatScalarKokkosViewHost(NoInit("D_h"), m);
2216aac854edSJunchao Zhang       factors->aU_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aU_h);
2217aac854edSJunchao Zhang       factors->D_d  = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->D_h);
2218aac854edSJunchao Zhang 
2219aac854edSJunchao Zhang       // Build jU_h from the skewed Aj
2220aac854edSJunchao Zhang       PetscInt *Uj = factors->jU_h.data();
2221aac854edSJunchao Zhang       for (PetscInt i = 0; i < m; i++) {
2222aac854edSJunchao Zhang         PetscInt ulen = Bi[i + 1] - Bi[i];
2223aac854edSJunchao Zhang         Uj[Bi[i]]     = i;                                              // diagonal entry
2224aac854edSJunchao Zhang         PetscCall(PetscArraycpy(Uj + Bi[i] + 1, Bj + Bi[i], ulen - 1)); // entries of U on the right of the diagonal
2225aac854edSJunchao Zhang       }
2226aac854edSJunchao Zhang 
2227aac854edSJunchao Zhang       // Copy iU, jU to device
2228aac854edSJunchao Zhang       PetscCallCXX(factors->iU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iU_h));
2229aac854edSJunchao Zhang       PetscCallCXX(factors->jU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jU_h));
2230aac854edSJunchao Zhang 
2231aac854edSJunchao Zhang       // Copy row/col permutation to device
2232aac854edSJunchao Zhang       IS        rowperm = ((Mat_SeqAIJ *)B->data)->row;
2233aac854edSJunchao Zhang       PetscBool row_identity;
2234aac854edSJunchao Zhang       PetscCall(ISIdentity(rowperm, &row_identity));
2235aac854edSJunchao Zhang       if (!row_identity) {
2236aac854edSJunchao Zhang         const PetscInt *ip;
2237aac854edSJunchao Zhang 
2238aac854edSJunchao Zhang         PetscCall(ISGetIndices(rowperm, &ip));
2239aac854edSJunchao Zhang         PetscCallCXX(factors->rowperm = PetscIntKokkosView(NoInit("rowperm"), m));
2240aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->rowperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), m)));
2241aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(rowperm, &ip));
2242aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(m * sizeof(PetscInt)));
2243aac854edSJunchao Zhang       }
2244aac854edSJunchao Zhang 
2245aac854edSJunchao Zhang       // Create sptrsv handles for U and U^T
2246aac854edSJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2247aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
2248aac854edSJunchao Zhang #else
2249aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
2250aac854edSJunchao Zhang #endif
2251aac854edSJunchao Zhang       factors->khU.create_sptrsv_handle(sptrsv_alg, m, false /* U is not lower tri */);
2252aac854edSJunchao Zhang       factors->khUt.create_sptrsv_handle(sptrsv_alg, m, true /* U^T is lower tri */);
2253aac854edSJunchao Zhang     }
2254aac854edSJunchao Zhang     // These pointers were set MatCholeskyFactorNumeric_SeqAIJ(), so we always need to update them
2255aac854edSJunchao Zhang     B->ops->solve          = MatSolve_SeqAIJKokkos_Cholesky;
2256aac854edSJunchao Zhang     B->ops->solvetranspose = MatSolve_SeqAIJKokkos_Cholesky;
2257aac854edSJunchao Zhang 
2258aac854edSJunchao Zhang     // Copy the value
2259aac854edSJunchao Zhang     PetscScalar *Ua = factors->aU_h.data();
2260aac854edSJunchao Zhang     PetscScalar *D  = factors->D_h.data();
2261aac854edSJunchao Zhang     for (PetscInt i = 0; i < m; i++) {
2262aac854edSJunchao Zhang       D[i]      = Ba[Bdiag[i]];     // actually Aa[Adiag[i]] is the inverse of the diagonal
2263aac854edSJunchao Zhang       Ua[Bi[i]] = (PetscScalar)1.0; // set the unit diagonal for U
2264aac854edSJunchao Zhang       for (PetscInt k = 0; k < Bi[i + 1] - Bi[i] - 1; k++) Ua[Bi[i] + 1 + k] = -Ba[Bi[i] + k];
2265aac854edSJunchao Zhang     }
2266aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aU_d, factors->aU_h));
2267aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->D_d, factors->D_h));
2268aac854edSJunchao Zhang 
2269aac854edSJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_FALSE; // When numeric value changed, we must do these again
2270aac854edSJunchao Zhang     factors->transpose_updated         = PETSC_FALSE;
2271aac854edSJunchao Zhang   }
2272aac854edSJunchao Zhang 
2273aac854edSJunchao Zhang   B->ops->matsolve          = NULL;
2274aac854edSJunchao Zhang   B->ops->matsolvetranspose = NULL;
2275aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2276aac854edSJunchao Zhang }
2277aac854edSJunchao Zhang 
2278aac854edSJunchao Zhang static PetscErrorCode MatICCFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS perm, const MatFactorInfo *info)
2279aac854edSJunchao Zhang {
2280aac854edSJunchao Zhang   PetscFunctionBegin;
2281aac854edSJunchao Zhang   if (info->solveonhost) {
2282aac854edSJunchao Zhang     // If solve on host, we have to change the type, as eventually we need to call MatSolve_SeqSBAIJ_1_NaturalOrdering() etc.
2283aac854edSJunchao Zhang     PetscCall(MatSetType(B, MATSEQSBAIJ));
2284aac854edSJunchao Zhang     PetscCall(MatSeqSBAIJSetPreallocation(B, 1, MAT_SKIP_ALLOCATION, NULL));
2285aac854edSJunchao Zhang   }
2286aac854edSJunchao Zhang 
2287aac854edSJunchao Zhang   PetscCall(MatICCFactorSymbolic_SeqAIJ(B, A, perm, info));
2288aac854edSJunchao Zhang 
2289aac854edSJunchao Zhang   if (!info->solveonhost) {
2290bfe80ac4SPierre Jolivet     // If solve on device, B is still a MATSEQAIJKOKKOS, so we are good to allocate B->spptr
2291aac854edSJunchao Zhang     PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2292aac854edSJunchao Zhang     PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2293aac854edSJunchao Zhang     B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJKokkos;
2294aac854edSJunchao Zhang   }
2295aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2296aac854edSJunchao Zhang }
2297aac854edSJunchao Zhang 
2298aac854edSJunchao Zhang static PetscErrorCode MatCholeskyFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS perm, const MatFactorInfo *info)
2299aac854edSJunchao Zhang {
2300aac854edSJunchao Zhang   PetscFunctionBegin;
2301aac854edSJunchao Zhang   if (info->solveonhost) {
2302aac854edSJunchao Zhang     // If solve on host, we have to change the type, as eventually we need to call MatSolve_SeqSBAIJ_1_NaturalOrdering() etc.
2303aac854edSJunchao Zhang     PetscCall(MatSetType(B, MATSEQSBAIJ));
2304aac854edSJunchao Zhang     PetscCall(MatSeqSBAIJSetPreallocation(B, 1, MAT_SKIP_ALLOCATION, NULL));
2305aac854edSJunchao Zhang   }
2306aac854edSJunchao Zhang 
2307aac854edSJunchao Zhang   PetscCall(MatCholeskyFactorSymbolic_SeqAIJ(B, A, perm, info)); // it sets B's two ISes ((Mat_SeqAIJ*)B->data)->{row, col} to perm
2308aac854edSJunchao Zhang 
2309aac854edSJunchao Zhang   if (!info->solveonhost) {
2310bfe80ac4SPierre Jolivet     // If solve on device, B is still a MATSEQAIJKOKKOS, so we are good to allocate B->spptr
2311aac854edSJunchao Zhang     PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2312aac854edSJunchao Zhang     PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2313aac854edSJunchao Zhang     B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJKokkos;
2314aac854edSJunchao Zhang   }
2315aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2316aac854edSJunchao Zhang }
2317aac854edSJunchao Zhang 
2318aac854edSJunchao Zhang // The _Kokkos suffix means we will use Kokkos as a solver for the SeqAIJKokkos matrix
2319aac854edSJunchao Zhang static PetscErrorCode MatFactorGetSolverType_SeqAIJKokkos_Kokkos(Mat A, MatSolverType *type)
2320d71ae5a4SJacob Faibussowitsch {
2321930e68a5SMark Adams   PetscFunctionBegin;
2322930e68a5SMark Adams   *type = MATSOLVERKOKKOS;
23233ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2324930e68a5SMark Adams }
2325930e68a5SMark Adams 
2326930e68a5SMark Adams /*MC
  MATSOLVERKOKKOS = "Kokkos" - A matrix solver type providing triangular solvers for sequential matrices
  of type `MATSEQAIJKOKKOS`, `MATAIJKOKKOS` on a single GPU.
2329930e68a5SMark Adams 
2330930e68a5SMark Adams   Level: beginner
2331930e68a5SMark Adams 
23321cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJKokkos()`, `MATAIJKOKKOS`, `MatKokkosSetFormat()`, `MatKokkosStorageFormat`, `MatKokkosFormatOperation`
2333930e68a5SMark Adams M*/
233486a27549SJunchao Zhang PETSC_EXTERN PetscErrorCode MatGetFactor_SeqAIJKokkos_Kokkos(Mat A, MatFactorType ftype, Mat *B) /* MatGetFactor_<MatType>_<MatSolverType> */
2335930e68a5SMark Adams {
2336930e68a5SMark Adams   PetscInt n = A->rmap->n;
2337aac854edSJunchao Zhang   MPI_Comm comm;
2338930e68a5SMark Adams 
2339930e68a5SMark Adams   PetscFunctionBegin;
2340aac854edSJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
2341aac854edSJunchao Zhang   PetscCall(MatCreate(comm, B));
23429566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*B, n, n, n, n));
2343aac854edSJunchao Zhang   PetscCall(MatSetBlockSizesFromMats(*B, A, A));
2344930e68a5SMark Adams   (*B)->factortype = ftype;
23459566063dSJacob Faibussowitsch   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
23469566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
2347aac854edSJunchao Zhang   PetscCheck(!(*B)->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2348aac854edSJunchao Zhang 
2349aac854edSJunchao Zhang   if (ftype == MAT_FACTOR_LU || ftype == MAT_FACTOR_ILU || ftype == MAT_FACTOR_ILUDT) {
2350aac854edSJunchao Zhang     (*B)->ops->lufactorsymbolic  = MatLUFactorSymbolic_SeqAIJKokkos;
2351aac854edSJunchao Zhang     (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJKokkos;
2352aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
2353aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILU]));
2354aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILUDT]));
2355aac854edSJunchao Zhang   } else if (ftype == MAT_FACTOR_CHOLESKY || ftype == MAT_FACTOR_ICC) {
2356aac854edSJunchao Zhang     (*B)->ops->iccfactorsymbolic      = MatICCFactorSymbolic_SeqAIJKokkos;
2357aac854edSJunchao Zhang     (*B)->ops->choleskyfactorsymbolic = MatCholeskyFactorSymbolic_SeqAIJKokkos;
2358aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_CHOLESKY]));
2359aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ICC]));
2360aac854edSJunchao Zhang   } else SETERRQ(comm, PETSC_ERR_SUP, "MatFactorType %s is not supported by MatType SeqAIJKokkos", MatFactorTypes[ftype]);
2361aac854edSJunchao Zhang 
2362aac854edSJunchao Zhang   // The factorization can use the ordering provided in MatLUFactorSymbolic(), MatCholeskyFactorSymbolic() etc, though we do it on host
2363aac854edSJunchao Zhang   (*B)->canuseordering = PETSC_TRUE;
2364aac854edSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)*B, "MatFactorGetSolverType_C", MatFactorGetSolverType_SeqAIJKokkos_Kokkos));
23653ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2366930e68a5SMark Adams }
23678f7e8f9dSMark Adams 
2368aac854edSJunchao Zhang PETSC_INTERN PetscErrorCode MatSolverTypeRegister_Kokkos(void)
2369d71ae5a4SJacob Faibussowitsch {
237086a27549SJunchao Zhang   PetscFunctionBegin;
23719566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_SeqAIJKokkos_Kokkos));
2372aac854edSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_CHOLESKY, MatGetFactor_SeqAIJKokkos_Kokkos));
23739566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ILU, MatGetFactor_SeqAIJKokkos_Kokkos));
2374aac854edSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ICC, MatGetFactor_SeqAIJKokkos_Kokkos));
23753ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
237686a27549SJunchao Zhang }
237786a27549SJunchao Zhang 
2378076ba34aSJunchao Zhang /* Utility to print out a KokkosCsrMatrix for debugging */
2379d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode PrintCsrMatrix(const KokkosCsrMatrix &csrmat)
2380d71ae5a4SJacob Faibussowitsch {
238145402d8aSJunchao Zhang   const auto        &iv = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.graph.row_map);
238245402d8aSJunchao Zhang   const auto        &jv = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.graph.entries);
238345402d8aSJunchao Zhang   const auto        &av = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.values);
2384076ba34aSJunchao Zhang   const PetscInt    *i  = iv.data();
2385076ba34aSJunchao Zhang   const PetscInt    *j  = jv.data();
2386076ba34aSJunchao Zhang   const PetscScalar *a  = av.data();
2387076ba34aSJunchao Zhang   PetscInt           m = csrmat.numRows(), n = csrmat.numCols(), nnz = csrmat.nnz();
2388076ba34aSJunchao Zhang 
2389076ba34aSJunchao Zhang   PetscFunctionBegin;
23909566063dSJacob Faibussowitsch   PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT " x %" PetscInt_FMT " SeqAIJKokkos, with %" PetscInt_FMT " nonzeros\n", m, n, nnz));
2391076ba34aSJunchao Zhang   for (PetscInt k = 0; k < m; k++) {
23929566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT ": ", k));
239348a46eb9SPierre Jolivet     for (PetscInt p = i[k]; p < i[k + 1]; p++) PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT "(%.1f), ", j[p], (double)PetscRealPart(a[p])));
23949566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "\n"));
2395076ba34aSJunchao Zhang   }
23963ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2397076ba34aSJunchao Zhang }
2398