xref: /petsc/src/mat/impls/aij/seq/kokkos/aijkok.kokkos.cxx (revision 3ba1676111f5c958fe6c2729b46ca4d523958bb3)
111d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
2076ba34aSJunchao Zhang #include <petscpkg_version.h>
3152b3e56SJunchao Zhang #include <petsc/private/petscimpl.h>
442550becSJunchao Zhang #include <petsc/private/sfimpl.h>
58c3ff71bSJunchao Zhang #include <petscsystypes.h>
68c3ff71bSJunchao Zhang #include <petscerror.h>
78c3ff71bSJunchao Zhang 
88c3ff71bSJunchao Zhang #include <Kokkos_Core.hpp>
9f0cf5187SStefano Zampini #include <KokkosBlas.hpp>
108c3ff71bSJunchao Zhang #include <KokkosSparse_CrsMatrix.hpp>
118c3ff71bSJunchao Zhang #include <KokkosSparse_spmv.hpp>
1286a27549SJunchao Zhang #include <KokkosSparse_spiluk.hpp>
1386a27549SJunchao Zhang #include <KokkosSparse_sptrsv.hpp>
14076ba34aSJunchao Zhang #include <KokkosSparse_spgemm.hpp>
15076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
1686a27549SJunchao Zhang 
1742550becSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
188c3ff71bSJunchao Zhang 
19f98996d3SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(3, 6, 99)
20f98996d3SJunchao Zhang   #include <KokkosSparse_Utils.hpp>
21f98996d3SJunchao Zhang using KokkosSparse::sort_crs_matrix;
229371c9d4SSatish Balay using KokkosSparse::Impl::transpose_matrix;
23f98996d3SJunchao Zhang #else
24f98996d3SJunchao Zhang   #include <KokkosKernels_Sorting.hpp>
25f98996d3SJunchao Zhang using KokkosKernels::sort_crs_matrix;
269371c9d4SSatish Balay using KokkosKernels::Impl::transpose_matrix;
27f98996d3SJunchao Zhang #endif
28f98996d3SJunchao Zhang 
298c3ff71bSJunchao Zhang static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat); /* Forward declaration */
308c3ff71bSJunchao Zhang 
31076ba34aSJunchao Zhang /* MatAssemblyEnd_SeqAIJKokkos() happens when we finalized nonzeros of the matrix, either after
32076ba34aSJunchao Zhang    we assembled the matrix on host, or after we directly produced the matrix data on device (ex., through MatMatMult).
33076ba34aSJunchao Zhang    In the latter case, it is important to set a_dual's sync state correctly.
34076ba34aSJunchao Zhang  */
35d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_SeqAIJKokkos(Mat A, MatAssemblyType mode)
36d71ae5a4SJacob Faibussowitsch {
37076ba34aSJunchao Zhang   Mat_SeqAIJ       *aijseq;
38076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
398c3ff71bSJunchao Zhang 
408c3ff71bSJunchao Zhang   PetscFunctionBegin;
41*3ba16761SJacob Faibussowitsch   if (mode == MAT_FLUSH_ASSEMBLY) PetscFunctionReturn(PETSC_SUCCESS);
429566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
43076ba34aSJunchao Zhang 
44076ba34aSJunchao Zhang   aijseq = static_cast<Mat_SeqAIJ *>(A->data);
45076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
46076ba34aSJunchao Zhang 
47076ba34aSJunchao Zhang   /* If aijkok does not exist, we just copy i, j to device.
48076ba34aSJunchao Zhang      If aijkok already exists, but the device's nonzero pattern does not match with the host's, we assume the latest data is on host.
49076ba34aSJunchao Zhang      In both cases, we build a new aijkok structure.
50076ba34aSJunchao Zhang   */
51076ba34aSJunchao Zhang   if (!aijkok || aijkok->nonzerostate != A->nonzerostate) { /* aijkok might not exist yet or nonzero pattern has changed */
52076ba34aSJunchao Zhang     delete aijkok;
53076ba34aSJunchao Zhang     aijkok   = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aijseq->nz, aijseq->i, aijseq->j, aijseq->a, A->nonzerostate, PETSC_FALSE /*don't copy mat values to device*/);
54076ba34aSJunchao Zhang     A->spptr = aijkok;
55076ba34aSJunchao Zhang   }
56076ba34aSJunchao Zhang 
575519a089SJose E. Roman   if (aijkok->device_mat_d.data()) {
58a587d139SMark     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
59a587d139SMark   }
60*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
618c3ff71bSJunchao Zhang }
628c3ff71bSJunchao Zhang 
6386a27549SJunchao Zhang /* Sync CSR data to device if not yet */
64d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosSyncDevice(Mat A)
65d71ae5a4SJacob Faibussowitsch {
668c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
678c3ff71bSJunchao Zhang 
688c3ff71bSJunchao Zhang   PetscFunctionBegin;
695f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Cann't sync factorized matrix from host to device");
705f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
71076ba34aSJunchao Zhang   if (aijkok->a_dual.need_sync_device()) {
72076ba34aSJunchao Zhang     aijkok->a_dual.sync_device();
73580c7c76SPierre Jolivet     aijkok->transpose_updated = PETSC_FALSE; /* values of the transpose is out-of-date */
7486a27549SJunchao Zhang     aijkok->hermitian_updated = PETSC_FALSE;
758c3ff71bSJunchao Zhang   }
76*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
778c3ff71bSJunchao Zhang }
788c3ff71bSJunchao Zhang 
79076ba34aSJunchao Zhang /* Mark the CSR data on device as modified */
80d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosModifyDevice(Mat A)
81d71ae5a4SJacob Faibussowitsch {
8286a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
8386a27549SJunchao Zhang 
8486a27549SJunchao Zhang   PetscFunctionBegin;
855f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Not supported for factorized matries");
8686a27549SJunchao Zhang   aijkok->a_dual.clear_sync_state();
8786a27549SJunchao Zhang   aijkok->a_dual.modify_device();
8886a27549SJunchao Zhang   aijkok->transpose_updated = PETSC_FALSE;
8986a27549SJunchao Zhang   aijkok->hermitian_updated = PETSC_FALSE;
909566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJInvalidateDiagonal(A));
919566063dSJacob Faibussowitsch   PetscCall(PetscObjectStateIncrease((PetscObject)A));
92*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
9386a27549SJunchao Zhang }
9486a27549SJunchao Zhang 
95d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSyncHost(Mat A)
96d71ae5a4SJacob Faibussowitsch {
97f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
98f0cf5187SStefano Zampini 
99f0cf5187SStefano Zampini   PetscFunctionBegin;
100f0cf5187SStefano Zampini   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10186a27549SJunchao Zhang   /* We do not expect one needs factors on host  */
1025f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Cann't sync factorized matrix from device to host");
1035f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing AIJKOK");
104076ba34aSJunchao Zhang   aijkok->a_dual.sync_host();
105*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
106f0cf5187SStefano Zampini }
107f0cf5187SStefano Zampini 
108d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
109d71ae5a4SJacob Faibussowitsch {
110076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
111f0cf5187SStefano Zampini 
112f0cf5187SStefano Zampini   PetscFunctionBegin;
1135519a089SJose E. Roman   /* aijkok contains valid pointers only if the host's nonzerostate matches with the device's.
1145519a089SJose E. Roman     Calling MatSeqAIJSetPreallocation() or MatSetValues() on host, where aijseq->{i,j,a} might be
1155519a089SJose E. Roman     reallocated, will lead to stale {i,j,a}_dual in aijkok. In both operations, the hosts's nonzerostate
1165519a089SJose E. Roman     must have been updated. The stale aijkok will be rebuilt during MatAssemblyEnd.
1175519a089SJose E. Roman   */
1185519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
119076ba34aSJunchao Zhang     aijkok->a_dual.sync_host();
120076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
121076ba34aSJunchao Zhang   } else { /* Happens when calling MatSetValues on a newly created matrix */
122076ba34aSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
123076ba34aSJunchao Zhang   }
124*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
125076ba34aSJunchao Zhang }
126076ba34aSJunchao Zhang 
127d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
128d71ae5a4SJacob Faibussowitsch {
129076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
130076ba34aSJunchao Zhang 
131076ba34aSJunchao Zhang   PetscFunctionBegin;
1325519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) aijkok->a_dual.modify_host();
133*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
134076ba34aSJunchao Zhang }
135076ba34aSJunchao Zhang 
136d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
137d71ae5a4SJacob Faibussowitsch {
138076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
139076ba34aSJunchao Zhang 
140076ba34aSJunchao Zhang   PetscFunctionBegin;
1415519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
142076ba34aSJunchao Zhang     aijkok->a_dual.sync_host();
143076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1442328674fSJunchao Zhang   } else {
1452328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1462328674fSJunchao Zhang   }
147*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
148076ba34aSJunchao Zhang }
149076ba34aSJunchao Zhang 
150d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
151d71ae5a4SJacob Faibussowitsch {
152076ba34aSJunchao Zhang   PetscFunctionBegin;
153076ba34aSJunchao Zhang   *array = NULL;
154*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
155076ba34aSJunchao Zhang }
156076ba34aSJunchao Zhang 
157d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
158d71ae5a4SJacob Faibussowitsch {
159076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
160076ba34aSJunchao Zhang 
161076ba34aSJunchao Zhang   PetscFunctionBegin;
1625519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
163076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1642328674fSJunchao Zhang   } else { /* Ex. happens with MatZeroEntries on a preallocated but not assembled matrix */
1652328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1662328674fSJunchao Zhang   }
167*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
168076ba34aSJunchao Zhang }
169076ba34aSJunchao Zhang 
170d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
171d71ae5a4SJacob Faibussowitsch {
172076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
173076ba34aSJunchao Zhang 
174076ba34aSJunchao Zhang   PetscFunctionBegin;
1755519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
176076ba34aSJunchao Zhang     aijkok->a_dual.clear_sync_state();
177076ba34aSJunchao Zhang     aijkok->a_dual.modify_host();
1782328674fSJunchao Zhang   }
179*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
180f0cf5187SStefano Zampini }
181f0cf5187SStefano Zampini 
182d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJKokkos(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype)
183d71ae5a4SJacob Faibussowitsch {
1847ee59b9bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1857ee59b9bSJunchao Zhang 
1867ee59b9bSJunchao Zhang   PetscFunctionBegin;
1877ee59b9bSJunchao Zhang   PetscCheck(aijkok != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "aijkok is NULL");
1887ee59b9bSJunchao Zhang 
1897ee59b9bSJunchao Zhang   if (i) *i = aijkok->i_device_data();
1907ee59b9bSJunchao Zhang   if (j) *j = aijkok->j_device_data();
1917ee59b9bSJunchao Zhang   if (a) {
1927ee59b9bSJunchao Zhang     aijkok->a_dual.sync_device();
1937ee59b9bSJunchao Zhang     *a = aijkok->a_device_data();
1947ee59b9bSJunchao Zhang   }
1957ee59b9bSJunchao Zhang   if (mtype) *mtype = PETSC_MEMTYPE_KOKKOS;
196*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1977ee59b9bSJunchao Zhang }
1987ee59b9bSJunchao Zhang 
199a587d139SMark // MatSeqAIJKokkosSetDeviceMat takes a PetscSplitCSRDataStructure with device data and copies it to the device. Note, "deep_copy" here is really a shallow copy
200d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosSetDeviceMat(Mat A, PetscSplitCSRDataStructure h_mat)
201d71ae5a4SJacob Faibussowitsch {
202042217e8SBarry Smith   Mat_SeqAIJKokkos                            *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
203042217e8SBarry Smith   Kokkos::View<SplitCSRMat, Kokkos::HostSpace> h_mat_k(h_mat);
204a587d139SMark 
205a587d139SMark   PetscFunctionBegin;
2065f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
207152b3e56SJunchao Zhang   aijkok->device_mat_d = create_mirror(DefaultMemorySpace(), h_mat_k);
208a587d139SMark   Kokkos::deep_copy(aijkok->device_mat_d, h_mat_k);
209*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
210a587d139SMark }
211a587d139SMark 
212a587d139SMark // MatSeqAIJKokkosGetDeviceMat gets the device if it is here, otherwise it creates a place for it and returns NULL
213d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosGetDeviceMat(Mat A, PetscSplitCSRDataStructure *d_mat)
214d71ae5a4SJacob Faibussowitsch {
215042217e8SBarry Smith   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
216a587d139SMark 
217a587d139SMark   PetscFunctionBegin;
218a587d139SMark   if (aijkok && aijkok->device_mat_d.data()) {
219a587d139SMark     *d_mat = aijkok->device_mat_d.data();
220a587d139SMark   } else {
2219566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosSyncDevice(A)); // create aijkok (we are making d_mat now so make a place for it)
222a587d139SMark     *d_mat = NULL;
223a587d139SMark   }
224*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
225a587d139SMark }
226076ba34aSJunchao Zhang 
227076ba34aSJunchao Zhang /* Generate the transpose on device and cache it internally */
228d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosGenerateTranspose_Private(Mat A, KokkosCsrMatrix **csrmatT)
229d71ae5a4SJacob Faibussowitsch {
230152b3e56SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
231152b3e56SJunchao Zhang 
232152b3e56SJunchao Zhang   PetscFunctionBegin;
2335f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
234076ba34aSJunchao Zhang   if (!aijkok->csrmatT.nnz() || !aijkok->transpose_updated) { /* Generate At for the first time OR just update its values */
235076ba34aSJunchao Zhang     /* FIXME: KK does not separate symbolic/numeric transpose. We could have a permutation array to help value-only update */
2369566063dSJacob Faibussowitsch     PetscCallCXX(aijkok->a_dual.sync_device());
237f98996d3SJunchao Zhang     PetscCallCXX(aijkok->csrmatT = transpose_matrix(aijkok->csrmat));
238f98996d3SJunchao Zhang     PetscCallCXX(sort_crs_matrix(aijkok->csrmatT));
23986a27549SJunchao Zhang     aijkok->transpose_updated = PETSC_TRUE;
240076ba34aSJunchao Zhang   }
241076ba34aSJunchao Zhang   *csrmatT = &aijkok->csrmatT;
242*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
243152b3e56SJunchao Zhang }
244152b3e56SJunchao Zhang 
245076ba34aSJunchao Zhang /* Generate the Hermitian on device and cache it internally */
246d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosGenerateHermitian_Private(Mat A, KokkosCsrMatrix **csrmatH)
247d71ae5a4SJacob Faibussowitsch {
248152b3e56SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
249152b3e56SJunchao Zhang 
250152b3e56SJunchao Zhang   PetscFunctionBegin;
2519566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
2525f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
253076ba34aSJunchao Zhang   if (!aijkok->csrmatH.nnz() || !aijkok->hermitian_updated) { /* Generate Ah for the first time OR just update its values */
2549566063dSJacob Faibussowitsch     PetscCallCXX(aijkok->a_dual.sync_device());
255f98996d3SJunchao Zhang     PetscCallCXX(aijkok->csrmatH = transpose_matrix(aijkok->csrmat));
256f98996d3SJunchao Zhang     PetscCallCXX(sort_crs_matrix(aijkok->csrmatH));
257076ba34aSJunchao Zhang #if defined(PETSC_USE_COMPLEX)
258076ba34aSJunchao Zhang     const auto &a = aijkok->csrmatH.values;
2599371c9d4SSatish Balay     Kokkos::parallel_for(
2609371c9d4SSatish Balay       a.extent(0), KOKKOS_LAMBDA(MatRowMapType i) { a(i) = PetscConj(a(i)); });
261076ba34aSJunchao Zhang #endif
26286a27549SJunchao Zhang     aijkok->hermitian_updated = PETSC_TRUE;
263076ba34aSJunchao Zhang   }
264076ba34aSJunchao Zhang   *csrmatH = &aijkok->csrmatH;
2659566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
266*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
267152b3e56SJunchao Zhang }
268a587d139SMark 
2698c3ff71bSJunchao Zhang /* y = A x */
270d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMult_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
271d71ae5a4SJacob Faibussowitsch {
2728c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
273152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
274152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
2758c3ff71bSJunchao Zhang 
2768c3ff71bSJunchao Zhang   PetscFunctionBegin;
2779566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
2789566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
2799566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
2809566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
2818c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
2829d52486cSJunchao Zhang   PetscCallCXX(KokkosSparse::spmv("N", 1.0 /*alpha*/, aijkok->csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A x + beta y */
2839566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
2849566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
285076ba34aSJunchao Zhang   /* 2.0*nnz - numRows seems more accurate here but assumes there are no zero-rows. So a little sloppy here. */
2869566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
2879566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
288*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2898c3ff71bSJunchao Zhang }
2908c3ff71bSJunchao Zhang 
2918c3ff71bSJunchao Zhang /* y = A^T x */
292d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
293d71ae5a4SJacob Faibussowitsch {
2948c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
295152b3e56SJunchao Zhang   const char                *mode;
296152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
297152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
298076ba34aSJunchao Zhang   KokkosCsrMatrix           *csrmat;
2998c3ff71bSJunchao Zhang 
3008c3ff71bSJunchao Zhang   PetscFunctionBegin;
3019566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3029566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3039566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3049566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
305152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
3069566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
307152b3e56SJunchao Zhang     mode = "N";
308152b3e56SJunchao Zhang   } else {
309076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
310076ba34aSJunchao Zhang     csrmat = &aijkok->csrmat;
311152b3e56SJunchao Zhang     mode   = "T";
312152b3e56SJunchao Zhang   }
3139d52486cSJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, *csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^T x + beta y */
3149566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3159566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
3169566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * csrmat->nnz()));
3179566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
318*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
3198c3ff71bSJunchao Zhang }
3208c3ff71bSJunchao Zhang 
3218c3ff71bSJunchao Zhang /* y = A^H x */
322d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
323d71ae5a4SJacob Faibussowitsch {
3248c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
325152b3e56SJunchao Zhang   const char                *mode;
326152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
327152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
328076ba34aSJunchao Zhang   KokkosCsrMatrix           *csrmat;
3298c3ff71bSJunchao Zhang 
3308c3ff71bSJunchao Zhang   PetscFunctionBegin;
3319566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3329566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3339566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3349566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
335152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
3369566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
337152b3e56SJunchao Zhang     mode = "N";
338152b3e56SJunchao Zhang   } else {
339076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
340076ba34aSJunchao Zhang     csrmat = &aijkok->csrmat;
341152b3e56SJunchao Zhang     mode   = "C";
342152b3e56SJunchao Zhang   }
3439d52486cSJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, *csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^H x + beta y */
3449566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3459566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
3469566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * csrmat->nnz()));
3479566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
348*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
3498c3ff71bSJunchao Zhang }
3508c3ff71bSJunchao Zhang 
3518c3ff71bSJunchao Zhang /* z = A x + y */
352d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
353d71ae5a4SJacob Faibussowitsch {
3548c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
355152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
356152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
3578c3ff71bSJunchao Zhang 
3588c3ff71bSJunchao Zhang   PetscFunctionBegin;
3599566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3609566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3619566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3629566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
3639566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
3648c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
3658c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3669d52486cSJunchao Zhang   PetscCallCXX(KokkosSparse::spmv("N", 1.0 /*alpha*/, aijkok->csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A x + beta z */
3679566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3689566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
3699566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
3709566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
3719566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
372*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
3738c3ff71bSJunchao Zhang }
3748c3ff71bSJunchao Zhang 
3758c3ff71bSJunchao Zhang /* z = A^T x + y */
376d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
377d71ae5a4SJacob Faibussowitsch {
3788c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
379152b3e56SJunchao Zhang   const char                *mode;
380152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
381152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
382076ba34aSJunchao Zhang   KokkosCsrMatrix           *csrmat;
3838c3ff71bSJunchao Zhang 
3848c3ff71bSJunchao Zhang   PetscFunctionBegin;
3859566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3869566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3879566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3889566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
3899566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
3908c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
391152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
3929566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
393152b3e56SJunchao Zhang     mode = "N";
394152b3e56SJunchao Zhang   } else {
395076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
396076ba34aSJunchao Zhang     csrmat = &aijkok->csrmat;
397152b3e56SJunchao Zhang     mode   = "T";
398152b3e56SJunchao Zhang   }
3999d52486cSJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, *csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^T x + beta z */
4009566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4019566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
4029566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
4039566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * csrmat->nnz()));
4049566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
405*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4068c3ff71bSJunchao Zhang }
4078c3ff71bSJunchao Zhang 
4088c3ff71bSJunchao Zhang /* z = A^H x + y */
409d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
410d71ae5a4SJacob Faibussowitsch {
4118c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
412152b3e56SJunchao Zhang   const char                *mode;
413152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv, yv;
414152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
415076ba34aSJunchao Zhang   KokkosCsrMatrix           *csrmat;
4168c3ff71bSJunchao Zhang 
4178c3ff71bSJunchao Zhang   PetscFunctionBegin;
4189566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4199566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4209566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4219566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(yy, &yv));
4229566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(zz, &zv));
4238c3ff71bSJunchao Zhang   if (zz != yy) Kokkos::deep_copy(zv, yv);
424152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4259566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
426152b3e56SJunchao Zhang     mode = "N";
427152b3e56SJunchao Zhang   } else {
428076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
429076ba34aSJunchao Zhang     csrmat = &aijkok->csrmat;
430152b3e56SJunchao Zhang     mode   = "C";
431152b3e56SJunchao Zhang   }
4329d52486cSJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(mode, 1.0 /*alpha*/, *csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^H x + beta z */
4339566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4349566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(yy, &yv));
4359566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(zz, &zv));
4369566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * csrmat->nnz()));
4379566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
438*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
439152b3e56SJunchao Zhang }
440152b3e56SJunchao Zhang 
441d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSetOption_SeqAIJKokkos(Mat A, MatOption op, PetscBool flg)
442d71ae5a4SJacob Faibussowitsch {
443152b3e56SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
444152b3e56SJunchao Zhang 
445152b3e56SJunchao Zhang   PetscFunctionBegin;
446152b3e56SJunchao Zhang   switch (op) {
447152b3e56SJunchao Zhang   case MAT_FORM_EXPLICIT_TRANSPOSE:
448152b3e56SJunchao Zhang     /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */
4499566063dSJacob Faibussowitsch     if (A->form_explicit_transpose && !flg && aijkok) PetscCall(aijkok->DestroyMatTranspose());
450152b3e56SJunchao Zhang     A->form_explicit_transpose = flg;
451152b3e56SJunchao Zhang     break;
452d71ae5a4SJacob Faibussowitsch   default:
453d71ae5a4SJacob Faibussowitsch     PetscCall(MatSetOption_SeqAIJ(A, op, flg));
454d71ae5a4SJacob Faibussowitsch     break;
455152b3e56SJunchao Zhang   }
456*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4578c3ff71bSJunchao Zhang }
4588c3ff71bSJunchao Zhang 
459076ba34aSJunchao Zhang /* Depending on reuse, either build a new mat, or use the existing mat */
460d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
461d71ae5a4SJacob Faibussowitsch {
462076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
4638c3ff71bSJunchao Zhang 
4648c3ff71bSJunchao Zhang   PetscFunctionBegin;
4659566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
466076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {                      /* Build a brand new mat */
4679566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));  /* the returned newmat is a SeqAIJKokkos */
4688c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {                 /* Reuse the mat created before */
4699566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); /* newmat is already a SeqAIJKokkos */
470076ba34aSJunchao Zhang   } else if (reuse == MAT_INPLACE_MATRIX) {               /* newmat is A */
4715f80ce2aSJacob Faibussowitsch     PetscCheck(A == *newmat, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "A != *newmat with MAT_INPLACE_MATRIX");
4729566063dSJacob Faibussowitsch     PetscCall(PetscFree(A->defaultvectype));
4739566063dSJacob Faibussowitsch     PetscCall(PetscStrallocpy(VECKOKKOS, &A->defaultvectype)); /* Allocate and copy the string */
4749566063dSJacob Faibussowitsch     PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATSEQAIJKOKKOS));
4759566063dSJacob Faibussowitsch     PetscCall(MatSetOps_SeqAIJKokkos(A));
476076ba34aSJunchao Zhang     aseq = static_cast<Mat_SeqAIJ *>(A->data);
477394ed5ebSJunchao Zhang     if (A->assembled) { /* Copy i, j (but not values) to device for an assembled matrix if not yet */
4785f80ce2aSJacob Faibussowitsch       PetscCheck(!A->spptr, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Expect NULL (Mat_SeqAIJKokkos*)A->spptr");
479076ba34aSJunchao Zhang       A->spptr = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aseq->nz, aseq->i, aseq->j, aseq->a, A->nonzerostate, PETSC_FALSE);
4808c3ff71bSJunchao Zhang     }
481076ba34aSJunchao Zhang   }
482*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4838c3ff71bSJunchao Zhang }
4848c3ff71bSJunchao Zhang 
485076ba34aSJunchao Zhang /* MatDuplicate always creates a new matrix. MatDuplicate can be called either on an assembled matrix or
486076ba34aSJunchao Zhang    an unassembled matrix, even though MAT_COPY_VALUES is not allowed for unassembled matrix.
487076ba34aSJunchao Zhang  */
488d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDuplicate_SeqAIJKokkos(Mat A, MatDuplicateOption dupOption, Mat *B)
489d71ae5a4SJacob Faibussowitsch {
490076ba34aSJunchao Zhang   Mat_SeqAIJ       *bseq;
491076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok;
492076ba34aSJunchao Zhang   Mat               mat;
4938c3ff71bSJunchao Zhang 
4948c3ff71bSJunchao Zhang   PetscFunctionBegin;
495076ba34aSJunchao Zhang   /* Do not copy values on host as A's latest values might be on device. We don't want to do sync blindly */
4969566063dSJacob Faibussowitsch   PetscCall(MatDuplicate_SeqAIJ(A, MAT_DO_NOT_COPY_VALUES, B));
497076ba34aSJunchao Zhang   mat = *B;
498076ba34aSJunchao Zhang   if (A->assembled) {
499076ba34aSJunchao Zhang     bseq = static_cast<Mat_SeqAIJ *>(mat->data);
500076ba34aSJunchao Zhang     bkok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, bseq->nz, bseq->i, bseq->j, bseq->a, mat->nonzerostate, PETSC_FALSE);
501076ba34aSJunchao Zhang     bkok->a_dual.clear_sync_state(); /* Clear B's sync state as it will be decided below */
502076ba34aSJunchao Zhang     /* Now copy values to B if needed */
503076ba34aSJunchao Zhang     if (dupOption == MAT_COPY_VALUES) {
504076ba34aSJunchao Zhang       if (akok->a_dual.need_sync_device()) {
505076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_host(), akok->a_dual.view_host());
506076ba34aSJunchao Zhang         bkok->a_dual.modify_host();
507076ba34aSJunchao Zhang       } else { /* If device has the latest data, we only copy data on device */
508076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_device(), akok->a_dual.view_device());
509076ba34aSJunchao Zhang         bkok->a_dual.modify_device();
510076ba34aSJunchao Zhang       }
511076ba34aSJunchao Zhang     } else { /* MAT_DO_NOT_COPY_VALUES or MAT_SHARE_NONZERO_PATTERN. B's values should be zeroed */
512076ba34aSJunchao Zhang       /* B's values on host should be already zeroed by MatDuplicate_SeqAIJ() */
513076ba34aSJunchao Zhang       bkok->a_dual.modify_host();
514076ba34aSJunchao Zhang     }
515076ba34aSJunchao Zhang     mat->spptr = bkok;
516076ba34aSJunchao Zhang   }
517076ba34aSJunchao Zhang 
5189566063dSJacob Faibussowitsch   PetscCall(PetscFree(mat->defaultvectype));
5199566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &mat->defaultvectype)); /* Allocate and copy the string */
5209566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)mat, MATSEQAIJKOKKOS));
5219566063dSJacob Faibussowitsch   PetscCall(MatSetOps_SeqAIJKokkos(mat));
522*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5238c3ff71bSJunchao Zhang }
5248c3ff71bSJunchao Zhang 
525d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatTranspose_SeqAIJKokkos(Mat A, MatReuse reuse, Mat *B)
526d71ae5a4SJacob Faibussowitsch {
5270ecb592aSJunchao Zhang   Mat               At;
528ff751488SJunchao Zhang   KokkosCsrMatrix  *internT;
5290ecb592aSJunchao Zhang   Mat_SeqAIJKokkos *atkok, *bkok;
5300ecb592aSJunchao Zhang 
5310ecb592aSJunchao Zhang   PetscFunctionBegin;
5327fb60732SBarry Smith   if (reuse == MAT_REUSE_MATRIX) PetscCall(MatTransposeCheckNonzeroState_Private(A, *B));
5339566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &internT)); /* Generate a transpose internally */
5340ecb592aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX || reuse == MAT_INPLACE_MATRIX) {
535ff751488SJunchao Zhang     /* Deep copy internT, as we want to isolate the internal transpose */
5369566063dSJacob Faibussowitsch     PetscCallCXX(atkok = new Mat_SeqAIJKokkos(KokkosCsrMatrix("csrmat", *internT)));
5379566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PetscObjectComm((PetscObject)A), atkok, &At));
5380ecb592aSJunchao Zhang     if (reuse == MAT_INITIAL_MATRIX) *B = At;
5399566063dSJacob Faibussowitsch     else PetscCall(MatHeaderReplace(A, &At)); /* Replace A with At inplace */
5400ecb592aSJunchao Zhang   } else {                                    /* MAT_REUSE_MATRIX, just need to copy values to B on device */
5410ecb592aSJunchao Zhang     if ((*B)->assembled) {
5420ecb592aSJunchao Zhang       bkok = static_cast<Mat_SeqAIJKokkos *>((*B)->spptr);
5439566063dSJacob Faibussowitsch       PetscCallCXX(Kokkos::deep_copy(bkok->a_dual.view_device(), internT->values));
5449566063dSJacob Faibussowitsch       PetscCall(MatSeqAIJKokkosModifyDevice(*B));
5450ecb592aSJunchao Zhang     } else if ((*B)->preallocated) { /* It is ok for B to be only preallocated, as needed in MatTranspose_MPIAIJ */
5460ecb592aSJunchao Zhang       Mat_SeqAIJ             *bseq = static_cast<Mat_SeqAIJ *>((*B)->data);
5470ecb592aSJunchao Zhang       MatScalarKokkosViewHost a_h(bseq->a, internT->nnz()); /* bseq->nz = 0 if unassembled */
5480ecb592aSJunchao Zhang       MatColIdxKokkosViewHost j_h(bseq->j, internT->nnz());
5499566063dSJacob Faibussowitsch       PetscCallCXX(Kokkos::deep_copy(a_h, internT->values));
5509566063dSJacob Faibussowitsch       PetscCallCXX(Kokkos::deep_copy(j_h, internT->graph.entries));
5510ecb592aSJunchao Zhang     } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "B must be assembled or preallocated");
5520ecb592aSJunchao Zhang   }
553*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5540ecb592aSJunchao Zhang }
5550ecb592aSJunchao Zhang 
556d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDestroy_SeqAIJKokkos(Mat A)
557d71ae5a4SJacob Faibussowitsch {
55886a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
5598c3ff71bSJunchao Zhang 
5608c3ff71bSJunchao Zhang   PetscFunctionBegin;
56186a27549SJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) {
56286a27549SJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
5638c3ff71bSJunchao Zhang     delete aijkok;
56486a27549SJunchao Zhang   } else {
56586a27549SJunchao Zhang     delete static_cast<Mat_SeqAIJKokkosTriFactors *>(A->spptr);
56686a27549SJunchao Zhang   }
567cbc6b225SStefano Zampini   A->spptr = NULL;
5689566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
5699566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
5709566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
5719566063dSJacob Faibussowitsch   PetscCall(MatDestroy_SeqAIJ(A));
572*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5738c3ff71bSJunchao Zhang }
5748c3ff71bSJunchao Zhang 
5753f3ba80aSJunchao Zhang /*MC
5763f3ba80aSJunchao Zhang    MATSEQAIJKOKKOS - MATAIJKOKKOS = "(seq)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos
5773f3ba80aSJunchao Zhang 
5783f3ba80aSJunchao Zhang    A matrix type type using Kokkos-Kernels CrsMatrix type for portability across different device types
5793f3ba80aSJunchao Zhang 
5803f3ba80aSJunchao Zhang    Options Database Keys:
58111a5261eSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATSEQAIJKOKKOS` during a call to `MatSetFromOptions()`
5823f3ba80aSJunchao Zhang 
5833f3ba80aSJunchao Zhang   Level: beginner
5843f3ba80aSJunchao Zhang 
585db781477SPatrick Sanan .seealso: `MatCreateSeqAIJKokkos()`, `MATMPIAIJKOKKOS`
5863f3ba80aSJunchao Zhang M*/
587d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJKokkos(Mat A)
588d71ae5a4SJacob Faibussowitsch {
58986a27549SJunchao Zhang   PetscFunctionBegin;
5909566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
5919566063dSJacob Faibussowitsch   PetscCall(MatCreate_SeqAIJ(A));
5929566063dSJacob Faibussowitsch   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
593*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
59486a27549SJunchao Zhang }
59586a27549SJunchao Zhang 
596076ba34aSJunchao Zhang /* Merge A, B into a matrix C. A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n) */
597d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C)
598d71ae5a4SJacob Faibussowitsch {
599076ba34aSJunchao Zhang   Mat_SeqAIJ         *a, *b;
600076ba34aSJunchao Zhang   Mat_SeqAIJKokkos   *akok, *bkok, *ckok;
601076ba34aSJunchao Zhang   MatScalarKokkosView aa, ba, ca;
602076ba34aSJunchao Zhang   MatRowMapKokkosView ai, bi, ci;
603076ba34aSJunchao Zhang   MatColIdxKokkosView aj, bj, cj;
604076ba34aSJunchao Zhang   PetscInt            m, n, nnz, aN;
605a3f881fbSStefano Zampini 
606a3f881fbSStefano Zampini   PetscFunctionBegin;
607076ba34aSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
608076ba34aSJunchao Zhang   PetscValidHeaderSpecific(B, MAT_CLASSID, 2);
609076ba34aSJunchao Zhang   PetscValidPointer(C, 4);
610076ba34aSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
611076ba34aSJunchao Zhang   PetscCheckTypeName(B, MATSEQAIJKOKKOS);
6125f80ce2aSJacob Faibussowitsch   PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n);
6135f80ce2aSJacob Faibussowitsch   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
614076ba34aSJunchao Zhang 
6159566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
6169566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
617076ba34aSJunchao Zhang   a    = static_cast<Mat_SeqAIJ *>(A->data);
618076ba34aSJunchao Zhang   b    = static_cast<Mat_SeqAIJ *>(B->data);
619076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
620076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
621076ba34aSJunchao Zhang   aa   = akok->a_dual.view_device();
622076ba34aSJunchao Zhang   ai   = akok->i_dual.view_device();
623076ba34aSJunchao Zhang   ba   = bkok->a_dual.view_device();
624076ba34aSJunchao Zhang   bi   = bkok->i_dual.view_device();
625076ba34aSJunchao Zhang   m    = A->rmap->n; /* M, N and nnz of C */
626076ba34aSJunchao Zhang   n    = A->cmap->n + B->cmap->n;
627076ba34aSJunchao Zhang   nnz  = a->nz + b->nz;
628076ba34aSJunchao Zhang   aN   = A->cmap->n; /* N of A */
629076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
630076ba34aSJunchao Zhang     aj           = akok->j_dual.view_device();
631076ba34aSJunchao Zhang     bj           = bkok->j_dual.view_device();
632076ba34aSJunchao Zhang     auto ca_dual = MatScalarKokkosDualView("a", aa.extent(0) + ba.extent(0));
633076ba34aSJunchao Zhang     auto ci_dual = MatRowMapKokkosDualView("i", ai.extent(0));
634076ba34aSJunchao Zhang     auto cj_dual = MatColIdxKokkosDualView("j", aj.extent(0) + bj.extent(0));
635076ba34aSJunchao Zhang     ca           = ca_dual.view_device();
636076ba34aSJunchao Zhang     ci           = ci_dual.view_device();
637076ba34aSJunchao Zhang     cj           = cj_dual.view_device();
638076ba34aSJunchao Zhang 
639076ba34aSJunchao Zhang     /* Concatenate A and B in parallel using Kokkos hierarchical parallelism */
6409371c9d4SSatish Balay     Kokkos::parallel_for(
6419371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
642076ba34aSJunchao Zhang         PetscInt i       = t.league_rank(); /* row i */
643076ba34aSJunchao Zhang         PetscInt coffset = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
644076ba34aSJunchao Zhang 
645076ba34aSJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() { /* this side effect only happens once per whole team */
646076ba34aSJunchao Zhang                                                    ci(i) = coffset;
647076ba34aSJunchao Zhang                                                    if (i == m - 1) ci(m) = ai(m) + bi(m);
648076ba34aSJunchao Zhang         });
649076ba34aSJunchao Zhang 
650076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
651076ba34aSJunchao Zhang           if (k < alen) {
652076ba34aSJunchao Zhang             ca(coffset + k) = aa(ai(i) + k);
653076ba34aSJunchao Zhang             cj(coffset + k) = aj(ai(i) + k);
654076ba34aSJunchao Zhang           } else {
655076ba34aSJunchao Zhang             ca(coffset + k) = ba(bi(i) + k - alen);
656076ba34aSJunchao Zhang             cj(coffset + k) = bj(bi(i) + k - alen) + aN; /* Entries in B get new column indices in C */
657076ba34aSJunchao Zhang           }
658076ba34aSJunchao Zhang         });
659076ba34aSJunchao Zhang       });
660076ba34aSJunchao Zhang     ca_dual.modify_device();
661076ba34aSJunchao Zhang     ci_dual.modify_device();
662076ba34aSJunchao Zhang     cj_dual.modify_device();
6639566063dSJacob Faibussowitsch     PetscCallCXX(ckok = new Mat_SeqAIJKokkos(m, n, nnz, ci_dual, cj_dual, ca_dual));
6649566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, C));
665076ba34aSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
666076ba34aSJunchao Zhang     PetscValidHeaderSpecific(*C, MAT_CLASSID, 4);
667076ba34aSJunchao Zhang     PetscCheckTypeName(*C, MATSEQAIJKOKKOS);
668076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos *>((*C)->spptr);
669076ba34aSJunchao Zhang     ca   = ckok->a_dual.view_device();
670076ba34aSJunchao Zhang     ci   = ckok->i_dual.view_device();
671076ba34aSJunchao Zhang 
6729371c9d4SSatish Balay     Kokkos::parallel_for(
6739371c9d4SSatish Balay       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
674076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* row i */
675076ba34aSJunchao Zhang         PetscInt alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
676076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
677076ba34aSJunchao Zhang           if (k < alen) ca(ci(i) + k) = aa(ai(i) + k);
678076ba34aSJunchao Zhang           else ca(ci(i) + k) = ba(bi(i) + k - alen);
679076ba34aSJunchao Zhang         });
680076ba34aSJunchao Zhang       });
6819566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(*C));
682076ba34aSJunchao Zhang   }
683*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
684076ba34aSJunchao Zhang }
685076ba34aSJunchao Zhang 
686d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_SeqAIJKokkos(void *pdata)
687d71ae5a4SJacob Faibussowitsch {
688076ba34aSJunchao Zhang   PetscFunctionBegin;
689076ba34aSJunchao Zhang   delete static_cast<MatProductData_SeqAIJKokkos *>(pdata);
690*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
691a3f881fbSStefano Zampini }
692a3f881fbSStefano Zampini 
693d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos(Mat C)
694d71ae5a4SJacob Faibussowitsch {
695a3f881fbSStefano Zampini   Mat_Product                 *product = C->product;
696a3f881fbSStefano Zampini   Mat                          A, B;
697076ba34aSJunchao Zhang   bool                         transA, transB; /* use bool, since KK needs this type */
698a3f881fbSStefano Zampini   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
699a3f881fbSStefano Zampini   Mat_SeqAIJ                  *c;
700076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
701076ba34aSJunchao Zhang   KokkosCsrMatrix             *csrmatA, *csrmatB;
702a3f881fbSStefano Zampini 
703a3f881fbSStefano Zampini   PetscFunctionBegin;
704a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
7055f80ce2aSJacob Faibussowitsch   PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data empty");
706076ba34aSJunchao Zhang   pdata = static_cast<MatProductData_SeqAIJKokkos *>(C->product->data);
707076ba34aSJunchao Zhang 
708076ba34aSJunchao Zhang   if (pdata->reusesym) {           /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
709076ba34aSJunchao Zhang     pdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
710*3ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
711076ba34aSJunchao Zhang   }
712076ba34aSJunchao Zhang 
713076ba34aSJunchao Zhang   switch (product->type) {
7149371c9d4SSatish Balay   case MATPRODUCT_AB:
7159371c9d4SSatish Balay     transA = false;
7169371c9d4SSatish Balay     transB = false;
7179371c9d4SSatish Balay     break;
7189371c9d4SSatish Balay   case MATPRODUCT_AtB:
7199371c9d4SSatish Balay     transA = true;
7209371c9d4SSatish Balay     transB = false;
7219371c9d4SSatish Balay     break;
7229371c9d4SSatish Balay   case MATPRODUCT_ABt:
7239371c9d4SSatish Balay     transA = false;
7249371c9d4SSatish Balay     transB = true;
7259371c9d4SSatish Balay     break;
726d71ae5a4SJacob Faibussowitsch   default:
727d71ae5a4SJacob Faibussowitsch     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
728076ba34aSJunchao Zhang   }
729076ba34aSJunchao Zhang 
730a3f881fbSStefano Zampini   A = product->A;
731a3f881fbSStefano Zampini   B = product->B;
7329566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
7339566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
734a3f881fbSStefano Zampini   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
735a3f881fbSStefano Zampini   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
736a3f881fbSStefano Zampini   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
737076ba34aSJunchao Zhang 
7385f80ce2aSJacob Faibussowitsch   PetscCheck(ckok, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Device data structure spptr is empty");
739076ba34aSJunchao Zhang 
740076ba34aSJunchao Zhang   csrmatA = &akok->csrmat;
741076ba34aSJunchao Zhang   csrmatB = &bkok->csrmat;
742076ba34aSJunchao Zhang 
743076ba34aSJunchao Zhang   /* TODO: Once KK spgemm implements transpose, we can get rid of the explicit transpose here */
744076ba34aSJunchao Zhang   if (transA) {
7459566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
746076ba34aSJunchao Zhang     transA = false;
747a3f881fbSStefano Zampini   }
748a3f881fbSStefano Zampini 
749076ba34aSJunchao Zhang   if (transB) {
7509566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
751076ba34aSJunchao Zhang     transB = false;
752076ba34aSJunchao Zhang   }
7539566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
7549566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, *csrmatA, transA, *csrmatB, transB, ckok->csrmat));
755e944a159SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(3, 7, 99)
756866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
757866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(ckok->csrmat)); /* without sort, mat_tests-ex62_14_seqaijkokkos fails */
758e944a159SJunchao Zhang #endif
759866eb059SJunchao Zhang 
7609566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
7619566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(C));
762a3f881fbSStefano Zampini   /* shorter version of MatAssemblyEnd_SeqAIJ */
763a3f881fbSStefano Zampini   c = (Mat_SeqAIJ *)C->data;
7649566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded,%" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz));
7659566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n"));
7669566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax));
767a3f881fbSStefano Zampini   c->reallocs         = 0;
768076ba34aSJunchao Zhang   C->info.mallocs     = 0;
769a3f881fbSStefano Zampini   C->info.nz_unneeded = 0;
770a3f881fbSStefano Zampini   C->assembled = C->was_assembled = PETSC_TRUE;
771a3f881fbSStefano Zampini   C->num_ass++;
772*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
773a3f881fbSStefano Zampini }
774a3f881fbSStefano Zampini 
775d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos(Mat C)
776d71ae5a4SJacob Faibussowitsch {
777076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
778076ba34aSJunchao Zhang   MatProductType               ptype;
779076ba34aSJunchao Zhang   Mat                          A, B;
780076ba34aSJunchao Zhang   bool                         transA, transB;
781076ba34aSJunchao Zhang   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
782076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
783076ba34aSJunchao Zhang   MPI_Comm                     comm;
784076ba34aSJunchao Zhang   KokkosCsrMatrix             *csrmatA, *csrmatB, csrmatC;
785a3f881fbSStefano Zampini 
786a3f881fbSStefano Zampini   PetscFunctionBegin;
787a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
7889566063dSJacob Faibussowitsch   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
7895f80ce2aSJacob Faibussowitsch   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
790a3f881fbSStefano Zampini   A = product->A;
791a3f881fbSStefano Zampini   B = product->B;
7929566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
7939566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
794a3f881fbSStefano Zampini   akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
795a3f881fbSStefano Zampini   bkok    = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
796076ba34aSJunchao Zhang   csrmatA = &akok->csrmat;
797076ba34aSJunchao Zhang   csrmatB = &bkok->csrmat;
798076ba34aSJunchao Zhang 
799a3f881fbSStefano Zampini   ptype = product->type;
800a3f881fbSStefano Zampini   switch (ptype) {
8019371c9d4SSatish Balay   case MATPRODUCT_AB:
8029371c9d4SSatish Balay     transA = false;
8039371c9d4SSatish Balay     transB = false;
8049371c9d4SSatish Balay     break;
8059371c9d4SSatish Balay   case MATPRODUCT_AtB:
8069371c9d4SSatish Balay     transA = true;
8079371c9d4SSatish Balay     transB = false;
8089371c9d4SSatish Balay     break;
8099371c9d4SSatish Balay   case MATPRODUCT_ABt:
8109371c9d4SSatish Balay     transA = false;
8119371c9d4SSatish Balay     transB = true;
8129371c9d4SSatish Balay     break;
813d71ae5a4SJacob Faibussowitsch   default:
814d71ae5a4SJacob Faibussowitsch     SETERRQ(comm, PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
815a3f881fbSStefano Zampini   }
816a3f881fbSStefano Zampini 
817076ba34aSJunchao Zhang   product->data = pdata = new MatProductData_SeqAIJKokkos();
818076ba34aSJunchao Zhang   pdata->kh.set_team_work_size(16);
819076ba34aSJunchao Zhang   pdata->kh.set_dynamic_scheduling(true);
820076ba34aSJunchao Zhang   pdata->reusesym = product->api_user;
821a3f881fbSStefano Zampini 
822076ba34aSJunchao Zhang   /* TODO: add command line options to select spgemm algorithms */
823866eb059SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; /* default alg is TPL if enabled, otherwise KK */
824866eb059SJunchao Zhang 
825866eb059SJunchao Zhang   /* CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 */
826866eb059SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
827866eb059SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
828866eb059SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
829866eb059SJunchao Zhang   #endif
830866eb059SJunchao Zhang #endif
831866eb059SJunchao Zhang 
832076ba34aSJunchao Zhang   pdata->kh.create_spgemm_handle(spgemm_alg);
833076ba34aSJunchao Zhang 
8349566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
835076ba34aSJunchao Zhang   /* TODO: Get rid of the explicit transpose once KK-spgemm implements the transpose option */
836076ba34aSJunchao Zhang   if (transA) {
8379566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
838076ba34aSJunchao Zhang     transA = false;
839076ba34aSJunchao Zhang   }
840076ba34aSJunchao Zhang 
841076ba34aSJunchao Zhang   if (transB) {
8429566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
843076ba34aSJunchao Zhang     transB = false;
844076ba34aSJunchao Zhang   }
845076ba34aSJunchao Zhang 
8469566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::spgemm_symbolic(pdata->kh, *csrmatA, transA, *csrmatB, transB, csrmatC));
847866eb059SJunchao Zhang 
848076ba34aSJunchao Zhang   /* spgemm_symbolic() only populates C's rowmap, but not C's column indices.
849076ba34aSJunchao Zhang     So we have to do a fake spgemm_numeric() here to get csrmatC.j_d setup, before
850076ba34aSJunchao Zhang     calling new Mat_SeqAIJKokkos().
851076ba34aSJunchao Zhang     TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
852076ba34aSJunchao Zhang   */
8539566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, *csrmatA, transA, *csrmatB, transB, csrmatC));
854e944a159SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(3, 7, 99)
855866eb059SJunchao Zhang   /* Query if KK outputs a sorted matrix. If not, we need to sort it */
856866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
857866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(csrmatC)); /* sort_option defaults to -1 in KK!*/
858e944a159SJunchao Zhang #endif
8599566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
860076ba34aSJunchao Zhang 
8619566063dSJacob Faibussowitsch   PetscCallCXX(ckok = new Mat_SeqAIJKokkos(csrmatC));
8629566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(C, ckok));
863076ba34aSJunchao Zhang   C->product->destroy = MatProductDataDestroy_SeqAIJKokkos;
864*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
865a3f881fbSStefano Zampini }
866a3f881fbSStefano Zampini 
867a3f881fbSStefano Zampini /* handles sparse matrix matrix ops */
868d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSetFromOptions_SeqAIJKokkos(Mat mat)
869d71ae5a4SJacob Faibussowitsch {
870076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
871a3f881fbSStefano Zampini   PetscBool    Biskok = PETSC_FALSE, Ciskok = PETSC_TRUE;
872a3f881fbSStefano Zampini 
873a3f881fbSStefano Zampini   PetscFunctionBegin;
874a3f881fbSStefano Zampini   MatCheckProduct(mat, 1);
8759566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJKOKKOS, &Biskok));
87648a46eb9SPierre Jolivet   if (product->type == MATPRODUCT_ABC) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJKOKKOS, &Ciskok));
877a3f881fbSStefano Zampini   if (Biskok && Ciskok) {
878a3f881fbSStefano Zampini     switch (product->type) {
879a3f881fbSStefano Zampini     case MATPRODUCT_AB:
880a3f881fbSStefano Zampini     case MATPRODUCT_AtB:
881d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABt:
882d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos;
883d71ae5a4SJacob Faibussowitsch       break;
884a3f881fbSStefano Zampini     case MATPRODUCT_PtAP:
885a3f881fbSStefano Zampini     case MATPRODUCT_RARt:
886d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABC:
887d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
888d71ae5a4SJacob Faibussowitsch       break;
889d71ae5a4SJacob Faibussowitsch     default:
890d71ae5a4SJacob Faibussowitsch       break;
891a3f881fbSStefano Zampini     }
892a3f881fbSStefano Zampini   } else { /* fallback for AIJ */
8939566063dSJacob Faibussowitsch     PetscCall(MatProductSetFromOptions_SeqAIJ(mat));
894a3f881fbSStefano Zampini   }
895*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
896a3f881fbSStefano Zampini }
897a587d139SMark 
898d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatScale_SeqAIJKokkos(Mat A, PetscScalar a)
899d71ae5a4SJacob Faibussowitsch {
900f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok;
901f0cf5187SStefano Zampini 
902f0cf5187SStefano Zampini   PetscFunctionBegin;
9039566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
9049566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
905f0cf5187SStefano Zampini   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
906076ba34aSJunchao Zhang   KokkosBlas::scal(aijkok->a_dual.view_device(), a, aijkok->a_dual.view_device());
9079566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
9089566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(aijkok->a_dual.extent(0)));
9099566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
910*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
911f0cf5187SStefano Zampini }
912f0cf5187SStefano Zampini 
913d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatZeroEntries_SeqAIJKokkos(Mat A)
914d71ae5a4SJacob Faibussowitsch {
915076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
916a587d139SMark 
917a587d139SMark   PetscFunctionBegin;
918076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
9192328674fSJunchao Zhang   if (aijkok) { /* Only zero the device if data is already there */
920076ba34aSJunchao Zhang     KokkosBlas::fill(aijkok->a_dual.view_device(), 0.0);
9219566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(A));
9222328674fSJunchao Zhang   } else { /* Might be preallocated but not assembled */
9239566063dSJacob Faibussowitsch     PetscCall(MatZeroEntries_SeqAIJ(A));
9242328674fSJunchao Zhang   }
925*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
926a587d139SMark }
927a587d139SMark 
928d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatGetDiagonal_SeqAIJKokkos(Mat A, Vec x)
929d71ae5a4SJacob Faibussowitsch {
930f78ce678SMark Adams   Mat_SeqAIJ           *aijseq;
931f78ce678SMark Adams   Mat_SeqAIJKokkos     *aijkok;
932f78ce678SMark Adams   PetscInt              n;
933f78ce678SMark Adams   PetscScalarKokkosView xv;
934f78ce678SMark Adams 
935f78ce678SMark Adams   PetscFunctionBegin;
936f78ce678SMark Adams   PetscCall(VecGetLocalSize(x, &n));
937f78ce678SMark Adams   PetscCheck(n == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Nonconforming matrix and vector");
938f78ce678SMark Adams   PetscCheck(A->factortype == MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_SUP, "MatGetDiagonal_SeqAIJKokkos not supported on factored matrices");
939f78ce678SMark Adams 
940f78ce678SMark Adams   PetscCall(MatSeqAIJKokkosSyncDevice(A));
941f78ce678SMark Adams   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
942f78ce678SMark Adams 
943f78ce678SMark Adams   if (A->rmap->n && aijkok->diag_dual.extent(0) == 0) { /* Set the diagonal pointer if not already */
944f78ce678SMark Adams     PetscCall(MatMarkDiagonal_SeqAIJ(A));
945f78ce678SMark Adams     aijseq = static_cast<Mat_SeqAIJ *>(A->data);
946f78ce678SMark Adams     aijkok->SetDiagonal(aijseq->diag);
947f78ce678SMark Adams   }
948f78ce678SMark Adams 
949f78ce678SMark Adams   const auto &Aa    = aijkok->a_dual.view_device();
950f78ce678SMark Adams   const auto &Ai    = aijkok->i_dual.view_device();
951f78ce678SMark Adams   const auto &Adiag = aijkok->diag_dual.view_device();
952f78ce678SMark Adams 
953f78ce678SMark Adams   PetscCall(VecGetKokkosViewWrite(x, &xv));
9549371c9d4SSatish Balay   Kokkos::parallel_for(
9559371c9d4SSatish Balay     n, KOKKOS_LAMBDA(const PetscInt i) {
956f78ce678SMark Adams       if (Adiag(i) < Ai(i + 1)) xv(i) = Aa(Adiag(i));
957f78ce678SMark Adams       else xv(i) = 0;
958f78ce678SMark Adams     });
959f78ce678SMark Adams   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
960*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
961f78ce678SMark Adams }
962f78ce678SMark Adams 
963db78de30SJunchao Zhang /* Get a Kokkos View from a mat of type MatSeqAIJKokkos */
964d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, ConstMatScalarKokkosView *kv)
965d71ae5a4SJacob Faibussowitsch {
966db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
967db78de30SJunchao Zhang 
968db78de30SJunchao Zhang   PetscFunctionBegin;
969db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
970db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
971db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
9729566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
973db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
974076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
975*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
976db78de30SJunchao Zhang }
977db78de30SJunchao Zhang 
978d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, ConstMatScalarKokkosView *kv)
979d71ae5a4SJacob Faibussowitsch {
980db78de30SJunchao Zhang   PetscFunctionBegin;
981db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
982db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
983db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
984*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
985db78de30SJunchao Zhang }
986db78de30SJunchao Zhang 
987d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, MatScalarKokkosView *kv)
988d71ae5a4SJacob Faibussowitsch {
989db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
990db78de30SJunchao Zhang 
991db78de30SJunchao Zhang   PetscFunctionBegin;
992db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
993db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
994db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
9959566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
996db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
997076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
998*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
999db78de30SJunchao Zhang }
1000db78de30SJunchao Zhang 
1001d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, MatScalarKokkosView *kv)
1002d71ae5a4SJacob Faibussowitsch {
1003db78de30SJunchao Zhang   PetscFunctionBegin;
1004db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1005db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1006db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10079566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
1008*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1009db78de30SJunchao Zhang }
1010db78de30SJunchao Zhang 
1011d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1012d71ae5a4SJacob Faibussowitsch {
1013db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1014db78de30SJunchao Zhang 
1015db78de30SJunchao Zhang   PetscFunctionBegin;
1016db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1017db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1018db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1019db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1020076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
1021*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1022db78de30SJunchao Zhang }
1023db78de30SJunchao Zhang 
1024d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1025d71ae5a4SJacob Faibussowitsch {
1026db78de30SJunchao Zhang   PetscFunctionBegin;
1027db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
1028db78de30SJunchao Zhang   PetscValidPointer(kv, 2);
1029db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
10309566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
1031*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1032db78de30SJunchao Zhang }
1033db78de30SJunchao Zhang 
1034c17cf699SJunchao Zhang /* Computes Y += alpha X */
1035d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAXPY_SeqAIJKokkos(Mat Y, PetscScalar alpha, Mat X, MatStructure pattern)
1036d71ae5a4SJacob Faibussowitsch {
1037a587d139SMark   Mat_SeqAIJ              *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data;
1038c17cf699SJunchao Zhang   Mat_SeqAIJKokkos        *xkok, *ykok, *zkok;
1039c17cf699SJunchao Zhang   ConstMatScalarKokkosView Xa;
1040c17cf699SJunchao Zhang   MatScalarKokkosView      Ya;
1041a587d139SMark 
1042a587d139SMark   PetscFunctionBegin;
1043c17cf699SJunchao Zhang   PetscCheckTypeName(Y, MATSEQAIJKOKKOS);
1044c17cf699SJunchao Zhang   PetscCheckTypeName(X, MATSEQAIJKOKKOS);
10459566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(Y));
10469566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(X));
10479566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1048db78de30SJunchao Zhang 
1049c17cf699SJunchao Zhang   if (pattern != SAME_NONZERO_PATTERN && x->nz == y->nz) {
1050c17cf699SJunchao Zhang     /* We could compare on device, but have to get the comparison result on host. So compare on host instead. */
1051a587d139SMark     PetscBool e;
10529566063dSJacob Faibussowitsch     PetscCall(PetscArraycmp(x->i, y->i, Y->rmap->n + 1, &e));
1053a587d139SMark     if (e) {
10549566063dSJacob Faibussowitsch       PetscCall(PetscArraycmp(x->j, y->j, y->nz, &e));
1055c17cf699SJunchao Zhang       if (e) pattern = SAME_NONZERO_PATTERN;
1056a587d139SMark     }
1057a587d139SMark   }
1058db78de30SJunchao Zhang 
1059c17cf699SJunchao Zhang   /* cusparseDcsrgeam2() computes C = alpha A + beta B. If one knew sparsity pattern of C, one can skip
1060c17cf699SJunchao Zhang     cusparseScsrgeam2_bufferSizeExt() / cusparseXcsrgeam2Nnz(), and directly call cusparseScsrgeam2().
1061c17cf699SJunchao Zhang     If X is SUBSET_NONZERO_PATTERN of Y, we could take advantage of this cusparse feature. However,
1062c17cf699SJunchao Zhang     KokkosSparse::spadd(alpha,A,beta,B,C) has symbolic and numeric phases, MatAXPY does not.
1063c17cf699SJunchao Zhang   */
1064c17cf699SJunchao Zhang   ykok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1065c17cf699SJunchao Zhang   xkok = static_cast<Mat_SeqAIJKokkos *>(X->spptr);
1066c17cf699SJunchao Zhang   Xa   = xkok->a_dual.view_device();
1067c17cf699SJunchao Zhang   Ya   = ykok->a_dual.view_device();
1068c17cf699SJunchao Zhang 
1069c17cf699SJunchao Zhang   if (pattern == SAME_NONZERO_PATTERN) {
1070c17cf699SJunchao Zhang     KokkosBlas::axpy(alpha, Xa, Ya);
10719566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1072c17cf699SJunchao Zhang   } else if (pattern == SUBSET_NONZERO_PATTERN) {
1073c17cf699SJunchao Zhang     MatRowMapKokkosView Xi = xkok->i_dual.view_device(), Yi = ykok->i_dual.view_device();
1074c17cf699SJunchao Zhang     MatColIdxKokkosView Xj = xkok->j_dual.view_device(), Yj = ykok->j_dual.view_device();
1075c17cf699SJunchao Zhang 
10769371c9d4SSatish Balay     Kokkos::parallel_for(
10779371c9d4SSatish Balay       Kokkos::TeamPolicy<>(Y->rmap->n, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1078c17cf699SJunchao Zhang         PetscInt i = t.league_rank();              /* row i */
1079c17cf699SJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() { /* Only one thread works in a team */
1080c17cf699SJunchao Zhang                                                    PetscInt p, q = Yi(i);
1081c17cf699SJunchao Zhang                                                    for (p = Xi(i); p < Xi(i + 1); p++) {          /* For each nonzero on row i of X */
1082c17cf699SJunchao Zhang                                                      while (Xj(p) != Yj(q) && q < Yi(i + 1)) q++; /* find the matching nonzero on row i of Y */
1083c17cf699SJunchao Zhang                                                      if (Xj(p) == Yj(q)) {                        /* Found it */
1084c17cf699SJunchao Zhang                                                        Ya(q) += alpha * Xa(p);
1085c17cf699SJunchao Zhang                                                        q++;
1086a587d139SMark                                                      } else {
1087c17cf699SJunchao Zhang                                                        /* If not found, it indicates the input is wrong (X is not a SUBSET_NONZERO_PATTERN of Y).
1088c17cf699SJunchao Zhang                Just insert a NaN at the beginning of row i if it is not empty, to make the result wrong.
1089c17cf699SJunchao Zhang             */
10909371c9d4SSatish Balay                                                        if (Yi(i) != Yi(i + 1))
10919371c9d4SSatish Balay                                                          Ya(Yi(i)) =
10928b8b16f9SJunchao Zhang #if PETSC_PKG_KOKKOS_VERSION_GE(3, 6, 99)
10938b8b16f9SJunchao Zhang                                                            Kokkos::nan("1"); /* auto promote the double NaN if needed */
10948b8b16f9SJunchao Zhang #else
10958b8b16f9SJunchao Zhang               Kokkos::Experimental::nan("1");
10968b8b16f9SJunchao Zhang #endif
1097a587d139SMark                                                      }
1098c17cf699SJunchao Zhang                                                    }
1099c17cf699SJunchao Zhang         });
1100c17cf699SJunchao Zhang       });
11019566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1102c17cf699SJunchao Zhang   } else { /* different nonzero patterns */
1103c17cf699SJunchao Zhang     Mat             Z;
1104c17cf699SJunchao Zhang     KokkosCsrMatrix zcsr;
1105c17cf699SJunchao Zhang     KernelHandle    kh;
1106c17cf699SJunchao Zhang     kh.create_spadd_handle(false);
1107c17cf699SJunchao Zhang     KokkosSparse::spadd_symbolic(&kh, xkok->csrmat, ykok->csrmat, zcsr);
1108c17cf699SJunchao Zhang     KokkosSparse::spadd_numeric(&kh, alpha, xkok->csrmat, (PetscScalar)1.0, ykok->csrmat, zcsr);
1109c17cf699SJunchao Zhang     zkok = new Mat_SeqAIJKokkos(zcsr);
11109566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, zkok, &Z));
11119566063dSJacob Faibussowitsch     PetscCall(MatHeaderReplace(Y, &Z));
1112c17cf699SJunchao Zhang     kh.destroy_spadd_handle();
1113c17cf699SJunchao Zhang   }
11149566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
11159566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(xkok->a_dual.extent(0) * 2)); /* Because we scaled X and then added it to Y */
1116*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1117a587d139SMark }
1118a587d139SMark 
1119d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_SeqAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1120d71ae5a4SJacob Faibussowitsch {
112142550becSJunchao Zhang   Mat_SeqAIJKokkos *akok;
112242550becSJunchao Zhang   Mat_SeqAIJ       *aseq;
112342550becSJunchao Zhang 
112442550becSJunchao Zhang   PetscFunctionBegin;
11259566063dSJacob Faibussowitsch   PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, coo_i, coo_j));
1126394ed5ebSJunchao Zhang   aseq = static_cast<Mat_SeqAIJ *>(mat->data);
112742550becSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(mat->spptr);
1128cbc6b225SStefano Zampini   delete akok;
1129cbc6b225SStefano Zampini   mat->spptr = akok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, aseq->nz, aseq->i, aseq->j, aseq->a, mat->nonzerostate + 1, PETSC_FALSE);
11309566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries_SeqAIJKokkos(mat));
1131394ed5ebSJunchao Zhang   akok->SetUpCOO(aseq);
1132*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
113342550becSJunchao Zhang }
113442550becSJunchao Zhang 
1135d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_SeqAIJKokkos(Mat A, const PetscScalar v[], InsertMode imode)
1136d71ae5a4SJacob Faibussowitsch {
113742550becSJunchao Zhang   Mat_SeqAIJ                 *aseq = static_cast<Mat_SeqAIJ *>(A->data);
113842550becSJunchao Zhang   Mat_SeqAIJKokkos           *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1139394ed5ebSJunchao Zhang   PetscCount                  Annz = aseq->nz;
1140394ed5ebSJunchao Zhang   const PetscCountKokkosView &jmap = akok->jmap_d;
1141394ed5ebSJunchao Zhang   const PetscCountKokkosView &perm = akok->perm_d;
114242550becSJunchao Zhang   MatScalarKokkosView         Aa;
114342550becSJunchao Zhang   ConstMatScalarKokkosView    kv;
114442550becSJunchao Zhang   PetscMemType                memtype;
114542550becSJunchao Zhang 
114642550becSJunchao Zhang   PetscFunctionBegin;
11479566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype));
114842550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */
1149394ed5ebSJunchao Zhang     kv = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatScalarKokkosViewHost(v, aseq->coo_n));
115042550becSJunchao Zhang   } else {
1151394ed5ebSJunchao Zhang     kv = ConstMatScalarKokkosView(v, aseq->coo_n); /* Directly use v[]'s memory */
115242550becSJunchao Zhang   }
115342550becSJunchao Zhang 
1154c7b718f4SJunchao Zhang   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1155c7b718f4SJunchao Zhang   else PetscCall(MatSeqAIJGetKokkosView(A, &Aa));                             /* read & write matrix values */
115642550becSJunchao Zhang 
11579371c9d4SSatish Balay   Kokkos::parallel_for(
11589371c9d4SSatish Balay     Annz, KOKKOS_LAMBDA(const PetscCount i) {
1159c7b718f4SJunchao Zhang       PetscScalar sum = 0.0;
1160c7b718f4SJunchao Zhang       for (PetscCount k = jmap(i); k < jmap(i + 1); k++) sum += kv(perm(k));
1161c7b718f4SJunchao Zhang       Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1162c7b718f4SJunchao Zhang     });
1163394ed5ebSJunchao Zhang 
11649566063dSJacob Faibussowitsch   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
11659566063dSJacob Faibussowitsch   else PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1166*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
116742550becSJunchao Zhang }
116842550becSJunchao Zhang 
1169d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJMoveDiagonalValuesFront_SeqAIJKokkos(Mat A, const PetscInt *diag)
1170d71ae5a4SJacob Faibussowitsch {
11715fbaff96SJunchao Zhang   Mat_SeqAIJKokkos          *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
11725fbaff96SJunchao Zhang   MatScalarKokkosView        Aa;
11735fbaff96SJunchao Zhang   const MatRowMapKokkosView &Ai = akok->i_dual.view_device();
11745fbaff96SJunchao Zhang   PetscInt                   m  = A->rmap->n;
11755fbaff96SJunchao Zhang   ConstMatRowMapKokkosView   Adiag(diag, m); /* diag is a device pointer */
11765fbaff96SJunchao Zhang 
11775fbaff96SJunchao Zhang   PetscFunctionBegin;
11785fbaff96SJunchao Zhang   PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa));
11799371c9d4SSatish Balay   Kokkos::parallel_for(
11809371c9d4SSatish Balay     m, KOKKOS_LAMBDA(const PetscInt i) {
11815fbaff96SJunchao Zhang       PetscScalar tmp;
11825fbaff96SJunchao Zhang       if (Adiag(i) >= Ai(i) && Adiag(i) < Ai(i + 1)) { /* The diagonal element exists */
11835fbaff96SJunchao Zhang         tmp          = Aa(Ai(i));
11845fbaff96SJunchao Zhang         Aa(Ai(i))    = Aa(Adiag(i));
11855fbaff96SJunchao Zhang         Aa(Adiag(i)) = tmp;
11865fbaff96SJunchao Zhang       }
11875fbaff96SJunchao Zhang     });
11885fbaff96SJunchao Zhang   PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
1189*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
11905fbaff96SJunchao Zhang }
11915fbaff96SJunchao Zhang 
1192d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatLUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1193d71ae5a4SJacob Faibussowitsch {
11948f7e8f9dSMark Adams   PetscFunctionBegin;
11959566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncHost(A));
11969566063dSJacob Faibussowitsch   PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info));
11978f7e8f9dSMark Adams   B->offloadmask = PETSC_OFFLOAD_CPU;
1198*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
11998f7e8f9dSMark Adams }
12008f7e8f9dSMark Adams 
1201d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat A)
1202d71ae5a4SJacob Faibussowitsch {
1203076ba34aSJunchao Zhang   Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
1204076ba34aSJunchao Zhang 
12058c3ff71bSJunchao Zhang   PetscFunctionBegin;
1206076ba34aSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_KOKKOS; /* We do not really use this flag */
12076f3d89d0SStefano Zampini   A->boundtocpu  = PETSC_FALSE;
12086f3d89d0SStefano Zampini 
12098c3ff71bSJunchao Zhang   A->ops->assemblyend               = MatAssemblyEnd_SeqAIJKokkos;
12108c3ff71bSJunchao Zhang   A->ops->destroy                   = MatDestroy_SeqAIJKokkos;
12118c3ff71bSJunchao Zhang   A->ops->duplicate                 = MatDuplicate_SeqAIJKokkos;
1212a587d139SMark   A->ops->axpy                      = MatAXPY_SeqAIJKokkos;
1213f0cf5187SStefano Zampini   A->ops->scale                     = MatScale_SeqAIJKokkos;
1214a587d139SMark   A->ops->zeroentries               = MatZeroEntries_SeqAIJKokkos;
1215076ba34aSJunchao Zhang   A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJKokkos;
12168c3ff71bSJunchao Zhang   A->ops->mult                      = MatMult_SeqAIJKokkos;
12178c3ff71bSJunchao Zhang   A->ops->multadd                   = MatMultAdd_SeqAIJKokkos;
12188c3ff71bSJunchao Zhang   A->ops->multtranspose             = MatMultTranspose_SeqAIJKokkos;
12198c3ff71bSJunchao Zhang   A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJKokkos;
12208c3ff71bSJunchao Zhang   A->ops->multhermitiantranspose    = MatMultHermitianTranspose_SeqAIJKokkos;
12218c3ff71bSJunchao Zhang   A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJKokkos;
1222076ba34aSJunchao Zhang   A->ops->productnumeric            = MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos;
12230ecb592aSJunchao Zhang   A->ops->transpose                 = MatTranspose_SeqAIJKokkos;
1224152b3e56SJunchao Zhang   A->ops->setoption                 = MatSetOption_SeqAIJKokkos;
1225f78ce678SMark Adams   A->ops->getdiagonal               = MatGetDiagonal_SeqAIJKokkos;
1226076ba34aSJunchao Zhang   a->ops->getarray                  = MatSeqAIJGetArray_SeqAIJKokkos;
1227076ba34aSJunchao Zhang   a->ops->restorearray              = MatSeqAIJRestoreArray_SeqAIJKokkos;
1228076ba34aSJunchao Zhang   a->ops->getarrayread              = MatSeqAIJGetArrayRead_SeqAIJKokkos;
1229076ba34aSJunchao Zhang   a->ops->restorearrayread          = MatSeqAIJRestoreArrayRead_SeqAIJKokkos;
1230076ba34aSJunchao Zhang   a->ops->getarraywrite             = MatSeqAIJGetArrayWrite_SeqAIJKokkos;
1231076ba34aSJunchao Zhang   a->ops->restorearraywrite         = MatSeqAIJRestoreArrayWrite_SeqAIJKokkos;
12327ee59b9bSJunchao Zhang   a->ops->getcsrandmemtype          = MatSeqAIJGetCSRAndMemType_SeqAIJKokkos;
123342550becSJunchao Zhang 
12349566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJKokkos));
12359566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJKokkos));
1236*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1237076ba34aSJunchao Zhang }
1238076ba34aSJunchao Zhang 
1239d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSetSeqAIJKokkosWithCSRMatrix(Mat A, Mat_SeqAIJKokkos *akok)
1240d71ae5a4SJacob Faibussowitsch {
1241076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
1242076ba34aSJunchao Zhang   PetscInt    i, m, n;
1243076ba34aSJunchao Zhang 
1244076ba34aSJunchao Zhang   PetscFunctionBegin;
12455f80ce2aSJacob Faibussowitsch   PetscCheck(!A->spptr, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A->spptr is supposed to be empty");
1246076ba34aSJunchao Zhang 
1247076ba34aSJunchao Zhang   m = akok->nrows();
1248076ba34aSJunchao Zhang   n = akok->ncols();
12499566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(A, m, n, m, n));
12509566063dSJacob Faibussowitsch   PetscCall(MatSetType(A, MATSEQAIJKOKKOS));
1251076ba34aSJunchao Zhang 
1252076ba34aSJunchao Zhang   /* Set up data structures of A as a MATSEQAIJ */
12539566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(A, MAT_SKIP_ALLOCATION, NULL));
1254076ba34aSJunchao Zhang   aseq = (Mat_SeqAIJ *)(A)->data;
1255076ba34aSJunchao Zhang 
1256076ba34aSJunchao Zhang   akok->i_dual.sync_host(); /* We always need sync'ed i, j on host */
1257076ba34aSJunchao Zhang   akok->j_dual.sync_host();
1258076ba34aSJunchao Zhang 
1259076ba34aSJunchao Zhang   aseq->i            = akok->i_host_data();
1260076ba34aSJunchao Zhang   aseq->j            = akok->j_host_data();
1261076ba34aSJunchao Zhang   aseq->a            = akok->a_host_data();
1262076ba34aSJunchao Zhang   aseq->nonew        = -1; /*this indicates that inserting a new value in the matrix that generates a new nonzero is an error*/
1263076ba34aSJunchao Zhang   aseq->singlemalloc = PETSC_FALSE;
1264076ba34aSJunchao Zhang   aseq->free_a       = PETSC_FALSE;
1265076ba34aSJunchao Zhang   aseq->free_ij      = PETSC_FALSE;
1266076ba34aSJunchao Zhang   aseq->nz           = akok->nnz();
1267076ba34aSJunchao Zhang   aseq->maxnz        = aseq->nz;
1268076ba34aSJunchao Zhang 
12699566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->imax));
12709566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->ilen));
1271ad540459SPierre Jolivet   for (i = 0; i < m; i++) aseq->ilen[i] = aseq->imax[i] = aseq->i[i + 1] - aseq->i[i];
1272076ba34aSJunchao Zhang 
1273076ba34aSJunchao Zhang   /* It is critical to set the nonzerostate, as we use it to check if sparsity pattern (hence data) has changed on host in MatAssemblyEnd */
1274076ba34aSJunchao Zhang   akok->nonzerostate = A->nonzerostate;
1275ff751488SJunchao Zhang   A->spptr           = akok; /* Set A->spptr before MatAssembly so that A->spptr won't be allocated again there */
12769566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
12779566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
1278*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1279076ba34aSJunchao Zhang }
1280076ba34aSJunchao Zhang 
1281076ba34aSJunchao Zhang /* Crete a SEQAIJKOKKOS matrix with a Mat_SeqAIJKokkos data structure
1282076ba34aSJunchao Zhang 
1283076ba34aSJunchao Zhang    Note we have names like MatSeqAIJSetPreallocationCSR, so I use capitalized CSR
1284076ba34aSJunchao Zhang  */
1285d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithCSRMatrix(MPI_Comm comm, Mat_SeqAIJKokkos *akok, Mat *A)
1286d71ae5a4SJacob Faibussowitsch {
1287076ba34aSJunchao Zhang   PetscFunctionBegin;
12889566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
12899566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
1290*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
12918c3ff71bSJunchao Zhang }
12928c3ff71bSJunchao Zhang 
12938c3ff71bSJunchao Zhang /* --------------------------------------------------------------------------------*/
1294152b3e56SJunchao Zhang /*@C
129511a5261eSBarry Smith    MatCreateSeqAIJKokkos - Creates a sparse matrix in `MATSEQAIJKOKKOS` (compressed row) format
12968c3ff71bSJunchao Zhang    (the default parallel PETSc format). This matrix will ultimately be handled by
12978c3ff71bSJunchao Zhang    Kokkos for calculations. For good matrix
12988c3ff71bSJunchao Zhang    assembly performance the user should preallocate the matrix storage by setting
12998c3ff71bSJunchao Zhang    the parameter nz (or the array nnz).  By setting these parameters accurately,
13008c3ff71bSJunchao Zhang    performance during matrix assembly can be increased by more than a factor of 50.
13018c3ff71bSJunchao Zhang 
13028c3ff71bSJunchao Zhang    Collective
13038c3ff71bSJunchao Zhang 
13048c3ff71bSJunchao Zhang    Input Parameters:
130511a5261eSBarry Smith +  comm - MPI communicator, set to `PETSC_COMM_SELF`
13068c3ff71bSJunchao Zhang .  m - number of rows
13078c3ff71bSJunchao Zhang .  n - number of columns
13088c3ff71bSJunchao Zhang .  nz - number of nonzeros per row (same for all rows)
13098c3ff71bSJunchao Zhang -  nnz - array containing the number of nonzeros in the various rows
13108c3ff71bSJunchao Zhang          (possibly different for each row) or NULL
13118c3ff71bSJunchao Zhang 
13128c3ff71bSJunchao Zhang    Output Parameter:
13138c3ff71bSJunchao Zhang .  A - the matrix
13148c3ff71bSJunchao Zhang 
131511a5261eSBarry Smith    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
13168c3ff71bSJunchao Zhang    MatXXXXSetPreallocation() paradgm instead of this routine directly.
131711a5261eSBarry Smith    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
13188c3ff71bSJunchao Zhang 
13198c3ff71bSJunchao Zhang    Notes:
13208c3ff71bSJunchao Zhang    If nnz is given then nz is ignored
13218c3ff71bSJunchao Zhang 
132211a5261eSBarry Smith    The AIJ format, also called
132311a5261eSBarry Smith    compressed row storage, is fully compatible with standard Fortran 77
13248c3ff71bSJunchao Zhang    storage.  That is, the stored row and column indices can begin at
13258c3ff71bSJunchao Zhang    either one (as in Fortran) or zero.  See the users' manual for details.
13268c3ff71bSJunchao Zhang 
13278c3ff71bSJunchao Zhang    Specify the preallocated storage with either nz or nnz (not both).
132811a5261eSBarry Smith    Set nz = `PETSC_DEFAULT` and nnz = NULL for PETSc to control dynamic memory
13298c3ff71bSJunchao Zhang    allocation.  For large problems you MUST preallocate memory or you
13308c3ff71bSJunchao Zhang    will get TERRIBLE performance, see the users' manual chapter on matrices.
13318c3ff71bSJunchao Zhang 
13328c3ff71bSJunchao Zhang    By default, this format uses inodes (identical nodes) when possible, to
13338c3ff71bSJunchao Zhang    improve numerical efficiency of matrix-vector products and solves. We
13348c3ff71bSJunchao Zhang    search for consecutive rows with the same nonzero structure, thereby
13358c3ff71bSJunchao Zhang    reusing matrix information to achieve increased efficiency.
13368c3ff71bSJunchao Zhang 
13378c3ff71bSJunchao Zhang    Level: intermediate
13388c3ff71bSJunchao Zhang 
1339db781477SPatrick Sanan .seealso: `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MatCreateAIJ()`
13408c3ff71bSJunchao Zhang @*/
1341d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateSeqAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A)
1342d71ae5a4SJacob Faibussowitsch {
13438c3ff71bSJunchao Zhang   PetscFunctionBegin;
13449566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
13459566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
13469566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, m, n));
13479566063dSJacob Faibussowitsch   PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
13489566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
1349*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
13508c3ff71bSJunchao Zhang }
1351930e68a5SMark Adams 
13528f7e8f9dSMark Adams typedef Kokkos::TeamPolicy<>::member_type team_member;
13538f7e8f9dSMark Adams //
135446804e07SMark Adams // This factorization exploits block diagonal matrices with "Nf" (not used).
13558f7e8f9dSMark Adams // Use -pc_factor_mat_ordering_type rcm to order decouple blocks of size N/Nf for this optimization
13568f7e8f9dSMark Adams //
1357d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatLUFactorNumeric_SeqAIJKOKKOSDEVICE(Mat B, Mat A, const MatFactorInfo *info)
1358d71ae5a4SJacob Faibussowitsch {
13598f7e8f9dSMark Adams   Mat_SeqAIJ       *b      = (Mat_SeqAIJ *)B->data;
13608f7e8f9dSMark Adams   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *baijkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
13618f7e8f9dSMark Adams   IS                isrow = b->row, isicol = b->icol;
13628f7e8f9dSMark Adams   const PetscInt   *r_h, *ic_h;
1363300d22a6SJunchao Zhang   const PetscInt n = A->rmap->n, *ai_d = aijkok->i_dual.view_device().data(), *aj_d = aijkok->j_dual.view_device().data(), *bi_d = baijkok->i_dual.view_device().data(), *bj_d = baijkok->j_dual.view_device().data(), *bdiag_d = baijkok->diag_d.data();
1364076ba34aSJunchao Zhang   const PetscScalar *aa_d = aijkok->a_dual.view_device().data();
1365076ba34aSJunchao Zhang   PetscScalar       *ba_d = baijkok->a_dual.view_device().data();
13668f7e8f9dSMark Adams   PetscBool          row_identity, col_identity;
136746804e07SMark Adams   PetscInt           nc, Nf = 1, nVec = 32; // should be a parameter, Nf is batch size - not used
1368930e68a5SMark Adams 
1369930e68a5SMark Adams   PetscFunctionBegin;
13702c71b3e2SJacob Faibussowitsch   PetscCheck(A->rmap->n == n, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "square matrices only supported %" PetscInt_FMT " %" PetscInt_FMT, A->rmap->n, n);
1371b94d7dedSBarry Smith   PetscCall(MatIsStructurallySymmetric(A, &row_identity));
13722c71b3e2SJacob Faibussowitsch   PetscCheck(row_identity, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "structurally symmetric matrices only supported");
13739566063dSJacob Faibussowitsch   PetscCall(ISGetIndices(isrow, &r_h));
13749566063dSJacob Faibussowitsch   PetscCall(ISGetIndices(isicol, &ic_h));
13759566063dSJacob Faibussowitsch   PetscCall(ISGetSize(isicol, &nc));
13769566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
13779566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
13788f7e8f9dSMark Adams   {
13798f7e8f9dSMark Adams #define KOKKOS_SHARED_LEVEL 1
13808f7e8f9dSMark Adams     using scr_mem_t    = Kokkos::DefaultExecutionSpace::scratch_memory_space;
13818f7e8f9dSMark Adams     using sizet_scr_t  = Kokkos::View<size_t, scr_mem_t>;
13828f7e8f9dSMark Adams     using scalar_scr_t = Kokkos::View<PetscScalar, scr_mem_t>;
13838f7e8f9dSMark Adams     const Kokkos::View<const PetscInt *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_r_k(r_h, n);
13848f7e8f9dSMark Adams     Kokkos::View<PetscInt *, Kokkos::LayoutLeft>                                                                         d_r_k("r", n);
13858f7e8f9dSMark Adams     const Kokkos::View<const PetscInt *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_ic_k(ic_h, nc);
13868f7e8f9dSMark Adams     Kokkos::View<PetscInt *, Kokkos::LayoutLeft>                                                                         d_ic_k("ic", nc);
13878f7e8f9dSMark Adams     size_t                                                                                                               flops_h = 0.0;
13888f7e8f9dSMark Adams     Kokkos::View<size_t, Kokkos::HostSpace>                                                                              h_flops_k(&flops_h);
13898f7e8f9dSMark Adams     Kokkos::View<size_t>                                                                                                 d_flops_k("flops");
13908f7e8f9dSMark Adams     const int                                                                                                            conc = Kokkos::DefaultExecutionSpace().concurrency(), team_size = conc > 1 ? 16 : 1; // 8*32 = 256
13918f7e8f9dSMark Adams     const int                                                                                                            nloc = n / Nf, Ni = (conc > 8) ? 1 /* some intelegent number of SMs -- but need league_barrier */ : 1;
13928f7e8f9dSMark Adams     Kokkos::deep_copy(d_flops_k, h_flops_k);
13938f7e8f9dSMark Adams     Kokkos::deep_copy(d_r_k, h_r_k);
13948f7e8f9dSMark Adams     Kokkos::deep_copy(d_ic_k, h_ic_k);
13958f7e8f9dSMark Adams     // Fill A --> fact
13969371c9d4SSatish Balay     Kokkos::parallel_for(
13979371c9d4SSatish Balay       Kokkos::TeamPolicy<>(Nf * Ni, team_size, nVec), KOKKOS_LAMBDA(const team_member team) {
1398042217e8SBarry Smith         const PetscInt  field = team.league_rank() / Ni, field_block = team.league_rank() % Ni; // use grid.x/y in CUDA
13998f7e8f9dSMark Adams         const PetscInt  nloc_i = (nloc / Ni + !!(nloc % Ni)), start_i = field * nloc + field_block * nloc_i, end_i = (start_i + nloc_i) > (field + 1) * nloc ? (field + 1) * nloc : (start_i + nloc_i);
14008f7e8f9dSMark Adams         const PetscInt *ic = d_ic_k.data(), *r = d_r_k.data();
14018f7e8f9dSMark Adams         // zero rows of B
14028f7e8f9dSMark Adams         Kokkos::parallel_for(Kokkos::TeamVectorRange(team, start_i, end_i), [=](const int &rowb) {
14038f7e8f9dSMark Adams           PetscInt     nzbL = bi_d[rowb + 1] - bi_d[rowb], nzbU = bdiag_d[rowb] - bdiag_d[rowb + 1]; // with diag
14048f7e8f9dSMark Adams           PetscScalar *baL = ba_d + bi_d[rowb];
14058f7e8f9dSMark Adams           PetscScalar *baU = ba_d + bdiag_d[rowb + 1] + 1;
14068f7e8f9dSMark Adams           /* zero (unfactored row) */
14078f7e8f9dSMark Adams           for (int j = 0; j < nzbL; j++) baL[j] = 0;
14088f7e8f9dSMark Adams           for (int j = 0; j < nzbU; j++) baU[j] = 0;
14098f7e8f9dSMark Adams         });
14108f7e8f9dSMark Adams         // copy A into B
14118f7e8f9dSMark Adams         Kokkos::parallel_for(Kokkos::TeamVectorRange(team, start_i, end_i), [=](const int &rowb) {
14128f7e8f9dSMark Adams           PetscInt           rowa = r[rowb], nza = ai_d[rowa + 1] - ai_d[rowa];
14138f7e8f9dSMark Adams           const PetscScalar *av    = aa_d + ai_d[rowa];
14148f7e8f9dSMark Adams           const PetscInt    *ajtmp = aj_d + ai_d[rowa];
14158f7e8f9dSMark Adams           /* load in initial (unfactored row) */
14168f7e8f9dSMark Adams           for (int j = 0; j < nza; j++) {
14178f7e8f9dSMark Adams             PetscInt    colb = ic[ajtmp[j]];
14188f7e8f9dSMark Adams             PetscScalar vala = av[j];
14198f7e8f9dSMark Adams             if (colb == rowb) {
14208f7e8f9dSMark Adams               *(ba_d + bdiag_d[rowb]) = vala;
14218f7e8f9dSMark Adams             } else {
14228f7e8f9dSMark Adams               const PetscInt *pbj = bj_d + ((colb > rowb) ? bdiag_d[rowb + 1] + 1 : bi_d[rowb]);
14238f7e8f9dSMark Adams               PetscScalar    *pba = ba_d + ((colb > rowb) ? bdiag_d[rowb + 1] + 1 : bi_d[rowb]);
14248f7e8f9dSMark Adams               PetscInt        nz = (colb > rowb) ? bdiag_d[rowb] - (bdiag_d[rowb + 1] + 1) : bi_d[rowb + 1] - bi_d[rowb], set = 0;
14258f7e8f9dSMark Adams               for (int j = 0; j < nz; j++) {
14268f7e8f9dSMark Adams                 if (pbj[j] == colb) {
14278f7e8f9dSMark Adams                   pba[j] = vala;
14288f7e8f9dSMark Adams                   set++;
14298f7e8f9dSMark Adams                   break;
14308f7e8f9dSMark Adams                 }
14318f7e8f9dSMark Adams               }
14328f1da0b2SJunchao Zhang #if !defined(PETSC_HAVE_SYCL)
14338f7e8f9dSMark Adams               if (set != 1) printf("\t\t\t ERROR DID NOT SET ?????\n");
14348f1da0b2SJunchao Zhang #endif
14358f7e8f9dSMark Adams             }
14368f7e8f9dSMark Adams           }
14378f7e8f9dSMark Adams         });
14388f7e8f9dSMark Adams       });
14398f7e8f9dSMark Adams     Kokkos::fence();
1440930e68a5SMark Adams 
14419371c9d4SSatish Balay     Kokkos::parallel_for(
14429371c9d4SSatish Balay       Kokkos::TeamPolicy<>(Nf * Ni, team_size, nVec).set_scratch_size(KOKKOS_SHARED_LEVEL, Kokkos::PerThread(sizet_scr_t::shmem_size() + scalar_scr_t::shmem_size()), Kokkos::PerTeam(sizet_scr_t::shmem_size())), KOKKOS_LAMBDA(const team_member team) {
14438f7e8f9dSMark Adams         sizet_scr_t    colkIdx(team.thread_scratch(KOKKOS_SHARED_LEVEL));
14448f7e8f9dSMark Adams         scalar_scr_t   L_ki(team.thread_scratch(KOKKOS_SHARED_LEVEL));
14458f7e8f9dSMark Adams         sizet_scr_t    flops(team.team_scratch(KOKKOS_SHARED_LEVEL));
1446042217e8SBarry Smith         const PetscInt field = team.league_rank() / Ni, field_block_idx = team.league_rank() % Ni; // use grid.x/y in CUDA
14478f7e8f9dSMark Adams         const PetscInt start = field * nloc, end = start + nloc;
14488f7e8f9dSMark Adams         Kokkos::single(Kokkos::PerTeam(team), [=]() { flops() = 0; });
14498f7e8f9dSMark Adams         // A22 panel update for each row A(1,:) and col A(:,1)
14508f7e8f9dSMark Adams         for (int ii = start; ii < end - 1; ii++) {
14518f7e8f9dSMark Adams           const PetscInt    *bjUi = bj_d + bdiag_d[ii + 1] + 1, nzUi = bdiag_d[ii] - (bdiag_d[ii + 1] + 1); // vector, and vector size, of column indices of U(i,(i+1):end)
14528f7e8f9dSMark Adams           const PetscScalar *baUi    = ba_d + bdiag_d[ii + 1] + 1;                                          // vector of data  U(i,i+1:end)
14538f7e8f9dSMark Adams           const PetscInt     nUi_its = nzUi / Ni + !!(nzUi % Ni);
14548f7e8f9dSMark Adams           const PetscScalar  Bii     = *(ba_d + bdiag_d[ii]); // diagonal in its special place
14558f7e8f9dSMark Adams           Kokkos::parallel_for(Kokkos::TeamThreadRange(team, nUi_its), [=](const int j) {
14568f7e8f9dSMark Adams             PetscInt kIdx = j * Ni + field_block_idx;
14579371c9d4SSatish Balay             if (kIdx >= nzUi) /* void */
14589371c9d4SSatish Balay               ;
14598f7e8f9dSMark Adams             else {
14608f7e8f9dSMark Adams               const PetscInt  myk = bjUi[kIdx];                // assume symmetric structure, need a transposed meta-data here in general
14618f7e8f9dSMark Adams               const PetscInt *pjL = bj_d + bi_d[myk];          // look for L(myk,ii) in start of row
14628f7e8f9dSMark Adams               const PetscInt  nzL = bi_d[myk + 1] - bi_d[myk]; // size of L_k(:)
14638f7e8f9dSMark Adams               size_t          st_idx;
14648f7e8f9dSMark Adams               // find and do L(k,i) = A(:k,i) / A(i,i)
14658f7e8f9dSMark Adams               Kokkos::single(Kokkos::PerThread(team), [&]() { colkIdx() = PETSC_MAX_INT; });
14668f7e8f9dSMark Adams               // get column, there has got to be a better way
14679371c9d4SSatish Balay               Kokkos::parallel_reduce(
14689371c9d4SSatish Balay                 Kokkos::ThreadVectorRange(team, nzL),
14699371c9d4SSatish Balay                 [&](const int &j, size_t &idx) {
14708f7e8f9dSMark Adams                   if (pjL[j] == ii) {
14718f7e8f9dSMark Adams                     PetscScalar *pLki = ba_d + bi_d[myk] + j;
14728f7e8f9dSMark Adams                     idx               = j;           // output
14738f7e8f9dSMark Adams                     *pLki             = *pLki / Bii; // column scaling:  L(k,i) = A(:k,i) / A(i,i)
14748f7e8f9dSMark Adams                   }
14759371c9d4SSatish Balay                 },
14769371c9d4SSatish Balay                 st_idx);
14779371c9d4SSatish Balay               Kokkos::single(Kokkos::PerThread(team), [=]() {
14789371c9d4SSatish Balay                 colkIdx() = st_idx;
14799371c9d4SSatish Balay                 L_ki()    = *(ba_d + bi_d[myk] + st_idx);
14809371c9d4SSatish Balay               });
14818f1da0b2SJunchao Zhang #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
148299551766SMark Adams               if (colkIdx() == PETSC_MAX_INT) printf("\t\t\t\t\t\t\tERROR: failed to find L_ki(%d,%d)\n", (int)myk, ii); // uses a register
148399551766SMark Adams #endif
148499551766SMark Adams               // active row k, do  A_kj -= Lki * U_ij; j \in U(i,:) j != i
14858f7e8f9dSMark Adams               // U(i+1,:end)
14868f7e8f9dSMark Adams               Kokkos::parallel_for(Kokkos::ThreadVectorRange(team, nzUi), [=](const int &uiIdx) { // index into i (U)
14878f7e8f9dSMark Adams                 PetscScalar Uij = baUi[uiIdx];
14888f7e8f9dSMark Adams                 PetscInt    col = bjUi[uiIdx];
14898f7e8f9dSMark Adams                 if (col == myk) {
14908f7e8f9dSMark Adams                   // A_kk = A_kk - L_ki * U_ij(k)
14918f7e8f9dSMark Adams                   PetscScalar *Akkv = (ba_d + bdiag_d[myk]); // diagonal in its special place
14928f7e8f9dSMark Adams                   *Akkv             = *Akkv - L_ki() * Uij;  // UiK
14938f7e8f9dSMark Adams                 } else {
14948f7e8f9dSMark Adams                   PetscScalar    *start, *end, *pAkjv = NULL;
14958f7e8f9dSMark Adams                   PetscInt        high, low;
14968f7e8f9dSMark Adams                   const PetscInt *startj;
14978f7e8f9dSMark Adams                   if (col < myk) { // L
14988f7e8f9dSMark Adams                     PetscScalar *pLki = ba_d + bi_d[myk] + colkIdx();
14998f7e8f9dSMark Adams                     PetscInt     idx  = (pLki + 1) - (ba_d + bi_d[myk]); // index into row
15008f7e8f9dSMark Adams                     start             = pLki + 1;                        // start at pLki+1, A22(myk,1)
15018f7e8f9dSMark Adams                     startj            = bj_d + bi_d[myk] + idx;
15028f7e8f9dSMark Adams                     end               = ba_d + bi_d[myk + 1];
15038f7e8f9dSMark Adams                   } else {
15048f7e8f9dSMark Adams                     PetscInt idx = bdiag_d[myk + 1] + 1;
15058f7e8f9dSMark Adams                     start        = ba_d + idx;
15068f7e8f9dSMark Adams                     startj       = bj_d + idx;
15078f7e8f9dSMark Adams                     end          = ba_d + bdiag_d[myk];
15088f7e8f9dSMark Adams                   }
15098f7e8f9dSMark Adams                   // search for 'col', use bisection search - TODO
15108f7e8f9dSMark Adams                   low  = 0;
15118f7e8f9dSMark Adams                   high = (PetscInt)(end - start);
15128f7e8f9dSMark Adams                   while (high - low > 5) {
15138f7e8f9dSMark Adams                     int t = (low + high) / 2;
15148f7e8f9dSMark Adams                     if (startj[t] > col) high = t;
15158f7e8f9dSMark Adams                     else low = t;
15168f7e8f9dSMark Adams                   }
15178f7e8f9dSMark Adams                   for (pAkjv = start + low; pAkjv < start + high; pAkjv++) {
15188f7e8f9dSMark Adams                     if (startj[pAkjv - start] == col) break;
15198f7e8f9dSMark Adams                   }
15208f1da0b2SJunchao Zhang #if defined(PETSC_USE_DEBUG) && !defined(PETSC_HAVE_SYCL)
152199551766SMark Adams                   if (pAkjv == start + high) printf("\t\t\t\t\t\t\t\t\t\t\tERROR: *** failed to find Akj(%d,%d)\n", (int)myk, (int)col); // uses a register
152299551766SMark Adams #endif
15238f7e8f9dSMark Adams                   *pAkjv = *pAkjv - L_ki() * Uij; // A_kj = A_kj - L_ki * U_ij
15248f7e8f9dSMark Adams                 }
15258f7e8f9dSMark Adams               });
15268f7e8f9dSMark Adams             }
15278f7e8f9dSMark Adams           });
15288f7e8f9dSMark Adams           team.team_barrier(); // this needs to be a league barrier to use more that one SM per block
15298f7e8f9dSMark Adams           if (field_block_idx == 0) Kokkos::single(Kokkos::PerTeam(team), [&]() { Kokkos::atomic_add(flops.data(), (size_t)(2 * (nzUi * nzUi) + 2)); });
15308f7e8f9dSMark Adams         } /* endof for (i=0; i<n; i++) { */
15319371c9d4SSatish Balay         Kokkos::single(Kokkos::PerTeam(team), [=]() {
15329371c9d4SSatish Balay           Kokkos::atomic_add(&d_flops_k(), flops());
15339371c9d4SSatish Balay           flops() = 0;
15349371c9d4SSatish Balay         });
15358f7e8f9dSMark Adams       });
15368f7e8f9dSMark Adams     Kokkos::fence();
15378f7e8f9dSMark Adams     Kokkos::deep_copy(h_flops_k, d_flops_k);
15389566063dSJacob Faibussowitsch     PetscCall(PetscLogGpuFlops((PetscLogDouble)h_flops_k()));
15399371c9d4SSatish Balay     Kokkos::parallel_for(
15409371c9d4SSatish Balay       Kokkos::TeamPolicy<>(Nf * Ni, 1, 256), KOKKOS_LAMBDA(const team_member team) {
15418f7e8f9dSMark Adams         const PetscInt lg_rank = team.league_rank(), field = lg_rank / Ni;                            //, field_offset = lg_rank%Ni;
15428f7e8f9dSMark Adams         const PetscInt start = field * nloc, end = start + nloc, n_its = (nloc / Ni + !!(nloc % Ni)); // 1/Ni iters
15438f7e8f9dSMark Adams         /* Invert diagonal for simpler triangular solves */
15448f7e8f9dSMark Adams         Kokkos::parallel_for(Kokkos::TeamVectorRange(team, n_its), [=](int outer_index) {
15458f7e8f9dSMark Adams           int i = start + outer_index * Ni + lg_rank % Ni;
15468f7e8f9dSMark Adams           if (i < end) {
15478f7e8f9dSMark Adams             PetscScalar *pv = ba_d + bdiag_d[i];
15488f7e8f9dSMark Adams             *pv             = 1.0 / (*pv);
15498f7e8f9dSMark Adams           }
15508f7e8f9dSMark Adams         });
15518f7e8f9dSMark Adams       });
15528f7e8f9dSMark Adams   }
15539566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
15549566063dSJacob Faibussowitsch   PetscCall(ISRestoreIndices(isicol, &ic_h));
15559566063dSJacob Faibussowitsch   PetscCall(ISRestoreIndices(isrow, &r_h));
15568f7e8f9dSMark Adams 
15579566063dSJacob Faibussowitsch   PetscCall(ISIdentity(isrow, &row_identity));
15589566063dSJacob Faibussowitsch   PetscCall(ISIdentity(isicol, &col_identity));
15598f7e8f9dSMark Adams   if (b->inode.size) {
15608f7e8f9dSMark Adams     B->ops->solve = MatSolve_SeqAIJ_Inode;
15618f7e8f9dSMark Adams   } else if (row_identity && col_identity) {
15628f7e8f9dSMark Adams     B->ops->solve = MatSolve_SeqAIJ_NaturalOrdering;
15638f7e8f9dSMark Adams   } else {
15648f7e8f9dSMark Adams     B->ops->solve = MatSolve_SeqAIJ; // at least this needs to be in Kokkos
15658f7e8f9dSMark Adams   }
15668f7e8f9dSMark Adams   B->offloadmask = PETSC_OFFLOAD_GPU;
15679566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncHost(B));          // solve on CPU
15688f7e8f9dSMark Adams   B->ops->solveadd          = MatSolveAdd_SeqAIJ; // and this
15698f7e8f9dSMark Adams   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJ;
15708f7e8f9dSMark Adams   B->ops->solvetransposeadd = MatSolveTransposeAdd_SeqAIJ;
15718f7e8f9dSMark Adams   B->ops->matsolve          = MatMatSolve_SeqAIJ;
15728f7e8f9dSMark Adams   B->assembled              = PETSC_TRUE;
15738f7e8f9dSMark Adams   B->preallocated           = PETSC_TRUE;
15748f7e8f9dSMark Adams 
1575*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1576930e68a5SMark Adams }
1577930e68a5SMark Adams 
1578d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatLUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1579d71ae5a4SJacob Faibussowitsch {
1580930e68a5SMark Adams   PetscFunctionBegin;
15819566063dSJacob Faibussowitsch   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
158286a27549SJunchao Zhang   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
1583*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
158486a27549SJunchao Zhang }
158586a27549SJunchao Zhang 
1586d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSymbolicSolveCheck(Mat A)
1587d71ae5a4SJacob Faibussowitsch {
158886a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
158986a27549SJunchao Zhang 
159086a27549SJunchao Zhang   PetscFunctionBegin;
159186a27549SJunchao Zhang   if (!factors->sptrsv_symbolic_completed) {
159286a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d);
159386a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d);
159486a27549SJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_TRUE;
159586a27549SJunchao Zhang   }
1596*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
159786a27549SJunchao Zhang }
159886a27549SJunchao Zhang 
159986a27549SJunchao Zhang /* Check if we need to update factors etc for transpose solve */
1600d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosTransposeSolveCheck(Mat A)
1601d71ae5a4SJacob Faibussowitsch {
160286a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1603076ba34aSJunchao Zhang   MatColIdxType               n       = A->rmap->n;
160486a27549SJunchao Zhang 
160586a27549SJunchao Zhang   PetscFunctionBegin;
160686a27549SJunchao Zhang   if (!factors->transpose_updated) { /* TODO: KK needs to provide functions to do numeric transpose only */
160786a27549SJunchao Zhang     /* Update L^T and do sptrsv symbolic */
1608076ba34aSJunchao Zhang     factors->iLt_d = MatRowMapKokkosView("factors->iLt_d", n + 1);
160986a27549SJunchao Zhang     Kokkos::deep_copy(factors->iLt_d, 0); /* KK requires 0 */
1610076ba34aSJunchao Zhang     factors->jLt_d = MatColIdxKokkosView("factors->jLt_d", factors->jL_d.extent(0));
1611076ba34aSJunchao Zhang     factors->aLt_d = MatScalarKokkosView("factors->aLt_d", factors->aL_d.extent(0));
161286a27549SJunchao Zhang 
16139371c9d4SSatish Balay     transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iL_d, factors->jL_d, factors->aL_d,
161486a27549SJunchao Zhang                                                                                                                                                                                                               factors->iLt_d, factors->jLt_d, factors->aLt_d);
161586a27549SJunchao Zhang 
161686a27549SJunchao Zhang     /* TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices.
161786a27549SJunchao Zhang       We have to sort the indices, until KK provides finer control options.
161886a27549SJunchao Zhang     */
16199371c9d4SSatish Balay     sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iLt_d, factors->jLt_d, factors->aLt_d);
162086a27549SJunchao Zhang 
162186a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d);
162286a27549SJunchao Zhang 
162386a27549SJunchao Zhang     /* Update U^T and do sptrsv symbolic */
1624076ba34aSJunchao Zhang     factors->iUt_d = MatRowMapKokkosView("factors->iUt_d", n + 1);
162586a27549SJunchao Zhang     Kokkos::deep_copy(factors->iUt_d, 0); /* KK requires 0 */
1626076ba34aSJunchao Zhang     factors->jUt_d = MatColIdxKokkosView("factors->jUt_d", factors->jU_d.extent(0));
1627076ba34aSJunchao Zhang     factors->aUt_d = MatScalarKokkosView("factors->aUt_d", factors->aU_d.extent(0));
162886a27549SJunchao Zhang 
16299371c9d4SSatish Balay     transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iU_d, factors->jU_d, factors->aU_d,
163086a27549SJunchao Zhang                                                                                                                                                                                                               factors->iUt_d, factors->jUt_d, factors->aUt_d);
163186a27549SJunchao Zhang 
163286a27549SJunchao Zhang     /* Sort indices. See comments above */
16339371c9d4SSatish Balay     sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iUt_d, factors->jUt_d, factors->aUt_d);
163486a27549SJunchao Zhang 
163586a27549SJunchao Zhang     KokkosSparse::Experimental::sptrsv_symbolic(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d);
163686a27549SJunchao Zhang     factors->transpose_updated = PETSC_TRUE;
163786a27549SJunchao Zhang   }
1638*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
163986a27549SJunchao Zhang }
164086a27549SJunchao Zhang 
164186a27549SJunchao Zhang /* Solve Ax = b, with A = LU */
1642d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSolve_SeqAIJKokkos(Mat A, Vec b, Vec x)
1643d71ae5a4SJacob Faibussowitsch {
164486a27549SJunchao Zhang   ConstPetscScalarKokkosView  bv;
164586a27549SJunchao Zhang   PetscScalarKokkosView       xv;
164686a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
164786a27549SJunchao Zhang 
164886a27549SJunchao Zhang   PetscFunctionBegin;
16499566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
16509566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSymbolicSolveCheck(A));
16519566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(b, &bv));
16529566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(x, &xv));
165386a27549SJunchao Zhang   /* Solve L tmpv = b */
16549566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::Experimental::sptrsv_solve(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d, bv, factors->workVector));
165586a27549SJunchao Zhang   /* Solve Ux = tmpv */
16569566063dSJacob Faibussowitsch   PetscCallCXX(KokkosSparse::Experimental::sptrsv_solve(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, factors->workVector, xv));
16579566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(b, &bv));
16589566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
16599566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
1660*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
166186a27549SJunchao Zhang }
166286a27549SJunchao Zhang 
1663076ba34aSJunchao Zhang /* Solve A^T x = b, where A^T = U^T L^T */
1664d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSolveTranspose_SeqAIJKokkos(Mat A, Vec b, Vec x)
1665d71ae5a4SJacob Faibussowitsch {
166686a27549SJunchao Zhang   ConstPetscScalarKokkosView  bv;
166786a27549SJunchao Zhang   PetscScalarKokkosView       xv;
166886a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
166986a27549SJunchao Zhang 
167086a27549SJunchao Zhang   PetscFunctionBegin;
16719566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
16729566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A));
16739566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(b, &bv));
16749566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(x, &xv));
167586a27549SJunchao Zhang   /* Solve U^T tmpv = b */
167686a27549SJunchao Zhang   KokkosSparse::Experimental::sptrsv_solve(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, bv, factors->workVector);
167786a27549SJunchao Zhang 
167886a27549SJunchao Zhang   /* Solve L^T x = tmpv */
167986a27549SJunchao Zhang   KokkosSparse::Experimental::sptrsv_solve(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d, factors->workVector, xv);
16809566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(b, &bv));
16819566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
16829566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
1683*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
168486a27549SJunchao Zhang }
168586a27549SJunchao Zhang 
1686d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatILUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1687d71ae5a4SJacob Faibussowitsch {
168886a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
168986a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
169086a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
169186a27549SJunchao Zhang 
169286a27549SJunchao Zhang   PetscFunctionBegin;
16939566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
16949566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1695076ba34aSJunchao Zhang 
1696076ba34aSJunchao Zhang   auto a_d = aijkok->a_dual.view_device();
1697076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
1698076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
1699076ba34aSJunchao Zhang 
1700076ba34aSJunchao Zhang   KokkosSparse::Experimental::spiluk_numeric(&factors->kh, fill_lev, i_d, j_d, a_d, factors->iL_d, factors->jL_d, factors->aL_d, factors->iU_d, factors->jU_d, factors->aU_d);
170186a27549SJunchao Zhang 
170286a27549SJunchao Zhang   B->assembled              = PETSC_TRUE;
170386a27549SJunchao Zhang   B->preallocated           = PETSC_TRUE;
170486a27549SJunchao Zhang   B->ops->solve             = MatSolve_SeqAIJKokkos;
170586a27549SJunchao Zhang   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJKokkos;
170686a27549SJunchao Zhang   B->ops->matsolve          = NULL;
170786a27549SJunchao Zhang   B->ops->matsolvetranspose = NULL;
170886a27549SJunchao Zhang   B->offloadmask            = PETSC_OFFLOAD_GPU;
170986a27549SJunchao Zhang 
171086a27549SJunchao Zhang   /* Once the factors' value changed, we need to update their transpose and sptrsv handle */
171186a27549SJunchao Zhang   factors->transpose_updated         = PETSC_FALSE;
171286a27549SJunchao Zhang   factors->sptrsv_symbolic_completed = PETSC_FALSE;
1713eeadb341SJunchao Zhang   /* TODO: log flops, but how to know that? */
17149566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
1715*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
171686a27549SJunchao Zhang }
171786a27549SJunchao Zhang 
1718d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1719d71ae5a4SJacob Faibussowitsch {
172086a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok;
172186a27549SJunchao Zhang   Mat_SeqAIJ                 *b;
172286a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
172386a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
172486a27549SJunchao Zhang   PetscInt                    nnzA     = ((Mat_SeqAIJ *)A->data)->nz, nnzL, nnzU;
172586a27549SJunchao Zhang   PetscInt                    n        = A->rmap->n;
172686a27549SJunchao Zhang 
172786a27549SJunchao Zhang   PetscFunctionBegin;
17289566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
172986a27549SJunchao Zhang   /* Rebuild factors */
17309371c9d4SSatish Balay   if (factors) {
17319371c9d4SSatish Balay     factors->Destroy();
17329371c9d4SSatish Balay   } /* Destroy the old if it exists */
17339371c9d4SSatish Balay   else {
17349371c9d4SSatish Balay     B->spptr = factors = new Mat_SeqAIJKokkosTriFactors(n);
17359371c9d4SSatish Balay   }
173686a27549SJunchao Zhang 
173786a27549SJunchao Zhang   /* Create a spiluk handle and then do symbolic factorization */
173886a27549SJunchao Zhang   nnzL = nnzU = PetscRealIntMultTruncate(info->fill, nnzA);
173986a27549SJunchao Zhang   factors->kh.create_spiluk_handle(KokkosSparse::Experimental::SPILUKAlgorithm::SEQLVLSCHD_TP1, n, nnzL, nnzU);
174086a27549SJunchao Zhang 
174186a27549SJunchao Zhang   auto spiluk_handle = factors->kh.get_spiluk_handle();
174286a27549SJunchao Zhang 
174386a27549SJunchao Zhang   Kokkos::realloc(factors->iL_d, n + 1); /* Free old arrays and realloc */
174486a27549SJunchao Zhang   Kokkos::realloc(factors->jL_d, spiluk_handle->get_nnzL());
174586a27549SJunchao Zhang   Kokkos::realloc(factors->iU_d, n + 1);
174686a27549SJunchao Zhang   Kokkos::realloc(factors->jU_d, spiluk_handle->get_nnzU());
174786a27549SJunchao Zhang 
174886a27549SJunchao Zhang   aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
1749076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
1750076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
1751076ba34aSJunchao Zhang   KokkosSparse::Experimental::spiluk_symbolic(&factors->kh, fill_lev, i_d, j_d, factors->iL_d, factors->jL_d, factors->iU_d, factors->jU_d);
175286a27549SJunchao Zhang   /* TODO: if spiluk_symbolic is asynchronous, do we need to sync before calling get_nnzL()? */
175386a27549SJunchao Zhang 
175486a27549SJunchao Zhang   Kokkos::resize(factors->jL_d, spiluk_handle->get_nnzL()); /* Shrink or expand, and retain old value */
175586a27549SJunchao Zhang   Kokkos::resize(factors->jU_d, spiluk_handle->get_nnzU());
175686a27549SJunchao Zhang   Kokkos::realloc(factors->aL_d, spiluk_handle->get_nnzL()); /* No need to retain old value */
175786a27549SJunchao Zhang   Kokkos::realloc(factors->aU_d, spiluk_handle->get_nnzU());
175886a27549SJunchao Zhang 
175986a27549SJunchao Zhang   /* TODO: add options to select sptrsv algorithms */
176086a27549SJunchao Zhang   /* Create sptrsv handles for L, U and their transpose */
176186a27549SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
176286a27549SJunchao Zhang   auto sptrsv_alg = KokkosSparse::Experimental::SPTRSVAlgorithm::SPTRSV_CUSPARSE;
176386a27549SJunchao Zhang #else
176486a27549SJunchao Zhang   auto sptrsv_alg = KokkosSparse::Experimental::SPTRSVAlgorithm::SEQLVLSCHD_TP1;
176586a27549SJunchao Zhang #endif
176686a27549SJunchao Zhang 
176786a27549SJunchao Zhang   factors->khL.create_sptrsv_handle(sptrsv_alg, n, true /* L is lower tri */);
176886a27549SJunchao Zhang   factors->khU.create_sptrsv_handle(sptrsv_alg, n, false /* U is not lower tri */);
176986a27549SJunchao Zhang   factors->khLt.create_sptrsv_handle(sptrsv_alg, n, false /* L^T is not lower tri */);
177086a27549SJunchao Zhang   factors->khUt.create_sptrsv_handle(sptrsv_alg, n, true /* U^T is lower tri */);
177186a27549SJunchao Zhang 
177286a27549SJunchao Zhang   /* Fill fields of the factor matrix B */
17739566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(B, MAT_SKIP_ALLOCATION, NULL));
177486a27549SJunchao Zhang   b     = (Mat_SeqAIJ *)B->data;
177586a27549SJunchao Zhang   b->nz = b->maxnz          = spiluk_handle->get_nnzL() + spiluk_handle->get_nnzU();
177686a27549SJunchao Zhang   B->info.fill_ratio_given  = info->fill;
177786a27549SJunchao Zhang   B->info.fill_ratio_needed = ((PetscReal)b->nz) / ((PetscReal)nnzA);
177886a27549SJunchao Zhang 
177986a27549SJunchao Zhang   B->offloadmask          = PETSC_OFFLOAD_GPU;
178086a27549SJunchao Zhang   B->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJKokkos;
1781*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1782930e68a5SMark Adams }
1783930e68a5SMark Adams 
1784d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatLUFactorSymbolic_SeqAIJKOKKOSDEVICE(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1785d71ae5a4SJacob Faibussowitsch {
17868f7e8f9dSMark Adams   Mat_SeqAIJ    *b     = (Mat_SeqAIJ *)B->data;
17878f7e8f9dSMark Adams   const PetscInt nrows = A->rmap->n;
1788930e68a5SMark Adams 
17898f7e8f9dSMark Adams   PetscFunctionBegin;
17909566063dSJacob Faibussowitsch   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
17918f7e8f9dSMark Adams   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKOKKOSDEVICE;
17928f7e8f9dSMark Adams   // move B data into Kokkos
17939566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B)); // create aijkok
17949566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A)); // create aijkok
17958f7e8f9dSMark Adams   {
17968f7e8f9dSMark Adams     Mat_SeqAIJKokkos *baijkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
1797300d22a6SJunchao Zhang     if (!baijkok->diag_d.extent(0)) {
17988f7e8f9dSMark Adams       const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_diag(b->diag, nrows + 1);
1799300d22a6SJunchao Zhang       baijkok->diag_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_diag));
1800300d22a6SJunchao Zhang       Kokkos::deep_copy(baijkok->diag_d, h_diag);
18018f7e8f9dSMark Adams     }
18028f7e8f9dSMark Adams   }
1803*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
18048f7e8f9dSMark Adams }
18058f7e8f9dSMark Adams 
1806d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatFactorGetSolverType_SeqAIJKokkos(Mat A, MatSolverType *type)
1807d71ae5a4SJacob Faibussowitsch {
1808930e68a5SMark Adams   PetscFunctionBegin;
1809930e68a5SMark Adams   *type = MATSOLVERKOKKOS;
1810*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1811930e68a5SMark Adams }
1812930e68a5SMark Adams 
1813d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatFactorGetSolverType_seqaij_kokkos_device(Mat A, MatSolverType *type)
1814d71ae5a4SJacob Faibussowitsch {
18158f7e8f9dSMark Adams   PetscFunctionBegin;
18168f7e8f9dSMark Adams   *type = MATSOLVERKOKKOSDEVICE;
1817*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
18188f7e8f9dSMark Adams }
18198f7e8f9dSMark Adams 
1820930e68a5SMark Adams /*MC
  MATSOLVERKOKKOS = "Kokkos" - A matrix solver type providing triangular solvers for sequential matrices
  of type `MATSEQAIJKOKKOS` or `MATAIJKOKKOS` on a single GPU.
1823930e68a5SMark Adams 
1824930e68a5SMark Adams   Level: beginner
1825930e68a5SMark Adams 
1826db781477SPatrick Sanan .seealso: `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJKokkos()`, `MATAIJKOKKOS`, `MatKokkosSetFormat()`, `MatKokkosStorageFormat`, `MatKokkosFormatOperation`
1827930e68a5SMark Adams M*/
182886a27549SJunchao Zhang PETSC_EXTERN PetscErrorCode MatGetFactor_SeqAIJKokkos_Kokkos(Mat A, MatFactorType ftype, Mat *B) /* MatGetFactor_<MatType>_<MatSolverType> */
1829930e68a5SMark Adams {
1830930e68a5SMark Adams   PetscInt n = A->rmap->n;
1831930e68a5SMark Adams 
1832930e68a5SMark Adams   PetscFunctionBegin;
18339566063dSJacob Faibussowitsch   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), B));
18349566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*B, n, n, n, n));
1835930e68a5SMark Adams   (*B)->factortype = ftype;
18369566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
18379566063dSJacob Faibussowitsch   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
1838930e68a5SMark Adams 
18398f7e8f9dSMark Adams   if (ftype == MAT_FACTOR_LU) {
18409566063dSJacob Faibussowitsch     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
184186a27549SJunchao Zhang     (*B)->canuseordering        = PETSC_TRUE;
184286a27549SJunchao Zhang     (*B)->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqAIJKokkos;
184386a27549SJunchao Zhang   } else if (ftype == MAT_FACTOR_ILU) {
18449566063dSJacob Faibussowitsch     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
184586a27549SJunchao Zhang     (*B)->canuseordering         = PETSC_FALSE;
184686a27549SJunchao Zhang     (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJKokkos;
184798921bdaSJacob Faibussowitsch   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "MatFactorType %s is not supported by MatType SeqAIJKokkos", MatFactorTypes[ftype]);
1848930e68a5SMark Adams 
18499566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
18509566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)(*B), "MatFactorGetSolverType_C", MatFactorGetSolverType_SeqAIJKokkos));
1851*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1852930e68a5SMark Adams }
18538f7e8f9dSMark Adams 
1854d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatGetFactor_seqaijkokkos_kokkos_device(Mat A, MatFactorType ftype, Mat *B)
1855d71ae5a4SJacob Faibussowitsch {
18568f7e8f9dSMark Adams   PetscInt n = A->rmap->n;
18578f7e8f9dSMark Adams 
18588f7e8f9dSMark Adams   PetscFunctionBegin;
18599566063dSJacob Faibussowitsch   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), B));
18609566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*B, n, n, n, n));
18618f7e8f9dSMark Adams   (*B)->factortype     = ftype;
1862f73b0415SBarry Smith   (*B)->canuseordering = PETSC_TRUE;
18639566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
18649566063dSJacob Faibussowitsch   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
18658f7e8f9dSMark Adams 
18668f7e8f9dSMark Adams   if (ftype == MAT_FACTOR_LU) {
18679566063dSJacob Faibussowitsch     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
18688f7e8f9dSMark Adams     (*B)->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqAIJKOKKOSDEVICE;
18698f7e8f9dSMark Adams   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Factor type not supported for KOKKOS Matrix Types");
18708f7e8f9dSMark Adams 
18719566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
18729566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)(*B), "MatFactorGetSolverType_C", MatFactorGetSolverType_seqaij_kokkos_device));
1873*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
18748f7e8f9dSMark Adams }
187586a27549SJunchao Zhang 
1876d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatSolverTypeRegister_KOKKOS(void)
1877d71ae5a4SJacob Faibussowitsch {
187886a27549SJunchao Zhang   PetscFunctionBegin;
18799566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_SeqAIJKokkos_Kokkos));
18809566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ILU, MatGetFactor_SeqAIJKokkos_Kokkos));
18819566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOSDEVICE, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_seqaijkokkos_kokkos_device));
1882*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
188386a27549SJunchao Zhang }
188486a27549SJunchao Zhang 
1885076ba34aSJunchao Zhang /* Utility to print out a KokkosCsrMatrix for debugging */
1886d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode PrintCsrMatrix(const KokkosCsrMatrix &csrmat)
1887d71ae5a4SJacob Faibussowitsch {
1888076ba34aSJunchao Zhang   const auto        &iv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.graph.row_map);
1889076ba34aSJunchao Zhang   const auto        &jv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.graph.entries);
1890076ba34aSJunchao Zhang   const auto        &av = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), csrmat.values);
1891076ba34aSJunchao Zhang   const PetscInt    *i  = iv.data();
1892076ba34aSJunchao Zhang   const PetscInt    *j  = jv.data();
1893076ba34aSJunchao Zhang   const PetscScalar *a  = av.data();
1894076ba34aSJunchao Zhang   PetscInt           m = csrmat.numRows(), n = csrmat.numCols(), nnz = csrmat.nnz();
1895076ba34aSJunchao Zhang 
1896076ba34aSJunchao Zhang   PetscFunctionBegin;
18979566063dSJacob Faibussowitsch   PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT " x %" PetscInt_FMT " SeqAIJKokkos, with %" PetscInt_FMT " nonzeros\n", m, n, nnz));
1898076ba34aSJunchao Zhang   for (PetscInt k = 0; k < m; k++) {
18999566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT ": ", k));
190048a46eb9SPierre Jolivet     for (PetscInt p = i[k]; p < i[k + 1]; p++) PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT "(%.1f), ", j[p], (double)PetscRealPart(a[p])));
19019566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "\n"));
1902076ba34aSJunchao Zhang   }
1903*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1904076ba34aSJunchao Zhang }
1905