xref: /petsc/src/mat/impls/aij/seq/kokkos/aijkok.kokkos.cxx (revision d71ae5a4db6382e7f06317b8d368875286fe9008)
111d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
2076ba34aSJunchao Zhang #include <petscpkg_version.h>
3152b3e56SJunchao Zhang #include <petsc/private/petscimpl.h>
442550becSJunchao Zhang #include <petsc/private/sfimpl.h>
58c3ff71bSJunchao Zhang #include <petscsystypes.h>
68c3ff71bSJunchao Zhang #include <petscerror.h>
78c3ff71bSJunchao Zhang 
88c3ff71bSJunchao Zhang #include <Kokkos_Core.hpp>
9f0cf5187SStefano Zampini #include <KokkosBlas.hpp>
108c3ff71bSJunchao Zhang #include <KokkosSparse_CrsMatrix.hpp>
118c3ff71bSJunchao Zhang #include <KokkosSparse_spmv.hpp>
1286a27549SJunchao Zhang #include <KokkosSparse_spiluk.hpp>
1386a27549SJunchao Zhang #include <KokkosSparse_sptrsv.hpp>
14076ba34aSJunchao Zhang #include <KokkosSparse_spgemm.hpp>
15076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
1686a27549SJunchao Zhang 
1742550becSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
188c3ff71bSJunchao Zhang 
19f98996d3SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(3, 6, 99)
20f98996d3SJunchao Zhang   #include <KokkosSparse_Utils.hpp>
21f98996d3SJunchao Zhang using KokkosSparse::sort_crs_matrix;
229371c9d4SSatish Balay using KokkosSparse::Impl::transpose_matrix;
23f98996d3SJunchao Zhang #else
24f98996d3SJunchao Zhang   #include <KokkosKernels_Sorting.hpp>
25f98996d3SJunchao Zhang using KokkosKernels::sort_crs_matrix;
269371c9d4SSatish Balay using KokkosKernels::Impl::transpose_matrix;
27f98996d3SJunchao Zhang #endif
28f98996d3SJunchao Zhang 
298c3ff71bSJunchao Zhang static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat); /* Forward declaration */
308c3ff71bSJunchao Zhang 
31076ba34aSJunchao Zhang /* MatAssemblyEnd_SeqAIJKokkos() happens when we finalized nonzeros of the matrix, either after
32076ba34aSJunchao Zhang    we assembled the matrix on host, or after we directly produced the matrix data on device (ex., through MatMatMult).
33076ba34aSJunchao Zhang    In the latter case, it is important to set a_dual's sync state correctly.
34076ba34aSJunchao Zhang  */
35*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_SeqAIJKokkos(Mat A, MatAssemblyType mode)
36*d71ae5a4SJacob Faibussowitsch {
37076ba34aSJunchao Zhang   Mat_SeqAIJ       *aijseq;
38076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
398c3ff71bSJunchao Zhang 
408c3ff71bSJunchao Zhang   PetscFunctionBegin;
41076ba34aSJunchao Zhang   if (mode == MAT_FLUSH_ASSEMBLY) PetscFunctionReturn(0);
429566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
43076ba34aSJunchao Zhang 
44076ba34aSJunchao Zhang   aijseq = static_cast<Mat_SeqAIJ *>(A->data);
45076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
46076ba34aSJunchao Zhang 
47076ba34aSJunchao Zhang   /* If aijkok does not exist, we just copy i, j to device.
48076ba34aSJunchao Zhang      If aijkok already exists, but the device's nonzero pattern does not match with the host's, we assume the latest data is on host.
49076ba34aSJunchao Zhang      In both cases, we build a new aijkok structure.
50076ba34aSJunchao Zhang   */
51076ba34aSJunchao Zhang   if (!aijkok || aijkok->nonzerostate != A->nonzerostate) { /* aijkok might not exist yet or nonzero pattern has changed */
52076ba34aSJunchao Zhang     delete aijkok;
53076ba34aSJunchao Zhang     aijkok   = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aijseq->nz, aijseq->i, aijseq->j, aijseq->a, A->nonzerostate, PETSC_FALSE /*don't copy mat values to device*/);
54076ba34aSJunchao Zhang     A->spptr = aijkok;
55076ba34aSJunchao Zhang   }
56076ba34aSJunchao Zhang 
575519a089SJose E. Roman   if (aijkok->device_mat_d.data()) {
58a587d139SMark     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
59a587d139SMark   }
608c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
618c3ff71bSJunchao Zhang }
628c3ff71bSJunchao Zhang 
6386a27549SJunchao Zhang /* Sync CSR data to device if not yet */
64*d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosSyncDevice(Mat A)
65*d71ae5a4SJacob Faibussowitsch {
668c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
678c3ff71bSJunchao Zhang 
688c3ff71bSJunchao Zhang   PetscFunctionBegin;
695f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Cann't sync factorized matrix from host to device");
705f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
71076ba34aSJunchao Zhang   if (aijkok->a_dual.need_sync_device()) {
72076ba34aSJunchao Zhang     aijkok->a_dual.sync_device();
73580c7c76SPierre Jolivet     aijkok->transpose_updated = PETSC_FALSE; /* values of the transpose is out-of-date */
7486a27549SJunchao Zhang     aijkok->hermitian_updated = PETSC_FALSE;
758c3ff71bSJunchao Zhang   }
768c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
778c3ff71bSJunchao Zhang }
788c3ff71bSJunchao Zhang 
79076ba34aSJunchao Zhang /* Mark the CSR data on device as modified */
80*d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosModifyDevice(Mat A)
81*d71ae5a4SJacob Faibussowitsch {
8286a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
8386a27549SJunchao Zhang 
8486a27549SJunchao Zhang   PetscFunctionBegin;
855f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Not supported for factorized matries");
8686a27549SJunchao Zhang   aijkok->a_dual.clear_sync_state();
8786a27549SJunchao Zhang   aijkok->a_dual.modify_device();
8886a27549SJunchao Zhang   aijkok->transpose_updated = PETSC_FALSE;
8986a27549SJunchao Zhang   aijkok->hermitian_updated = PETSC_FALSE;
909566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJInvalidateDiagonal(A));
919566063dSJacob Faibussowitsch   PetscCall(PetscObjectStateIncrease((PetscObject)A));
9286a27549SJunchao Zhang   PetscFunctionReturn(0);
9386a27549SJunchao Zhang }
9486a27549SJunchao Zhang 
95*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSyncHost(Mat A)
96*d71ae5a4SJacob Faibussowitsch {
97f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
98f0cf5187SStefano Zampini 
99f0cf5187SStefano Zampini   PetscFunctionBegin;
100f0cf5187SStefano Zampini   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10186a27549SJunchao Zhang   /* We do not expect one needs factors on host  */
1025f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Cann't sync factorized matrix from device to host");
1035f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing AIJKOK");
104076ba34aSJunchao Zhang   aijkok->a_dual.sync_host();
105f0cf5187SStefano Zampini   PetscFunctionReturn(0);
106f0cf5187SStefano Zampini }
107f0cf5187SStefano Zampini 
108*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
109*d71ae5a4SJacob Faibussowitsch {
110076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
111f0cf5187SStefano Zampini 
112f0cf5187SStefano Zampini   PetscFunctionBegin;
1135519a089SJose E. Roman   /* aijkok contains valid pointers only if the host's nonzerostate matches with the device's.
1145519a089SJose E. Roman     Calling MatSeqAIJSetPreallocation() or MatSetValues() on host, where aijseq->{i,j,a} might be
1155519a089SJose E. Roman     reallocated, will lead to stale {i,j,a}_dual in aijkok. In both operations, the hosts's nonzerostate
1165519a089SJose E. Roman     must have been updated. The stale aijkok will be rebuilt during MatAssemblyEnd.
1175519a089SJose E. Roman   */
1185519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
119076ba34aSJunchao Zhang     aijkok->a_dual.sync_host();
120076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
121076ba34aSJunchao Zhang   } else { /* Happens when calling MatSetValues on a newly created matrix */
122076ba34aSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
123076ba34aSJunchao Zhang   }
124076ba34aSJunchao Zhang   PetscFunctionReturn(0);
125076ba34aSJunchao Zhang }
126076ba34aSJunchao Zhang 
127*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
128*d71ae5a4SJacob Faibussowitsch {
129076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
130076ba34aSJunchao Zhang 
131076ba34aSJunchao Zhang   PetscFunctionBegin;
1325519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) aijkok->a_dual.modify_host();
133076ba34aSJunchao Zhang   PetscFunctionReturn(0);
134076ba34aSJunchao Zhang }
135076ba34aSJunchao Zhang 
136*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
137*d71ae5a4SJacob Faibussowitsch {
138076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
139076ba34aSJunchao Zhang 
140076ba34aSJunchao Zhang   PetscFunctionBegin;
1415519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
142076ba34aSJunchao Zhang     aijkok->a_dual.sync_host();
143076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1442328674fSJunchao Zhang   } else {
1452328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1462328674fSJunchao Zhang   }
147076ba34aSJunchao Zhang   PetscFunctionReturn(0);
148076ba34aSJunchao Zhang }
149076ba34aSJunchao Zhang 
150*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
151*d71ae5a4SJacob Faibussowitsch {
152076ba34aSJunchao Zhang   PetscFunctionBegin;
153076ba34aSJunchao Zhang   *array = NULL;
154076ba34aSJunchao Zhang   PetscFunctionReturn(0);
155076ba34aSJunchao Zhang }
156076ba34aSJunchao Zhang 
157*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
158*d71ae5a4SJacob Faibussowitsch {
159076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
160076ba34aSJunchao Zhang 
161076ba34aSJunchao Zhang   PetscFunctionBegin;
1625519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
163076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1642328674fSJunchao Zhang   } else { /* Ex. happens with MatZeroEntries on a preallocated but not assembled matrix */
1652328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1662328674fSJunchao Zhang   }
167076ba34aSJunchao Zhang   PetscFunctionReturn(0);
168076ba34aSJunchao Zhang }
169076ba34aSJunchao Zhang 
170*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
171*d71ae5a4SJacob Faibussowitsch {
172076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
173076ba34aSJunchao Zhang 
174076ba34aSJunchao Zhang   PetscFunctionBegin;
1755519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
176076ba34aSJunchao Zhang     aijkok->a_dual.clear_sync_state();
177076ba34aSJunchao Zhang     aijkok->a_dual.modify_host();
1782328674fSJunchao Zhang   }
179f0cf5187SStefano Zampini   PetscFunctionReturn(0);
180f0cf5187SStefano Zampini }
181f0cf5187SStefano Zampini 
182*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJKokkos(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype)
183*d71ae5a4SJacob Faibussowitsch {
1847ee59b9bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1857ee59b9bSJunchao Zhang 
1867ee59b9bSJunchao Zhang   PetscFunctionBegin;
1877ee59b9bSJunchao Zhang   PetscCheck(aijkok != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "aijkok is NULL");
1887ee59b9bSJunchao Zhang 
1897ee59b9bSJunchao Zhang   if (i) *i = aijkok->i_device_data();
1907ee59b9bSJunchao Zhang   if (j) *j = aijkok->j_device_data();
1917ee59b9bSJunchao Zhang   if (a) {
1927ee59b9bSJunchao Zhang     aijkok->a_dual.sync_device();
1937ee59b9bSJunchao Zhang     *a = aijkok->a_device_data();
1947ee59b9bSJunchao Zhang   }
1957ee59b9bSJunchao Zhang   if (mtype) *mtype = PETSC_MEMTYPE_KOKKOS;
1967ee59b9bSJunchao Zhang   PetscFunctionReturn(0);
1977ee59b9bSJunchao Zhang }
1987ee59b9bSJunchao Zhang 
199a587d139SMark // MatSeqAIJKokkosSetDeviceMat takes a PetscSplitCSRDataStructure with device data and copies it to the device. Note, "deep_copy" here is really a shallow copy
200*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosSetDeviceMat(Mat A, PetscSplitCSRDataStructure h_mat)
201*d71ae5a4SJacob Faibussowitsch {
202042217e8SBarry Smith   Mat_SeqAIJKokkos                            *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
203042217e8SBarry Smith   Kokkos::View<SplitCSRMat, Kokkos::HostSpace> h_mat_k(h_mat);
204a587d139SMark 
205a587d139SMark   PetscFunctionBegin;
2065f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
207152b3e56SJunchao Zhang   aijkok->device_mat_d = create_mirror(DefaultMemorySpace(), h_mat_k);
208a587d139SMark   Kokkos::deep_copy(aijkok->device_mat_d, h_mat_k);
209a587d139SMark   PetscFunctionReturn(0);
210a587d139SMark }
211a587d139SMark 
212a587d139SMark // MatSeqAIJKokkosGetDeviceMat gets the device if it is here, otherwise it creates a place for it and returns NULL
213*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosGetDeviceMat(Mat A, PetscSplitCSRDataStructure *d_mat)
214*d71ae5a4SJacob Faibussowitsch {
215042217e8SBarry Smith   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
216a587d139SMark 
217a587d139SMark   PetscFunctionBegin;
218a587d139SMark   if (aijkok && aijkok->device_mat_d.data()) {
219a587d139SMark     *d_mat = aijkok->device_mat_d.data();
220a587d139SMark   } else {
2219566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosSyncDevice(A)); // create aijkok (we are making d_mat now so make a place for it)
222a587d139SMark     *d_mat = NULL;
223a587d139SMark   }
224a587d139SMark   PetscFunctionReturn(0);
225a587d139SMark }
226076ba34aSJunchao Zhang 
227076ba34aSJunchao Zhang /* Generate the transpose on device and cache it internally */
228*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosGenerateTranspose_Private(Mat A, KokkosCsrMatrix **csrmatT)
229*d71ae5a4SJacob Faibussowitsch {
230152b3e56SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
231152b3e56SJunchao Zhang 
232152b3e56SJunchao Zhang   PetscFunctionBegin;
2335f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
234076ba34aSJunchao Zhang   if (!aijkok->csrmatT.nnz() || !aijkok->transpose_updated) { /* Generate At for the first time OR just update its values */
235076ba34aSJunchao Zhang     /* FIXME: KK does not separate symbolic/numeric transpose. We could have a permutation array to help value-only update */
2369566063dSJacob Faibussowitsch     PetscCallCXX(aijkok->a_dual.sync_device());
237f98996d3SJunchao Zhang     PetscCallCXX(aijkok->csrmatT = transpose_matrix(aijkok->csrmat));
238f98996d3SJunchao Zhang     PetscCallCXX(sort_crs_matrix(aijkok->csrmatT));
23986a27549SJunchao Zhang     aijkok->transpose_updated = PETSC_TRUE;
240076ba34aSJunchao Zhang   }
241076ba34aSJunchao Zhang   *csrmatT = &aijkok->csrmatT;
242152b3e56SJunchao Zhang   PetscFunctionReturn(0);
243152b3e56SJunchao Zhang }
244152b3e56SJunchao Zhang 
245076ba34aSJunchao Zhang /* Generate the Hermitian on device and cache it internally */
246*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosGenerateHermitian_Private(Mat A, KokkosCsrMatrix **csrmatH)
247*d71ae5a4SJacob Faibussowitsch {
248152b3e56SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
249152b3e56SJunchao Zhang 
250152b3e56SJunchao Zhang   PetscFunctionBegin;
2519566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
2525f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
253076ba34aSJunchao Zhang   if (!aijkok->csrmatH.nnz() || !aijkok->hermitian_updated) { /* Generate Ah for the first time OR just update its values */
2549566063dSJacob Faibussowitsch     PetscCallCXX(aijkok->a_dual.sync_device());
255f98996d3SJunchao Zhang     PetscCallCXX(aijkok->csrmatH = transpose_matrix(aijkok->csrmat));
256f98996d3SJunchao Zhang     PetscCallCXX(sort_crs_matrix(aijkok->csrmatH));
257076ba34aSJunchao Zhang #if defined(PETSC_USE_COMPLEX)
258076ba34aSJunchao Zhang     const auto &a = aijkok->csrmatH.values;
2599371c9d4SSatish Balay     Kokkos::parallel_for(
2609371c9d4SSatish Balay       a.extent(0), KOKKOS_LAMBDA(MatRowMapType i) { a(i) = PetscConj(a(i)); });
261076ba34aSJunchao Zhang #endif
26286a27549SJunchao Zhang     aijkok->hermitian_updated = PETSC_TRUE;
263076ba34aSJunchao Zhang   }
264076ba34aSJunchao Zhang   *csrmatH = &aijkok->csrmatH;
2659566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
266152b3e56SJunchao Zhang   PetscFunctionReturn(0);
267152b3e56SJunchao Zhang }
268a587d139SMark 
2698c3ff71bSJunchao Zhang /* y = A x */
270*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMult_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
271*d71ae5a4SJacob Faibussowitsch {
2728c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
273152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
274152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
2758c3ff71bSJunchao Zhang 
2768c3ff71bSJunchao Zhang   PetscFunctionBegin;
2779566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
2789566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
2799566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
2809566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
2818c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
282152b3e56SJunchao Zhang   KokkosSparse::spmv("N", 1.0 /*alpha*/, aijkok->csrmat, xv, 0.0 /*beta*/, yv); /* y = alpha A x + beta y */
2839566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
2849566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
285076ba34aSJunchao Zhang   /* 2.0*nnz - numRows seems more accurate here but assumes there are no zero-rows. So a little sloppy here. */
2869566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
2879566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
2888c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
2898c3ff71bSJunchao Zhang }
2908c3ff71bSJunchao Zhang 
2918c3ff71bSJunchao Zhang /* y = A^T x */
292*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
293*d71ae5a4SJacob Faibussowitsch {
2948c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
295152b3e56SJunchao Zhang   const char                *mode;
296152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
297152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
298076ba34aSJunchao Zhang   KokkosCsrMatrix           *csrmat;
2998c3ff71bSJunchao Zhang 
3008c3ff71bSJunchao Zhang   PetscFunctionBegin;
3019566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3029566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3039566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3049566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
305152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
3069566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
307152b3e56SJunchao Zhang     mode = "N";
308152b3e56SJunchao Zhang   } else {
309076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
310076ba34aSJunchao Zhang     csrmat = &aijkok->csrmat;
311152b3e56SJunchao Zhang     mode   = "T";
312152b3e56SJunchao Zhang   }
313076ba34aSJunchao Zhang   KokkosSparse::spmv(mode, 1.0 /*alpha*/, *csrmat, xv, 0.0 /*beta*/, yv); /* y = alpha A^T x + beta y */
3149566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3159566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
3169566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * csrmat->nnz()));
3179566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
3188c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
3198c3ff71bSJunchao Zhang }
3208c3ff71bSJunchao Zhang 
3218c3ff71bSJunchao Zhang /* y = A^H x */
322*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
323*d71ae5a4SJacob Faibussowitsch {
3248c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
325152b3e56SJunchao Zhang   const char                *mode;
326152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
327152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
328076ba34aSJunchao Zhang   KokkosCsrMatrix           *csrmat;
3298c3ff71bSJunchao Zhang 
3308c3ff71bSJunchao Zhang   PetscFunctionBegin;
3319566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3329566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3339566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3349566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
335152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
3369566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
337152b3e56SJunchao Zhang     mode = "N";
338152b3e56SJunchao Zhang   } else {
339076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
340076ba34aSJunchao Zhang     csrmat = &aijkok->csrmat;
341152b3e56SJunchao Zhang     mode   = "C";
342152b3e56SJunchao Zhang   }
343076ba34aSJunchao Zhang   KokkosSparse::spmv(mode, 1.0 /*alpha*/, *csrmat, xv, 0.0 /*beta*/, yv); /* y = alpha A^H x + beta y */
3449566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3459566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
3469566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * csrmat->nnz()));
3479566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
3488c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
3498c3ff71bSJunchao Zhang }
3508c3ff71bSJunchao Zhang 
3518c3ff71bSJunchao Zhang /* z = A x + y */
352*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
353*d71ae5a4SJacob Faibussowitsch {
3548c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
355152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
356152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
3578c3ff71bSJunchao Zhang 
3588c3ff71bSJunchao Zhang   PetscFunctionBegin;
3599566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3609566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3619566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3629566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
3639566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
3648c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
3658c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
366152b3e56SJunchao Zhang   KokkosSparse::spmv("N", 1.0 /*alpha*/, aijkok->csrmat, xv, 1.0 /*beta*/, zv); /* z = alpha A x + beta z */
3679566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3689566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
3699566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
3709566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
3719566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
3728c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
3738c3ff71bSJunchao Zhang }
3748c3ff71bSJunchao Zhang 
3758c3ff71bSJunchao Zhang /* z = A^T x + y */
376*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
377*d71ae5a4SJacob Faibussowitsch {
3788c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
379152b3e56SJunchao Zhang   const char                *mode;
380152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
381152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
382076ba34aSJunchao Zhang   KokkosCsrMatrix           *csrmat;
3838c3ff71bSJunchao Zhang 
3848c3ff71bSJunchao Zhang   PetscFunctionBegin;
3859566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3869566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3879566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3889566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
3899566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
3908c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
391152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
3929566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
393152b3e56SJunchao Zhang     mode = "N";
394152b3e56SJunchao Zhang   } else {
395076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
396076ba34aSJunchao Zhang     csrmat = &aijkok->csrmat;
397152b3e56SJunchao Zhang     mode   = "T";
398152b3e56SJunchao Zhang   }
399076ba34aSJunchao Zhang   KokkosSparse::spmv(mode, 1.0 /*alpha*/, *csrmat, xv, 1.0 /*beta*/, zv); /* z = alpha A^T x + beta z */
4009566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4019566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
4029566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
4039566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * csrmat->nnz()));
4049566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4058c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
4068c3ff71bSJunchao Zhang }
4078c3ff71bSJunchao Zhang 
4088c3ff71bSJunchao Zhang /* z = A^H x + y */
409*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
410*d71ae5a4SJacob Faibussowitsch {
4118c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
412152b3e56SJunchao Zhang   const char                *mode;
413152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
414152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
415076ba34aSJunchao Zhang   KokkosCsrMatrix           *csrmat;
4168c3ff71bSJunchao Zhang 
4178c3ff71bSJunchao Zhang   PetscFunctionBegin;
4189566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4199566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4209566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4219566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
4229566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
4238c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
424152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4259566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
426152b3e56SJunchao Zhang     mode = "N";
427152b3e56SJunchao Zhang   } else {
428076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
429076ba34aSJunchao Zhang     csrmat = &aijkok->csrmat;
430152b3e56SJunchao Zhang     mode   = "C";
431152b3e56SJunchao Zhang   }
432076ba34aSJunchao Zhang   KokkosSparse::spmv(mode, 1.0 /*alpha*/, *csrmat, xv, 1.0 /*beta*/, zv); /* z = alpha A^H x + beta z */
4339566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4349566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
4359566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
4369566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * csrmat->nnz()));
4379566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
438152b3e56SJunchao Zhang   PetscFunctionReturn(0);
439152b3e56SJunchao Zhang }
440152b3e56SJunchao Zhang 
441*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSetOption_SeqAIJKokkos(Mat A, MatOption op, PetscBool flg)
442*d71ae5a4SJacob Faibussowitsch {
443152b3e56SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
444152b3e56SJunchao Zhang 
445152b3e56SJunchao Zhang   PetscFunctionBegin;
446152b3e56SJunchao Zhang   switch (op) {
447152b3e56SJunchao Zhang   case MAT_FORM_EXPLICIT_TRANSPOSE:
448152b3e56SJunchao Zhang     /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */
4499566063dSJacob Faibussowitsch     if (A->form_explicit_transpose && !flg && aijkok) PetscCall(aijkok->DestroyMatTranspose());
450152b3e56SJunchao Zhang     A->form_explicit_transpose = flg;
451152b3e56SJunchao Zhang     break;
452*d71ae5a4SJacob Faibussowitsch   default:
453*d71ae5a4SJacob Faibussowitsch     PetscCall(MatSetOption_SeqAIJ(A, op, flg));
454*d71ae5a4SJacob Faibussowitsch     break;
455152b3e56SJunchao Zhang   }
4568c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
4578c3ff71bSJunchao Zhang }
4588c3ff71bSJunchao Zhang 
459076ba34aSJunchao Zhang /* Depending on reuse, either build a new mat, or use the existing mat */
460*d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
461*d71ae5a4SJacob Faibussowitsch {
462076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
4638c3ff71bSJunchao Zhang 
4648c3ff71bSJunchao Zhang   PetscFunctionBegin;
4659566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
466076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {                      /* Build a brand new mat */
4679566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));  /* the returned newmat is a SeqAIJKokkos */
4688c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {                 /* Reuse the mat created before */
4699566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); /* newmat is already a SeqAIJKokkos */
470076ba34aSJunchao Zhang   } else if (reuse == MAT_INPLACE_MATRIX) {               /* newmat is A */
4715f80ce2aSJacob Faibussowitsch     PetscCheck(A == *newmat, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "A != *newmat with MAT_INPLACE_MATRIX");
4729566063dSJacob Faibussowitsch     PetscCall(PetscFree(A->defaultvectype));
4739566063dSJacob Faibussowitsch     PetscCall(PetscStrallocpy(VECKOKKOS, &A->defaultvectype)); /* Allocate and copy the string */
4749566063dSJacob Faibussowitsch     PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATSEQAIJKOKKOS));
4759566063dSJacob Faibussowitsch     PetscCall(MatSetOps_SeqAIJKokkos(A));
476076ba34aSJunchao Zhang     aseq = static_cast<Mat_SeqAIJ *>(A->data);
477394ed5ebSJunchao Zhang     if (A->assembled) { /* Copy i, j (but not values) to device for an assembled matrix if not yet */
4785f80ce2aSJacob Faibussowitsch       PetscCheck(!A->spptr, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Expect NULL (Mat_SeqAIJKokkos*)A->spptr");
479076ba34aSJunchao Zhang       A->spptr = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aseq->nz, aseq->i, aseq->j, aseq->a, A->nonzerostate, PETSC_FALSE);
4808c3ff71bSJunchao Zhang     }
481076ba34aSJunchao Zhang   }
4828c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
4838c3ff71bSJunchao Zhang }
4848c3ff71bSJunchao Zhang 
485076ba34aSJunchao Zhang /* MatDuplicate always creates a new matrix. MatDuplicate can be called either on an assembled matrix or
486076ba34aSJunchao Zhang    an unassembled matrix, even though MAT_COPY_VALUES is not allowed for unassembled matrix.
487076ba34aSJunchao Zhang  */
488*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDuplicate_SeqAIJKokkos(Mat A, MatDuplicateOption dupOption, Mat *B)
489*d71ae5a4SJacob Faibussowitsch {
490076ba34aSJunchao Zhang   Mat_SeqAIJ       *bseq;
491076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok;
492076ba34aSJunchao Zhang   Mat               mat;
4938c3ff71bSJunchao Zhang 
4948c3ff71bSJunchao Zhang   PetscFunctionBegin;
495076ba34aSJunchao Zhang   /* Do not copy values on host as A's latest values might be on device. We don't want to do sync blindly */
4969566063dSJacob Faibussowitsch   PetscCall(MatDuplicate_SeqAIJ(A, MAT_DO_NOT_COPY_VALUES, B));
497076ba34aSJunchao Zhang   mat = *B;
498076ba34aSJunchao Zhang   if (A->assembled) {
499076ba34aSJunchao Zhang     bseq = static_cast<Mat_SeqAIJ *>(mat->data);
500076ba34aSJunchao Zhang     bkok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, bseq->nz, bseq->i, bseq->j, bseq->a, mat->nonzerostate, PETSC_FALSE);
501076ba34aSJunchao Zhang     bkok->a_dual.clear_sync_state(); /* Clear B's sync state as it will be decided below */
502076ba34aSJunchao Zhang     /* Now copy values to B if needed */
503076ba34aSJunchao Zhang     if (dupOption == MAT_COPY_VALUES) {
504076ba34aSJunchao Zhang       if (akok->a_dual.need_sync_device()) {
505076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_host(), akok->a_dual.view_host());
506076ba34aSJunchao Zhang         bkok->a_dual.modify_host();
507076ba34aSJunchao Zhang       } else { /* If device has the latest data, we only copy data on device */
508076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_device(), akok->a_dual.view_device());
509076ba34aSJunchao Zhang         bkok->a_dual.modify_device();
510076ba34aSJunchao Zhang       }
511076ba34aSJunchao Zhang     } else { /* MAT_DO_NOT_COPY_VALUES or MAT_SHARE_NONZERO_PATTERN. B's values should be zeroed */
512076ba34aSJunchao Zhang       /* B's values on host should be already zeroed by MatDuplicate_SeqAIJ() */
513076ba34aSJunchao Zhang       bkok->a_dual.modify_host();
514076ba34aSJunchao Zhang     }
515076ba34aSJunchao Zhang     mat->spptr = bkok;
516076ba34aSJunchao Zhang   }
517076ba34aSJunchao Zhang 
5189566063dSJacob Faibussowitsch   PetscCall(PetscFree(mat->defaultvectype));
5199566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &mat->defaultvectype)); /* Allocate and copy the string */
5209566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)mat, MATSEQAIJKOKKOS));
5219566063dSJacob Faibussowitsch   PetscCall(MatSetOps_SeqAIJKokkos(mat));
5228c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
5238c3ff71bSJunchao Zhang }
5248c3ff71bSJunchao Zhang 
525*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatTranspose_SeqAIJKokkos(Mat A, MatReuse reuse, Mat *B)
526*d71ae5a4SJacob Faibussowitsch {
5270ecb592aSJunchao Zhang   Mat               At;
528ff751488SJunchao Zhang   KokkosCsrMatrix  *internT;
5290ecb592aSJunchao Zhang   Mat_SeqAIJKokkos *atkok, *bkok;
5300ecb592aSJunchao Zhang 
5310ecb592aSJunchao Zhang   PetscFunctionBegin;
5327fb60732SBarry Smith   if (reuse == MAT_REUSE_MATRIX) PetscCall(MatTransposeCheckNonzeroState_Private(A, *B));
5339566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &internT)); /* Generate a transpose internally */
5340ecb592aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX || reuse == MAT_INPLACE_MATRIX) {
535ff751488SJunchao Zhang     /* Deep copy internT, as we want to isolate the internal transpose */
5369566063dSJacob Faibussowitsch     PetscCallCXX(atkok = new Mat_SeqAIJKokkos(KokkosCsrMatrix("csrmat", *internT)));
5379566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PetscObjectComm((PetscObject)A), atkok, &At));
5380ecb592aSJunchao Zhang     if (reuse == MAT_INITIAL_MATRIX) *B = At;
5399566063dSJacob Faibussowitsch     else PetscCall(MatHeaderReplace(A, &At)); /* Replace A with At inplace */
5400ecb592aSJunchao Zhang   } else {                                    /* MAT_REUSE_MATRIX, just need to copy values to B on device */
5410ecb592aSJunchao Zhang     if ((*B)->assembled) {
5420ecb592aSJunchao Zhang       bkok = static_cast<Mat_SeqAIJKokkos *>((*B)->spptr);
5439566063dSJacob Faibussowitsch       PetscCallCXX(Kokkos::deep_copy(bkok->a_dual.view_device(), internT->values));
5449566063dSJacob Faibussowitsch       PetscCall(MatSeqAIJKokkosModifyDevice(*B));
5450ecb592aSJunchao Zhang     } else if ((*B)->preallocated) { /* It is ok for B to be only preallocated, as needed in MatTranspose_MPIAIJ */
5460ecb592aSJunchao Zhang       Mat_SeqAIJ             *bseq = static_cast<Mat_SeqAIJ *>((*B)->data);
5470ecb592aSJunchao Zhang       MatScalarKokkosViewHost a_h(bseq->a, internT->nnz()); /* bseq->nz = 0 if unassembled */
5480ecb592aSJunchao Zhang       MatColIdxKokkosViewHost j_h(bseq->j, internT->nnz());
5499566063dSJacob Faibussowitsch       PetscCallCXX(Kokkos::deep_copy(a_h, internT->values));
5509566063dSJacob Faibussowitsch       PetscCallCXX(Kokkos::deep_copy(j_h, internT->graph.entries));
5510ecb592aSJunchao Zhang     } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "B must be assembled or preallocated");
5520ecb592aSJunchao Zhang   }
5530ecb592aSJunchao Zhang   PetscFunctionReturn(0);
5540ecb592aSJunchao Zhang }
5550ecb592aSJunchao Zhang 
556*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDestroy_SeqAIJKokkos(Mat A)
557*d71ae5a4SJacob Faibussowitsch {
55886a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
5598c3ff71bSJunchao Zhang 
5608c3ff71bSJunchao Zhang   PetscFunctionBegin;
56186a27549SJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) {
56286a27549SJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
5638c3ff71bSJunchao Zhang     delete aijkok;
56486a27549SJunchao Zhang   } else {
56586a27549SJunchao Zhang     delete static_cast<Mat_SeqAIJKokkosTriFactors *>(A->spptr);
56686a27549SJunchao Zhang   }
567cbc6b225SStefano Zampini   A->spptr = NULL;
5689566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
5699566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
5709566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
5719566063dSJacob Faibussowitsch   PetscCall(MatDestroy_SeqAIJ(A));
5728c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
5738c3ff71bSJunchao Zhang }
5748c3ff71bSJunchao Zhang 
5753f3ba80aSJunchao Zhang /*MC
5763f3ba80aSJunchao Zhang    MATSEQAIJKOKKOS - MATAIJKOKKOS = "(seq)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos
5773f3ba80aSJunchao Zhang 
5783f3ba80aSJunchao Zhang    A matrix type type using Kokkos-Kernels CrsMatrix type for portability across different device types
5793f3ba80aSJunchao Zhang 
5803f3ba80aSJunchao Zhang    Options Database Keys:
58111a5261eSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATSEQAIJKOKKOS` during a call to `MatSetFromOptions()`
5823f3ba80aSJunchao Zhang 
5833f3ba80aSJunchao Zhang   Level: beginner
5843f3ba80aSJunchao Zhang 
585db781477SPatrick Sanan .seealso: `MatCreateSeqAIJKokkos()`, `MATMPIAIJKOKKOS`
5863f3ba80aSJunchao Zhang M*/
587*d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJKokkos(Mat A)
588*d71ae5a4SJacob Faibussowitsch {
58986a27549SJunchao Zhang   PetscFunctionBegin;
5909566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
5919566063dSJacob Faibussowitsch   PetscCall(MatCreate_SeqAIJ(A));
5929566063dSJacob Faibussowitsch   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
59386a27549SJunchao Zhang   PetscFunctionReturn(0);
59486a27549SJunchao Zhang }
59586a27549SJunchao Zhang 
596076ba34aSJunchao Zhang /* Merge A, B into a matrix C. A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n) */
597*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C)
598*d71ae5a4SJacob Faibussowitsch {
599076ba34aSJunchao Zhang   Mat_SeqAIJ         *a, *b;
600076ba34aSJunchao Zhang   Mat_SeqAIJKokkos   *akok, *bkok, *ckok;
601076ba34aSJunchao Zhang   MatScalarKokkosView aa, ba, ca;
602076ba34aSJunchao Zhang   MatRowMapKokkosView ai, bi, ci;
603076ba34aSJunchao Zhang   MatColIdxKokkosView aj, bj, cj;
604076ba34aSJunchao Zhang   PetscInt            m, n, nnz, aN;
605a3f881fbSStefano Zampini 
606a3f881fbSStefano Zampini   PetscFunctionBegin;
607076ba34aSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
608076ba34aSJunchao Zhang   PetscValidHeaderSpecific(B, MAT_CLASSID, 2);
609076ba34aSJunchao Zhang   PetscValidPointer(C, 4);
610076ba34aSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
611076ba34aSJunchao Zhang   PetscCheckTypeName(B, MATSEQAIJKOKKOS);
6125f80ce2aSJacob Faibussowitsch   PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n);
6135f80ce2aSJacob Faibussowitsch   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
614076ba34aSJunchao Zhang 
6159566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
6169566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
617076ba34aSJunchao Zhang   a    = static_cast<Mat_SeqAIJ *>(A->data);
618076ba34aSJunchao Zhang   b    = static_cast<Mat_SeqAIJ *>(B->data);
619076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
620076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
621076ba34aSJunchao Zhang   aa   = akok->a_dual.view_device();
622076ba34aSJunchao Zhang   ai   = akok->i_dual.view_device();
623076ba34aSJunchao Zhang   ba   = bkok->a_dual.view_device();
624076ba34aSJunchao Zhang   bi   = bkok->i_dual.view_device();
625076ba34aSJunchao Zhang   m    = A->rmap->n; /* M, N and nnz of C */
626076ba34aSJunchao Zhang   n    = A->cmap->n + B->cmap->n;
627076ba34aSJunchao Zhang   nnz  = a->nz + b->nz;
628076ba34aSJunchao Zhang   aN   = A->cmap->n; /* N of A */
629076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
630076ba34aSJunchao Zhang     aj           = akok->j_dual.view_device();
631076ba34aSJunchao Zhang     bj           = bkok->j_dual.view_device();
632076ba34aSJunchao Zhang     auto ca_dual = MatScalarKokkosDualView("a", aa.extent(0) + ba.extent(0));
633076ba34aSJunchao Zhang     auto ci_dual = MatRowMapKokkosDualView("i", ai.extent(0));
634076ba34aSJunchao Zhang     auto cj_dual = MatColIdxKokkosDualView("j", aj.extent(0) + bj.extent(0));
635076ba34aSJunchao Zhang     ca           = ca_dual.view_device();
636076ba34aSJunchao Zhang     ci           = ci_dual.view_device();
637076ba34aSJunchao Zhang     cj           = cj_dual.view_device();
638076ba34aSJunchao Zhang 
639076ba34aSJunchao Zhang     /* Concatenate A and B in parallel using Kokkos hierarchical parallelism */
6409371c9d4SSatish Balay     Kokkos::parallel_for(
6419371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
642076ba34aSJunchao Zhang         PetscInt i       = t.league_rank(); /* row i */
643076ba34aSJunchao Zhang         PetscInt coffset = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
644076ba34aSJunchao Zhang 
645076ba34aSJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() { /* this side effect only happens once per whole team */
646076ba34aSJunchao Zhang                                                    ci(i) = coffset;
647076ba34aSJunchao Zhang                                                    if (i == m - 1) ci(m) = ai(m) + bi(m);
648076ba34aSJunchao Zhang         });
649076ba34aSJunchao Zhang 
650076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
651076ba34aSJunchao Zhang           if (k < alen) {
652076ba34aSJunchao Zhang             ca(coffset + k) = aa(ai(i) + k);
653076ba34aSJunchao Zhang             cj(coffset + k) = aj(ai(i) + k);
654076ba34aSJunchao Zhang           } else {
655076ba34aSJunchao Zhang             ca(coffset + k) = ba(bi(i) + k - alen);
656076ba34aSJunchao Zhang             cj(coffset + k) = bj(bi(i) + k - alen) + aN; /* Entries in B get new column indices in C */
657076ba34aSJunchao Zhang           }
658076ba34aSJunchao Zhang         });
659076ba34aSJunchao Zhang       });
660076ba34aSJunchao Zhang     ca_dual.modify_device();
661076ba34aSJunchao Zhang     ci_dual.modify_device();
662076ba34aSJunchao Zhang     cj_dual.modify_device();
6639566063dSJacob Faibussowitsch     PetscCallCXX(ckok = new Mat_SeqAIJKokkos(m, n, nnz, ci_dual, cj_dual, ca_dual));
6649566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, C));
665076ba34aSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
666076ba34aSJunchao Zhang     PetscValidHeaderSpecific(*C, MAT_CLASSID, 4);
667076ba34aSJunchao Zhang     PetscCheckTypeName(*C, MATSEQAIJKOKKOS);
668076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos *>((*C)->spptr);
669076ba34aSJunchao Zhang     ca   = ckok->a_dual.view_device();
670076ba34aSJunchao Zhang     ci   = ckok->i_dual.view_device();
671076ba34aSJunchao Zhang 
6729371c9d4SSatish Balay     Kokkos::parallel_for(
6739371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
674076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* row i */
675076ba34aSJunchao Zhang         PetscInt alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
676076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
677076ba34aSJunchao Zhang           if (k < alen) ca(ci(i) + k) = aa(ai(i) + k);
678076ba34aSJunchao Zhang           else ca(ci(i) + k) = ba(bi(i) + k - alen);
679076ba34aSJunchao Zhang         });
680076ba34aSJunchao Zhang       });
6819566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(*C));
682076ba34aSJunchao Zhang   }
683076ba34aSJunchao Zhang   PetscFunctionReturn(0);
684076ba34aSJunchao Zhang }
685076ba34aSJunchao Zhang 
686*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_SeqAIJKokkos(void *pdata)
687*d71ae5a4SJacob Faibussowitsch {
688076ba34aSJunchao Zhang   PetscFunctionBegin;
689076ba34aSJunchao Zhang   delete static_cast<MatProductData_SeqAIJKokkos *>(pdata);
690a3f881fbSStefano Zampini   PetscFunctionReturn(0);
691a3f881fbSStefano Zampini }
692a3f881fbSStefano Zampini 
693*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos(Mat C)
694*d71ae5a4SJacob Faibussowitsch {
695a3f881fbSStefano Zampini   Mat_Product                 *product = C->product;
696a3f881fbSStefano Zampini   Mat                          A, B;
697076ba34aSJunchao Zhang   bool                         transA, transB; /* use bool, since KK needs this type */
698a3f881fbSStefano Zampini   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
699a3f881fbSStefano Zampini   Mat_SeqAIJ                  *c;
700076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
701076ba34aSJunchao Zhang   KokkosCsrMatrix             *csrmatA, *csrmatB;
702a3f881fbSStefano Zampini 
703a3f881fbSStefano Zampini   PetscFunctionBegin;
704a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
7055f80ce2aSJacob Faibussowitsch   PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data empty");
706076ba34aSJunchao Zhang   pdata = static_cast<MatProductData_SeqAIJKokkos *>(C->product->data);
707076ba34aSJunchao Zhang 
708076ba34aSJunchao Zhang   if (pdata->reusesym) {           /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
709076ba34aSJunchao Zhang     pdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
710076ba34aSJunchao Zhang     PetscFunctionReturn(0);
711076ba34aSJunchao Zhang   }
712076ba34aSJunchao Zhang 
713076ba34aSJunchao Zhang   switch (product->type) {
7149371c9d4SSatish Balay   case MATPRODUCT_AB:
7159371c9d4SSatish Balay     transA = false;
7169371c9d4SSatish Balay     transB = false;
7179371c9d4SSatish Balay     break;
7189371c9d4SSatish Balay   case MATPRODUCT_AtB:
7199371c9d4SSatish Balay     transA = true;
7209371c9d4SSatish Balay     transB = false;
7219371c9d4SSatish Balay     break;
7229371c9d4SSatish Balay   case MATPRODUCT_ABt:
7239371c9d4SSatish Balay     transA = false;
7249371c9d4SSatish Balay     transB = true;
7259371c9d4SSatish Balay     break;
726*d71ae5a4SJacob Faibussowitsch   default:
727*d71ae5a4SJacob Faibussowitsch     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
728076ba34aSJunchao Zhang   }
729076ba34aSJunchao Zhang 
730a3f881fbSStefano Zampini   A = product->A;
731a3f881fbSStefano Zampini   B = product->B;
7329566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
7339566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
734a3f881fbSStefano Zampini   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
735a3f881fbSStefano Zampini   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
736a3f881fbSStefano Zampini   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
737076ba34aSJunchao Zhang 
7385f80ce2aSJacob Faibussowitsch   PetscCheck(ckok, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Device data structure spptr is empty");
739076ba34aSJunchao Zhang 
740076ba34aSJunchao Zhang   csrmatA = &akok->csrmat;
741076ba34aSJunchao Zhang   csrmatB = &bkok->csrmat;
742076ba34aSJunchao Zhang 
743076ba34aSJunchao Zhang   /* TODO: Once KK spgemm implements transpose, we can get rid of the explicit transpose here */
744076ba34aSJunchao Zhang   if (transA) {
7459566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
746076ba34aSJunchao Zhang     transA = false;
747a3f881fbSStefano Zampini   }
748a3f881fbSStefano Zampini 
749076ba34aSJunchao Zhang   if (transB) {
7509566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
751076ba34aSJunchao Zhang     transB = false;
752076ba34aSJunchao Zhang   }
7539566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
7549566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, *csrmatA, transA, *csrmatB, transB, ckok->csrmat));
755866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
756866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(ckok->csrmat)); /* without sort, mat_tests-ex62_14_seqaijkokkos fails */
757866eb059SJunchao Zhang 
7589566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
7599566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(C));
760a3f881fbSStefano Zampini   /* shorter version of MatAssemblyEnd_SeqAIJ */
761a3f881fbSStefano Zampini   c = (Mat_SeqAIJ *)C->data;
7629566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded,%" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz));
7639566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n"));
7649566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax));
765a3f881fbSStefano Zampini   c->reallocs         = 0;
766076ba34aSJunchao Zhang   C->info.mallocs     = 0;
767a3f881fbSStefano Zampini   C->info.nz_unneeded = 0;
768a3f881fbSStefano Zampini   C->assembled = C->was_assembled = PETSC_TRUE;
769a3f881fbSStefano Zampini   C->num_ass++;
770a3f881fbSStefano Zampini   PetscFunctionReturn(0);
771a3f881fbSStefano Zampini }
772a3f881fbSStefano Zampini 
773*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos(Mat C)
774*d71ae5a4SJacob Faibussowitsch {
775076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
776076ba34aSJunchao Zhang   MatProductType               ptype;
777076ba34aSJunchao Zhang   Mat                          A, B;
778076ba34aSJunchao Zhang   bool                         transA, transB;
779076ba34aSJunchao Zhang   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
780076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
781076ba34aSJunchao Zhang   MPI_Comm                     comm;
782076ba34aSJunchao Zhang   KokkosCsrMatrix             *csrmatA, *csrmatB, csrmatC;
783a3f881fbSStefano Zampini 
784a3f881fbSStefano Zampini   PetscFunctionBegin;
785a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
7869566063dSJacob Faibussowitsch   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
7875f80ce2aSJacob Faibussowitsch   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
788a3f881fbSStefano Zampini   A = product->A;
789a3f881fbSStefano Zampini   B = product->B;
7909566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
7919566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
792a3f881fbSStefano Zampini   akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
793a3f881fbSStefano Zampini   bkok    = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
794076ba34aSJunchao Zhang   csrmatA = &akok->csrmat;
795076ba34aSJunchao Zhang   csrmatB = &bkok->csrmat;
796076ba34aSJunchao Zhang 
797a3f881fbSStefano Zampini   ptype = product->type;
798a3f881fbSStefano Zampini   switch (ptype) {
7999371c9d4SSatish Balay   case MATPRODUCT_AB:
8009371c9d4SSatish Balay     transA = false;
8019371c9d4SSatish Balay     transB = false;
8029371c9d4SSatish Balay     break;
8039371c9d4SSatish Balay   case MATPRODUCT_AtB:
8049371c9d4SSatish Balay     transA = true;
8059371c9d4SSatish Balay     transB = false;
8069371c9d4SSatish Balay     break;
8079371c9d4SSatish Balay   case MATPRODUCT_ABt:
8089371c9d4SSatish Balay     transA = false;
8099371c9d4SSatish Balay     transB = true;
8109371c9d4SSatish Balay     break;
811*d71ae5a4SJacob Faibussowitsch   default:
812*d71ae5a4SJacob Faibussowitsch     SETERRQ(comm, PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
813a3f881fbSStefano Zampini   }
814a3f881fbSStefano Zampini 
815076ba34aSJunchao Zhang   product->data = pdata = new MatProductData_SeqAIJKokkos();
816076ba34aSJunchao Zhang   pdata->kh.set_team_work_size(16);
817076ba34aSJunchao Zhang   pdata->kh.set_dynamic_scheduling(true);
818076ba34aSJunchao Zhang   pdata->reusesym = product->api_user;
819a3f881fbSStefano Zampini 
820076ba34aSJunchao Zhang   /* TODO: add command line options to select spgemm algorithms */
821866eb059SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; /* default alg is TPL if enabled, otherwise KK */
822866eb059SJunchao Zhang 
823866eb059SJunchao Zhang   /* CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 */
824866eb059SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
825866eb059SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
826866eb059SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
827866eb059SJunchao Zhang   #endif
828866eb059SJunchao Zhang #endif
829866eb059SJunchao Zhang 
830076ba34aSJunchao Zhang   pdata->kh.create_spgemm_handle(spgemm_alg);
831076ba34aSJunchao Zhang 
8329566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
833076ba34aSJunchao Zhang   /* TODO: Get rid of the explicit transpose once KK-spgemm implements the transpose option */
834076ba34aSJunchao Zhang   if (transA) {
8359566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
836076ba34aSJunchao Zhang     transA = false;
837076ba34aSJunchao Zhang   }
838076ba34aSJunchao Zhang 
839076ba34aSJunchao Zhang   if (transB) {
8409566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
841076ba34aSJunchao Zhang     transB = false;
842076ba34aSJunchao Zhang   }
843076ba34aSJunchao Zhang 
8449566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::spgemm_symbolic(pdata->kh, *csrmatA, transA, *csrmatB, transB, csrmatC));
845866eb059SJunchao Zhang 
846076ba34aSJunchao Zhang   /* spgemm_symbolic() only populates C's rowmap, but not C's column indices.
847076ba34aSJunchao Zhang     So we have to do a fake spgemm_numeric() here to get csrmatC.j_d setup, before
848076ba34aSJunchao Zhang     calling new Mat_SeqAIJKokkos().
849076ba34aSJunchao Zhang     TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
850076ba34aSJunchao Zhang   */
8519566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, *csrmatA, transA, *csrmatB, transB, csrmatC));
852866eb059SJunchao Zhang   /* Query if KK outputs a sorted matrix. If not, we need to sort it */
853866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
854866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(csrmatC)); /* sort_option defaults to -1 in KK!*/
8559566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
856076ba34aSJunchao Zhang 
8579566063dSJacob Faibussowitsch   PetscCallCXX(ckok = new Mat_SeqAIJKokkos(csrmatC));
8589566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(C, ckok));
859076ba34aSJunchao Zhang   C->product->destroy = MatProductDataDestroy_SeqAIJKokkos;
860a3f881fbSStefano Zampini   PetscFunctionReturn(0);
861a3f881fbSStefano Zampini }
862a3f881fbSStefano Zampini 
863a3f881fbSStefano Zampini /* handles sparse matrix matrix ops */
864*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSetFromOptions_SeqAIJKokkos(Mat mat)
865*d71ae5a4SJacob Faibussowitsch {
866076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
867a3f881fbSStefano Zampini   PetscBool    Biskok = PETSC_FALSE, Ciskok = PETSC_TRUE;
868a3f881fbSStefano Zampini 
869a3f881fbSStefano Zampini   PetscFunctionBegin;
870a3f881fbSStefano Zampini   MatCheckProduct(mat, 1);
8719566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJKOKKOS, &Biskok));
87248a46eb9SPierre Jolivet   if (product->type == MATPRODUCT_ABC) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJKOKKOS, &Ciskok));
873a3f881fbSStefano Zampini   if (Biskok && Ciskok) {
874a3f881fbSStefano Zampini     switch (product->type) {
875a3f881fbSStefano Zampini     case MATPRODUCT_AB:
876a3f881fbSStefano Zampini     case MATPRODUCT_AtB:
877*d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABt:
878*d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos;
879*d71ae5a4SJacob Faibussowitsch       break;
880a3f881fbSStefano Zampini     case MATPRODUCT_PtAP:
881a3f881fbSStefano Zampini     case MATPRODUCT_RARt:
882*d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABC:
883*d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
884*d71ae5a4SJacob Faibussowitsch       break;
885*d71ae5a4SJacob Faibussowitsch     default:
886*d71ae5a4SJacob Faibussowitsch       break;
887a3f881fbSStefano Zampini     }
888a3f881fbSStefano Zampini   } else { /* fallback for AIJ */
8899566063dSJacob Faibussowitsch     PetscCall(MatProductSetFromOptions_SeqAIJ(mat));
890a3f881fbSStefano Zampini   }
891a3f881fbSStefano Zampini   PetscFunctionReturn(0);
892a3f881fbSStefano Zampini }
893a587d139SMark 
894*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatScale_SeqAIJKokkos(Mat A, PetscScalar a)
895*d71ae5a4SJacob Faibussowitsch {
896f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok;
897f0cf5187SStefano Zampini 
898f0cf5187SStefano Zampini   PetscFunctionBegin;
8999566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
9009566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
901f0cf5187SStefano Zampini   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
902076ba34aSJunchao Zhang   KokkosBlas::scal(aijkok->a_dual.view_device(), a, aijkok->a_dual.view_device());
9039566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
9049566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(aijkok->a_dual.extent(0)));
9059566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
906f0cf5187SStefano Zampini   PetscFunctionReturn(0);
907f0cf5187SStefano Zampini }
908f0cf5187SStefano Zampini 
909*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatZeroEntries_SeqAIJKokkos(Mat A)
910*d71ae5a4SJacob Faibussowitsch {
911076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
912a587d139SMark 
913a587d139SMark   PetscFunctionBegin;
914076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
9152328674fSJunchao Zhang   if (aijkok) { /* Only zero the device if data is already there */
916076ba34aSJunchao Zhang     KokkosBlas::fill(aijkok->a_dual.view_device(), 0.0);
9179566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(A));
9182328674fSJunchao Zhang   } else { /* Might be preallocated but not assembled */
9199566063dSJacob Faibussowitsch     PetscCall(MatZeroEntries_SeqAIJ(A));
9202328674fSJunchao Zhang   }
921a587d139SMark   PetscFunctionReturn(0);
922a587d139SMark }
923a587d139SMark 
924*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatGetDiagonal_SeqAIJKokkos(Mat A, Vec x)
925*d71ae5a4SJacob Faibussowitsch {
926f78ce678SMark Adams   Mat_SeqAIJ           *aijseq;
927f78ce678SMark Adams   Mat_SeqAIJKokkos     *aijkok;
928f78ce678SMark Adams   PetscInt              n;
929f78ce678SMark Adams   PetscScalarKokkosView xv;
930f78ce678SMark Adams 
931f78ce678SMark Adams   PetscFunctionBegin;
932f78ce678SMark Adams   PetscCall(VecGetLocalSize(x, &n));
933f78ce678SMark Adams   PetscCheck(n == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Nonconforming matrix and vector");
934f78ce678SMark Adams   PetscCheck(A->factortype == MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_SUP, "MatGetDiagonal_SeqAIJKokkos not supported on factored matrices");
935f78ce678SMark Adams 
936f78ce678SMark Adams   PetscCall(MatSeqAIJKokkosSyncDevice(A));
937f78ce678SMark Adams   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
938f78ce678SMark Adams 
939f78ce678SMark Adams   if (A->rmap->n && aijkok->diag_dual.extent(0) == 0) { /* Set the diagonal pointer if not already */
940f78ce678SMark Adams     PetscCall(MatMarkDiagonal_SeqAIJ(A));
941f78ce678SMark Adams     aijseq = static_cast<Mat_SeqAIJ *>(A->data);
942f78ce678SMark Adams     aijkok->SetDiagonal(aijseq->diag);
943f78ce678SMark Adams   }
944f78ce678SMark Adams 
945f78ce678SMark Adams   const auto &Aa    = aijkok->a_dual.view_device();
946f78ce678SMark Adams   const auto &Ai    = aijkok->i_dual.view_device();
947f78ce678SMark Adams   const auto &Adiag = aijkok->diag_dual.view_device();
948f78ce678SMark Adams 
949f78ce678SMark Adams   PetscCall(VecGetKokkosViewWrite(x, &xv));
9509371c9d4SSatish Balay   Kokkos::parallel_for(
9519371c9d4SSatish Balay     n, KOKKOS_LAMBDA(const PetscInt i) {
952f78ce678SMark Adams       if (Adiag(i) < Ai(i + 1)) xv(i) = Aa(Adiag(i));
953f78ce678SMark Adams       else xv(i) = 0;
954f78ce678SMark Adams     });
955f78ce678SMark Adams   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
956f78ce678SMark Adams   PetscFunctionReturn(0);
957f78ce678SMark Adams }
958f78ce678SMark Adams 
959db78de30SJunchao Zhang /* Get a Kokkos View from a mat of type MatSeqAIJKokkos */
960*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, ConstMatScalarKokkosView *kv)
961*d71ae5a4SJacob Faibussowitsch {
962db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
963db78de30SJunchao Zhang 
964db78de30SJunchao Zhang   PetscFunctionBegin;
965db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
966db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
967db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
9689566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
969db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
970076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
971db78de30SJunchao Zhang   PetscFunctionReturn(0);
972db78de30SJunchao Zhang }
973db78de30SJunchao Zhang 
974*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, ConstMatScalarKokkosView *kv)
975*d71ae5a4SJacob Faibussowitsch {
976db78de30SJunchao Zhang   PetscFunctionBegin;
977db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
978db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
979db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
980db78de30SJunchao Zhang   PetscFunctionReturn(0);
981db78de30SJunchao Zhang }
982db78de30SJunchao Zhang 
983*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, MatScalarKokkosView *kv)
984*d71ae5a4SJacob Faibussowitsch {
985db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
986db78de30SJunchao Zhang 
987db78de30SJunchao Zhang   PetscFunctionBegin;
988db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
989db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
990db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
9919566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
992db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
993076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
994db78de30SJunchao Zhang   PetscFunctionReturn(0);
995db78de30SJunchao Zhang }
996db78de30SJunchao Zhang 
997*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, MatScalarKokkosView *kv)
998*d71ae5a4SJacob Faibussowitsch {
999db78de30SJunchao Zhang   PetscFunctionBegin;
1000db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1001db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1002db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10039566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
1004db78de30SJunchao Zhang   PetscFunctionReturn(0);
1005db78de30SJunchao Zhang }
1006db78de30SJunchao Zhang 
1007*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1008*d71ae5a4SJacob Faibussowitsch {
1009db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1010db78de30SJunchao Zhang 
1011db78de30SJunchao Zhang   PetscFunctionBegin;
1012db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1013db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1014db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1015db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1016076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
1017db78de30SJunchao Zhang   PetscFunctionReturn(0);
1018db78de30SJunchao Zhang }
1019db78de30SJunchao Zhang 
1020*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1021*d71ae5a4SJacob Faibussowitsch {
1022db78de30SJunchao Zhang   PetscFunctionBegin;
1023db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1024db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1025db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10269566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
1027db78de30SJunchao Zhang   PetscFunctionReturn(0);
1028db78de30SJunchao Zhang }
1029db78de30SJunchao Zhang 
1030c17cf699SJunchao Zhang /* Computes Y += alpha X */
1031*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAXPY_SeqAIJKokkos(Mat Y, PetscScalar alpha, Mat X, MatStructure pattern)
1032*d71ae5a4SJacob Faibussowitsch {
1033a587d139SMark   Mat_SeqAIJ              *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data;
1034c17cf699SJunchao Zhang   Mat_SeqAIJKokkos        *xkok, *ykok, *zkok;
1035c17cf699SJunchao Zhang   ConstMatScalarKokkosView Xa;
1036c17cf699SJunchao Zhang   MatScalarKokkosView      Ya;
1037a587d139SMark 
1038a587d139SMark   PetscFunctionBegin;
1039c17cf699SJunchao Zhang   PetscCheckTypeName(Y, MATSEQAIJKOKKOS);
1040c17cf699SJunchao Zhang   PetscCheckTypeName(X, MATSEQAIJKOKKOS);
10419566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(Y));
10429566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(X));
10439566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1044db78de30SJunchao Zhang 
1045c17cf699SJunchao Zhang   if (pattern != SAME_NONZERO_PATTERN && x->nz == y->nz) {
1046c17cf699SJunchao Zhang     /* We could compare on device, but have to get the comparison result on host. So compare on host instead. */
1047a587d139SMark     PetscBool e;
10489566063dSJacob Faibussowitsch     PetscCall(PetscArraycmp(x->i, y->i, Y->rmap->n + 1, &e));
1049a587d139SMark     if (e) {
10509566063dSJacob Faibussowitsch       PetscCall(PetscArraycmp(x->j, y->j, y->nz, &e));
1051c17cf699SJunchao Zhang       if (e) pattern = SAME_NONZERO_PATTERN;
1052a587d139SMark     }
1053a587d139SMark   }
1054db78de30SJunchao Zhang 
1055c17cf699SJunchao Zhang   /* cusparseDcsrgeam2() computes C = alpha A + beta B. If one knew sparsity pattern of C, one can skip
1056c17cf699SJunchao Zhang     cusparseScsrgeam2_bufferSizeExt() / cusparseXcsrgeam2Nnz(), and directly call cusparseScsrgeam2().
1057c17cf699SJunchao Zhang     If X is SUBSET_NONZERO_PATTERN of Y, we could take advantage of this cusparse feature. However,
1058c17cf699SJunchao Zhang     KokkosSparse::spadd(alpha,A,beta,B,C) has symbolic and numeric phases, MatAXPY does not.
1059c17cf699SJunchao Zhang   */
1060c17cf699SJunchao Zhang   ykok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1061c17cf699SJunchao Zhang   xkok = static_cast<Mat_SeqAIJKokkos *>(X->spptr);
1062c17cf699SJunchao Zhang   Xa   = xkok->a_dual.view_device();
1063c17cf699SJunchao Zhang   Ya   = ykok->a_dual.view_device();
1064c17cf699SJunchao Zhang 
1065c17cf699SJunchao Zhang   if (pattern == SAME_NONZERO_PATTERN) {
1066c17cf699SJunchao Zhang     KokkosBlas::axpy(alpha, Xa, Ya);
10679566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1068c17cf699SJunchao Zhang   } else if (pattern == SUBSET_NONZERO_PATTERN) {
1069c17cf699SJunchao Zhang     MatRowMapKokkosView Xi = xkok->i_dual.view_device(), Yi = ykok->i_dual.view_device();
1070c17cf699SJunchao Zhang     MatColIdxKokkosView Xj = xkok->j_dual.view_device(), Yj = ykok->j_dual.view_device();
1071c17cf699SJunchao Zhang 
10729371c9d4SSatish Balay     Kokkos::parallel_for(
10739371c9d4SSatish Balay       Kokkos::TeamPolicy<>(Y->rmap->n, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1074c17cf699SJunchao Zhang         PetscInt i = t.league_rank();              /* row i */
1075c17cf699SJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() { /* Only one thread works in a team */
1076c17cf699SJunchao Zhang                                                    PetscInt p, q = Yi(i);
1077c17cf699SJunchao Zhang                                                    for (p = Xi(i); p < Xi(i + 1); p++) {          /* For each nonzero on row i of X */
1078c17cf699SJunchao Zhang                                                      while (Xj(p) != Yj(q) && q < Yi(i + 1)) q++; /* find the matching nonzero on row i of Y */
1079c17cf699SJunchao Zhang                                                      if (Xj(p) == Yj(q)) {                        /* Found it */
1080c17cf699SJunchao Zhang                                                        Ya(q) += alpha * Xa(p);
1081c17cf699SJunchao Zhang                                                        q++;
1082a587d139SMark                                                      } else {
1083c17cf699SJunchao Zhang                                                        /* If not found, it indicates the input is wrong (X is not a SUBSET_NONZERO_PATTERN of Y).
1084c17cf699SJunchao Zhang                Just insert a NaN at the beginning of row i if it is not empty, to make the result wrong.
1085c17cf699SJunchao Zhang             */
10869371c9d4SSatish Balay                                                        if (Yi(i) != Yi(i + 1))
10879371c9d4SSatish Balay                                                          Ya(Yi(i)) =
10888b8b16f9SJunchao Zhang #if PETSC_PKG_KOKKOS_VERSION_GE(3, 6, 99)
10898b8b16f9SJunchao Zhang                                                            Kokkos::nan("1"); /* auto promote the double NaN if needed */
10908b8b16f9SJunchao Zhang #else
10918b8b16f9SJunchao Zhang               Kokkos::Experimental::nan("1");
10928b8b16f9SJunchao Zhang #endif
1093a587d139SMark                                                      }
1094c17cf699SJunchao Zhang                                                    }
1095c17cf699SJunchao Zhang         });
1096c17cf699SJunchao Zhang       });
10979566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1098c17cf699SJunchao Zhang   } else { /* different nonzero patterns */
1099c17cf699SJunchao Zhang     Mat             Z;
1100c17cf699SJunchao Zhang     KokkosCsrMatrix zcsr;
1101c17cf699SJunchao Zhang     KernelHandle    kh;
1102c17cf699SJunchao Zhang     kh.create_spadd_handle(false);
1103c17cf699SJunchao Zhang     KokkosSparse::spadd_symbolic(&kh, xkok->csrmat, ykok->csrmat, zcsr);
1104c17cf699SJunchao Zhang     KokkosSparse::spadd_numeric(&kh, alpha, xkok->csrmat, (PetscScalar)1.0, ykok->csrmat, zcsr);
1105c17cf699SJunchao Zhang     zkok = new Mat_SeqAIJKokkos(zcsr);
11069566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, zkok, &Z));
11079566063dSJacob Faibussowitsch     PetscCall(MatHeaderReplace(Y, &Z));
1108c17cf699SJunchao Zhang     kh.destroy_spadd_handle();
1109c17cf699SJunchao Zhang   }
11109566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
11119566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(xkok->a_dual.extent(0) * 2)); /* Because we scaled X and then added it to Y */
1112a587d139SMark   PetscFunctionReturn(0);
1113a587d139SMark }
1114a587d139SMark 
1115*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_SeqAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1116*d71ae5a4SJacob Faibussowitsch {
111742550becSJunchao Zhang   Mat_SeqAIJKokkos *akok;
111842550becSJunchao Zhang   Mat_SeqAIJ       *aseq;
111942550becSJunchao Zhang 
112042550becSJunchao Zhang   PetscFunctionBegin;
11219566063dSJacob Faibussowitsch   PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, coo_i, coo_j));
1122394ed5ebSJunchao Zhang   aseq = static_cast<Mat_SeqAIJ *>(mat->data);
112342550becSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(mat->spptr);
1124cbc6b225SStefano Zampini   delete akok;
1125cbc6b225SStefano Zampini   mat->spptr = akok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, aseq->nz, aseq->i, aseq->j, aseq->a, mat->nonzerostate + 1, PETSC_FALSE);
11269566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries_SeqAIJKokkos(mat));
1127394ed5ebSJunchao Zhang   akok->SetUpCOO(aseq);
112842550becSJunchao Zhang   PetscFunctionReturn(0);
112942550becSJunchao Zhang }
113042550becSJunchao Zhang 
1131*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_SeqAIJKokkos(Mat A, const PetscScalar v[], InsertMode imode)
1132*d71ae5a4SJacob Faibussowitsch {
113342550becSJunchao Zhang   Mat_SeqAIJ                 *aseq = static_cast<Mat_SeqAIJ *>(A->data);
113442550becSJunchao Zhang   Mat_SeqAIJKokkos           *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1135394ed5ebSJunchao Zhang   PetscCount                  Annz = aseq->nz;
1136394ed5ebSJunchao Zhang   const PetscCountKokkosView &jmap = akok->jmap_d;
1137394ed5ebSJunchao Zhang   const PetscCountKokkosView &perm = akok->perm_d;
113842550becSJunchao Zhang   MatScalarKokkosView         Aa;
113942550becSJunchao Zhang   ConstMatScalarKokkosView    kv;
114042550becSJunchao Zhang   PetscMemType                memtype;
114142550becSJunchao Zhang 
114242550becSJunchao Zhang   PetscFunctionBegin;
11439566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype));
114442550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */
1145394ed5ebSJunchao Zhang     kv = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatScalarKokkosViewHost(v, aseq->coo_n));
114642550becSJunchao Zhang   } else {
1147394ed5ebSJunchao Zhang     kv = ConstMatScalarKokkosView(v, aseq->coo_n); /* Directly use v[]'s memory */
114842550becSJunchao Zhang   }
114942550becSJunchao Zhang 
1150c7b718f4SJunchao Zhang   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1151c7b718f4SJunchao Zhang   else PetscCall(MatSeqAIJGetKokkosView(A, &Aa));                             /* read & write matrix values */
115242550becSJunchao Zhang 
11539371c9d4SSatish Balay   Kokkos::parallel_for(
11549371c9d4SSatish Balay     Annz, KOKKOS_LAMBDA(const PetscCount i) {
1155c7b718f4SJunchao Zhang       PetscScalar sum = 0.0;
1156c7b718f4SJunchao Zhang       for (PetscCount k = jmap(i); k < jmap(i + 1); k++) sum += kv(perm(k));
1157c7b718f4SJunchao Zhang       Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1158c7b718f4SJunchao Zhang     });
1159394ed5ebSJunchao Zhang 
11609566063dSJacob Faibussowitsch   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
11619566063dSJacob Faibussowitsch   else PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
116242550becSJunchao Zhang   PetscFunctionReturn(0);
116342550becSJunchao Zhang }
116442550becSJunchao Zhang 
1165*d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJMoveDiagonalValuesFront_SeqAIJKokkos(Mat A, const PetscInt *diag)
1166*d71ae5a4SJacob Faibussowitsch {
11675fbaff96SJunchao Zhang   Mat_SeqAIJKokkos          *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
11685fbaff96SJunchao Zhang   MatScalarKokkosView        Aa;
11695fbaff96SJunchao Zhang   const MatRowMapKokkosView &Ai = akok->i_dual.view_device();
11705fbaff96SJunchao Zhang   PetscInt                   m  = A->rmap->n;
11715fbaff96SJunchao Zhang   ConstMatRowMapKokkosView   Adiag(diag, m); /* diag is a device pointer */
11725fbaff96SJunchao Zhang 
11735fbaff96SJunchao Zhang   PetscFunctionBegin;
11745fbaff96SJunchao Zhang   PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa));
11759371c9d4SSatish Balay   Kokkos::parallel_for(
11769371c9d4SSatish Balay     m, KOKKOS_LAMBDA(const PetscInt i) {
11775fbaff96SJunchao Zhang       PetscScalar tmp;
11785fbaff96SJunchao Zhang       if (Adiag(i) >= Ai(i) && Adiag(i) < Ai(i + 1)) { /* The diagonal element exists */
11795fbaff96SJunchao Zhang         tmp          = Aa(Ai(i));
11805fbaff96SJunchao Zhang         Aa(Ai(i))    = Aa(Adiag(i));
11815fbaff96SJunchao Zhang         Aa(Adiag(i)) = tmp;
11825fbaff96SJunchao Zhang       }
11835fbaff96SJunchao Zhang     });
11845fbaff96SJunchao Zhang   PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
11855fbaff96SJunchao Zhang   PetscFunctionReturn(0);
11865fbaff96SJunchao Zhang }
11875fbaff96SJunchao Zhang 
1188*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatLUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1189*d71ae5a4SJacob Faibussowitsch {
11908f7e8f9dSMark Adams   PetscFunctionBegin;
11919566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncHost(A));
11929566063dSJacob Faibussowitsch   PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info));
11938f7e8f9dSMark Adams   B->offloadmask = PETSC_OFFLOAD_CPU;
11948f7e8f9dSMark Adams   PetscFunctionReturn(0);
11958f7e8f9dSMark Adams }
11968f7e8f9dSMark Adams 
1197*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat A)
1198*d71ae5a4SJacob Faibussowitsch {
1199076ba34aSJunchao Zhang   Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
1200076ba34aSJunchao Zhang 
12018c3ff71bSJunchao Zhang   PetscFunctionBegin;
1202076ba34aSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_KOKKOS; /* We do not really use this flag */
12036f3d89d0SStefano Zampini   A->boundtocpu  = PETSC_FALSE;
12046f3d89d0SStefano Zampini 
12058c3ff71bSJunchao Zhang   A->ops->assemblyend               = MatAssemblyEnd_SeqAIJKokkos;
12068c3ff71bSJunchao Zhang   A->ops->destroy                   = MatDestroy_SeqAIJKokkos;
12078c3ff71bSJunchao Zhang   A->ops->duplicate                 = MatDuplicate_SeqAIJKokkos;
1208a587d139SMark   A->ops->axpy                      = MatAXPY_SeqAIJKokkos;
1209f0cf5187SStefano Zampini   A->ops->scale                     = MatScale_SeqAIJKokkos;
1210a587d139SMark   A->ops->zeroentries               = MatZeroEntries_SeqAIJKokkos;
1211076ba34aSJunchao Zhang   A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJKokkos;
12128c3ff71bSJunchao Zhang   A->ops->mult                      = MatMult_SeqAIJKokkos;
12138c3ff71bSJunchao Zhang   A->ops->multadd                   = MatMultAdd_SeqAIJKokkos;
12148c3ff71bSJunchao Zhang   A->ops->multtranspose             = MatMultTranspose_SeqAIJKokkos;
12158c3ff71bSJunchao Zhang   A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJKokkos;
12168c3ff71bSJunchao Zhang   A->ops->multhermitiantranspose    = MatMultHermitianTranspose_SeqAIJKokkos;
12178c3ff71bSJunchao Zhang   A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJKokkos;
1218076ba34aSJunchao Zhang   A->ops->productnumeric            = MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos;
12190ecb592aSJunchao Zhang   A->ops->transpose                 = MatTranspose_SeqAIJKokkos;
1220152b3e56SJunchao Zhang   A->ops->setoption                 = MatSetOption_SeqAIJKokkos;
1221f78ce678SMark Adams   A->ops->getdiagonal               = MatGetDiagonal_SeqAIJKokkos;
1222076ba34aSJunchao Zhang   a->ops->getarray                  = MatSeqAIJGetArray_SeqAIJKokkos;
1223076ba34aSJunchao Zhang   a->ops->restorearray              = MatSeqAIJRestoreArray_SeqAIJKokkos;
1224076ba34aSJunchao Zhang   a->ops->getarrayread              = MatSeqAIJGetArrayRead_SeqAIJKokkos;
1225076ba34aSJunchao Zhang   a->ops->restorearrayread          = MatSeqAIJRestoreArrayRead_SeqAIJKokkos;
1226076ba34aSJunchao Zhang   a->ops->getarraywrite             = MatSeqAIJGetArrayWrite_SeqAIJKokkos;
1227076ba34aSJunchao Zhang   a->ops->restorearraywrite         = MatSeqAIJRestoreArrayWrite_SeqAIJKokkos;
12287ee59b9bSJunchao Zhang   a->ops->getcsrandmemtype          = MatSeqAIJGetCSRAndMemType_SeqAIJKokkos;
122942550becSJunchao Zhang 
12309566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJKokkos));
12319566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJKokkos));
1232076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1233076ba34aSJunchao Zhang }
1234076ba34aSJunchao Zhang 
1235*d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSetSeqAIJKokkosWithCSRMatrix(Mat A, Mat_SeqAIJKokkos *akok)
1236*d71ae5a4SJacob Faibussowitsch {
1237076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
1238076ba34aSJunchao Zhang   PetscInt    i, m, n;
1239076ba34aSJunchao Zhang 
1240076ba34aSJunchao Zhang   PetscFunctionBegin;
12415f80ce2aSJacob Faibussowitsch   PetscCheck(!A->spptr, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A->spptr is supposed to be empty");
1242076ba34aSJunchao Zhang 
1243076ba34aSJunchao Zhang   m = akok->nrows();
1244076ba34aSJunchao Zhang   n = akok->ncols();
12459566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(A, m, n, m, n));
12469566063dSJacob Faibussowitsch   PetscCall(MatSetType(A, MATSEQAIJKOKKOS));
1247076ba34aSJunchao Zhang 
1248076ba34aSJunchao Zhang   /* Set up data structures of A as a MATSEQAIJ */
12499566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(A, MAT_SKIP_ALLOCATION, NULL));
1250076ba34aSJunchao Zhang   aseq = (Mat_SeqAIJ *)(A)->data;
1251076ba34aSJunchao Zhang 
1252076ba34aSJunchao Zhang   akok->i_dual.sync_host(); /* We always need sync'ed i, j on host */
1253076ba34aSJunchao Zhang   akok->j_dual.sync_host();
1254076ba34aSJunchao Zhang 
1255076ba34aSJunchao Zhang   aseq->i            = akok->i_host_data();
1256076ba34aSJunchao Zhang   aseq->j            = akok->j_host_data();
1257076ba34aSJunchao Zhang   aseq->a            = akok->a_host_data();
1258076ba34aSJunchao Zhang   aseq->nonew        = -1; /*this indicates that inserting a new value in the matrix that generates a new nonzero is an error*/
1259076ba34aSJunchao Zhang   aseq->singlemalloc = PETSC_FALSE;
1260076ba34aSJunchao Zhang   aseq->free_a       = PETSC_FALSE;
1261076ba34aSJunchao Zhang   aseq->free_ij      = PETSC_FALSE;
1262076ba34aSJunchao Zhang   aseq->nz           = akok->nnz();
1263076ba34aSJunchao Zhang   aseq->maxnz        = aseq->nz;
1264076ba34aSJunchao Zhang 
12659566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->imax));
12669566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->ilen));
1267ad540459SPierre Jolivet   for (i = 0; i < m; i++) aseq->ilen[i] = aseq->imax[i] = aseq->i[i + 1] - aseq->i[i];
1268076ba34aSJunchao Zhang 
1269076ba34aSJunchao Zhang   /* It is critical to set the nonzerostate, as we use it to check if sparsity pattern (hence data) has changed on host in MatAssemblyEnd */
1270076ba34aSJunchao Zhang   akok->nonzerostate = A->nonzerostate;
1271ff751488SJunchao Zhang   A->spptr           = akok; /* Set A->spptr before MatAssembly so that A->spptr won't be allocated again there */
12729566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
12739566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
1274076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1275076ba34aSJunchao Zhang }
1276076ba34aSJunchao Zhang 
1277076ba34aSJunchao Zhang /* Crete a SEQAIJKOKKOS matrix with a Mat_SeqAIJKokkos data structure
1278076ba34aSJunchao Zhang 
1279076ba34aSJunchao Zhang    Note we have names like MatSeqAIJSetPreallocationCSR, so I use capitalized CSR
1280076ba34aSJunchao Zhang  */
1281*d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithCSRMatrix(MPI_Comm comm, Mat_SeqAIJKokkos *akok, Mat *A)
1282*d71ae5a4SJacob Faibussowitsch {
1283076ba34aSJunchao Zhang   PetscFunctionBegin;
12849566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
12859566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
12868c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
12878c3ff71bSJunchao Zhang }
12888c3ff71bSJunchao Zhang 
/* --------------------------------------------------------------------------------*/
/*@C
   MatCreateSeqAIJKokkos - Creates a sparse matrix in `MATSEQAIJKOKKOS` (compressed row) format
   (the default parallel PETSc format). This matrix will ultimately be handled by
   Kokkos for calculations. For good matrix
   assembly performance the user should preallocate the matrix storage by setting
   the parameter nz (or the array nnz).  By setting these parameters accurately,
   performance during matrix assembly can be increased by more than a factor of 50.

   Collective

   Input Parameters:
+  comm - MPI communicator, set to `PETSC_COMM_SELF`
.  m - number of rows
.  n - number of columns
.  nz - number of nonzeros per row (same for all rows)
-  nnz - array containing the number of nonzeros in the various rows
         (possibly different for each row) or NULL

   Output Parameter:
.  A - the matrix

   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
   MatXXXXSetPreallocation() paradigm instead of this routine directly.
   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]

   Notes:
   If nnz is given then nz is ignored

   The AIJ format, also called
   compressed row storage, is fully compatible with standard Fortran 77
   storage.  That is, the stored row and column indices can begin at
   either one (as in Fortran) or zero.  See the users' manual for details.

   Specify the preallocated storage with either nz or nnz (not both).
   Set nz = `PETSC_DEFAULT` and nnz = NULL for PETSc to control dynamic memory
   allocation.  For large problems you MUST preallocate memory or you
   will get TERRIBLE performance, see the users' manual chapter on matrices.

   By default, this format uses inodes (identical nodes) when possible, to
   improve numerical efficiency of matrix-vector products and solves. We
   search for consecutive rows with the same nonzero structure, thereby
   reusing matrix information to achieve increased efficiency.

   Level: intermediate

.seealso: `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
@*/
1337*d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateSeqAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A)
1338*d71ae5a4SJacob Faibussowitsch {
13398c3ff71bSJunchao Zhang   PetscFunctionBegin;
13409566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
13419566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
13429566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, m, n));
13439566063dSJacob Faibussowitsch   PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
13449566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
13458c3ff71bSJunchao Zhang   PetscFunctionReturn(0);
13468c3ff71bSJunchao Zhang }
1347930e68a5SMark Adams 
13488f7e8f9dSMark Adams typedef Kokkos::TeamPolicy<>::member_type team_member;
13498f7e8f9dSMark Adams //
135046804e07SMark Adams // This factorization exploits block diagonal matrices with "Nf" (not used).
13518f7e8f9dSMark Adams // Use -pc_factor_mat_ordering_type rcm to order decouple blocks of size N/Nf for this optimization
13528f7e8f9dSMark Adams //
1353*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatLUFactorNumeric_SeqAIJKOKKOSDEVICE(Mat B, Mat A, const MatFactorInfo *info)
1354*d71ae5a4SJacob Faibussowitsch {
13558f7e8f9dSMark Adams   Mat_SeqAIJ       *b      = (Mat_SeqAIJ *)B->data;
13568f7e8f9dSMark Adams   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *baijkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
13578f7e8f9dSMark Adams   IS                isrow = b->row, isicol = b->icol;
13588f7e8f9dSMark Adams   const PetscInt   *r_h, *ic_h;
1359300d22a6SJunchao Zhang   const PetscInt n = A->rmap->n, *ai_d = aijkok->i_dual.view_device().data(), *aj_d = aijkok->j_dual.view_device().data(), *bi_d = baijkok->i_dual.view_device().data(), *bj_d = baijkok->j_dual.view_device().data(), *bdiag_d = baijkok->diag_d.data();
1360076ba34aSJunchao Zhang   const PetscScalar *aa_d = aijkok->a_dual.view_device().data();
1361076ba34aSJunchao Zhang   PetscScalar       *ba_d = baijkok->a_dual.view_device().data();
13628f7e8f9dSMark Adams   PetscBool          row_identity, col_identity;
136346804e07SMark Adams   PetscInt           nc, Nf = 1, nVec = 32; // should be a parameter, Nf is batch size - not used
1364930e68a5SMark Adams 
1365930e68a5SMark Adams   PetscFunctionBegin;
13662c71b3e2SJacob Faibussowitsch   PetscCheck(A->rmap->n == n, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "square matrices only supported %" PetscInt_FMT " %" PetscInt_FMT, A->rmap->n, n);
1367b94d7dedSBarry Smith   PetscCall(MatIsStructurallySymmetric(A, &row_identity));
13682c71b3e2SJacob Faibussowitsch   PetscCheck(row_identity, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "structurally symmetric matrices only supported");
13699566063dSJacob Faibussowitsch   PetscCall(ISGetIndices(isrow, &r_h));
13709566063dSJacob Faibussowitsch   PetscCall(ISGetIndices(isicol, &ic_h));
13719566063dSJacob Faibussowitsch   PetscCall(ISGetSize(isicol, &nc));
13729566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
13739566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
13748f7e8f9dSMark Adams   {
13758f7e8f9dSMark Adams #define KOKKOS_SHARED_LEVEL 1
13768f7e8f9dSMark Adams     using scr_mem_t    = Kokkos::DefaultExecutionSpace::scratch_memory_space;
13778f7e8f9dSMark Adams     using sizet_scr_t  = Kokkos::View<size_t, scr_mem_t>;
13788f7e8f9dSMark Adams     using scalar_scr_t = Kokkos::View<PetscScalar, scr_mem_t>;
13798f7e8f9dSMark Adams     const Kokkos::View<const PetscInt *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_r_k(r_h, n);
13808f7e8f9dSMark Adams     Kokkos::View<PetscInt *, Kokkos::LayoutLeft>                                                                         d_r_k("r", n);
13818f7e8f9dSMark Adams     const Kokkos::View<const PetscInt *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_ic_k(ic_h, nc);
13828f7e8f9dSMark Adams     Kokkos::View<PetscInt *, Kokkos::LayoutLeft>                                                                         d_ic_k("ic", nc);
13838f7e8f9dSMark Adams     size_t                                                                                                               flops_h = 0.0;
13848f7e8f9dSMark Adams     Kokkos::View<size_t, Kokkos::HostSpace>                                                                              h_flops_k(&flops_h);
13858f7e8f9dSMark Adams     Kokkos::View<size_t>                                                                                                 d_flops_k("flops");
13868f7e8f9dSMark Adams     const int                                                                                                            conc = Kokkos::DefaultExecutionSpace().concurrency(), team_size = conc > 1 ? 16 : 1; // 8*32 = 256
13878f7e8f9dSMark Adams     const int                                                                                                            nloc = n / Nf, Ni = (conc > 8) ? 1 /* some intelegent number of SMs -- but need league_barrier */ : 1;
13888f7e8f9dSMark Adams     Kokkos::deep_copy(d_flops_k, h_flops_k);
13898f7e8f9dSMark Adams     Kokkos::deep_copy(d_r_k, h_r_k);
13908f7e8f9dSMark Adams     Kokkos::deep_copy(d_ic_k, h_ic_k);
13918f7e8f9dSMark Adams     // Fill A --> fact
13929371c9d4SSatish Balay     Kokkos::parallel_for(
13939371c9d4SSatish Balay       Kokkos::TeamPolicy<>(Nf * Ni, team_size, nVec), KOKKOS_LAMBDA(const team_member team) {
1394042217e8SBarry Smith         const PetscInt  field = team.league_rank() / Ni, field_block = team.league_rank() % Ni; // use grid.x/y in CUDA
13958f7e8f9dSMark Adams         const PetscInt  nloc_i = (nloc / Ni + !!(nloc % Ni)), start_i = field * nloc + field_block * nloc_i, end_i = (start_i + nloc_i) > (field + 1) * nloc ? (field + 1) * nloc : (start_i + nloc_i);
13968f7e8f9dSMark Adams         const PetscInt *ic = d_ic_k.data(), *r = d_r_k.data();
13978f7e8f9dSMark Adams         // zero rows of B
13988f7e8f9dSMark Adams         Kokkos::parallel_for(Kokkos::TeamVectorRange(team, start_i, end_i), [=](const int &rowb) {
13998f7e8f9dSMark Adams           PetscInt     nzbL = bi_d[rowb + 1] - bi_d[rowb], nzbU = bdiag_d[rowb] - bdiag_d[rowb + 1]; // with diag
14008f7e8f9dSMark Adams           PetscScalar *baL = ba_d + bi_d[rowb];
14018f7e8f9dSMark Adams           PetscScalar *baU = ba_d + bdiag_d[rowb + 1] + 1;
14028f7e8f9dSMark Adams           /* zero (unfactored row) */
14038f7e8f9dSMark Adams           for (int j = 0; j < nzbL; j++) baL[j] = 0;
14048f7e8f9dSMark Adams           for (int j = 0; j < nzbU; j++) baU[j] = 0;
14058f7e8f9dSMark Adams         });
14068f7e8f9dSMark Adams         // copy A into B
14078f7e8f9dSMark Adams         Kokkos::parallel_for(Kokkos::TeamVectorRange(team, start_i, end_i), [=](const int &rowb) {
14088f7e8f9dSMark Adams           PetscInt           rowa = r[rowb], nza = ai_d[rowa + 1] - ai_d[rowa];
14098f7e8f9dSMark Adams           const PetscScalar *av    = aa_d + ai_d[rowa];
14108f7e8f9dSMark Adams           const PetscInt    *ajtmp = aj_d + ai_d[rowa];
14118f7e8f9dSMark Adams           /* load in initial (unfactored row) */
14128f7e8f9dSMark Adams           for (int j = 0; j < nza; j++) {
14138f7e8f9dSMark Adams             PetscInt    colb = ic[ajtmp[j]];
14148f7e8f9dSMark Adams             PetscScalar vala = av[j];
14158f7e8f9dSMark Adams             if (colb == rowb) {
14168f7e8f9dSMark Adams               *(ba_d + bdiag_d[rowb]) = vala;
14178f7e8f9dSMark Adams             } else {
14188f7e8f9dSMark Adams               const PetscInt *pbj = bj_d + ((colb > rowb) ? bdiag_d[rowb + 1] + 1 : bi_d[rowb]);
14198f7e8f9dSMark Adams               PetscScalar    *pba = ba_d + ((colb > rowb) ? bdiag_d[rowb + 1] + 1 : bi_d[rowb]);
14208f7e8f9dSMark Adams               PetscInt        nz = (colb > rowb) ? bdiag_d[rowb] - (bdiag_d[rowb + 1] + 1) : bi_d[rowb + 1] - bi_d[rowb], set = 0;
14218f7e8f9dSMark Adams               for (int j = 0; j < nz; j++) {
14228f7e8f9dSMark Adams                 if (pbj[j] == colb) {
14238f7e8f9dSMark Adams                   pba[j] = vala;
14248f7e8f9dSMark Adams                   set++;
14258f7e8f9dSMark Adams                   break;
14268f7e8f9dSMark Adams                 }
14278f7e8f9dSMark Adams               }
14288f1da0b2SJunchao Zhang #if !defined(PETSC_HAVE_SYCL)
14298f7e8f9dSMark Adams               if (set != 1) printf("\t\t\t ERROR DID NOT SET ?????\n");
14308f1da0b2SJunchao Zhang #endif
14318f7e8f9dSMark Adams             }
14328f7e8f9dSMark Adams           }
14338f7e8f9dSMark Adams         });
14348f7e8f9dSMark Adams       });
14358f7e8f9dSMark Adams     Kokkos::fence();
1436930e68a5SMark Adams 
14379371c9d4SSatish Balay     Kokkos::parallel_for(
14389371c9d4SSatish Balay       Kokkos::TeamPolicy<>(Nf * Ni, team_size, nVec).set_scratch_size(KOKKOS_SHARED_LEVEL, Kokkos::PerThread(sizet_scr_t::shmem_size() + scalar_scr_t::shmem_size()), Kokkos::PerTeam(sizet_scr_t::shmem_size())), KOKKOS_LAMBDA(const team_member team) {
14398f7e8f9dSMark Adams         sizet_scr_t    colkIdx(team.thread_scratch(KOKKOS_SHARED_LEVEL));
14408f7e8f9dSMark Adams         scalar_scr_t   L_ki(team.thread_scratch(KOKKOS_SHARED_LEVEL));
14418f7e8f9dSMark Adams         sizet_scr_t    flops(team.team_scratch(KOKKOS_SHARED_LEVEL));
1442042217e8SBarry Smith         const PetscInt field = team.league_rank() / Ni, field_block_idx = team.league_rank() % Ni; // use grid.x/y in CUDA
14438f7e8f9dSMark Adams         const PetscInt start = field * nloc, end = start + nloc;
14448f7e8f9dSMark Adams         Kokkos::single(Kokkos::PerTeam(team), [=]() { flops() = 0; });
14458f7e8f9dSMark Adams         // A22 panel update for each row A(1,:) and col A(:,1)
14468f7e8f9dSMark Adams         for (int ii = start; ii < end - 1; ii++) {
14478f7e8f9dSMark Adams           const PetscInt    *bjUi = bj_d + bdiag_d[ii + 1] + 1, nzUi = bdiag_d[ii] - (bdiag_d[ii + 1] + 1); // vector, and vector size, of column indices of U(i,(i+1):end)
14488f7e8f9dSMark Adams           const PetscScalar *baUi    = ba_d + bdiag_d[ii + 1] + 1;                                          // vector of data  U(i,i+1:end)
14498f7e8f9dSMark Adams           const PetscInt     nUi_its = nzUi / Ni + !!(nzUi % Ni);
14508f7e8f9dSMark Adams           const PetscScalar  Bii     = *(ba_d + bdiag_d[ii]); // diagonal in its special place
14518f7e8f9dSMark Adams           Kokkos::parallel_for(Kokkos::TeamThreadRange(team, nUi_its), [=](const int j) {
14528f7e8f9dSMark Adams             PetscInt kIdx = j * Ni + field_block_idx;
14539371c9d4SSatish Balay             if (kIdx >= nzUi) /* void */
14549371c9d4SSatish Balay               ;
14558f7e8f9dSMark Adams             else {
14568f7e8f9dSMark Adams               const PetscInt  myk = bjUi[kIdx];                // assume symmetric structure, need a transposed meta-data here in general
14578f7e8f9dSMark Adams               const PetscInt *pjL = bj_d + bi_d[myk];          // look for L(myk,ii) in start of row
14588f7e8f9dSMark Adams               const PetscInt  nzL = bi_d[myk + 1] - bi_d[myk]; // size of L_k(:)
14598f7e8f9dSMark Adams               size_t          st_idx;
14608f7e8f9dSMark Adams               // find and do L(k,i) = A(:k,i) / A(i,i)
14618f7e8f9dSMark Adams               Kokkos::single(Kokkos::PerThread(team), [&]() { colkIdx() = PETSC_MAX_INT; });
14628f7e8f9dSMark Adams               // get column, there has got to be a better way
14639371c9d4SSatish Balay               Kokkos::parallel_reduce(
14649371c9d4SSatish Balay                 Kokkos::ThreadVectorRange(team, nzL),
14659371c9d4SSatish Balay                 [&](const int &j, size_t &idx) {
14668f7e8f9dSMark Adams                   if (pjL[j] == ii) {
14678f7e8f9dSMark Adams                     PetscScalar *pLki = ba_d + bi_d[myk] + j;
14688f7e8f9dSMark Adams                     idx               = j;           // output
14698f7e8f9dSMark Adams                     *pLki             = *pLki / Bii; // column scaling:  L(k,i) = A(:k,i) / A(i,i)
14708f7e8f9dSMark Adams                   }
14719371c9d4SSatish Balay                 },
14729371c9d4SSatish Balay                 st_idx);
14739371c9d4SSatish Balay               Kokkos::single(Kokkos::PerThread(team), [=]() {
14749371c9d4SSatish Balay                 colkIdx() = st_idx;
14759371c9d4SSatish Balay                 L_ki()    = *(ba_d + bi_d[myk] + st_idx);
14769371c9d4SSatish Balay               });
14778f1da0b2SJunchao Zhang #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
147899551766SMark Adams               if (colkIdx() == PETSC_MAX_INT) printf("\t\t\t\t\t\t\tERROR: failed to find L_ki(%d,%d)\n", (int)myk, ii); // uses a register
147999551766SMark Adams #endif
148099551766SMark Adams               // active row k, do  A_kj -= Lki * U_ij; j \in U(i,:) j != i
14818f7e8f9dSMark Adams               // U(i+1,:end)
14828f7e8f9dSMark Adams               Kokkos::parallel_for(Kokkos::ThreadVectorRange(team, nzUi), [=](const int &uiIdx) { // index into i (U)
14838f7e8f9dSMark Adams                 PetscScalar Uij = baUi[uiIdx];
14848f7e8f9dSMark Adams                 PetscInt    col = bjUi[uiIdx];
14858f7e8f9dSMark Adams                 if (col == myk) {
14868f7e8f9dSMark Adams                   // A_kk = A_kk - L_ki * U_ij(k)
14878f7e8f9dSMark Adams                   PetscScalar *Akkv = (ba_d + bdiag_d[myk]); // diagonal in its special place
14888f7e8f9dSMark Adams                   *Akkv             = *Akkv - L_ki() * Uij;  // UiK
14898f7e8f9dSMark Adams                 } else {
14908f7e8f9dSMark Adams                   PetscScalar    *start, *end, *pAkjv = NULL;
14918f7e8f9dSMark Adams                   PetscInt        high, low;
14928f7e8f9dSMark Adams                   const PetscInt *startj;
14938f7e8f9dSMark Adams                   if (col < myk) { // L
14948f7e8f9dSMark Adams                     PetscScalar *pLki = ba_d + bi_d[myk] + colkIdx();
14958f7e8f9dSMark Adams                     PetscInt     idx  = (pLki + 1) - (ba_d + bi_d[myk]); // index into row
14968f7e8f9dSMark Adams                     start             = pLki + 1;                        // start at pLki+1, A22(myk,1)
14978f7e8f9dSMark Adams                     startj            = bj_d + bi_d[myk] + idx;
14988f7e8f9dSMark Adams                     end               = ba_d + bi_d[myk + 1];
14998f7e8f9dSMark Adams                   } else {
15008f7e8f9dSMark Adams                     PetscInt idx = bdiag_d[myk + 1] + 1;
15018f7e8f9dSMark Adams                     start        = ba_d + idx;
15028f7e8f9dSMark Adams                     startj       = bj_d + idx;
15038f7e8f9dSMark Adams                     end          = ba_d + bdiag_d[myk];
15048f7e8f9dSMark Adams                   }
15058f7e8f9dSMark Adams                   // search for 'col', use bisection search - TODO
15068f7e8f9dSMark Adams                   low  = 0;
15078f7e8f9dSMark Adams                   high = (PetscInt)(end - start);
15088f7e8f9dSMark Adams                   while (high - low > 5) {
15098f7e8f9dSMark Adams                     int t = (low + high) / 2;
15108f7e8f9dSMark Adams                     if (startj[t] > col) high = t;
15118f7e8f9dSMark Adams                     else low = t;
15128f7e8f9dSMark Adams                   }
15138f7e8f9dSMark Adams                   for (pAkjv = start + low; pAkjv < start + high; pAkjv++) {
15148f7e8f9dSMark Adams                     if (startj[pAkjv - start] == col) break;
15158f7e8f9dSMark Adams                   }
15168f1da0b2SJunchao Zhang #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
151799551766SMark Adams                   if (pAkjv == start + high) printf("\t\t\t\t\t\t\t\t\t\t\tERROR: *** failed to find Akj(%d,%d)\n", (int)myk, (int)col); // uses a register
151899551766SMark Adams #endif
15198f7e8f9dSMark Adams                   *pAkjv = *pAkjv - L_ki() * Uij; // A_kj = A_kj - L_ki * U_ij
15208f7e8f9dSMark Adams                 }
15218f7e8f9dSMark Adams               });
15228f7e8f9dSMark Adams             }
15238f7e8f9dSMark Adams           });
15248f7e8f9dSMark Adams           team.team_barrier(); // this needs to be a league barrier to use more that one SM per block
15258f7e8f9dSMark Adams           if (field_block_idx == 0) Kokkos::single(Kokkos::PerTeam(team), [&]() { Kokkos::atomic_add(flops.data(), (size_t)(2 * (nzUi * nzUi) + 2)); });
15268f7e8f9dSMark Adams         } /* endof for (i=0; i<n; i++) { */
15279371c9d4SSatish Balay         Kokkos::single(Kokkos::PerTeam(team), [=]() {
15289371c9d4SSatish Balay           Kokkos::atomic_add(&d_flops_k(), flops());
15299371c9d4SSatish Balay           flops() = 0;
15309371c9d4SSatish Balay         });
15318f7e8f9dSMark Adams       });
15328f7e8f9dSMark Adams     Kokkos::fence();
15338f7e8f9dSMark Adams     Kokkos::deep_copy(h_flops_k, d_flops_k);
15349566063dSJacob Faibussowitsch     PetscCall(PetscLogGpuFlops((PetscLogDouble)h_flops_k()));
15359371c9d4SSatish Balay     Kokkos::parallel_for(
15369371c9d4SSatish Balay       Kokkos::TeamPolicy<>(Nf * Ni, 1, 256), KOKKOS_LAMBDA(const team_member team) {
15378f7e8f9dSMark Adams         const PetscInt lg_rank = team.league_rank(), field = lg_rank / Ni;                            //, field_offset = lg_rank%Ni;
15388f7e8f9dSMark Adams         const PetscInt start = field * nloc, end = start + nloc, n_its = (nloc / Ni + !!(nloc % Ni)); // 1/Ni iters
15398f7e8f9dSMark Adams         /* Invert diagonal for simpler triangular solves */
15408f7e8f9dSMark Adams         Kokkos::parallel_for(Kokkos::TeamVectorRange(team, n_its), [=](int outer_index) {
15418f7e8f9dSMark Adams           int i = start + outer_index * Ni + lg_rank % Ni;
15428f7e8f9dSMark Adams           if (i < end) {
15438f7e8f9dSMark Adams             PetscScalar *pv = ba_d + bdiag_d[i];
15448f7e8f9dSMark Adams             *pv             = 1.0 / (*pv);
15458f7e8f9dSMark Adams           }
15468f7e8f9dSMark Adams         });
15478f7e8f9dSMark Adams       });
15488f7e8f9dSMark Adams   }
15499566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
15509566063dSJacob Faibussowitsch   PetscCall(ISRestoreIndices(isicol, &ic_h));
15519566063dSJacob Faibussowitsch   PetscCall(ISRestoreIndices(isrow, &r_h));
15528f7e8f9dSMark Adams 
15539566063dSJacob Faibussowitsch   PetscCall(ISIdentity(isrow, &row_identity));
15549566063dSJacob Faibussowitsch   PetscCall(ISIdentity(isicol, &col_identity));
15558f7e8f9dSMark Adams   if (b->inode.size) {
15568f7e8f9dSMark Adams     B->ops->solve = MatSolve_SeqAIJ_Inode;
15578f7e8f9dSMark Adams   } else if (row_identity && col_identity) {
15588f7e8f9dSMark Adams     B->ops->solve = MatSolve_SeqAIJ_NaturalOrdering;
15598f7e8f9dSMark Adams   } else {
15608f7e8f9dSMark Adams     B->ops->solve = MatSolve_SeqAIJ; // at least this needs to be in Kokkos
15618f7e8f9dSMark Adams   }
15628f7e8f9dSMark Adams   B->offloadmask = PETSC_OFFLOAD_GPU;
15639566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncHost(B));          // solve on CPU
15648f7e8f9dSMark Adams   B->ops->solveadd          = MatSolveAdd_SeqAIJ; // and this
15658f7e8f9dSMark Adams   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJ;
15668f7e8f9dSMark Adams   B->ops->solvetransposeadd = MatSolveTransposeAdd_SeqAIJ;
15678f7e8f9dSMark Adams   B->ops->matsolve          = MatMatSolve_SeqAIJ;
15688f7e8f9dSMark Adams   B->assembled              = PETSC_TRUE;
15698f7e8f9dSMark Adams   B->preallocated           = PETSC_TRUE;
15708f7e8f9dSMark Adams 
1571930e68a5SMark Adams   PetscFunctionReturn(0);
1572930e68a5SMark Adams }
1573930e68a5SMark Adams 
1574*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatLUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1575*d71ae5a4SJacob Faibussowitsch {
1576930e68a5SMark Adams   PetscFunctionBegin;
15779566063dSJacob Faibussowitsch   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
157886a27549SJunchao Zhang   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
157986a27549SJunchao Zhang   PetscFunctionReturn(0);
158086a27549SJunchao Zhang }
158186a27549SJunchao Zhang 
1582*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSymbolicSolveCheck(Mat A)
1583*d71ae5a4SJacob Faibussowitsch {
158486a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
158586a27549SJunchao Zhang 
158686a27549SJunchao Zhang   PetscFunctionBegin;
158786a27549SJunchao Zhang   if (!factors->sptrsv_symbolic_completed) {
158886a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d);
158986a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d);
159086a27549SJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_TRUE;
159186a27549SJunchao Zhang   }
159286a27549SJunchao Zhang   PetscFunctionReturn(0);
159386a27549SJunchao Zhang }
159486a27549SJunchao Zhang 
159586a27549SJunchao Zhang /* Check if we need to update factors etc for transpose solve */
1596*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosTransposeSolveCheck(Mat A)
1597*d71ae5a4SJacob Faibussowitsch {
159886a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1599076ba34aSJunchao Zhang   MatColIdxType               n       = A->rmap->n;
160086a27549SJunchao Zhang 
160186a27549SJunchao Zhang   PetscFunctionBegin;
160286a27549SJunchao Zhang   if (!factors->transpose_updated) { /* TODO: KK needs to provide functions to do numeric transpose only */
160386a27549SJunchao Zhang     /* Update L^T and do sptrsv symbolic */
1604076ba34aSJunchao Zhang     factors->iLt_d = MatRowMapKokkosView("factors->iLt_d", n + 1);
160586a27549SJunchao Zhang     Kokkos::deep_copy(factors->iLt_d, 0); /* KK requires 0 */
1606076ba34aSJunchao Zhang     factors->jLt_d = MatColIdxKokkosView("factors->jLt_d", factors->jL_d.extent(0));
1607076ba34aSJunchao Zhang     factors->aLt_d = MatScalarKokkosView("factors->aLt_d", factors->aL_d.extent(0));
160886a27549SJunchao Zhang 
16099371c9d4SSatish Balay     transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iL_d, factors->jL_d, factors->aL_d,
161086a27549SJunchao Zhang                                                                                                                                                                                                               factors->iLt_d, factors->jLt_d, factors->aLt_d);
161186a27549SJunchao Zhang 
161286a27549SJunchao Zhang     /* TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices.
161386a27549SJunchao Zhang       We have to sort the indices, until KK provides finer control options.
161486a27549SJunchao Zhang     */
16159371c9d4SSatish Balay     sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iLt_d, factors->jLt_d, factors->aLt_d);
161686a27549SJunchao Zhang 
161786a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d);
161886a27549SJunchao Zhang 
161986a27549SJunchao Zhang     /* Update U^T and do sptrsv symbolic */
1620076ba34aSJunchao Zhang     factors->iUt_d = MatRowMapKokkosView("factors->iUt_d", n + 1);
162186a27549SJunchao Zhang     Kokkos::deep_copy(factors->iUt_d, 0); /* KK requires 0 */
1622076ba34aSJunchao Zhang     factors->jUt_d = MatColIdxKokkosView("factors->jUt_d", factors->jU_d.extent(0));
1623076ba34aSJunchao Zhang     factors->aUt_d = MatScalarKokkosView("factors->aUt_d", factors->aU_d.extent(0));
162486a27549SJunchao Zhang 
16259371c9d4SSatish Balay     transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iU_d, factors->jU_d, factors->aU_d,
162686a27549SJunchao Zhang                                                                                                                                                                                                               factors->iUt_d, factors->jUt_d, factors->aUt_d);
162786a27549SJunchao Zhang 
162886a27549SJunchao Zhang     /* Sort indices. See comments above */
16299371c9d4SSatish Balay     sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iUt_d, factors->jUt_d, factors->aUt_d);
163086a27549SJunchao Zhang 
163186a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d);
163286a27549SJunchao Zhang     factors->transpose_updated = PETSC_TRUE;
163386a27549SJunchao Zhang   }
163486a27549SJunchao Zhang   PetscFunctionReturn(0);
163586a27549SJunchao Zhang }
163686a27549SJunchao Zhang 
163786a27549SJunchao Zhang /* Solve Ax = b, with A = LU */
1638*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSolve_SeqAIJKokkos(Mat A, Vec b, Vec x)
1639*d71ae5a4SJacob Faibussowitsch {
164086a27549SJunchao Zhang   ConstPetscScalarKokkosView  bv;
164186a27549SJunchao Zhang   PetscScalarKokkosView       xv;
164286a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
164386a27549SJunchao Zhang 
164486a27549SJunchao Zhang   PetscFunctionBegin;
16459566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
16469566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSymbolicSolveCheck(A));
16479566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(b, &bv));
16489566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(x, &xv));
164986a27549SJunchao Zhang   /* Solve L tmpv = b */
16509566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::Experimental::sptrsv_solve(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d, bv, factors->workVector));
165186a27549SJunchao Zhang   /* Solve Ux = tmpv */
16529566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::Experimental::sptrsv_solve(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, factors->workVector, xv));
16539566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(b, &bv));
16549566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
16559566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
165686a27549SJunchao Zhang   PetscFunctionReturn(0);
165786a27549SJunchao Zhang }
165886a27549SJunchao Zhang 
1659076ba34aSJunchao Zhang /* Solve A^T x = b, where A^T = U^T L^T */
1660*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSolveTranspose_SeqAIJKokkos(Mat A, Vec b, Vec x)
1661*d71ae5a4SJacob Faibussowitsch {
166286a27549SJunchao Zhang   ConstPetscScalarKokkosView  bv;
166386a27549SJunchao Zhang   PetscScalarKokkosView       xv;
166486a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
166586a27549SJunchao Zhang 
166686a27549SJunchao Zhang   PetscFunctionBegin;
16679566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
16689566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A));
16699566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(b, &bv));
16709566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(x, &xv));
167186a27549SJunchao Zhang   /* Solve U^T tmpv = b */
167286a27549SJunchao Zhang   KokkosSparse::Experimental::sptrsv_solve(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, bv, factors->workVector);
167386a27549SJunchao Zhang 
167486a27549SJunchao Zhang   /* Solve L^T x = tmpv */
167586a27549SJunchao Zhang   KokkosSparse::Experimental::sptrsv_solve(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d, factors->workVector, xv);
16769566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(b, &bv));
16779566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
16789566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
167986a27549SJunchao Zhang   PetscFunctionReturn(0);
168086a27549SJunchao Zhang }
168186a27549SJunchao Zhang 
1682*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatILUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1683*d71ae5a4SJacob Faibussowitsch {
168486a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
168586a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
168686a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
168786a27549SJunchao Zhang 
168886a27549SJunchao Zhang   PetscFunctionBegin;
16899566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
16909566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1691076ba34aSJunchao Zhang 
1692076ba34aSJunchao Zhang   auto a_d = aijkok->a_dual.view_device();
1693076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
1694076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
1695076ba34aSJunchao Zhang 
1696076ba34aSJunchao Zhang   KokkosSparse::Experimental::spiluk_numeric(&factors->kh, fill_lev, i_d, j_d, a_d, factors->iL_d, factors->jL_d, factors->aL_d, factors->iU_d, factors->jU_d, factors->aU_d);
169786a27549SJunchao Zhang 
169886a27549SJunchao Zhang   B->assembled              = PETSC_TRUE;
169986a27549SJunchao Zhang   B->preallocated           = PETSC_TRUE;
170086a27549SJunchao Zhang   B->ops->solve             = MatSolve_SeqAIJKokkos;
170186a27549SJunchao Zhang   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJKokkos;
170286a27549SJunchao Zhang   B->ops->matsolve          = NULL;
170386a27549SJunchao Zhang   B->ops->matsolvetranspose = NULL;
170486a27549SJunchao Zhang   B->offloadmask            = PETSC_OFFLOAD_GPU;
170586a27549SJunchao Zhang 
170686a27549SJunchao Zhang   /* Once the factors' value changed, we need to update their transpose and sptrsv handle */
170786a27549SJunchao Zhang   factors->transpose_updated         = PETSC_FALSE;
170886a27549SJunchao Zhang   factors->sptrsv_symbolic_completed = PETSC_FALSE;
1709eeadb341SJunchao Zhang   /* TODO: log flops, but how to know that? */
17109566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
171186a27549SJunchao Zhang   PetscFunctionReturn(0);
171286a27549SJunchao Zhang }
171386a27549SJunchao Zhang 
1714*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1715*d71ae5a4SJacob Faibussowitsch {
171686a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok;
171786a27549SJunchao Zhang   Mat_SeqAIJ                 *b;
171886a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
171986a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
172086a27549SJunchao Zhang   PetscInt                    nnzA     = ((Mat_SeqAIJ *)A->data)->nz, nnzL, nnzU;
172186a27549SJunchao Zhang   PetscInt                    n        = A->rmap->n;
172286a27549SJunchao Zhang 
172386a27549SJunchao Zhang   PetscFunctionBegin;
17249566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
172586a27549SJunchao Zhang   /* Rebuild factors */
17269371c9d4SSatish Balay   if (factors) {
17279371c9d4SSatish Balay     factors->Destroy();
17289371c9d4SSatish Balay   } /* Destroy the old if it exists */
17299371c9d4SSatish Balay   else {
17309371c9d4SSatish Balay     B->spptr = factors = new Mat_SeqAIJKokkosTriFactors(n);
17319371c9d4SSatish Balay   }
173286a27549SJunchao Zhang 
173386a27549SJunchao Zhang   /* Create a spiluk handle and then do symbolic factorization */
173486a27549SJunchao Zhang   nnzL = nnzU = PetscRealIntMultTruncate(info->fill, nnzA);
173586a27549SJunchao Zhang   factors->kh.create_spiluk_handle(KokkosSparse::Experimental::SPILUKAlgorithm::SEQLVLSCHD_TP1, n, nnzL, nnzU);
173686a27549SJunchao Zhang 
173786a27549SJunchao Zhang   auto spiluk_handle = factors->kh.get_spiluk_handle();
173886a27549SJunchao Zhang 
173986a27549SJunchao Zhang   Kokkos::realloc(factors->iL_d, n + 1); /* Free old arrays and realloc */
174086a27549SJunchao Zhang   Kokkos::realloc(factors->jL_d, spiluk_handle->get_nnzL());
174186a27549SJunchao Zhang   Kokkos::realloc(factors->iU_d, n + 1);
174286a27549SJunchao Zhang   Kokkos::realloc(factors->jU_d, spiluk_handle->get_nnzU());
174386a27549SJunchao Zhang 
174486a27549SJunchao Zhang   aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
1745076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
1746076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
1747076ba34aSJunchao Zhang   KokkosSparse::Experimental::spiluk_symbolic(&factors->kh, fill_lev, i_d, j_d, factors->iL_d, factors->jL_d, factors->iU_d, factors->jU_d);
174886a27549SJunchao Zhang   /* TODO: if spiluk_symbolic is asynchronous, do we need to sync before calling get_nnzL()? */
174986a27549SJunchao Zhang 
175086a27549SJunchao Zhang   Kokkos::resize(factors->jL_d, spiluk_handle->get_nnzL()); /* Shrink or expand, and retain old value */
175186a27549SJunchao Zhang   Kokkos::resize(factors->jU_d, spiluk_handle->get_nnzU());
175286a27549SJunchao Zhang   Kokkos::realloc(factors->aL_d, spiluk_handle->get_nnzL()); /* No need to retain old value */
175386a27549SJunchao Zhang   Kokkos::realloc(factors->aU_d, spiluk_handle->get_nnzU());
175486a27549SJunchao Zhang 
175586a27549SJunchao Zhang   /* TODO: add options to select sptrsv algorithms */
175686a27549SJunchao Zhang   /* Create sptrsv handles for L, U and their transpose */
175786a27549SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
175886a27549SJunchao Zhang   auto sptrsv_alg = KokkosSparse::Experimental::SPTRSVAlgorithm::SPTRSV_CUSPARSE;
175986a27549SJunchao Zhang #else
176086a27549SJunchao Zhang   auto sptrsv_alg = KokkosSparse::Experimental::SPTRSVAlgorithm::SEQLVLSCHD_TP1;
176186a27549SJunchao Zhang #endif
176286a27549SJunchao Zhang 
176386a27549SJunchao Zhang   factors->khL.create_sptrsv_handle(sptrsv_alg, n, true /* L is lower tri */);
176486a27549SJunchao Zhang   factors->khU.create_sptrsv_handle(sptrsv_alg, n, false /* U is not lower tri */);
176586a27549SJunchao Zhang   factors->khLt.create_sptrsv_handle(sptrsv_alg, n, false /* L^T is not lower tri */);
176686a27549SJunchao Zhang   factors->khUt.create_sptrsv_handle(sptrsv_alg, n, true /* U^T is lower tri */);
176786a27549SJunchao Zhang 
176886a27549SJunchao Zhang   /* Fill fields of the factor matrix B */
17699566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(B, MAT_SKIP_ALLOCATION, NULL));
177086a27549SJunchao Zhang   b     = (Mat_SeqAIJ *)B->data;
177186a27549SJunchao Zhang   b->nz = b->maxnz          = spiluk_handle->get_nnzL() + spiluk_handle->get_nnzU();
177286a27549SJunchao Zhang   B->info.fill_ratio_given  = info->fill;
177386a27549SJunchao Zhang   B->info.fill_ratio_needed = ((PetscReal)b->nz) / ((PetscReal)nnzA);
177486a27549SJunchao Zhang 
177586a27549SJunchao Zhang   B->offloadmask          = PETSC_OFFLOAD_GPU;
177686a27549SJunchao Zhang   B->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJKokkos;
1777930e68a5SMark Adams   PetscFunctionReturn(0);
1778930e68a5SMark Adams }
1779930e68a5SMark Adams 
1780*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatLUFactorSymbolic_SeqAIJKOKKOSDEVICE(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1781*d71ae5a4SJacob Faibussowitsch {
17828f7e8f9dSMark Adams   Mat_SeqAIJ    *b     = (Mat_SeqAIJ *)B->data;
17838f7e8f9dSMark Adams   const PetscInt nrows = A->rmap->n;
1784930e68a5SMark Adams 
17858f7e8f9dSMark Adams   PetscFunctionBegin;
17869566063dSJacob Faibussowitsch   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
17878f7e8f9dSMark Adams   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKOKKOSDEVICE;
17888f7e8f9dSMark Adams   // move B data into Kokkos
17899566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B)); // create aijkok
17909566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A)); // create aijkok
17918f7e8f9dSMark Adams   {
17928f7e8f9dSMark Adams     Mat_SeqAIJKokkos *baijkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
1793300d22a6SJunchao Zhang     if (!baijkok->diag_d.extent(0)) {
17948f7e8f9dSMark Adams       const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_diag(b->diag, nrows + 1);
1795300d22a6SJunchao Zhang       baijkok->diag_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_diag));
1796300d22a6SJunchao Zhang       Kokkos::deep_copy(baijkok->diag_d, h_diag);
17978f7e8f9dSMark Adams     }
17988f7e8f9dSMark Adams   }
17998f7e8f9dSMark Adams   PetscFunctionReturn(0);
18008f7e8f9dSMark Adams }
18018f7e8f9dSMark Adams 
1802*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatFactorGetSolverType_SeqAIJKokkos(Mat A, MatSolverType *type)
1803*d71ae5a4SJacob Faibussowitsch {
1804930e68a5SMark Adams   PetscFunctionBegin;
1805930e68a5SMark Adams   *type = MATSOLVERKOKKOS;
1806930e68a5SMark Adams   PetscFunctionReturn(0);
1807930e68a5SMark Adams }
1808930e68a5SMark Adams 
1809*d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatFactorGetSolverType_seqaij_kokkos_device(Mat A, MatSolverType *type)
1810*d71ae5a4SJacob Faibussowitsch {
18118f7e8f9dSMark Adams   PetscFunctionBegin;
18128f7e8f9dSMark Adams   *type = MATSOLVERKOKKOSDEVICE;
18138f7e8f9dSMark Adams   PetscFunctionReturn(0);
18148f7e8f9dSMark Adams }
18158f7e8f9dSMark Adams 
1816930e68a5SMark Adams /*MC
  MATSOLVERKOKKOS = "Kokkos" - A matrix solver type providing triangular solvers for sequential matrices
  of type `MATSEQAIJKOKKOS` or `MATAIJKOKKOS` on a single GPU.
1819930e68a5SMark Adams 
1820930e68a5SMark Adams   Level: beginner
1821930e68a5SMark Adams 
1822db781477SPatrick Sanan .seealso: `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJKokkos()`, `MATAIJKOKKOS`, `MatKokkosSetFormat()`, `MatKokkosStorageFormat`, `MatKokkosFormatOperation`
1823930e68a5SMark Adams M*/
182486a27549SJunchao Zhang PETSC_EXTERN PetscErrorCode MatGetFactor_SeqAIJKokkos_Kokkos(Mat A, MatFactorType ftype, Mat *B) /* MatGetFactor_<MatType>_<MatSolverType> */
1825930e68a5SMark Adams {
1826930e68a5SMark Adams   PetscInt n = A->rmap->n;
1827930e68a5SMark Adams 
1828930e68a5SMark Adams   PetscFunctionBegin;
18299566063dSJacob Faibussowitsch   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), B));
18309566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*B, n, n, n, n));
1831930e68a5SMark Adams   (*B)->factortype = ftype;
18329566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
18339566063dSJacob Faibussowitsch   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
1834930e68a5SMark Adams 
18358f7e8f9dSMark Adams   if (ftype == MAT_FACTOR_LU) {
18369566063dSJacob Faibussowitsch     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
183786a27549SJunchao Zhang     (*B)->canuseordering        = PETSC_TRUE;
183886a27549SJunchao Zhang     (*B)->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqAIJKokkos;
183986a27549SJunchao Zhang   } else if (ftype == MAT_FACTOR_ILU) {
18409566063dSJacob Faibussowitsch     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
184186a27549SJunchao Zhang     (*B)->canuseordering         = PETSC_FALSE;
184286a27549SJunchao Zhang     (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJKokkos;
184398921bdaSJacob Faibussowitsch   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "MatFactorType %s is not supported by MatType SeqAIJKokkos", MatFactorTypes[ftype]);
1844930e68a5SMark Adams 
18459566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
18469566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)(*B), "MatFactorGetSolverType_C", MatFactorGetSolverType_SeqAIJKokkos));
1847930e68a5SMark Adams   PetscFunctionReturn(0);
1848930e68a5SMark Adams }
18498f7e8f9dSMark Adams 
1850*d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatGetFactor_seqaijkokkos_kokkos_device(Mat A, MatFactorType ftype, Mat *B)
1851*d71ae5a4SJacob Faibussowitsch {
18528f7e8f9dSMark Adams   PetscInt n = A->rmap->n;
18538f7e8f9dSMark Adams 
18548f7e8f9dSMark Adams   PetscFunctionBegin;
18559566063dSJacob Faibussowitsch   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), B));
18569566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*B, n, n, n, n));
18578f7e8f9dSMark Adams   (*B)->factortype     = ftype;
1858f73b0415SBarry Smith   (*B)->canuseordering = PETSC_TRUE;
18599566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
18609566063dSJacob Faibussowitsch   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
18618f7e8f9dSMark Adams 
18628f7e8f9dSMark Adams   if (ftype == MAT_FACTOR_LU) {
18639566063dSJacob Faibussowitsch     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
18648f7e8f9dSMark Adams     (*B)->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqAIJKOKKOSDEVICE;
18658f7e8f9dSMark Adams   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Factor type not supported for KOKKOS Matrix Types");
18668f7e8f9dSMark Adams 
18679566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
18689566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)(*B), "MatFactorGetSolverType_C", MatFactorGetSolverType_seqaij_kokkos_device));
18698f7e8f9dSMark Adams   PetscFunctionReturn(0);
18708f7e8f9dSMark Adams }
187186a27549SJunchao Zhang 
1872*d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatSolverTypeRegister_KOKKOS(void)
1873*d71ae5a4SJacob Faibussowitsch {
187486a27549SJunchao Zhang   PetscFunctionBegin;
18759566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_SeqAIJKokkos_Kokkos));
18769566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ILU, MatGetFactor_SeqAIJKokkos_Kokkos));
18779566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOSDEVICE, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_seqaijkokkos_kokkos_device));
187886a27549SJunchao Zhang   PetscFunctionReturn(0);
187986a27549SJunchao Zhang }
188086a27549SJunchao Zhang 
1881076ba34aSJunchao Zhang /* Utility to print out a KokkosCsrMatrix for debugging */
1882*d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode PrintCsrMatrix(const KokkosCsrMatrix &csrmat)
1883*d71ae5a4SJacob Faibussowitsch {
1884076ba34aSJunchao Zhang   const auto        &iv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.graph.row_map);
1885076ba34aSJunchao Zhang   const auto        &jv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.graph.entries);
1886076ba34aSJunchao Zhang   const auto        &av = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.values);
1887076ba34aSJunchao Zhang   const PetscInt    *i  = iv.data();
1888076ba34aSJunchao Zhang   const PetscInt    *j  = jv.data();
1889076ba34aSJunchao Zhang   const PetscScalar *a  = av.data();
1890076ba34aSJunchao Zhang   PetscInt           m = csrmat.numRows(), n = csrmat.numCols(), nnz = csrmat.nnz();
1891076ba34aSJunchao Zhang 
1892076ba34aSJunchao Zhang   PetscFunctionBegin;
18939566063dSJacob Faibussowitsch   PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT " x %" PetscInt_FMT " SeqAIJKokkos, with %" PetscInt_FMT " nonzeros\n", m, n, nnz));
1894076ba34aSJunchao Zhang   for (PetscInt k = 0; k < m; k++) {
18959566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT ": ", k));
189648a46eb9SPierre Jolivet     for (PetscInt p = i[k]; p < i[k + 1]; p++) PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT "(%.1f), ", j[p], (double)PetscRealPart(a[p])));
18979566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "\n"));
1898076ba34aSJunchao Zhang   }
1899076ba34aSJunchao Zhang   PetscFunctionReturn(0);
1900076ba34aSJunchao Zhang }
1901