xref: /petsc/src/mat/impls/aij/seq/kokkos/aijkok.kokkos.cxx (revision 7233ce556edaaa29ce4b396318a4ca76d3b0291b)
1e36ced11SJunchao Zhang #include <petsc_kokkos.hpp>
211d22bbfSJunchao Zhang #include <petscvec_kokkos.hpp>
3c0c276a7Ssdargavi #include <petscmat_kokkos.hpp>
4076ba34aSJunchao Zhang #include <petscpkg_version.h>
5152b3e56SJunchao Zhang #include <petsc/private/petscimpl.h>
642550becSJunchao Zhang #include <petsc/private/sfimpl.h>
7aac854edSJunchao Zhang #include <petsc/private/kokkosimpl.hpp>
8*7233ce55SJed Brown #include <petscsys.h>
98c3ff71bSJunchao Zhang 
108c3ff71bSJunchao Zhang #include <Kokkos_Core.hpp>
11f0cf5187SStefano Zampini #include <KokkosBlas.hpp>
128c3ff71bSJunchao Zhang #include <KokkosSparse_CrsMatrix.hpp>
13cc6e31f1SJunchao Zhang 
14cc6e31f1SJunchao Zhang // To suppress compiler warnings:
15cc6e31f1SJunchao Zhang // /path/include/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp:434:63:
16cc6e31f1SJunchao Zhang // warning: 'cusparseStatus_t cusparseDbsrmm(cusparseHandle_t, cusparseDirection_t, cusparseOperation_t,
17cc6e31f1SJunchao Zhang // cusparseOperation_t, int, int, int, int, const double*, cusparseMatDescr_t, const double*, const int*, const int*,
18cc6e31f1SJunchao Zhang // int, const double*, int, const double*, double*, int)' is deprecated: please use cusparseSpMM instead [-Wdeprecated-declarations]
19cc6e31f1SJunchao Zhang PETSC_PRAGMA_DIAGNOSTIC_IGNORED_BEGIN("-Wdeprecated-declarations")
208c3ff71bSJunchao Zhang #include <KokkosSparse_spmv.hpp>
21cc6e31f1SJunchao Zhang PETSC_PRAGMA_DIAGNOSTIC_IGNORED_END()
22cc6e31f1SJunchao Zhang 
2386a27549SJunchao Zhang #include <KokkosSparse_spiluk.hpp>
2486a27549SJunchao Zhang #include <KokkosSparse_sptrsv.hpp>
25076ba34aSJunchao Zhang #include <KokkosSparse_spgemm.hpp>
26076ba34aSJunchao Zhang #include <KokkosSparse_spadd.hpp>
279d13fa56SJunchao Zhang #include <KokkosBatched_LU_Decl.hpp>
289d13fa56SJunchao Zhang #include <KokkosBatched_InverseLU_Decl.hpp>
2986a27549SJunchao Zhang 
3042550becSJunchao Zhang #include <../src/mat/impls/aij/seq/kokkos/aijkok.hpp>
318c3ff71bSJunchao Zhang 
320e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(3, 7, 0)
33f98996d3SJunchao Zhang   #include <KokkosSparse_Utils.hpp>
34f98996d3SJunchao Zhang using KokkosSparse::sort_crs_matrix;
359371c9d4SSatish Balay using KokkosSparse::Impl::transpose_matrix;
36f98996d3SJunchao Zhang #else
37f98996d3SJunchao Zhang   #include <KokkosKernels_Sorting.hpp>
38f98996d3SJunchao Zhang using KokkosKernels::sort_crs_matrix;
399371c9d4SSatish Balay using KokkosKernels::Impl::transpose_matrix;
40f98996d3SJunchao Zhang #endif
41f98996d3SJunchao Zhang 
42aac854edSJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(4, 6, 0)
43aac854edSJunchao Zhang using KokkosSparse::spiluk_symbolic;
44aac854edSJunchao Zhang using KokkosSparse::spiluk_numeric;
45aac854edSJunchao Zhang using KokkosSparse::sptrsv_symbolic;
46aac854edSJunchao Zhang using KokkosSparse::sptrsv_solve;
47aac854edSJunchao Zhang using KokkosSparse::Experimental::SPTRSVAlgorithm;
48aac854edSJunchao Zhang using KokkosSparse::Experimental::SPILUKAlgorithm;
49aac854edSJunchao Zhang #else
50aac854edSJunchao Zhang using KokkosSparse::Experimental::spiluk_symbolic;
51aac854edSJunchao Zhang using KokkosSparse::Experimental::spiluk_numeric;
52aac854edSJunchao Zhang using KokkosSparse::Experimental::sptrsv_symbolic;
53aac854edSJunchao Zhang using KokkosSparse::Experimental::sptrsv_solve;
54aac854edSJunchao Zhang using KokkosSparse::Experimental::SPTRSVAlgorithm;
55aac854edSJunchao Zhang using KokkosSparse::Experimental::SPILUKAlgorithm;
56aac854edSJunchao Zhang #endif
57aac854edSJunchao Zhang 
588c3ff71bSJunchao Zhang static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat); /* Forward declaration */
598c3ff71bSJunchao Zhang 
60076ba34aSJunchao Zhang /* MatAssemblyEnd_SeqAIJKokkos() happens when we finalized nonzeros of the matrix, either after
61076ba34aSJunchao Zhang    we assembled the matrix on host, or after we directly produced the matrix data on device (ex., through MatMatMult).
62076ba34aSJunchao Zhang    In the latter case, it is important to set a_dual's sync state correctly.
63076ba34aSJunchao Zhang  */
64d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAssemblyEnd_SeqAIJKokkos(Mat A, MatAssemblyType mode)
65d71ae5a4SJacob Faibussowitsch {
66076ba34aSJunchao Zhang   Mat_SeqAIJ       *aijseq;
67076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
688c3ff71bSJunchao Zhang 
698c3ff71bSJunchao Zhang   PetscFunctionBegin;
703ba16761SJacob Faibussowitsch   if (mode == MAT_FLUSH_ASSEMBLY) PetscFunctionReturn(PETSC_SUCCESS);
719566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
72076ba34aSJunchao Zhang 
73076ba34aSJunchao Zhang   aijseq = static_cast<Mat_SeqAIJ *>(A->data);
74076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
75076ba34aSJunchao Zhang 
76076ba34aSJunchao Zhang   /* If aijkok does not exist, we just copy i, j to device.
77076ba34aSJunchao Zhang      If aijkok already exists, but the device's nonzero pattern does not match with the host's, we assume the latest data is on host.
78076ba34aSJunchao Zhang      In both cases, we build a new aijkok structure.
79076ba34aSJunchao Zhang   */
80076ba34aSJunchao Zhang   if (!aijkok || aijkok->nonzerostate != A->nonzerostate) { /* aijkok might not exist yet or nonzero pattern has changed */
81076ba34aSJunchao Zhang     delete aijkok;
82f4747e26SJunchao Zhang     aijkok   = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aijseq, A->nonzerostate, PETSC_FALSE /*don't copy mat values to device*/);
83076ba34aSJunchao Zhang     A->spptr = aijkok;
84f4747e26SJunchao Zhang   } else if (A->rmap->n && aijkok->diag_dual.extent(0) == 0) { // MatProduct might directly produce AIJ on device, but not the diag.
85f4747e26SJunchao Zhang     MatRowMapKokkosViewHost diag_h(aijseq->diag, A->rmap->n);
86f4747e26SJunchao Zhang     auto                    diag_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), diag_h);
87f4747e26SJunchao Zhang     aijkok->diag_dual              = MatRowMapKokkosDualView(diag_d, diag_h);
88076ba34aSJunchao Zhang   }
893ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
908c3ff71bSJunchao Zhang }
918c3ff71bSJunchao Zhang 
9286a27549SJunchao Zhang /* Sync CSR data to device if not yet */
93d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosSyncDevice(Mat A)
94d71ae5a4SJacob Faibussowitsch {
958c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
968c3ff71bSJunchao Zhang 
978c3ff71bSJunchao Zhang   PetscFunctionBegin;
98aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from host to device");
995f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
100076ba34aSJunchao Zhang   if (aijkok->a_dual.need_sync_device()) {
101076ba34aSJunchao Zhang     aijkok->a_dual.sync_device();
102580c7c76SPierre Jolivet     aijkok->transpose_updated = PETSC_FALSE; /* values of the transpose is out-of-date */
10386a27549SJunchao Zhang     aijkok->hermitian_updated = PETSC_FALSE;
1048c3ff71bSJunchao Zhang   }
1053ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1068c3ff71bSJunchao Zhang }
1078c3ff71bSJunchao Zhang 
108076ba34aSJunchao Zhang /* Mark the CSR data on device as modified */
109d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSeqAIJKokkosModifyDevice(Mat A)
110d71ae5a4SJacob Faibussowitsch {
11186a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
11286a27549SJunchao Zhang 
11386a27549SJunchao Zhang   PetscFunctionBegin;
1145f80ce2aSJacob Faibussowitsch   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Not supported for factorized matries");
11586a27549SJunchao Zhang   aijkok->a_dual.clear_sync_state();
11686a27549SJunchao Zhang   aijkok->a_dual.modify_device();
11786a27549SJunchao Zhang   aijkok->transpose_updated = PETSC_FALSE;
11886a27549SJunchao Zhang   aijkok->hermitian_updated = PETSC_FALSE;
1199566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJInvalidateDiagonal(A));
1209566063dSJacob Faibussowitsch   PetscCall(PetscObjectStateIncrease((PetscObject)A));
1213ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
12286a27549SJunchao Zhang }
12386a27549SJunchao Zhang 
124d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosSyncHost(Mat A)
125d71ae5a4SJacob Faibussowitsch {
126f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1274df4a32cSJunchao Zhang   auto              exec   = PetscGetKokkosExecutionSpace();
128f0cf5187SStefano Zampini 
129f0cf5187SStefano Zampini   PetscFunctionBegin;
130f0cf5187SStefano Zampini   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
13186a27549SJunchao Zhang   /* We do not expect one needs factors on host  */
132aaa8cc7dSPierre Jolivet   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Can't sync factorized matrix from device to host");
1335f80ce2aSJacob Faibussowitsch   PetscCheck(aijkok, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "Missing AIJKOK");
134aac854edSJunchao Zhang   PetscCall(KokkosDualViewSync<HostMirrorMemorySpace>(aijkok->a_dual, exec));
1353ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
136f0cf5187SStefano Zampini }
137f0cf5187SStefano Zampini 
138d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
139d71ae5a4SJacob Faibussowitsch {
140076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
141f0cf5187SStefano Zampini 
142f0cf5187SStefano Zampini   PetscFunctionBegin;
1435519a089SJose E. Roman   /* aijkok contains valid pointers only if the host's nonzerostate matches with the device's.
1445519a089SJose E. Roman     Calling MatSeqAIJSetPreallocation() or MatSetValues() on host, where aijseq->{i,j,a} might be
1455519a089SJose E. Roman     reallocated, will lead to stale {i,j,a}_dual in aijkok. In both operations, the hosts's nonzerostate
1465519a089SJose E. Roman     must have been updated. The stale aijkok will be rebuilt during MatAssemblyEnd.
1475519a089SJose E. Roman   */
1485519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
1494df4a32cSJunchao Zhang     auto exec = PetscGetKokkosExecutionSpace();
150e36ced11SJunchao Zhang     PetscCallCXX(aijkok->a_dual.sync_host(exec));
151e36ced11SJunchao Zhang     PetscCallCXX(exec.fence());
152076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
153076ba34aSJunchao Zhang   } else { /* Happens when calling MatSetValues on a newly created matrix */
154076ba34aSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
155076ba34aSJunchao Zhang   }
1563ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
157076ba34aSJunchao Zhang }
158076ba34aSJunchao Zhang 
159d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJKokkos(Mat A, PetscScalar *array[])
160d71ae5a4SJacob Faibussowitsch {
161076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
162076ba34aSJunchao Zhang 
163076ba34aSJunchao Zhang   PetscFunctionBegin;
1645519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) aijkok->a_dual.modify_host();
1653ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
166076ba34aSJunchao Zhang }
167076ba34aSJunchao Zhang 
168d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
169d71ae5a4SJacob Faibussowitsch {
170076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
171076ba34aSJunchao Zhang 
172076ba34aSJunchao Zhang   PetscFunctionBegin;
1735519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
1744df4a32cSJunchao Zhang     auto exec = PetscGetKokkosExecutionSpace();
175e36ced11SJunchao Zhang     PetscCallCXX(aijkok->a_dual.sync_host(exec));
176e36ced11SJunchao Zhang     PetscCallCXX(exec.fence());
177076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1782328674fSJunchao Zhang   } else {
1792328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
1802328674fSJunchao Zhang   }
1813ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
182076ba34aSJunchao Zhang }
183076ba34aSJunchao Zhang 
184d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJKokkos(Mat A, const PetscScalar *array[])
185d71ae5a4SJacob Faibussowitsch {
186076ba34aSJunchao Zhang   PetscFunctionBegin;
187076ba34aSJunchao Zhang   *array = NULL;
1883ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
189076ba34aSJunchao Zhang }
190076ba34aSJunchao Zhang 
191d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
192d71ae5a4SJacob Faibussowitsch {
193076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
194076ba34aSJunchao Zhang 
195076ba34aSJunchao Zhang   PetscFunctionBegin;
1965519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
197076ba34aSJunchao Zhang     *array = aijkok->a_dual.view_host().data();
1982328674fSJunchao Zhang   } else { /* Ex. happens with MatZeroEntries on a preallocated but not assembled matrix */
1992328674fSJunchao Zhang     *array = static_cast<Mat_SeqAIJ *>(A->data)->a;
2002328674fSJunchao Zhang   }
2013ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
202076ba34aSJunchao Zhang }
203076ba34aSJunchao Zhang 
204d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJKokkos(Mat A, PetscScalar *array[])
205d71ae5a4SJacob Faibussowitsch {
206076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
207076ba34aSJunchao Zhang 
208076ba34aSJunchao Zhang   PetscFunctionBegin;
2095519a089SJose E. Roman   if (aijkok && A->nonzerostate == aijkok->nonzerostate) {
210076ba34aSJunchao Zhang     aijkok->a_dual.clear_sync_state();
211076ba34aSJunchao Zhang     aijkok->a_dual.modify_host();
2122328674fSJunchao Zhang   }
2133ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
214f0cf5187SStefano Zampini }
215f0cf5187SStefano Zampini 
216d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJKokkos(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype)
217d71ae5a4SJacob Faibussowitsch {
2187ee59b9bSJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
2197ee59b9bSJunchao Zhang 
2207ee59b9bSJunchao Zhang   PetscFunctionBegin;
2217ee59b9bSJunchao Zhang   PetscCheck(aijkok != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "aijkok is NULL");
2227ee59b9bSJunchao Zhang 
2237ee59b9bSJunchao Zhang   if (i) *i = aijkok->i_device_data();
2247ee59b9bSJunchao Zhang   if (j) *j = aijkok->j_device_data();
2257ee59b9bSJunchao Zhang   if (a) {
2267ee59b9bSJunchao Zhang     aijkok->a_dual.sync_device();
2277ee59b9bSJunchao Zhang     *a = aijkok->a_device_data();
2287ee59b9bSJunchao Zhang   }
2297ee59b9bSJunchao Zhang   if (mtype) *mtype = PETSC_MEMTYPE_KOKKOS;
2303ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2317ee59b9bSJunchao Zhang }
2327ee59b9bSJunchao Zhang 
2330e3ece09SJunchao Zhang /*
2340e3ece09SJunchao Zhang   Generate the sparsity pattern of a MatSeqAIJKokkos matrix's transpose on device.
2350e3ece09SJunchao Zhang 
2360e3ece09SJunchao Zhang   Input Parameter:
2370e3ece09SJunchao Zhang .  A       - the MATSEQAIJKOKKOS matrix
2380e3ece09SJunchao Zhang 
2390e3ece09SJunchao Zhang   Output Parameters:
2400e3ece09SJunchao Zhang +  perm_d - the permutation array on device, which connects Ta(i) = Aa(perm(i))
241aaa8cc7dSPierre Jolivet -  T_d    - the transpose on device, whose value array is allocated but not initialized
2420e3ece09SJunchao Zhang */
2430e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateTransposeStructure(Mat A, MatRowMapKokkosView &perm_d, KokkosCsrMatrix &T_d)
244d71ae5a4SJacob Faibussowitsch {
2450e3ece09SJunchao Zhang   Mat_SeqAIJ             *aseq = static_cast<Mat_SeqAIJ *>(A->data);
2460e3ece09SJunchao Zhang   PetscInt                nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
2470e3ece09SJunchao Zhang   const PetscInt         *Ai = aseq->i, *Aj = aseq->j;
2487b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost Ti_h(NoInit("Ti"), n + 1);
2490e3ece09SJunchao Zhang   MatRowMapType          *Ti = Ti_h.data();
2507b8d4ba6SJunchao Zhang   MatColIdxKokkosViewHost Tj_h(NoInit("Tj"), nz);
2517b8d4ba6SJunchao Zhang   MatRowMapKokkosViewHost perm_h(NoInit("permutation"), nz);
2520e3ece09SJunchao Zhang   PetscInt               *Tj   = Tj_h.data();
2530e3ece09SJunchao Zhang   PetscInt               *perm = perm_h.data();
2540e3ece09SJunchao Zhang   PetscInt               *offset;
255152b3e56SJunchao Zhang 
256152b3e56SJunchao Zhang   PetscFunctionBegin;
2570e3ece09SJunchao Zhang   // Populate Ti
2580e3ece09SJunchao Zhang   PetscCallCXX(Kokkos::deep_copy(Ti_h, 0));
2590e3ece09SJunchao Zhang   Ti++;
2600e3ece09SJunchao Zhang   for (PetscInt i = 0; i < nz; i++) Ti[Aj[i]]++;
2610e3ece09SJunchao Zhang   Ti--;
2620e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) Ti[i + 1] += Ti[i];
2630e3ece09SJunchao Zhang 
2640e3ece09SJunchao Zhang   // Populate Tj and the permutation array
2650e3ece09SJunchao Zhang   PetscCall(PetscCalloc1(n, &offset)); // offset in each T row to fill in its column indices
2660e3ece09SJunchao Zhang   for (PetscInt i = 0; i < m; i++) {
2670e3ece09SJunchao Zhang     for (PetscInt j = Ai[i]; j < Ai[i + 1]; j++) { // A's (i,j) is T's (j,i)
2680e3ece09SJunchao Zhang       PetscInt r    = Aj[j];                       // row r of T
2690e3ece09SJunchao Zhang       PetscInt disp = Ti[r] + offset[r];
2700e3ece09SJunchao Zhang 
2710e3ece09SJunchao Zhang       Tj[disp]   = i; // col i of T
2720e3ece09SJunchao Zhang       perm[disp] = j;
2730e3ece09SJunchao Zhang       offset[r]++;
274076ba34aSJunchao Zhang     }
2750e3ece09SJunchao Zhang   }
2760e3ece09SJunchao Zhang   PetscCall(PetscFree(offset));
2770e3ece09SJunchao Zhang 
2780e3ece09SJunchao Zhang   // Sort each row of T, along with the permutation array
2790e3ece09SJunchao Zhang   for (PetscInt i = 0; i < n; i++) PetscCall(PetscSortIntWithArray(Ti[i + 1] - Ti[i], Tj + Ti[i], perm + Ti[i]));
2800e3ece09SJunchao Zhang 
2810e3ece09SJunchao Zhang   // Output perm and T on device
2820e3ece09SJunchao Zhang   auto Ti_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ti_h);
2830e3ece09SJunchao Zhang   auto Tj_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Tj_h);
2840e3ece09SJunchao Zhang   PetscCallCXX(T_d = KokkosCsrMatrix("csrmatT", n, m, nz, MatScalarKokkosView("Ta", nz), Ti_d, Tj_d));
2850e3ece09SJunchao Zhang   PetscCallCXX(perm_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), perm_h));
2863ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
287152b3e56SJunchao Zhang }
288152b3e56SJunchao Zhang 
2890e3ece09SJunchao Zhang // Generate the transpose on device and cache it internally
2900e3ece09SJunchao Zhang // Note: KK transpose_matrix() does not have support symbolic/numeric transpose, so we do it on our own
2910e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGenerateTranspose_Private(Mat A, KokkosCsrMatrix *csrmatT)
292d71ae5a4SJacob Faibussowitsch {
2930e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
2940e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
2950e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
2960e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatT;
297152b3e56SJunchao Zhang 
298152b3e56SJunchao Zhang   PetscFunctionBegin;
2990e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
300145b44c9SPierre Jolivet   PetscCallCXX(akok->a_dual.sync_device()); // Sync A's values since we are going to access them on device
3010e3ece09SJunchao Zhang 
3020e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
3030e3ece09SJunchao Zhang 
3040e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE) {
3050e3ece09SJunchao Zhang     *csrmatT = akok->csrmat;
3060e3ece09SJunchao Zhang   } else {
3070e3ece09SJunchao Zhang     // See if we already have a cached transpose and its value is up to date
3080e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
3090e3ece09SJunchao Zhang       if (!akok->transpose_updated) {            // if the value is out of date, update the cached version
3100e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
3110e3ece09SJunchao Zhang         auto       &Ta   = T.values;
3120e3ece09SJunchao Zhang 
313d326c3f1SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = Aa(perm(i)); }));
314076ba34aSJunchao Zhang       }
3150e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
3160e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
3170e3ece09SJunchao Zhang 
3180e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
3190e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
320d326c3f1SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = Aa(perm(i)); }));
3210e3ece09SJunchao Zhang     }
3220e3ece09SJunchao Zhang     akok->transpose_updated = PETSC_TRUE;
3230e3ece09SJunchao Zhang     *csrmatT                = akok->csrmatT;
3240e3ece09SJunchao Zhang   }
3250e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3260e3ece09SJunchao Zhang }
3270e3ece09SJunchao Zhang 
3280e3ece09SJunchao Zhang // Generate the Hermitian on device and cache it internally
3290e3ece09SJunchao Zhang static PetscErrorCode MatSeqAIJKokkosGenerateHermitian_Private(Mat A, KokkosCsrMatrix *csrmatH)
3300e3ece09SJunchao Zhang {
3310e3ece09SJunchao Zhang   Mat_SeqAIJ       *aseq = static_cast<Mat_SeqAIJ *>(A->data);
3320e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
3330e3ece09SJunchao Zhang   PetscInt          nz = aseq->nz, m = A->rmap->N, n = A->cmap->n;
3340e3ece09SJunchao Zhang   KokkosCsrMatrix  &T = akok->csrmatH;
3350e3ece09SJunchao Zhang 
3360e3ece09SJunchao Zhang   PetscFunctionBegin;
3370e3ece09SJunchao Zhang   PetscCheck(akok, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Unexpected NULL (Mat_SeqAIJKokkos*)A->spptr");
3380e3ece09SJunchao Zhang   PetscCallCXX(akok->a_dual.sync_device()); // Sync A's values since we are going to access them on device
3390e3ece09SJunchao Zhang 
3400e3ece09SJunchao Zhang   const auto &Aa = akok->a_dual.view_device();
3410e3ece09SJunchao Zhang 
3420e3ece09SJunchao Zhang   if (A->hermitian == PETSC_BOOL3_TRUE) {
3430e3ece09SJunchao Zhang     *csrmatH = akok->csrmat;
3440e3ece09SJunchao Zhang   } else {
3450e3ece09SJunchao Zhang     // See if we already have a cached hermitian and its value is up to date
3460e3ece09SJunchao Zhang     if (T.numRows() == n && T.numCols() == m) {  // this indicates csrmatT had been generated before, otherwise T has 0 rows/cols after construction
3470e3ece09SJunchao Zhang       if (!akok->hermitian_updated) {            // if the value is out of date, update the cached version
3480e3ece09SJunchao Zhang         const auto &perm = akok->transpose_perm; // get the permutation array
3490e3ece09SJunchao Zhang         auto       &Ta   = T.values;
3500e3ece09SJunchao Zhang 
351d326c3f1SJunchao Zhang         PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { Ta(i) = PetscConj(Aa(perm(i))); }));
3520e3ece09SJunchao Zhang       }
3530e3ece09SJunchao Zhang     } else { // Generate T of size n x m for the first time
3540e3ece09SJunchao Zhang       MatRowMapKokkosView perm;
3550e3ece09SJunchao Zhang 
3560e3ece09SJunchao Zhang       PetscCall(MatSeqAIJKokkosGenerateTransposeStructure(A, perm, T));
3570e3ece09SJunchao Zhang       akok->transpose_perm = perm; // cache the perm in this matrix for reuse
358d326c3f1SJunchao Zhang       PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt i) { T.values(i) = PetscConj(Aa(perm(i))); }));
3590e3ece09SJunchao Zhang     }
3600e3ece09SJunchao Zhang     akok->hermitian_updated = PETSC_TRUE;
3610e3ece09SJunchao Zhang     *csrmatH                = akok->csrmatH;
3620e3ece09SJunchao Zhang   }
3633ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
364152b3e56SJunchao Zhang }
365a587d139SMark 
3668c3ff71bSJunchao Zhang /* y = A x */
367d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMult_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
368d71ae5a4SJacob Faibussowitsch {
3698c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
370152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
371152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
3728c3ff71bSJunchao Zhang 
3738c3ff71bSJunchao Zhang   PetscFunctionBegin;
3749566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3759566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
3769566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
3779566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
3788c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
379d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), "N", 1.0 /*alpha*/, aijkok->csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A x + beta y */
3809566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
3819566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
382076ba34aSJunchao Zhang   /* 2.0*nnz - numRows seems more accurate here but assumes there are no zero-rows. So a little sloppy here. */
3839566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
3849566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
3853ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
3868c3ff71bSJunchao Zhang }
3878c3ff71bSJunchao Zhang 
3888c3ff71bSJunchao Zhang /* y = A^T x */
389d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
390d71ae5a4SJacob Faibussowitsch {
3918c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
392152b3e56SJunchao Zhang   const char                *mode;
393152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
394152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
3950e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
3968c3ff71bSJunchao Zhang 
3978c3ff71bSJunchao Zhang   PetscFunctionBegin;
3989566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
3999566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4009566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4019566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
402152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4039566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
404152b3e56SJunchao Zhang     mode = "N";
405152b3e56SJunchao Zhang   } else {
406076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4070e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
408152b3e56SJunchao Zhang     mode   = "T";
409152b3e56SJunchao Zhang   }
410d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^T x + beta y */
4119566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4129566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
4130e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4149566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4153ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4168c3ff71bSJunchao Zhang }
4178c3ff71bSJunchao Zhang 
4188c3ff71bSJunchao Zhang /* y = A^H x */
419d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTranspose_SeqAIJKokkos(Mat A, Vec xx, Vec yy)
420d71ae5a4SJacob Faibussowitsch {
4218c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
422152b3e56SJunchao Zhang   const char                *mode;
423152b3e56SJunchao Zhang   ConstPetscScalarKokkosView xv;
424152b3e56SJunchao Zhang   PetscScalarKokkosView      yv;
4250e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4268c3ff71bSJunchao Zhang 
4278c3ff71bSJunchao Zhang   PetscFunctionBegin;
4289566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4299566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
4309566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
4319566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosViewWrite(yy, &yv));
432152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4339566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
434152b3e56SJunchao Zhang     mode = "N";
435152b3e56SJunchao Zhang   } else {
436076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4370e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
438152b3e56SJunchao Zhang     mode   = "C";
439152b3e56SJunchao Zhang   }
440d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 0.0 /*beta*/, yv)); /* y = alpha A^H x + beta y */
4419566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
4429566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosViewWrite(yy, &yv));
4430e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4449566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4453ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4468c3ff71bSJunchao Zhang }
4478c3ff71bSJunchao Zhang 
4488c3ff71bSJunchao Zhang /* z = A x + y */
449d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
450d71ae5a4SJacob Faibussowitsch {
4518c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
45292896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
453152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4548c3ff71bSJunchao Zhang 
4558c3ff71bSJunchao Zhang   PetscFunctionBegin;
4569566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4579566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
45892896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz)); // depending on yy's sync flags, zz might get its latest data on host
4599566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
46092896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv)); // do after VecCopy(yy, zz) to get the latest data on device
4618c3ff71bSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
462d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), "N", 1.0 /*alpha*/, aijkok->csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A x + beta z */
4639566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
46492896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
4659566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(2.0 * aijkok->csrmat.nnz()));
4669566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4673ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4688c3ff71bSJunchao Zhang }
4698c3ff71bSJunchao Zhang 
4708c3ff71bSJunchao Zhang /* z = A^T x + y */
471d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
472d71ae5a4SJacob Faibussowitsch {
4738c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
474152b3e56SJunchao Zhang   const char                *mode;
47592896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
476152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
4770e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
4788c3ff71bSJunchao Zhang 
4798c3ff71bSJunchao Zhang   PetscFunctionBegin;
4809566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
4819566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
48292896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz));
4839566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
48492896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv));
485152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
4869566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmat));
487152b3e56SJunchao Zhang     mode = "N";
488152b3e56SJunchao Zhang   } else {
489076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
4900e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
491152b3e56SJunchao Zhang     mode   = "T";
492152b3e56SJunchao Zhang   }
493d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^T x + beta z */
4949566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
49592896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
4960e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
4979566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
4983ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
4998c3ff71bSJunchao Zhang }
5008c3ff71bSJunchao Zhang 
5018c3ff71bSJunchao Zhang /* z = A^H x + y */
502d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJKokkos(Mat A, Vec xx, Vec yy, Vec zz)
503d71ae5a4SJacob Faibussowitsch {
5048c3ff71bSJunchao Zhang   Mat_SeqAIJKokkos          *aijkok;
505152b3e56SJunchao Zhang   const char                *mode;
50692896123SJunchao Zhang   ConstPetscScalarKokkosView xv;
507152b3e56SJunchao Zhang   PetscScalarKokkosView      zv;
5080e3ece09SJunchao Zhang   KokkosCsrMatrix            csrmat;
5098c3ff71bSJunchao Zhang 
5108c3ff71bSJunchao Zhang   PetscFunctionBegin;
5119566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
5129566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
51392896123SJunchao Zhang   if (zz != yy) PetscCall(VecCopy(yy, zz));
5149566063dSJacob Faibussowitsch   PetscCall(VecGetKokkosView(xx, &xv));
51592896123SJunchao Zhang   PetscCall(VecGetKokkosView(zz, &zv));
516152b3e56SJunchao Zhang   if (A->form_explicit_transpose) {
5179566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateHermitian_Private(A, &csrmat));
518152b3e56SJunchao Zhang     mode = "N";
519152b3e56SJunchao Zhang   } else {
520076ba34aSJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
5210e3ece09SJunchao Zhang     csrmat = aijkok->csrmat;
522152b3e56SJunchao Zhang     mode   = "C";
523152b3e56SJunchao Zhang   }
524d326c3f1SJunchao Zhang   PetscCallCXX(KokkosSparse::spmv(PetscGetKokkosExecutionSpace(), mode, 1.0 /*alpha*/, csrmat, xv, 1.0 /*beta*/, zv)); /* z = alpha A^H x + beta z */
5259566063dSJacob Faibussowitsch   PetscCall(VecRestoreKokkosView(xx, &xv));
52692896123SJunchao Zhang   PetscCall(VecRestoreKokkosView(zz, &zv));
5270e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * csrmat.nnz()));
5289566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
5293ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
530152b3e56SJunchao Zhang }
531152b3e56SJunchao Zhang 
53266976f2fSJacob Faibussowitsch static PetscErrorCode MatSetOption_SeqAIJKokkos(Mat A, MatOption op, PetscBool flg)
533d71ae5a4SJacob Faibussowitsch {
534152b3e56SJunchao Zhang   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
535152b3e56SJunchao Zhang 
536152b3e56SJunchao Zhang   PetscFunctionBegin;
537152b3e56SJunchao Zhang   switch (op) {
538152b3e56SJunchao Zhang   case MAT_FORM_EXPLICIT_TRANSPOSE:
539152b3e56SJunchao Zhang     /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */
5409566063dSJacob Faibussowitsch     if (A->form_explicit_transpose && !flg && aijkok) PetscCall(aijkok->DestroyMatTranspose());
541152b3e56SJunchao Zhang     A->form_explicit_transpose = flg;
542152b3e56SJunchao Zhang     break;
543d71ae5a4SJacob Faibussowitsch   default:
544d71ae5a4SJacob Faibussowitsch     PetscCall(MatSetOption_SeqAIJ(A, op, flg));
545d71ae5a4SJacob Faibussowitsch     break;
546152b3e56SJunchao Zhang   }
5473ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5488c3ff71bSJunchao Zhang }
5498c3ff71bSJunchao Zhang 
550076ba34aSJunchao Zhang /* Depending on reuse, either build a new mat, or use the existing mat */
551d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
552d71ae5a4SJacob Faibussowitsch {
553076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
5548c3ff71bSJunchao Zhang 
5558c3ff71bSJunchao Zhang   PetscFunctionBegin;
5569566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
557076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {                      /* Build a brand new mat */
5589566063dSJacob Faibussowitsch     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));  /* the returned newmat is a SeqAIJKokkos */
5598c3ff71bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {                 /* Reuse the mat created before */
5609566063dSJacob Faibussowitsch     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); /* newmat is already a SeqAIJKokkos */
561076ba34aSJunchao Zhang   } else if (reuse == MAT_INPLACE_MATRIX) {               /* newmat is A */
5625f80ce2aSJacob Faibussowitsch     PetscCheck(A == *newmat, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "A != *newmat with MAT_INPLACE_MATRIX");
5639566063dSJacob Faibussowitsch     PetscCall(PetscFree(A->defaultvectype));
5649566063dSJacob Faibussowitsch     PetscCall(PetscStrallocpy(VECKOKKOS, &A->defaultvectype)); /* Allocate and copy the string */
5659566063dSJacob Faibussowitsch     PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATSEQAIJKOKKOS));
5669566063dSJacob Faibussowitsch     PetscCall(MatSetOps_SeqAIJKokkos(A));
567076ba34aSJunchao Zhang     aseq = static_cast<Mat_SeqAIJ *>(A->data);
568394ed5ebSJunchao Zhang     if (A->assembled) { /* Copy i, j (but not values) to device for an assembled matrix if not yet */
5695f80ce2aSJacob Faibussowitsch       PetscCheck(!A->spptr, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Expect NULL (Mat_SeqAIJKokkos*)A->spptr");
570f4747e26SJunchao Zhang       A->spptr = new Mat_SeqAIJKokkos(A->rmap->n, A->cmap->n, aseq, A->nonzerostate, PETSC_FALSE);
5718c3ff71bSJunchao Zhang     }
572076ba34aSJunchao Zhang   }
5733ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
5748c3ff71bSJunchao Zhang }
5758c3ff71bSJunchao Zhang 
576076ba34aSJunchao Zhang /* MatDuplicate always creates a new matrix. MatDuplicate can be called either on an assembled matrix or
577076ba34aSJunchao Zhang    an unassembled matrix, even though MAT_COPY_VALUES is not allowed for unassembled matrix.
578076ba34aSJunchao Zhang  */
579d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDuplicate_SeqAIJKokkos(Mat A, MatDuplicateOption dupOption, Mat *B)
580d71ae5a4SJacob Faibussowitsch {
581076ba34aSJunchao Zhang   Mat_SeqAIJ       *bseq;
582076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr), *bkok;
583076ba34aSJunchao Zhang   Mat               mat;
5848c3ff71bSJunchao Zhang 
5858c3ff71bSJunchao Zhang   PetscFunctionBegin;
586076ba34aSJunchao Zhang   /* Do not copy values on host as A's latest values might be on device. We don't want to do sync blindly */
5879566063dSJacob Faibussowitsch   PetscCall(MatDuplicate_SeqAIJ(A, MAT_DO_NOT_COPY_VALUES, B));
588076ba34aSJunchao Zhang   mat = *B;
589f4747e26SJunchao Zhang   if (A->assembled) {
590076ba34aSJunchao Zhang     bseq = static_cast<Mat_SeqAIJ *>(mat->data);
591f4747e26SJunchao Zhang     bkok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, bseq, mat->nonzerostate, PETSC_FALSE);
592076ba34aSJunchao Zhang     bkok->a_dual.clear_sync_state(); /* Clear B's sync state as it will be decided below */
593076ba34aSJunchao Zhang     /* Now copy values to B if needed */
594076ba34aSJunchao Zhang     if (dupOption == MAT_COPY_VALUES) {
595076ba34aSJunchao Zhang       if (akok->a_dual.need_sync_device()) {
596076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_host(), akok->a_dual.view_host());
597076ba34aSJunchao Zhang         bkok->a_dual.modify_host();
598076ba34aSJunchao Zhang       } else { /* If device has the latest data, we only copy data on device */
599076ba34aSJunchao Zhang         Kokkos::deep_copy(bkok->a_dual.view_device(), akok->a_dual.view_device());
600076ba34aSJunchao Zhang         bkok->a_dual.modify_device();
601076ba34aSJunchao Zhang       }
602076ba34aSJunchao Zhang     } else { /* MAT_DO_NOT_COPY_VALUES or MAT_SHARE_NONZERO_PATTERN. B's values should be zeroed */
603076ba34aSJunchao Zhang       /* B's values on host should be already zeroed by MatDuplicate_SeqAIJ() */
604076ba34aSJunchao Zhang       bkok->a_dual.modify_host();
605076ba34aSJunchao Zhang     }
606076ba34aSJunchao Zhang     mat->spptr = bkok;
607076ba34aSJunchao Zhang   }
608076ba34aSJunchao Zhang 
6099566063dSJacob Faibussowitsch   PetscCall(PetscFree(mat->defaultvectype));
6109566063dSJacob Faibussowitsch   PetscCall(PetscStrallocpy(VECKOKKOS, &mat->defaultvectype)); /* Allocate and copy the string */
6119566063dSJacob Faibussowitsch   PetscCall(PetscObjectChangeTypeName((PetscObject)mat, MATSEQAIJKOKKOS));
6129566063dSJacob Faibussowitsch   PetscCall(MatSetOps_SeqAIJKokkos(mat));
6133ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6148c3ff71bSJunchao Zhang }
6158c3ff71bSJunchao Zhang 
616d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatTranspose_SeqAIJKokkos(Mat A, MatReuse reuse, Mat *B)
617d71ae5a4SJacob Faibussowitsch {
6180ecb592aSJunchao Zhang   Mat               At;
6190e3ece09SJunchao Zhang   KokkosCsrMatrix   internT;
6200ecb592aSJunchao Zhang   Mat_SeqAIJKokkos *atkok, *bkok;
6210ecb592aSJunchao Zhang 
6220ecb592aSJunchao Zhang   PetscFunctionBegin;
6237fb60732SBarry Smith   if (reuse == MAT_REUSE_MATRIX) PetscCall(MatTransposeCheckNonzeroState_Private(A, *B));
6249566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &internT)); /* Generate a transpose internally */
6250ecb592aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX || reuse == MAT_INPLACE_MATRIX) {
626ff751488SJunchao Zhang     /* Deep copy internT, as we want to isolate the internal transpose */
6270e3ece09SJunchao Zhang     PetscCallCXX(atkok = new Mat_SeqAIJKokkos(KokkosCsrMatrix("csrmat", internT)));
6289566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PetscObjectComm((PetscObject)A), atkok, &At));
6290ecb592aSJunchao Zhang     if (reuse == MAT_INITIAL_MATRIX) *B = At;
6309566063dSJacob Faibussowitsch     else PetscCall(MatHeaderReplace(A, &At)); /* Replace A with At inplace */
6310ecb592aSJunchao Zhang   } else {                                    /* MAT_REUSE_MATRIX, just need to copy values to B on device */
6320ecb592aSJunchao Zhang     if ((*B)->assembled) {
6330ecb592aSJunchao Zhang       bkok = static_cast<Mat_SeqAIJKokkos *>((*B)->spptr);
6340e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(bkok->a_dual.view_device(), internT.values));
6359566063dSJacob Faibussowitsch       PetscCall(MatSeqAIJKokkosModifyDevice(*B));
6360ecb592aSJunchao Zhang     } else if ((*B)->preallocated) { /* It is ok for B to be only preallocated, as needed in MatTranspose_MPIAIJ */
6370ecb592aSJunchao Zhang       Mat_SeqAIJ             *bseq = static_cast<Mat_SeqAIJ *>((*B)->data);
6380e3ece09SJunchao Zhang       MatScalarKokkosViewHost a_h(bseq->a, internT.nnz()); /* bseq->nz = 0 if unassembled */
6390e3ece09SJunchao Zhang       MatColIdxKokkosViewHost j_h(bseq->j, internT.nnz());
6400e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(a_h, internT.values));
6410e3ece09SJunchao Zhang       PetscCallCXX(Kokkos::deep_copy(j_h, internT.graph.entries));
6420ecb592aSJunchao Zhang     } else SETERRQ(PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "B must be assembled or preallocated");
6430ecb592aSJunchao Zhang   }
6443ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6450ecb592aSJunchao Zhang }
6460ecb592aSJunchao Zhang 
647d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatDestroy_SeqAIJKokkos(Mat A)
648d71ae5a4SJacob Faibussowitsch {
64986a27549SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
6508c3ff71bSJunchao Zhang 
6518c3ff71bSJunchao Zhang   PetscFunctionBegin;
65286a27549SJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) {
65386a27549SJunchao Zhang     aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
6548c3ff71bSJunchao Zhang     delete aijkok;
65586a27549SJunchao Zhang   } else {
65686a27549SJunchao Zhang     delete static_cast<Mat_SeqAIJKokkosTriFactors *>(A->spptr);
65786a27549SJunchao Zhang   }
658cbc6b225SStefano Zampini   A->spptr = NULL;
6599566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
6609566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
6619566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
66257761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
66357761e9aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijkokkos_hypre_C", NULL));
66457761e9aSJunchao Zhang #endif
6659566063dSJacob Faibussowitsch   PetscCall(MatDestroy_SeqAIJ(A));
6663ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
6678c3ff71bSJunchao Zhang }
6688c3ff71bSJunchao Zhang 
6693f3ba80aSJunchao Zhang /*MC
6703f3ba80aSJunchao Zhang    MATSEQAIJKOKKOS - MATAIJKOKKOS = "(seq)aijkokkos" - A matrix type to be used for sparse matrices with Kokkos
6713f3ba80aSJunchao Zhang 
67215229ffcSPierre Jolivet    A matrix type using Kokkos-Kernels CrsMatrix type for portability across different device types
6733f3ba80aSJunchao Zhang 
6742ef1f0ffSBarry Smith    Options Database Key:
67511a5261eSBarry Smith .  -mat_type aijkokkos - sets the matrix type to `MATSEQAIJKOKKOS` during a call to `MatSetFromOptions()`
6763f3ba80aSJunchao Zhang 
6773f3ba80aSJunchao Zhang   Level: beginner
6783f3ba80aSJunchao Zhang 
6791cc06b55SBarry Smith .seealso: [](ch_matrices), `Mat`, `MatCreateSeqAIJKokkos()`, `MATMPIAIJKOKKOS`
6803f3ba80aSJunchao Zhang M*/
681d71ae5a4SJacob Faibussowitsch PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJKokkos(Mat A)
682d71ae5a4SJacob Faibussowitsch {
68386a27549SJunchao Zhang   PetscFunctionBegin;
6849566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
6859566063dSJacob Faibussowitsch   PetscCall(MatCreate_SeqAIJ(A));
6869566063dSJacob Faibussowitsch   PetscCall(MatConvert_SeqAIJ_SeqAIJKokkos(A, MATSEQAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
6873ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
68886a27549SJunchao Zhang }
68986a27549SJunchao Zhang 
690076ba34aSJunchao Zhang /* Merge A, B into a matrix C. A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n) */
691d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJKokkosMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C)
692d71ae5a4SJacob Faibussowitsch {
693076ba34aSJunchao Zhang   Mat_SeqAIJ         *a, *b;
694076ba34aSJunchao Zhang   Mat_SeqAIJKokkos   *akok, *bkok, *ckok;
695076ba34aSJunchao Zhang   MatScalarKokkosView aa, ba, ca;
696076ba34aSJunchao Zhang   MatRowMapKokkosView ai, bi, ci;
697076ba34aSJunchao Zhang   MatColIdxKokkosView aj, bj, cj;
698076ba34aSJunchao Zhang   PetscInt            m, n, nnz, aN;
699a3f881fbSStefano Zampini 
700a3f881fbSStefano Zampini   PetscFunctionBegin;
701076ba34aSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
702076ba34aSJunchao Zhang   PetscValidHeaderSpecific(B, MAT_CLASSID, 2);
7034f572ea9SToby Isaac   PetscAssertPointer(C, 4);
704076ba34aSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
705076ba34aSJunchao Zhang   PetscCheckTypeName(B, MATSEQAIJKOKKOS);
7065f80ce2aSJacob Faibussowitsch   PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n);
7075f80ce2aSJacob Faibussowitsch   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
708076ba34aSJunchao Zhang 
7099566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
7109566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
711076ba34aSJunchao Zhang   a    = static_cast<Mat_SeqAIJ *>(A->data);
712076ba34aSJunchao Zhang   b    = static_cast<Mat_SeqAIJ *>(B->data);
713076ba34aSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
714076ba34aSJunchao Zhang   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
715076ba34aSJunchao Zhang   aa   = akok->a_dual.view_device();
716076ba34aSJunchao Zhang   ai   = akok->i_dual.view_device();
717076ba34aSJunchao Zhang   ba   = bkok->a_dual.view_device();
718076ba34aSJunchao Zhang   bi   = bkok->i_dual.view_device();
719076ba34aSJunchao Zhang   m    = A->rmap->n; /* M, N and nnz of C */
720076ba34aSJunchao Zhang   n    = A->cmap->n + B->cmap->n;
721076ba34aSJunchao Zhang   nnz  = a->nz + b->nz;
722076ba34aSJunchao Zhang   aN   = A->cmap->n; /* N of A */
723076ba34aSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
724076ba34aSJunchao Zhang     aj           = akok->j_dual.view_device();
725076ba34aSJunchao Zhang     bj           = bkok->j_dual.view_device();
726076ba34aSJunchao Zhang     auto ca_dual = MatScalarKokkosDualView("a", aa.extent(0) + ba.extent(0));
727076ba34aSJunchao Zhang     auto ci_dual = MatRowMapKokkosDualView("i", ai.extent(0));
728076ba34aSJunchao Zhang     auto cj_dual = MatColIdxKokkosDualView("j", aj.extent(0) + bj.extent(0));
729076ba34aSJunchao Zhang     ca           = ca_dual.view_device();
730076ba34aSJunchao Zhang     ci           = ci_dual.view_device();
731076ba34aSJunchao Zhang     cj           = cj_dual.view_device();
732076ba34aSJunchao Zhang 
733076ba34aSJunchao Zhang     /* Concatenate A and B in parallel using Kokkos hierarchical parallelism */
7349371c9d4SSatish Balay     Kokkos::parallel_for(
735d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
736076ba34aSJunchao Zhang         PetscInt i       = t.league_rank(); /* row i */
737076ba34aSJunchao Zhang         PetscInt coffset = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
738076ba34aSJunchao Zhang 
739076ba34aSJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() { /* this side effect only happens once per whole team */
740076ba34aSJunchao Zhang                                                    ci(i) = coffset;
741076ba34aSJunchao Zhang                                                    if (i == m - 1) ci(m) = ai(m) + bi(m);
742076ba34aSJunchao Zhang         });
743076ba34aSJunchao Zhang 
744076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
745076ba34aSJunchao Zhang           if (k < alen) {
746076ba34aSJunchao Zhang             ca(coffset + k) = aa(ai(i) + k);
747076ba34aSJunchao Zhang             cj(coffset + k) = aj(ai(i) + k);
748076ba34aSJunchao Zhang           } else {
749076ba34aSJunchao Zhang             ca(coffset + k) = ba(bi(i) + k - alen);
750076ba34aSJunchao Zhang             cj(coffset + k) = bj(bi(i) + k - alen) + aN; /* Entries in B get new column indices in C */
751076ba34aSJunchao Zhang           }
752076ba34aSJunchao Zhang         });
753076ba34aSJunchao Zhang       });
754076ba34aSJunchao Zhang     ca_dual.modify_device();
755076ba34aSJunchao Zhang     ci_dual.modify_device();
756076ba34aSJunchao Zhang     cj_dual.modify_device();
7579566063dSJacob Faibussowitsch     PetscCallCXX(ckok = new Mat_SeqAIJKokkos(m, n, nnz, ci_dual, cj_dual, ca_dual));
7589566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, C));
759076ba34aSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
760076ba34aSJunchao Zhang     PetscValidHeaderSpecific(*C, MAT_CLASSID, 4);
761076ba34aSJunchao Zhang     PetscCheckTypeName(*C, MATSEQAIJKOKKOS);
762076ba34aSJunchao Zhang     ckok = static_cast<Mat_SeqAIJKokkos *>((*C)->spptr);
763076ba34aSJunchao Zhang     ca   = ckok->a_dual.view_device();
764076ba34aSJunchao Zhang     ci   = ckok->i_dual.view_device();
765076ba34aSJunchao Zhang 
7669371c9d4SSatish Balay     Kokkos::parallel_for(
767d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
768076ba34aSJunchao Zhang         PetscInt i    = t.league_rank(); /* row i */
769076ba34aSJunchao Zhang         PetscInt alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
770076ba34aSJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
771076ba34aSJunchao Zhang           if (k < alen) ca(ci(i) + k) = aa(ai(i) + k);
772076ba34aSJunchao Zhang           else ca(ci(i) + k) = ba(bi(i) + k - alen);
773076ba34aSJunchao Zhang         });
774076ba34aSJunchao Zhang       });
7759566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(*C));
776076ba34aSJunchao Zhang   }
7773ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
778076ba34aSJunchao Zhang }
779076ba34aSJunchao Zhang 
780d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductDataDestroy_SeqAIJKokkos(void *pdata)
781d71ae5a4SJacob Faibussowitsch {
782076ba34aSJunchao Zhang   PetscFunctionBegin;
783076ba34aSJunchao Zhang   delete static_cast<MatProductData_SeqAIJKokkos *>(pdata);
7843ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
785a3f881fbSStefano Zampini }
786a3f881fbSStefano Zampini 
787d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos(Mat C)
788d71ae5a4SJacob Faibussowitsch {
789a3f881fbSStefano Zampini   Mat_Product                 *product = C->product;
790a3f881fbSStefano Zampini   Mat                          A, B;
791076ba34aSJunchao Zhang   bool                         transA, transB; /* use bool, since KK needs this type */
792a3f881fbSStefano Zampini   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
793a3f881fbSStefano Zampini   Mat_SeqAIJ                  *c;
794076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
7950e3ece09SJunchao Zhang   KokkosCsrMatrix              csrmatA, csrmatB;
796a3f881fbSStefano Zampini 
797a3f881fbSStefano Zampini   PetscFunctionBegin;
798a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
7995f80ce2aSJacob Faibussowitsch   PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data empty");
800076ba34aSJunchao Zhang   pdata = static_cast<MatProductData_SeqAIJKokkos *>(C->product->data);
801076ba34aSJunchao Zhang 
8020e3ece09SJunchao Zhang   // See if numeric has already been done in symbolic (e.g., user calls MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C)).
8030e3ece09SJunchao Zhang   // If yes, skip the numeric, but reset the flag so that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C),
8040e3ece09SJunchao Zhang   // we still do numeric.
8050e3ece09SJunchao Zhang   if (pdata->reusesym) { // numeric reuses results from symbolic
8060e3ece09SJunchao Zhang     pdata->reusesym = PETSC_FALSE;
8073ba16761SJacob Faibussowitsch     PetscFunctionReturn(PETSC_SUCCESS);
808076ba34aSJunchao Zhang   }
809076ba34aSJunchao Zhang 
810076ba34aSJunchao Zhang   switch (product->type) {
8119371c9d4SSatish Balay   case MATPRODUCT_AB:
8129371c9d4SSatish Balay     transA = false;
8139371c9d4SSatish Balay     transB = false;
8149371c9d4SSatish Balay     break;
8159371c9d4SSatish Balay   case MATPRODUCT_AtB:
8169371c9d4SSatish Balay     transA = true;
8179371c9d4SSatish Balay     transB = false;
8189371c9d4SSatish Balay     break;
8199371c9d4SSatish Balay   case MATPRODUCT_ABt:
8209371c9d4SSatish Balay     transA = false;
8219371c9d4SSatish Balay     transB = true;
8229371c9d4SSatish Balay     break;
823d71ae5a4SJacob Faibussowitsch   default:
824d71ae5a4SJacob Faibussowitsch     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
825076ba34aSJunchao Zhang   }
826076ba34aSJunchao Zhang 
827a3f881fbSStefano Zampini   A = product->A;
828a3f881fbSStefano Zampini   B = product->B;
8299566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
8309566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
831a3f881fbSStefano Zampini   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
832a3f881fbSStefano Zampini   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
833a3f881fbSStefano Zampini   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
834076ba34aSJunchao Zhang 
8355f80ce2aSJacob Faibussowitsch   PetscCheck(ckok, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Device data structure spptr is empty");
836076ba34aSJunchao Zhang 
8370e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
8380e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
839076ba34aSJunchao Zhang 
840076ba34aSJunchao Zhang   /* TODO: Once KK spgemm implements transpose, we can get rid of the explicit transpose here */
841076ba34aSJunchao Zhang   if (transA) {
8429566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
843076ba34aSJunchao Zhang     transA = false;
844a3f881fbSStefano Zampini   }
845a3f881fbSStefano Zampini 
846076ba34aSJunchao Zhang   if (transB) {
8479566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
848076ba34aSJunchao Zhang     transB = false;
849076ba34aSJunchao Zhang   }
8509566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
8510e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, ckok->csrmat));
8520e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
853866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
854866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(ckok->csrmat)); /* without sort, mat_tests-ex62_14_seqaijkokkos fails */
855e944a159SJunchao Zhang #endif
856866eb059SJunchao Zhang 
8579566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
8589566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(C));
859a3f881fbSStefano Zampini   /* shorter version of MatAssemblyEnd_SeqAIJ */
860a3f881fbSStefano Zampini   c = (Mat_SeqAIJ *)C->data;
8619566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded,%" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz));
8629566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n"));
8639566063dSJacob Faibussowitsch   PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax));
864a3f881fbSStefano Zampini   c->reallocs         = 0;
865076ba34aSJunchao Zhang   C->info.mallocs     = 0;
866a3f881fbSStefano Zampini   C->info.nz_unneeded = 0;
867a3f881fbSStefano Zampini   C->assembled = C->was_assembled = PETSC_TRUE;
868a3f881fbSStefano Zampini   C->num_ass++;
8693ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
870a3f881fbSStefano Zampini }
871a3f881fbSStefano Zampini 
872d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos(Mat C)
873d71ae5a4SJacob Faibussowitsch {
874076ba34aSJunchao Zhang   Mat_Product                 *product = C->product;
875076ba34aSJunchao Zhang   MatProductType               ptype;
876076ba34aSJunchao Zhang   Mat                          A, B;
877076ba34aSJunchao Zhang   bool                         transA, transB;
878076ba34aSJunchao Zhang   Mat_SeqAIJKokkos            *akok, *bkok, *ckok;
879076ba34aSJunchao Zhang   MatProductData_SeqAIJKokkos *pdata;
880076ba34aSJunchao Zhang   MPI_Comm                     comm;
8810e3ece09SJunchao Zhang   KokkosCsrMatrix              csrmatA, csrmatB, csrmatC;
882a3f881fbSStefano Zampini 
883a3f881fbSStefano Zampini   PetscFunctionBegin;
884a3f881fbSStefano Zampini   MatCheckProduct(C, 1);
8859566063dSJacob Faibussowitsch   PetscCall(PetscObjectGetComm((PetscObject)C, &comm));
8865f80ce2aSJacob Faibussowitsch   PetscCheck(!product->data, comm, PETSC_ERR_PLIB, "Product data not empty");
887a3f881fbSStefano Zampini   A = product->A;
888a3f881fbSStefano Zampini   B = product->B;
8899566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
8909566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(B));
891a3f881fbSStefano Zampini   akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
892a3f881fbSStefano Zampini   bkok    = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
8930e3ece09SJunchao Zhang   csrmatA = akok->csrmat;
8940e3ece09SJunchao Zhang   csrmatB = bkok->csrmat;
895076ba34aSJunchao Zhang 
896a3f881fbSStefano Zampini   ptype = product->type;
8970e3ece09SJunchao Zhang   // Take advantage of the symmetry if true
8980e3ece09SJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) {
8990e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
9000e3ece09SJunchao Zhang     product->symbolic_used_the_fact_A_is_symmetric = PETSC_TRUE;
9010e3ece09SJunchao Zhang   }
9020e3ece09SJunchao Zhang   if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) {
9030e3ece09SJunchao Zhang     ptype                                          = MATPRODUCT_AB;
9040e3ece09SJunchao Zhang     product->symbolic_used_the_fact_B_is_symmetric = PETSC_TRUE;
9050e3ece09SJunchao Zhang   }
9060e3ece09SJunchao Zhang 
907a3f881fbSStefano Zampini   switch (ptype) {
9089371c9d4SSatish Balay   case MATPRODUCT_AB:
9099371c9d4SSatish Balay     transA = false;
9109371c9d4SSatish Balay     transB = false;
9110e6a1e94SMark Adams     PetscCall(MatSetBlockSizesFromMats(C, A, B));
9129371c9d4SSatish Balay     break;
9139371c9d4SSatish Balay   case MATPRODUCT_AtB:
9149371c9d4SSatish Balay     transA = true;
9159371c9d4SSatish Balay     transB = false;
9160e6a1e94SMark Adams     if (A->cmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->rmap, A->cmap->bs));
9170e6a1e94SMark Adams     if (B->cmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->cmap, B->cmap->bs));
9189371c9d4SSatish Balay     break;
9199371c9d4SSatish Balay   case MATPRODUCT_ABt:
9209371c9d4SSatish Balay     transA = false;
9219371c9d4SSatish Balay     transB = true;
9220e6a1e94SMark Adams     if (A->rmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->rmap, A->rmap->bs));
9230e6a1e94SMark Adams     if (B->rmap->bs > 0) PetscCall(PetscLayoutSetBlockSize(C->cmap, B->rmap->bs));
9249371c9d4SSatish Balay     break;
925d71ae5a4SJacob Faibussowitsch   default:
926d71ae5a4SJacob Faibussowitsch     SETERRQ(comm, PETSC_ERR_PLIB, "Unsupported product type %s", MatProductTypes[product->type]);
927a3f881fbSStefano Zampini   }
9280e3ece09SJunchao Zhang   PetscCallCXX(product->data = pdata = new MatProductData_SeqAIJKokkos());
929076ba34aSJunchao Zhang   pdata->reusesym = product->api_user;
930a3f881fbSStefano Zampini 
931076ba34aSJunchao Zhang   /* TODO: add command line options to select spgemm algorithms */
932866eb059SJunchao Zhang   auto spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_DEFAULT; /* default alg is TPL if enabled, otherwise KK */
933866eb059SJunchao Zhang 
934866eb059SJunchao Zhang   /* CUDA-10.2's spgemm has bugs. We prefer the SpGEMMreuse APIs introduced in cuda-11.4 */
935866eb059SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
936866eb059SJunchao Zhang   #if PETSC_PKG_CUDA_VERSION_LT(11, 4, 0)
937866eb059SJunchao Zhang   spgemm_alg = KokkosSparse::SPGEMMAlgorithm::SPGEMM_KK;
938866eb059SJunchao Zhang   #endif
939866eb059SJunchao Zhang #endif
9400e3ece09SJunchao Zhang   PetscCallCXX(pdata->kh.create_spgemm_handle(spgemm_alg));
941076ba34aSJunchao Zhang 
9429566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
943076ba34aSJunchao Zhang   /* TODO: Get rid of the explicit transpose once KK-spgemm implements the transpose option */
944076ba34aSJunchao Zhang   if (transA) {
9459566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(A, &csrmatA));
946076ba34aSJunchao Zhang     transA = false;
947076ba34aSJunchao Zhang   }
948076ba34aSJunchao Zhang 
949076ba34aSJunchao Zhang   if (transB) {
9509566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosGenerateTranspose_Private(B, &csrmatB));
951076ba34aSJunchao Zhang     transB = false;
952076ba34aSJunchao Zhang   }
953076ba34aSJunchao Zhang 
9540e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_symbolic(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
955076ba34aSJunchao Zhang   /* spgemm_symbolic() only populates C's rowmap, but not C's column indices.
956076ba34aSJunchao Zhang     So we have to do a fake spgemm_numeric() here to get csrmatC.j_d setup, before
957076ba34aSJunchao Zhang     calling new Mat_SeqAIJKokkos().
958076ba34aSJunchao Zhang     TODO: Remove the fake spgemm_numeric() after KK fixed this problem.
959076ba34aSJunchao Zhang   */
9600e3ece09SJunchao Zhang   PetscCallCXX(KokkosSparse::spgemm_numeric(pdata->kh, csrmatA, transA, csrmatB, transB, csrmatC));
9610e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(4, 0, 0)
962866eb059SJunchao Zhang   /* Query if KK outputs a sorted matrix. If not, we need to sort it */
963866eb059SJunchao Zhang   auto spgemmHandle = pdata->kh.get_spgemm_handle();
964866eb059SJunchao Zhang   if (spgemmHandle->get_sort_option() != 1) PetscCallCXX(sort_crs_matrix(csrmatC)); /* sort_option defaults to -1 in KK!*/
965e944a159SJunchao Zhang #endif
9669566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
967076ba34aSJunchao Zhang 
9689566063dSJacob Faibussowitsch   PetscCallCXX(ckok = new Mat_SeqAIJKokkos(csrmatC));
9699566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(C, ckok));
970076ba34aSJunchao Zhang   C->product->destroy = MatProductDataDestroy_SeqAIJKokkos;
9713ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
972a3f881fbSStefano Zampini }
973a3f881fbSStefano Zampini 
974a3f881fbSStefano Zampini /* handles sparse matrix matrix ops */
975d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatProductSetFromOptions_SeqAIJKokkos(Mat mat)
976d71ae5a4SJacob Faibussowitsch {
977076ba34aSJunchao Zhang   Mat_Product *product = mat->product;
978a3f881fbSStefano Zampini   PetscBool    Biskok = PETSC_FALSE, Ciskok = PETSC_TRUE;
979a3f881fbSStefano Zampini 
980a3f881fbSStefano Zampini   PetscFunctionBegin;
981a3f881fbSStefano Zampini   MatCheckProduct(mat, 1);
9829566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJKOKKOS, &Biskok));
98348a46eb9SPierre Jolivet   if (product->type == MATPRODUCT_ABC) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJKOKKOS, &Ciskok));
984a3f881fbSStefano Zampini   if (Biskok && Ciskok) {
985a3f881fbSStefano Zampini     switch (product->type) {
986a3f881fbSStefano Zampini     case MATPRODUCT_AB:
987a3f881fbSStefano Zampini     case MATPRODUCT_AtB:
988d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABt:
989d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_SeqAIJKokkos_SeqAIJKokkos;
990d71ae5a4SJacob Faibussowitsch       break;
991a3f881fbSStefano Zampini     case MATPRODUCT_PtAP:
992a3f881fbSStefano Zampini     case MATPRODUCT_RARt:
993d71ae5a4SJacob Faibussowitsch     case MATPRODUCT_ABC:
994d71ae5a4SJacob Faibussowitsch       mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
995d71ae5a4SJacob Faibussowitsch       break;
996d71ae5a4SJacob Faibussowitsch     default:
997d71ae5a4SJacob Faibussowitsch       break;
998a3f881fbSStefano Zampini     }
999a3f881fbSStefano Zampini   } else { /* fallback for AIJ */
10009566063dSJacob Faibussowitsch     PetscCall(MatProductSetFromOptions_SeqAIJ(mat));
1001a3f881fbSStefano Zampini   }
10023ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1003a3f881fbSStefano Zampini }
1004a587d139SMark 
1005d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatScale_SeqAIJKokkos(Mat A, PetscScalar a)
1006d71ae5a4SJacob Faibussowitsch {
1007f0cf5187SStefano Zampini   Mat_SeqAIJKokkos *aijkok;
1008f0cf5187SStefano Zampini 
1009f0cf5187SStefano Zampini   PetscFunctionBegin;
10109566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
10119566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1012f0cf5187SStefano Zampini   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1013d326c3f1SJunchao Zhang   KokkosBlas::scal(PetscGetKokkosExecutionSpace(), aijkok->a_dual.view_device(), a, aijkok->a_dual.view_device());
10149566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
10159566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuFlops(aijkok->a_dual.extent(0)));
10169566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
10173ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1018f0cf5187SStefano Zampini }
1019f0cf5187SStefano Zampini 
1020f4747e26SJunchao Zhang // add a to A's diagonal (if A is square) or main diagonal (if A is rectangular)
1021f4747e26SJunchao Zhang static PetscErrorCode MatShift_SeqAIJKokkos(Mat A, PetscScalar a)
1022f4747e26SJunchao Zhang {
1023f4747e26SJunchao Zhang   Mat_SeqAIJ *aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1024f4747e26SJunchao Zhang 
1025f4747e26SJunchao Zhang   PetscFunctionBegin;
1026f4747e26SJunchao Zhang   if (A->assembled && aijseq->diagonaldense) { // no missing diagonals
1027f4747e26SJunchao Zhang     PetscInt n = PetscMin(A->rmap->n, A->cmap->n);
1028f4747e26SJunchao Zhang 
1029f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
1030f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosSyncDevice(A));
1031f4747e26SJunchao Zhang     const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1032f4747e26SJunchao Zhang     const auto &Aa     = aijkok->a_dual.view_device();
1033f4747e26SJunchao Zhang     const auto &Adiag  = aijkok->diag_dual.view_device();
1034d326c3f1SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) { Aa(Adiag(i)) += a; }));
1035f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosModifyDevice(A));
1036f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(n));
1037f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
1038f4747e26SJunchao Zhang   } else { // need reassembly, very slow!
1039f4747e26SJunchao Zhang     PetscCall(MatShift_Basic(A, a));
1040f4747e26SJunchao Zhang   }
1041f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1042f4747e26SJunchao Zhang }
1043f4747e26SJunchao Zhang 
1044f4747e26SJunchao Zhang static PetscErrorCode MatDiagonalSet_SeqAIJKokkos(Mat Y, Vec D, InsertMode is)
1045f4747e26SJunchao Zhang {
1046f4747e26SJunchao Zhang   Mat_SeqAIJ *aijseq = static_cast<Mat_SeqAIJ *>(Y->data);
1047f4747e26SJunchao Zhang 
1048f4747e26SJunchao Zhang   PetscFunctionBegin;
1049f4747e26SJunchao Zhang   if (Y->assembled && aijseq->diagonaldense) { // no missing diagonals
1050f4747e26SJunchao Zhang     ConstPetscScalarKokkosView dv;
1051f4747e26SJunchao Zhang     PetscInt                   n, nv;
1052f4747e26SJunchao Zhang 
1053f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
1054f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosSyncDevice(Y));
1055f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(D, &dv));
1056f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(D, &nv));
1057f4747e26SJunchao Zhang     n = PetscMin(Y->rmap->n, Y->cmap->n);
1058f4747e26SJunchao Zhang     PetscCheck(n == nv, PetscObjectComm((PetscObject)Y), PETSC_ERR_ARG_SIZ, "Matrix size and vector size do not match");
1059f4747e26SJunchao Zhang 
1060f4747e26SJunchao Zhang     const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1061f4747e26SJunchao Zhang     const auto &Aa     = aijkok->a_dual.view_device();
1062f4747e26SJunchao Zhang     const auto &Adiag  = aijkok->diag_dual.view_device();
1063f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(
1064d326c3f1SJunchao Zhang       Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) {
1065f4747e26SJunchao Zhang         if (is == INSERT_VALUES) Aa(Adiag(i)) = dv(i);
1066f4747e26SJunchao Zhang         else Aa(Adiag(i)) += dv(i);
1067f4747e26SJunchao Zhang       }));
1068f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(D, &dv));
1069f4747e26SJunchao Zhang     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1070f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(n));
1071f4747e26SJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
1072f4747e26SJunchao Zhang   } else { // need reassembly, very slow!
1073f4747e26SJunchao Zhang     PetscCall(MatDiagonalSet_Default(Y, D, is));
1074f4747e26SJunchao Zhang   }
1075f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1076f4747e26SJunchao Zhang }
1077f4747e26SJunchao Zhang 
1078f4747e26SJunchao Zhang static PetscErrorCode MatDiagonalScale_SeqAIJKokkos(Mat A, Vec ll, Vec rr)
1079f4747e26SJunchao Zhang {
1080f4747e26SJunchao Zhang   Mat_SeqAIJ                *aijseq = static_cast<Mat_SeqAIJ *>(A->data);
1081f4747e26SJunchao Zhang   PetscInt                   m = A->rmap->n, n = A->cmap->n, nz = aijseq->nz;
1082f4747e26SJunchao Zhang   ConstPetscScalarKokkosView lv, rv;
1083f4747e26SJunchao Zhang 
1084f4747e26SJunchao Zhang   PetscFunctionBegin;
1085f4747e26SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1086f4747e26SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1087f4747e26SJunchao Zhang   const auto  aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1088f4747e26SJunchao Zhang   const auto &Aa     = aijkok->a_dual.view_device();
1089f4747e26SJunchao Zhang   const auto &Ai     = aijkok->i_dual.view_device();
1090f4747e26SJunchao Zhang   const auto &Aj     = aijkok->j_dual.view_device();
1091f4747e26SJunchao Zhang   if (ll) {
1092f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(ll, &m));
1093f4747e26SJunchao Zhang     PetscCheck(m == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Left scaling vector wrong length");
1094f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(ll, &lv));
1095f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for( // for each row
1096d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
1097f4747e26SJunchao Zhang         PetscInt i   = t.league_rank(); // row i
1098f4747e26SJunchao Zhang         PetscInt len = Ai(i + 1) - Ai(i);
1099f4747e26SJunchao Zhang         // scale entries on the row
1100f4747e26SJunchao Zhang         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt j) { Aa(Ai(i) + j) *= lv(i); });
1101f4747e26SJunchao Zhang       }));
1102f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(ll, &lv));
1103f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(nz));
1104f4747e26SJunchao Zhang   }
1105f4747e26SJunchao Zhang   if (rr) {
1106f4747e26SJunchao Zhang     PetscCall(VecGetLocalSize(rr, &n));
1107f4747e26SJunchao Zhang     PetscCheck(n == A->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Right scaling vector wrong length");
1108f4747e26SJunchao Zhang     PetscCall(VecGetKokkosView(rr, &rv));
1109f4747e26SJunchao Zhang     PetscCallCXX(Kokkos::parallel_for( // for each nonzero
1110d326c3f1SJunchao Zhang       Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, nz), KOKKOS_LAMBDA(const PetscInt k) { Aa(k) *= rv(Aj(k)); }));
1111f4747e26SJunchao Zhang     PetscCall(VecRestoreKokkosView(rr, &lv));
1112f4747e26SJunchao Zhang     PetscCall(PetscLogGpuFlops(nz));
1113f4747e26SJunchao Zhang   }
1114f4747e26SJunchao Zhang   PetscCall(MatSeqAIJKokkosModifyDevice(A));
1115f4747e26SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1116f4747e26SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1117f4747e26SJunchao Zhang }
1118f4747e26SJunchao Zhang 
1119d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatZeroEntries_SeqAIJKokkos(Mat A)
1120d71ae5a4SJacob Faibussowitsch {
1121076ba34aSJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1122a587d139SMark 
1123a587d139SMark   PetscFunctionBegin;
1124076ba34aSJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
11252328674fSJunchao Zhang   if (aijkok) { /* Only zero the device if data is already there */
1126d326c3f1SJunchao Zhang     KokkosBlas::fill(PetscGetKokkosExecutionSpace(), aijkok->a_dual.view_device(), 0.0);
11279566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(A));
11282328674fSJunchao Zhang   } else { /* Might be preallocated but not assembled */
11299566063dSJacob Faibussowitsch     PetscCall(MatZeroEntries_SeqAIJ(A));
11302328674fSJunchao Zhang   }
11313ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1132a587d139SMark }
1133a587d139SMark 
1134d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatGetDiagonal_SeqAIJKokkos(Mat A, Vec x)
1135d71ae5a4SJacob Faibussowitsch {
1136f78ce678SMark Adams   Mat_SeqAIJKokkos     *aijkok;
1137f78ce678SMark Adams   PetscInt              n;
1138f78ce678SMark Adams   PetscScalarKokkosView xv;
1139f78ce678SMark Adams 
1140f78ce678SMark Adams   PetscFunctionBegin;
1141f78ce678SMark Adams   PetscCall(VecGetLocalSize(x, &n));
1142f78ce678SMark Adams   PetscCheck(n == A->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Nonconforming matrix and vector");
1143f78ce678SMark Adams   PetscCheck(A->factortype == MAT_FACTOR_NONE, PETSC_COMM_SELF, PETSC_ERR_SUP, "MatGetDiagonal_SeqAIJKokkos not supported on factored matrices");
1144f78ce678SMark Adams 
1145f78ce678SMark Adams   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1146f78ce678SMark Adams   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1147f78ce678SMark Adams 
1148f78ce678SMark Adams   const auto &Aa    = aijkok->a_dual.view_device();
1149f78ce678SMark Adams   const auto &Ai    = aijkok->i_dual.view_device();
1150f78ce678SMark Adams   const auto &Adiag = aijkok->diag_dual.view_device();
1151f78ce678SMark Adams 
1152f78ce678SMark Adams   PetscCall(VecGetKokkosViewWrite(x, &xv));
11539371c9d4SSatish Balay   Kokkos::parallel_for(
1154d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, n), KOKKOS_LAMBDA(const PetscInt i) {
1155f78ce678SMark Adams       if (Adiag(i) < Ai(i + 1)) xv(i) = Aa(Adiag(i));
1156f78ce678SMark Adams       else xv(i) = 0;
1157f78ce678SMark Adams     });
1158f78ce678SMark Adams   PetscCall(VecRestoreKokkosViewWrite(x, &xv));
11593ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1160f78ce678SMark Adams }
1161f78ce678SMark Adams 
1162db78de30SJunchao Zhang /* Get a Kokkos View from a mat of type MatSeqAIJKokkos */
1163d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1164d71ae5a4SJacob Faibussowitsch {
1165db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1166db78de30SJunchao Zhang 
1167db78de30SJunchao Zhang   PetscFunctionBegin;
1168db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11694f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1170db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11719566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1172db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1173076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
11743ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1175db78de30SJunchao Zhang }
1176db78de30SJunchao Zhang 
1177d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, ConstMatScalarKokkosView *kv)
1178d71ae5a4SJacob Faibussowitsch {
1179db78de30SJunchao Zhang   PetscFunctionBegin;
1180db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11814f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1182db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11833ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1184db78de30SJunchao Zhang }
1185db78de30SJunchao Zhang 
1186d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosView(Mat A, MatScalarKokkosView *kv)
1187d71ae5a4SJacob Faibussowitsch {
1188db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1189db78de30SJunchao Zhang 
1190db78de30SJunchao Zhang   PetscFunctionBegin;
1191db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
11924f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1193db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
11949566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
1195db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1196076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
11973ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1198db78de30SJunchao Zhang }
1199db78de30SJunchao Zhang 
1200d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosView(Mat A, MatScalarKokkosView *kv)
1201d71ae5a4SJacob Faibussowitsch {
1202db78de30SJunchao Zhang   PetscFunctionBegin;
1203db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12044f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1205db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
12069566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
12073ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1208db78de30SJunchao Zhang }
1209db78de30SJunchao Zhang 
1210d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJGetKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1211d71ae5a4SJacob Faibussowitsch {
1212db78de30SJunchao Zhang   Mat_SeqAIJKokkos *aijkok;
1213db78de30SJunchao Zhang 
1214db78de30SJunchao Zhang   PetscFunctionBegin;
1215db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12164f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1217db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
1218db78de30SJunchao Zhang   aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1219076ba34aSJunchao Zhang   *kv    = aijkok->a_dual.view_device();
12203ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1221db78de30SJunchao Zhang }
1222db78de30SJunchao Zhang 
1223d71ae5a4SJacob Faibussowitsch PetscErrorCode MatSeqAIJRestoreKokkosViewWrite(Mat A, MatScalarKokkosView *kv)
1224d71ae5a4SJacob Faibussowitsch {
1225db78de30SJunchao Zhang   PetscFunctionBegin;
1226db78de30SJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
12274f572ea9SToby Isaac   PetscAssertPointer(kv, 2);
1228db78de30SJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
12299566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosModifyDevice(A));
12303ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1231db78de30SJunchao Zhang }
1232db78de30SJunchao Zhang 
1233c0c276a7Ssdargavi PetscErrorCode MatCreateSeqAIJKokkosWithKokkosViews(MPI_Comm comm, PetscInt m, PetscInt n, Kokkos::View<PetscInt *> &i_d, Kokkos::View<PetscInt *> &j_d, Kokkos::View<PetscScalar *> &a_d, Mat *A)
1234c0c276a7Ssdargavi {
1235c0c276a7Ssdargavi   Mat_SeqAIJKokkos *akok;
1236c0c276a7Ssdargavi 
1237c0c276a7Ssdargavi   PetscFunctionBegin;
1238c0c276a7Ssdargavi   auto exec = PetscGetKokkosExecutionSpace();
1239c0c276a7Ssdargavi   // Create host copies of the input aij
1240c0c276a7Ssdargavi   auto i_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), i_d);
1241c0c276a7Ssdargavi   auto j_h = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), j_d);
1242c0c276a7Ssdargavi   // Don't copy the vals to the host now
1243c0c276a7Ssdargavi   auto a_h = Kokkos::create_mirror_view(HostMirrorMemorySpace(), a_d);
1244c0c276a7Ssdargavi 
1245c0c276a7Ssdargavi   MatScalarKokkosDualView a_dual = MatScalarKokkosDualView(a_d, a_h);
1246c0c276a7Ssdargavi   // Note we have modified device data so it will copy lazily
1247c0c276a7Ssdargavi   a_dual.modify_device();
1248c0c276a7Ssdargavi   MatRowMapKokkosDualView i_dual = MatRowMapKokkosDualView(i_d, i_h);
1249c0c276a7Ssdargavi   MatColIdxKokkosDualView j_dual = MatColIdxKokkosDualView(j_d, j_h);
1250c0c276a7Ssdargavi 
1251c0c276a7Ssdargavi   PetscCallCXX(akok = new Mat_SeqAIJKokkos(m, n, j_dual.extent(0), i_dual, j_dual, a_dual));
1252c0c276a7Ssdargavi   PetscCall(MatCreate(comm, A));
1253c0c276a7Ssdargavi   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
1254c0c276a7Ssdargavi   PetscFunctionReturn(PETSC_SUCCESS);
1255c0c276a7Ssdargavi }
1256c0c276a7Ssdargavi 
1257c17cf699SJunchao Zhang /* Computes Y += alpha X */
1258d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatAXPY_SeqAIJKokkos(Mat Y, PetscScalar alpha, Mat X, MatStructure pattern)
1259d71ae5a4SJacob Faibussowitsch {
1260a587d139SMark   Mat_SeqAIJ              *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data;
1261c17cf699SJunchao Zhang   Mat_SeqAIJKokkos        *xkok, *ykok, *zkok;
1262c17cf699SJunchao Zhang   ConstMatScalarKokkosView Xa;
1263c17cf699SJunchao Zhang   MatScalarKokkosView      Ya;
12644df4a32cSJunchao Zhang   auto                     exec = PetscGetKokkosExecutionSpace();
1265a587d139SMark 
1266a587d139SMark   PetscFunctionBegin;
1267c17cf699SJunchao Zhang   PetscCheckTypeName(Y, MATSEQAIJKOKKOS);
1268c17cf699SJunchao Zhang   PetscCheckTypeName(X, MATSEQAIJKOKKOS);
12699566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(Y));
12709566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(X));
12719566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1272db78de30SJunchao Zhang 
1273c17cf699SJunchao Zhang   if (pattern != SAME_NONZERO_PATTERN && x->nz == y->nz) {
1274a587d139SMark     PetscBool e;
12759566063dSJacob Faibussowitsch     PetscCall(PetscArraycmp(x->i, y->i, Y->rmap->n + 1, &e));
1276a587d139SMark     if (e) {
12779566063dSJacob Faibussowitsch       PetscCall(PetscArraycmp(x->j, y->j, y->nz, &e));
1278c17cf699SJunchao Zhang       if (e) pattern = SAME_NONZERO_PATTERN;
1279a587d139SMark     }
1280a587d139SMark   }
1281db78de30SJunchao Zhang 
1282c17cf699SJunchao Zhang   /* cusparseDcsrgeam2() computes C = alpha A + beta B. If one knew sparsity pattern of C, one can skip
1283c17cf699SJunchao Zhang     cusparseScsrgeam2_bufferSizeExt() / cusparseXcsrgeam2Nnz(), and directly call cusparseScsrgeam2().
1284c17cf699SJunchao Zhang     If X is SUBSET_NONZERO_PATTERN of Y, we could take advantage of this cusparse feature. However,
1285c17cf699SJunchao Zhang     KokkosSparse::spadd(alpha,A,beta,B,C) has symbolic and numeric phases, MatAXPY does not.
1286c17cf699SJunchao Zhang   */
1287c17cf699SJunchao Zhang   ykok = static_cast<Mat_SeqAIJKokkos *>(Y->spptr);
1288c17cf699SJunchao Zhang   xkok = static_cast<Mat_SeqAIJKokkos *>(X->spptr);
1289c17cf699SJunchao Zhang   Xa   = xkok->a_dual.view_device();
1290c17cf699SJunchao Zhang   Ya   = ykok->a_dual.view_device();
1291c17cf699SJunchao Zhang 
1292c17cf699SJunchao Zhang   if (pattern == SAME_NONZERO_PATTERN) {
1293d326c3f1SJunchao Zhang     KokkosBlas::axpy(exec, alpha, Xa, Ya);
12949566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
1295c17cf699SJunchao Zhang   } else if (pattern == SUBSET_NONZERO_PATTERN) {
1296c17cf699SJunchao Zhang     MatRowMapKokkosView Xi = xkok->i_dual.view_device(), Yi = ykok->i_dual.view_device();
1297c17cf699SJunchao Zhang     MatColIdxKokkosView Xj = xkok->j_dual.view_device(), Yj = ykok->j_dual.view_device();
1298c17cf699SJunchao Zhang 
12999371c9d4SSatish Balay     Kokkos::parallel_for(
1300d326c3f1SJunchao Zhang       Kokkos::TeamPolicy<>(exec, Y->rmap->n, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
13010e3ece09SJunchao Zhang         PetscInt i = t.league_rank(); // row i
13020e3ece09SJunchao Zhang         Kokkos::single(Kokkos::PerTeam(t), [=]() {
13030e3ece09SJunchao Zhang           // Only one thread works in a team
1304c17cf699SJunchao Zhang           PetscInt p, q = Yi(i);
13050e3ece09SJunchao Zhang           for (p = Xi(i); p < Xi(i + 1); p++) {          // For each nonzero on row i of X,
13060e3ece09SJunchao Zhang             while (Xj(p) != Yj(q) && q < Yi(i + 1)) q++; // find the matching nonzero on row i of Y.
13070e3ece09SJunchao Zhang             if (Xj(p) == Yj(q)) {                        // Found it
1308c17cf699SJunchao Zhang               Ya(q) += alpha * Xa(p);
1309c17cf699SJunchao Zhang               q++;
1310a587d139SMark             } else {
13110e3ece09SJunchao Zhang             // If not found, it indicates the input is wrong (X is not a SUBSET_NONZERO_PATTERN of Y).
13120e3ece09SJunchao Zhang             // Just insert a NaN at the beginning of row i if it is not empty, to make the result wrong.
13130e3ece09SJunchao Zhang #if PETSC_PKG_KOKKOS_VERSION_GE(3, 7, 0)
13140e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::ArithTraits<PetscScalar>::nan();
13158b8b16f9SJunchao Zhang #else
13160e3ece09SJunchao Zhang               if (Yi(i) != Yi(i + 1)) Ya(Yi(i)) = Kokkos::Experimental::nan("1");
13178b8b16f9SJunchao Zhang #endif
1318a587d139SMark             }
1319c17cf699SJunchao Zhang           }
1320c17cf699SJunchao Zhang         });
1321c17cf699SJunchao Zhang       });
13229566063dSJacob Faibussowitsch     PetscCall(MatSeqAIJKokkosModifyDevice(Y));
13230e3ece09SJunchao Zhang   } else { // different nonzero patterns
1324c17cf699SJunchao Zhang     Mat             Z;
1325c17cf699SJunchao Zhang     KokkosCsrMatrix zcsr;
1326c17cf699SJunchao Zhang     KernelHandle    kh;
13270e3ece09SJunchao Zhang     kh.create_spadd_handle(true); // X, Y are sorted
1328c17cf699SJunchao Zhang     KokkosSparse::spadd_symbolic(&kh, xkok->csrmat, ykok->csrmat, zcsr);
1329c17cf699SJunchao Zhang     KokkosSparse::spadd_numeric(&kh, alpha, xkok->csrmat, (PetscScalar)1.0, ykok->csrmat, zcsr);
1330c17cf699SJunchao Zhang     zkok = new Mat_SeqAIJKokkos(zcsr);
13319566063dSJacob Faibussowitsch     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, zkok, &Z));
13329566063dSJacob Faibussowitsch     PetscCall(MatHeaderReplace(Y, &Z));
1333c17cf699SJunchao Zhang     kh.destroy_spadd_handle();
1334c17cf699SJunchao Zhang   }
13359566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
13360e3ece09SJunchao Zhang   PetscCall(PetscLogGpuFlops(xkok->a_dual.extent(0) * 2)); // Because we scaled X and then added it to Y
13373ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1338a587d139SMark }
1339a587d139SMark 
13402c4ab24aSJunchao Zhang struct MatCOOStruct_SeqAIJKokkos {
13412c4ab24aSJunchao Zhang   PetscCount           n;
13422c4ab24aSJunchao Zhang   PetscCount           Atot;
13432c4ab24aSJunchao Zhang   PetscInt             nz;
13442c4ab24aSJunchao Zhang   PetscCountKokkosView jmap;
13452c4ab24aSJunchao Zhang   PetscCountKokkosView perm;
13462c4ab24aSJunchao Zhang 
13472c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos(const MatCOOStruct_SeqAIJ *coo_h)
13482c4ab24aSJunchao Zhang   {
13492c4ab24aSJunchao Zhang     nz   = coo_h->nz;
13502c4ab24aSJunchao Zhang     n    = coo_h->n;
13512c4ab24aSJunchao Zhang     Atot = coo_h->Atot;
13522c4ab24aSJunchao Zhang     jmap = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->jmap, nz + 1));
13532c4ab24aSJunchao Zhang     perm = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), PetscCountKokkosViewHost(coo_h->perm, Atot));
13542c4ab24aSJunchao Zhang   }
13552c4ab24aSJunchao Zhang };
13562c4ab24aSJunchao Zhang 
135749abdd8aSBarry Smith static PetscErrorCode MatCOOStructDestroy_SeqAIJKokkos(void **data)
13582c4ab24aSJunchao Zhang {
13592c4ab24aSJunchao Zhang   PetscFunctionBegin;
136049abdd8aSBarry Smith   PetscCallCXX(delete static_cast<MatCOOStruct_SeqAIJKokkos *>(*data));
13612c4ab24aSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
13622c4ab24aSJunchao Zhang }
13632c4ab24aSJunchao Zhang 
1364d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetPreallocationCOO_SeqAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1365d71ae5a4SJacob Faibussowitsch {
136642550becSJunchao Zhang   Mat_SeqAIJKokkos          *akok;
136742550becSJunchao Zhang   Mat_SeqAIJ                *aseq;
136803e76207SPierre Jolivet   PetscContainer             container_h;
13692c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJ       *coo_h;
13702c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo_d;
137142550becSJunchao Zhang 
137242550becSJunchao Zhang   PetscFunctionBegin;
13739566063dSJacob Faibussowitsch   PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, coo_i, coo_j));
1374394ed5ebSJunchao Zhang   aseq = static_cast<Mat_SeqAIJ *>(mat->data);
137542550becSJunchao Zhang   akok = static_cast<Mat_SeqAIJKokkos *>(mat->spptr);
1376cbc6b225SStefano Zampini   delete akok;
1377f4747e26SJunchao Zhang   mat->spptr = akok = new Mat_SeqAIJKokkos(mat->rmap->n, mat->cmap->n, aseq, mat->nonzerostate + 1, PETSC_FALSE);
13789566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries_SeqAIJKokkos(mat));
13792c4ab24aSJunchao Zhang 
13802c4ab24aSJunchao Zhang   // Copy the COO struct to device
13812c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
13822c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
13832c4ab24aSJunchao Zhang   PetscCallCXX(coo_d = new MatCOOStruct_SeqAIJKokkos(coo_h));
13842c4ab24aSJunchao Zhang 
13852c4ab24aSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
138603e76207SPierre Jolivet   PetscCall(PetscObjectContainerCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", coo_d, MatCOOStructDestroy_SeqAIJKokkos));
13873ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
138842550becSJunchao Zhang }
138942550becSJunchao Zhang 
1390d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetValuesCOO_SeqAIJKokkos(Mat A, const PetscScalar v[], InsertMode imode)
1391d71ae5a4SJacob Faibussowitsch {
139242550becSJunchao Zhang   MatScalarKokkosView        Aa;
139342550becSJunchao Zhang   ConstMatScalarKokkosView   kv;
139442550becSJunchao Zhang   PetscMemType               memtype;
13952c4ab24aSJunchao Zhang   PetscContainer             container;
13962c4ab24aSJunchao Zhang   MatCOOStruct_SeqAIJKokkos *coo;
139742550becSJunchao Zhang 
139842550becSJunchao Zhang   PetscFunctionBegin;
13992c4ab24aSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)A, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
14002c4ab24aSJunchao Zhang   PetscCall(PetscContainerGetPointer(container, (void **)&coo));
14012c4ab24aSJunchao Zhang 
14022c4ab24aSJunchao Zhang   const auto &n    = coo->n;
14032c4ab24aSJunchao Zhang   const auto &Annz = coo->nz;
14042c4ab24aSJunchao Zhang   const auto &jmap = coo->jmap;
14052c4ab24aSJunchao Zhang   const auto &perm = coo->perm;
14062c4ab24aSJunchao Zhang 
14079566063dSJacob Faibussowitsch   PetscCall(PetscGetMemType(v, &memtype));
140842550becSJunchao Zhang   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */
14092c4ab24aSJunchao Zhang     kv = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatScalarKokkosViewHost(v, n));
141042550becSJunchao Zhang   } else {
14112c4ab24aSJunchao Zhang     kv = ConstMatScalarKokkosView(v, n); /* Directly use v[]'s memory */
141242550becSJunchao Zhang   }
141342550becSJunchao Zhang 
1414c7b718f4SJunchao Zhang   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1415c7b718f4SJunchao Zhang   else PetscCall(MatSeqAIJGetKokkosView(A, &Aa));                             /* read & write matrix values */
141642550becSJunchao Zhang 
141708bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
14189371c9d4SSatish Balay   Kokkos::parallel_for(
1419d326c3f1SJunchao Zhang     Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, Annz), KOKKOS_LAMBDA(const PetscCount i) {
1420c7b718f4SJunchao Zhang       PetscScalar sum = 0.0;
1421c7b718f4SJunchao Zhang       for (PetscCount k = jmap(i); k < jmap(i + 1); k++) sum += kv(perm(k));
1422c7b718f4SJunchao Zhang       Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1423c7b718f4SJunchao Zhang     });
142408bb9926SJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1425394ed5ebSJunchao Zhang 
14269566063dSJacob Faibussowitsch   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa));
14279566063dSJacob Faibussowitsch   else PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
14283ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
142942550becSJunchao Zhang }
143042550becSJunchao Zhang 
1431d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSetOps_SeqAIJKokkos(Mat A)
1432d71ae5a4SJacob Faibussowitsch {
1433076ba34aSJunchao Zhang   Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
1434076ba34aSJunchao Zhang 
14358c3ff71bSJunchao Zhang   PetscFunctionBegin;
1436076ba34aSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_KOKKOS; /* We do not really use this flag */
14376f3d89d0SStefano Zampini   A->boundtocpu  = PETSC_FALSE;
14386f3d89d0SStefano Zampini 
14398c3ff71bSJunchao Zhang   A->ops->assemblyend               = MatAssemblyEnd_SeqAIJKokkos;
14408c3ff71bSJunchao Zhang   A->ops->destroy                   = MatDestroy_SeqAIJKokkos;
14418c3ff71bSJunchao Zhang   A->ops->duplicate                 = MatDuplicate_SeqAIJKokkos;
1442a587d139SMark   A->ops->axpy                      = MatAXPY_SeqAIJKokkos;
1443f0cf5187SStefano Zampini   A->ops->scale                     = MatScale_SeqAIJKokkos;
1444a587d139SMark   A->ops->zeroentries               = MatZeroEntries_SeqAIJKokkos;
1445076ba34aSJunchao Zhang   A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJKokkos;
14468c3ff71bSJunchao Zhang   A->ops->mult                      = MatMult_SeqAIJKokkos;
14478c3ff71bSJunchao Zhang   A->ops->multadd                   = MatMultAdd_SeqAIJKokkos;
14488c3ff71bSJunchao Zhang   A->ops->multtranspose             = MatMultTranspose_SeqAIJKokkos;
14498c3ff71bSJunchao Zhang   A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJKokkos;
14508c3ff71bSJunchao Zhang   A->ops->multhermitiantranspose    = MatMultHermitianTranspose_SeqAIJKokkos;
14518c3ff71bSJunchao Zhang   A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJKokkos;
1452076ba34aSJunchao Zhang   A->ops->productnumeric            = MatProductNumeric_SeqAIJKokkos_SeqAIJKokkos;
14530ecb592aSJunchao Zhang   A->ops->transpose                 = MatTranspose_SeqAIJKokkos;
1454152b3e56SJunchao Zhang   A->ops->setoption                 = MatSetOption_SeqAIJKokkos;
1455f78ce678SMark Adams   A->ops->getdiagonal               = MatGetDiagonal_SeqAIJKokkos;
1456f4747e26SJunchao Zhang   A->ops->shift                     = MatShift_SeqAIJKokkos;
1457f4747e26SJunchao Zhang   A->ops->diagonalset               = MatDiagonalSet_SeqAIJKokkos;
1458f4747e26SJunchao Zhang   A->ops->diagonalscale             = MatDiagonalScale_SeqAIJKokkos;
1459076ba34aSJunchao Zhang   a->ops->getarray                  = MatSeqAIJGetArray_SeqAIJKokkos;
1460076ba34aSJunchao Zhang   a->ops->restorearray              = MatSeqAIJRestoreArray_SeqAIJKokkos;
1461076ba34aSJunchao Zhang   a->ops->getarrayread              = MatSeqAIJGetArrayRead_SeqAIJKokkos;
1462076ba34aSJunchao Zhang   a->ops->restorearrayread          = MatSeqAIJRestoreArrayRead_SeqAIJKokkos;
1463076ba34aSJunchao Zhang   a->ops->getarraywrite             = MatSeqAIJGetArrayWrite_SeqAIJKokkos;
1464076ba34aSJunchao Zhang   a->ops->restorearraywrite         = MatSeqAIJRestoreArrayWrite_SeqAIJKokkos;
14657ee59b9bSJunchao Zhang   a->ops->getcsrandmemtype          = MatSeqAIJGetCSRAndMemType_SeqAIJKokkos;
146642550becSJunchao Zhang 
14679566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJKokkos));
14689566063dSJacob Faibussowitsch   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJKokkos));
146957761e9aSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
147057761e9aSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijkokkos_hypre_C", MatConvert_AIJ_HYPRE));
147157761e9aSJunchao Zhang #endif
14723ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1473076ba34aSJunchao Zhang }
1474076ba34aSJunchao Zhang 
14759d13fa56SJunchao Zhang /*
14769d13fa56SJunchao Zhang    Extract the (prescribled) diagonal blocks of the matrix and then invert them
14779d13fa56SJunchao Zhang 
14789d13fa56SJunchao Zhang   Input Parameters:
14799d13fa56SJunchao Zhang +  A       - the MATSEQAIJKOKKOS matrix
14809d13fa56SJunchao Zhang .  bs      - block sizes in 'csr' format, i.e., the i-th block has size bs(i+1) - bs(i)
14819d13fa56SJunchao Zhang .  bs2     - square of block sizes in 'csr' format, i.e., the i-th block should be stored at offset bs2(i) in diagVal[]
14829d13fa56SJunchao Zhang .  blkMap  - map row ids to block ids, i.e., row i belongs to the block blkMap(i)
14839d13fa56SJunchao Zhang -  work    - a pre-allocated work buffer (as big as diagVal) for use by this routine
14849d13fa56SJunchao Zhang 
14859d13fa56SJunchao Zhang   Output Parameter:
14869d13fa56SJunchao Zhang .  diagVal - the (pre-allocated) buffer to store the inverted blocks (each block is stored in column-major order)
14879d13fa56SJunchao Zhang */
14889d13fa56SJunchao Zhang PETSC_INTERN PetscErrorCode MatInvertVariableBlockDiagonal_SeqAIJKokkos(Mat A, const PetscIntKokkosView &bs, const PetscIntKokkosView &bs2, const PetscIntKokkosView &blkMap, PetscScalarKokkosView &work, PetscScalarKokkosView &diagVal)
14899d13fa56SJunchao Zhang {
14909d13fa56SJunchao Zhang   Mat_SeqAIJKokkos *akok    = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
14919d13fa56SJunchao Zhang   PetscInt          nblocks = bs.extent(0) - 1;
14929d13fa56SJunchao Zhang 
14939d13fa56SJunchao Zhang   PetscFunctionBegin;
14949d13fa56SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A)); // Since we'll access A's value on device
14959d13fa56SJunchao Zhang 
14969d13fa56SJunchao Zhang   // Pull out the diagonal blocks of the matrix and then invert the blocks
14979d13fa56SJunchao Zhang   auto Aa    = akok->a_dual.view_device();
14989d13fa56SJunchao Zhang   auto Ai    = akok->i_dual.view_device();
14999d13fa56SJunchao Zhang   auto Aj    = akok->j_dual.view_device();
15009d13fa56SJunchao Zhang   auto Adiag = akok->diag_dual.view_device();
15019d13fa56SJunchao Zhang   // TODO: how to tune the team size?
150245402d8aSJunchao Zhang #if defined(KOKKOS_ENABLE_UNIFIED_MEMORY)
15039d13fa56SJunchao Zhang   auto ts = Kokkos::AUTO();
15049d13fa56SJunchao Zhang #else
15059d13fa56SJunchao Zhang   auto ts = 16; // improved performance 30% over Kokkos::AUTO() with CUDA, but failed with "Kokkos::abort: Requested Team Size is too large!" on CPUs
15069d13fa56SJunchao Zhang #endif
15079d13fa56SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(
1508d326c3f1SJunchao Zhang     Kokkos::TeamPolicy<>(PetscGetKokkosExecutionSpace(), nblocks, ts), KOKKOS_LAMBDA(const KokkosTeamMemberType &teamMember) {
15099d13fa56SJunchao Zhang       const PetscInt bid    = teamMember.league_rank();                                                   // block id
15109d13fa56SJunchao Zhang       const PetscInt rstart = bs(bid);                                                                    // this block starts from this row
15119d13fa56SJunchao Zhang       const PetscInt m      = bs(bid + 1) - bs(bid);                                                      // size of this block
15129d13fa56SJunchao Zhang       const auto    &B      = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft>(&diagVal(bs2(bid)), m, m); // column-major order
15139d13fa56SJunchao Zhang       const auto    &W      = PetscScalarKokkosView(&work(bs2(bid)), m * m);
15149d13fa56SJunchao Zhang 
15159d13fa56SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(teamMember, m), [=](const PetscInt &r) { // r-th row in B
15169d13fa56SJunchao Zhang         PetscInt i = rstart + r;                                                            // i-th row in A
15179d13fa56SJunchao Zhang 
15189d13fa56SJunchao Zhang         if (Ai(i) <= Adiag(i) && Adiag(i) < Ai(i + 1)) { // if the diagonal exists (common case)
15199d13fa56SJunchao Zhang           PetscInt first = Adiag(i) - r;                 // we start to check nonzeros from here along this row
15209d13fa56SJunchao Zhang 
15219d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) {                   // walk n steps to see what column indices we will meet
15229d13fa56SJunchao Zhang             if (first + c < Ai(i) || first + c >= Ai(i + 1)) { // this entry (first+c) is out of range of this row, in other words, its value is zero
15239d13fa56SJunchao Zhang               B(r, c) = 0.0;
15249d13fa56SJunchao Zhang             } else if (Aj(first + c) == rstart + c) { // this entry is right on the (rstart+c) column
15259d13fa56SJunchao Zhang               B(r, c) = Aa(first + c);
15269d13fa56SJunchao Zhang             } else { // this entry does not show up in the CSR
15279d13fa56SJunchao Zhang               B(r, c) = 0.0;
15289d13fa56SJunchao Zhang             }
15299d13fa56SJunchao Zhang           }
15309d13fa56SJunchao Zhang         } else { // rare case that the diagonal does not exist
15319d13fa56SJunchao Zhang           const PetscInt begin = Ai(i);
15329d13fa56SJunchao Zhang           const PetscInt end   = Ai(i + 1);
15339d13fa56SJunchao Zhang           for (PetscInt c = 0; c < m; c++) B(r, c) = 0.0;
15349d13fa56SJunchao Zhang           for (PetscInt j = begin; j < end; j++) { // scan the whole row; could use binary search but this is a rare case so we did not.
15359d13fa56SJunchao Zhang             if (rstart <= Aj(j) && Aj(j) < rstart + m) B(r, Aj(j) - rstart) = Aa(j);
15369d13fa56SJunchao Zhang             else if (Aj(j) >= rstart + m) break;
15379d13fa56SJunchao Zhang           }
15389d13fa56SJunchao Zhang         }
15399d13fa56SJunchao Zhang       });
15409d13fa56SJunchao Zhang 
15419d13fa56SJunchao Zhang       // LU-decompose B (w/o pivoting) and then invert B
15429d13fa56SJunchao Zhang       KokkosBatched::TeamLU<KokkosTeamMemberType, KokkosBatched::Algo::LU::Unblocked>::invoke(teamMember, B, 0.0);
15439d13fa56SJunchao Zhang       KokkosBatched::TeamInverseLU<KokkosTeamMemberType, KokkosBatched::Algo::InverseLU::Unblocked>::invoke(teamMember, B, W);
15449d13fa56SJunchao Zhang     }));
15459d13fa56SJunchao Zhang   // PetscLogGpuFlops() is done in the caller PCSetUp_VPBJacobi_Kokkos as we don't want to compute the flops in kernels
15469d13fa56SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
15479d13fa56SJunchao Zhang }
15489d13fa56SJunchao Zhang 
1549d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatSetSeqAIJKokkosWithCSRMatrix(Mat A, Mat_SeqAIJKokkos *akok)
1550d71ae5a4SJacob Faibussowitsch {
1551076ba34aSJunchao Zhang   Mat_SeqAIJ *aseq;
1552076ba34aSJunchao Zhang   PetscInt    i, m, n;
15534df4a32cSJunchao Zhang   auto        exec = PetscGetKokkosExecutionSpace();
1554076ba34aSJunchao Zhang 
1555076ba34aSJunchao Zhang   PetscFunctionBegin;
15565f80ce2aSJacob Faibussowitsch   PetscCheck(!A->spptr, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A->spptr is supposed to be empty");
1557076ba34aSJunchao Zhang 
1558076ba34aSJunchao Zhang   m = akok->nrows();
1559076ba34aSJunchao Zhang   n = akok->ncols();
15609566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(A, m, n, m, n));
15619566063dSJacob Faibussowitsch   PetscCall(MatSetType(A, MATSEQAIJKOKKOS));
1562076ba34aSJunchao Zhang 
1563076ba34aSJunchao Zhang   /* Set up data structures of A as a MATSEQAIJ */
15649566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(A, MAT_SKIP_ALLOCATION, NULL));
156557508eceSPierre Jolivet   aseq = (Mat_SeqAIJ *)A->data;
1566076ba34aSJunchao Zhang 
1567e36ced11SJunchao Zhang   PetscCallCXX(akok->i_dual.sync_host(exec)); /* We always need sync'ed i, j on host */
1568e36ced11SJunchao Zhang   PetscCallCXX(akok->j_dual.sync_host(exec));
1569e36ced11SJunchao Zhang   PetscCallCXX(exec.fence());
1570076ba34aSJunchao Zhang 
1571076ba34aSJunchao Zhang   aseq->i       = akok->i_host_data();
1572076ba34aSJunchao Zhang   aseq->j       = akok->j_host_data();
1573076ba34aSJunchao Zhang   aseq->a       = akok->a_host_data();
1574076ba34aSJunchao Zhang   aseq->nonew   = -1; /*this indicates that inserting a new value in the matrix that generates a new nonzero is an error*/
1575076ba34aSJunchao Zhang   aseq->free_a  = PETSC_FALSE;
1576076ba34aSJunchao Zhang   aseq->free_ij = PETSC_FALSE;
1577076ba34aSJunchao Zhang   aseq->nz      = akok->nnz();
1578076ba34aSJunchao Zhang   aseq->maxnz   = aseq->nz;
1579076ba34aSJunchao Zhang 
15809566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->imax));
15819566063dSJacob Faibussowitsch   PetscCall(PetscMalloc1(m, &aseq->ilen));
1582ad540459SPierre Jolivet   for (i = 0; i < m; i++) aseq->ilen[i] = aseq->imax[i] = aseq->i[i + 1] - aseq->i[i];
1583076ba34aSJunchao Zhang 
1584076ba34aSJunchao Zhang   /* It is critical to set the nonzerostate, as we use it to check if sparsity pattern (hence data) has changed on host in MatAssemblyEnd */
1585076ba34aSJunchao Zhang   akok->nonzerostate = A->nonzerostate;
1586ff751488SJunchao Zhang   A->spptr           = akok; /* Set A->spptr before MatAssembly so that A->spptr won't be allocated again there */
15879566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
15889566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
15893ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
1590076ba34aSJunchao Zhang }
1591076ba34aSJunchao Zhang 
15920e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetKokkosCsrMatrix(Mat A, KokkosCsrMatrix *csr)
15930e3ece09SJunchao Zhang {
15940e3ece09SJunchao Zhang   PetscFunctionBegin;
15950e3ece09SJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncDevice(A));
15960e3ece09SJunchao Zhang   *csr = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
15970e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
15980e3ece09SJunchao Zhang }
15990e3ece09SJunchao Zhang 
16000e3ece09SJunchao Zhang PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithKokkosCsrMatrix(MPI_Comm comm, KokkosCsrMatrix csr, Mat *A)
16010e3ece09SJunchao Zhang {
16020e3ece09SJunchao Zhang   Mat_SeqAIJKokkos *akok;
16034d86920dSPierre Jolivet 
16040e3ece09SJunchao Zhang   PetscFunctionBegin;
16050e3ece09SJunchao Zhang   PetscCallCXX(akok = new Mat_SeqAIJKokkos(csr));
16060e3ece09SJunchao Zhang   PetscCall(MatCreate(comm, A));
16070e3ece09SJunchao Zhang   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
16080e3ece09SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
16090e3ece09SJunchao Zhang }
16100e3ece09SJunchao Zhang 
1611076ba34aSJunchao Zhang /* Crete a SEQAIJKOKKOS matrix with a Mat_SeqAIJKokkos data structure
1612076ba34aSJunchao Zhang 
1613076ba34aSJunchao Zhang    Note we have names like MatSeqAIJSetPreallocationCSR, so I use capitalized CSR
1614076ba34aSJunchao Zhang  */
1615d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode MatCreateSeqAIJKokkosWithCSRMatrix(MPI_Comm comm, Mat_SeqAIJKokkos *akok, Mat *A)
1616d71ae5a4SJacob Faibussowitsch {
1617076ba34aSJunchao Zhang   PetscFunctionBegin;
16189566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
16199566063dSJacob Faibussowitsch   PetscCall(MatSetSeqAIJKokkosWithCSRMatrix(*A, akok));
16203ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
16218c3ff71bSJunchao Zhang }
16228c3ff71bSJunchao Zhang 
1623152b3e56SJunchao Zhang /*@C
162411a5261eSBarry Smith   MatCreateSeqAIJKokkos - Creates a sparse matrix in `MATSEQAIJKOKKOS` (compressed row) format
16258c3ff71bSJunchao Zhang   (the default parallel PETSc format). This matrix will ultimately be handled by
162620f4b53cSBarry Smith   Kokkos for calculations.
16278c3ff71bSJunchao Zhang 
16288c3ff71bSJunchao Zhang   Collective
16298c3ff71bSJunchao Zhang 
16308c3ff71bSJunchao Zhang   Input Parameters:
163111a5261eSBarry Smith + comm - MPI communicator, set to `PETSC_COMM_SELF`
16328c3ff71bSJunchao Zhang . m    - number of rows
16338c3ff71bSJunchao Zhang . n    - number of columns
163420f4b53cSBarry Smith . nz   - number of nonzeros per row (same for all rows), ignored if `nnz` is provided
163520f4b53cSBarry Smith - nnz  - array containing the number of nonzeros in the various rows (possibly different for each row) or `NULL`
16368c3ff71bSJunchao Zhang 
16378c3ff71bSJunchao Zhang   Output Parameter:
16388c3ff71bSJunchao Zhang . A - the matrix
16398c3ff71bSJunchao Zhang 
16402ef1f0ffSBarry Smith   Level: intermediate
16412ef1f0ffSBarry Smith 
16422ef1f0ffSBarry Smith   Notes:
164311a5261eSBarry Smith   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
16448c3ff71bSJunchao Zhang   MatXXXXSetPreallocation() paradgm instead of this routine directly.
164511a5261eSBarry Smith   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]
16468c3ff71bSJunchao Zhang 
164711a5261eSBarry Smith   The AIJ format, also called
16482ef1f0ffSBarry Smith   compressed row storage, is fully compatible with standard Fortran
16498c3ff71bSJunchao Zhang   storage.  That is, the stored row and column indices can begin at
165020f4b53cSBarry Smith   either one (as in Fortran) or zero.
16518c3ff71bSJunchao Zhang 
16522ef1f0ffSBarry Smith   Specify the preallocated storage with either `nz` or `nnz` (not both).
16532ef1f0ffSBarry Smith   Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory
16542ef1f0ffSBarry Smith   allocation.
16558c3ff71bSJunchao Zhang 
1656fe59aa6dSJacob Faibussowitsch .seealso: [](ch_matrices), `Mat`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
16578c3ff71bSJunchao Zhang @*/
1658d71ae5a4SJacob Faibussowitsch PetscErrorCode MatCreateSeqAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A)
1659d71ae5a4SJacob Faibussowitsch {
16608c3ff71bSJunchao Zhang   PetscFunctionBegin;
16619566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
16629566063dSJacob Faibussowitsch   PetscCall(MatCreate(comm, A));
16639566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*A, m, n, m, n));
16649566063dSJacob Faibussowitsch   PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
16659566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
16663ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
16678c3ff71bSJunchao Zhang }
1668930e68a5SMark Adams 
1669aac854edSJunchao Zhang // After matrix numeric factorization, there are still steps to do before triangular solve can be called.
1670aac854edSJunchao Zhang // For example, for transpose solve, we might need to compute the transpose matrices if the solver does not support it (such as KK, while cusparse does).
1671aac854edSJunchao Zhang // In cusparse, one has to call cusparseSpSV_analysis() with updated triangular matrix values before calling cusparseSpSV_solve().
1672aac854edSJunchao Zhang // Simiarily, in KK sptrsv_symbolic() has to be called before sptrsv_solve(). We put these steps in MatSeqAIJKokkos{Transpose}SolveCheck.
1673aac854edSJunchao Zhang static PetscErrorCode MatSeqAIJKokkosSolveCheck(Mat A)
1674d71ae5a4SJacob Faibussowitsch {
167586a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors   = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1676aac854edSJunchao Zhang   const PetscBool             has_lower = factors->iL_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // false with Choleksy
1677aac854edSJunchao Zhang   const PetscBool             has_upper = factors->iU_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // true with LU and Choleksy
167886a27549SJunchao Zhang 
167986a27549SJunchao Zhang   PetscFunctionBegin;
1680aac854edSJunchao Zhang   if (!factors->sptrsv_symbolic_completed) { // If sptrsv_symbolic was not called yet
1681aac854edSJunchao Zhang     if (has_upper) PetscCallCXX(sptrsv_symbolic(&factors->khU, factors->iU_d, factors->jU_d, factors->aU_d));
1682aac854edSJunchao Zhang     if (has_lower) PetscCallCXX(sptrsv_symbolic(&factors->khL, factors->iL_d, factors->jL_d, factors->aL_d));
168386a27549SJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_TRUE;
168486a27549SJunchao Zhang   }
16853ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
168686a27549SJunchao Zhang }
168786a27549SJunchao Zhang 
1688d71ae5a4SJacob Faibussowitsch static PetscErrorCode MatSeqAIJKokkosTransposeSolveCheck(Mat A)
1689d71ae5a4SJacob Faibussowitsch {
1690aac854edSJunchao Zhang   const PetscInt              n         = A->rmap->n;
169186a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors   = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1692aac854edSJunchao Zhang   const PetscBool             has_lower = factors->iL_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // false with Choleksy
1693aac854edSJunchao Zhang   const PetscBool             has_upper = factors->iU_d.extent(0) ? PETSC_TRUE : PETSC_FALSE; // true with LU or Choleksy
169486a27549SJunchao Zhang 
169586a27549SJunchao Zhang   PetscFunctionBegin;
1696aac854edSJunchao Zhang   if (!factors->transpose_updated) {
1697aac854edSJunchao Zhang     if (has_upper) {
1698aac854edSJunchao Zhang       if (!factors->iUt_d.extent(0)) {                                 // Allocate Ut on device if not yet
1699aac854edSJunchao Zhang         factors->iUt_d = MatRowMapKokkosView("factors->iUt_d", n + 1); // KK requires this view to be initialized to 0 to call transpose_matrix
17007b8d4ba6SJunchao Zhang         factors->jUt_d = MatColIdxKokkosView(NoInit("factors->jUt_d"), factors->jU_d.extent(0));
17017b8d4ba6SJunchao Zhang         factors->aUt_d = MatScalarKokkosView(NoInit("factors->aUt_d"), factors->aU_d.extent(0));
1702aac854edSJunchao Zhang       }
170386a27549SJunchao Zhang 
1704aac854edSJunchao Zhang       if (factors->iU_h.extent(0)) { // If U is on host (factorization was done on host), we also compute the transpose on host
1705aac854edSJunchao Zhang         if (!factors->U) {
1706aac854edSJunchao Zhang           Mat_SeqAIJ *seq;
170786a27549SJunchao Zhang 
1708aac854edSJunchao Zhang           PetscCall(MatCreateSeqAIJWithArrays(PETSC_COMM_SELF, n, n, factors->iU_h.data(), factors->jU_h.data(), factors->aU_h.data(), &factors->U));
1709aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->U, MAT_INITIAL_MATRIX, &factors->Ut));
171086a27549SJunchao Zhang 
1711aac854edSJunchao Zhang           seq            = static_cast<Mat_SeqAIJ *>(factors->Ut->data);
1712aac854edSJunchao Zhang           factors->iUt_h = MatRowMapKokkosViewHost(seq->i, n + 1);
1713aac854edSJunchao Zhang           factors->jUt_h = MatColIdxKokkosViewHost(seq->j, seq->nz);
1714aac854edSJunchao Zhang           factors->aUt_h = MatScalarKokkosViewHost(seq->a, seq->nz);
1715aac854edSJunchao Zhang         } else {
1716aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->U, MAT_REUSE_MATRIX, &factors->Ut)); // Matrix Ut' data is aliased with {i, j, a}Ut_h
1717aac854edSJunchao Zhang         }
1718aac854edSJunchao Zhang         // Copy Ut from host to device
1719aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->iUt_d, factors->iUt_h));
1720aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->jUt_d, factors->jUt_h));
1721aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->aUt_d, factors->aUt_h));
1722aac854edSJunchao Zhang       } else { // If U was computed on device, we also compute the transpose there
1723aac854edSJunchao Zhang         // TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices. We have to sort the indices, until KK provides finer control options.
1724aac854edSJunchao Zhang         PetscCallCXX(transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iU_d,
1725aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->jU_d, factors->aU_d,
1726aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->iUt_d, factors->jUt_d,
1727aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->aUt_d));
1728aac854edSJunchao Zhang         PetscCallCXX(sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iUt_d, factors->jUt_d, factors->aUt_d));
1729aac854edSJunchao Zhang       }
1730aac854edSJunchao Zhang       PetscCallCXX(sptrsv_symbolic(&factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d));
1731aac854edSJunchao Zhang     }
1732aac854edSJunchao Zhang 
1733aac854edSJunchao Zhang     // do the same for L with LU
1734aac854edSJunchao Zhang     if (has_lower) {
1735aac854edSJunchao Zhang       if (!factors->iLt_d.extent(0)) {                                 // Allocate Lt on device if not yet
1736aac854edSJunchao Zhang         factors->iLt_d = MatRowMapKokkosView("factors->iLt_d", n + 1); // KK requires this view to be initialized to 0 to call transpose_matrix
1737aac854edSJunchao Zhang         factors->jLt_d = MatColIdxKokkosView(NoInit("factors->jLt_d"), factors->jL_d.extent(0));
1738aac854edSJunchao Zhang         factors->aLt_d = MatScalarKokkosView(NoInit("factors->aLt_d"), factors->aL_d.extent(0));
1739aac854edSJunchao Zhang       }
1740aac854edSJunchao Zhang 
1741aac854edSJunchao Zhang       if (factors->iL_h.extent(0)) { // If L is on host, we also compute the transpose on host
1742aac854edSJunchao Zhang         if (!factors->L) {
1743aac854edSJunchao Zhang           Mat_SeqAIJ *seq;
1744aac854edSJunchao Zhang 
1745aac854edSJunchao Zhang           PetscCall(MatCreateSeqAIJWithArrays(PETSC_COMM_SELF, n, n, factors->iL_h.data(), factors->jL_h.data(), factors->aL_h.data(), &factors->L));
1746aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->L, MAT_INITIAL_MATRIX, &factors->Lt));
1747aac854edSJunchao Zhang 
1748aac854edSJunchao Zhang           seq            = static_cast<Mat_SeqAIJ *>(factors->Lt->data);
1749aac854edSJunchao Zhang           factors->iLt_h = MatRowMapKokkosViewHost(seq->i, n + 1);
1750aac854edSJunchao Zhang           factors->jLt_h = MatColIdxKokkosViewHost(seq->j, seq->nz);
1751aac854edSJunchao Zhang           factors->aLt_h = MatScalarKokkosViewHost(seq->a, seq->nz);
1752aac854edSJunchao Zhang         } else {
1753aac854edSJunchao Zhang           PetscCall(MatTranspose(factors->L, MAT_REUSE_MATRIX, &factors->Lt)); // Matrix Lt' data is aliased with {i, j, a}Lt_h
1754aac854edSJunchao Zhang         }
1755aac854edSJunchao Zhang         // Copy Lt from host to device
1756aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->iLt_d, factors->iLt_h));
1757aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->jLt_d, factors->jLt_h));
1758aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->aLt_d, factors->aLt_h));
1759aac854edSJunchao Zhang       } else { // If L was computed on device, we also compute the transpose there
1760aac854edSJunchao Zhang         // TODO: KK transpose_matrix() does not sort column indices, however cusparse requires sorted indices. We have to sort the indices, until KK provides finer control options.
1761aac854edSJunchao Zhang         PetscCallCXX(transpose_matrix<ConstMatRowMapKokkosView, ConstMatColIdxKokkosView, ConstMatScalarKokkosView, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView, MatRowMapKokkosView, DefaultExecutionSpace>(n, n, factors->iL_d,
1762aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->jL_d, factors->aL_d,
1763aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->iLt_d, factors->jLt_d,
1764aac854edSJunchao Zhang                                                                                                                                                                                                                                factors->aLt_d));
1765aac854edSJunchao Zhang         PetscCallCXX(sort_crs_matrix<DefaultExecutionSpace, MatRowMapKokkosView, MatColIdxKokkosView, MatScalarKokkosView>(factors->iLt_d, factors->jLt_d, factors->aLt_d));
1766aac854edSJunchao Zhang       }
1767aac854edSJunchao Zhang       PetscCallCXX(sptrsv_symbolic(&factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d));
1768aac854edSJunchao Zhang     }
1769aac854edSJunchao Zhang 
177086a27549SJunchao Zhang     factors->transpose_updated = PETSC_TRUE;
177186a27549SJunchao Zhang   }
17723ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
177386a27549SJunchao Zhang }
177486a27549SJunchao Zhang 
1775aac854edSJunchao Zhang // Solve Ax = b, with RAR = U^T D U, where R is the row (and col) permutation matrix on A.
1776aac854edSJunchao Zhang // R is represented by rowperm in factors. If R is identity (i.e, no reordering), then rowperm is empty.
1777aac854edSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJKokkos_Cholesky(Mat A, Vec bb, Vec xx)
1778d71ae5a4SJacob Faibussowitsch {
1779aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
178086a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1781aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1782aac854edSJunchao Zhang   PetscScalarKokkosView       D       = factors->D_d;
1783aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1784aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1785aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1786aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm  = factors->rowperm;
1787aac854edSJunchao Zhang   PetscBool                   identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
178886a27549SJunchao Zhang 
178986a27549SJunchao Zhang   PetscFunctionBegin;
17909566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1791aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSolveCheck(A));          // for UX = T
1792aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A)); // for U^T Y = B
1793aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1794aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
1795aac854edSJunchao Zhang 
1796aac854edSJunchao Zhang   // Solve U^T Y = B
1797aac854edSJunchao Zhang   if (identity) { // Reorder b with the row permutation
1798aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1799aac854edSJunchao Zhang     Y = factors->workVector;
1800aac854edSJunchao Zhang   } else {
1801aac854edSJunchao Zhang     B = factors->workVector;
1802aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(rowperm(i)); }));
1803aac854edSJunchao Zhang     Y = x;
1804aac854edSJunchao Zhang   }
1805aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, B, Y));
1806aac854edSJunchao Zhang 
1807aac854edSJunchao Zhang   // Solve diag(D) Y' = Y.
1808aac854edSJunchao Zhang   // Actually just do Y' = Y*D since D is already inverted in MatCholeskyFactorNumeric_SeqAIJ(). It is basically a vector element-wise multiplication.
1809aac854edSJunchao Zhang   PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { Y(i) = Y(i) * D(i); }));
1810aac854edSJunchao Zhang 
1811aac854edSJunchao Zhang   // Solve UX = Y
1812aac854edSJunchao Zhang   if (identity) {
1813aac854edSJunchao Zhang     X = x;
1814aac854edSJunchao Zhang   } else {
1815aac854edSJunchao Zhang     X = factors->workVector; // B is not needed anymore
1816aac854edSJunchao Zhang   }
1817aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, Y, X));
1818aac854edSJunchao Zhang 
1819aac854edSJunchao Zhang   // Reorder X with the inverse column (row) permutation
1820aac854edSJunchao Zhang   if (!identity) {
1821aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(rowperm(i)) = X(i); }));
1822aac854edSJunchao Zhang   }
1823aac854edSJunchao Zhang 
1824aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1825aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
18269566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
18273ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
182886a27549SJunchao Zhang }
182986a27549SJunchao Zhang 
1830aac854edSJunchao Zhang // Solve Ax = b, with RAC = LU, where R and C are row and col permutation matrices on A respectively.
1831aac854edSJunchao Zhang // R and C are represented by rowperm and colperm in factors.
1832aac854edSJunchao Zhang // If R or C is identity (i.e, no reordering), then rowperm or colperm is empty.
1833aac854edSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJKokkos_LU(Mat A, Vec bb, Vec xx)
1834d71ae5a4SJacob Faibussowitsch {
1835aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
183686a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1837aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1838aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1839aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1840aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1841aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm      = factors->rowperm;
1842aac854edSJunchao Zhang   PetscIntKokkosView         &colperm      = factors->colperm;
1843aac854edSJunchao Zhang   PetscBool                   row_identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1844aac854edSJunchao Zhang   PetscBool                   col_identity = colperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
184586a27549SJunchao Zhang 
184686a27549SJunchao Zhang   PetscFunctionBegin;
18479566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
1848aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSolveCheck(A));
1849aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1850aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
185186a27549SJunchao Zhang 
1852aac854edSJunchao Zhang   // Solve L Y = B (i.e., L (U C^- x) = R b).  R b indicates applying the row permutation on b.
1853aac854edSJunchao Zhang   if (row_identity) {
1854aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1855aac854edSJunchao Zhang     Y = factors->workVector;
1856aac854edSJunchao Zhang   } else {
1857aac854edSJunchao Zhang     B = factors->workVector;
1858aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(rowperm(i)); }));
1859aac854edSJunchao Zhang     Y = x;
1860aac854edSJunchao Zhang   }
1861aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khL, factors->iL_d, factors->jL_d, factors->aL_d, B, Y));
1862aac854edSJunchao Zhang 
1863aac854edSJunchao Zhang   // Solve U C^- x = Y
1864aac854edSJunchao Zhang   if (col_identity) {
1865aac854edSJunchao Zhang     X = x;
1866aac854edSJunchao Zhang   } else {
1867aac854edSJunchao Zhang     X = factors->workVector;
1868aac854edSJunchao Zhang   }
1869aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khU, factors->iU_d, factors->jU_d, factors->aU_d, Y, X));
1870aac854edSJunchao Zhang 
1871aac854edSJunchao Zhang   // x = C X; Reorder X with the inverse col permutation
1872aac854edSJunchao Zhang   if (!col_identity) {
1873aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(colperm(i)) = X(i); }));
1874aac854edSJunchao Zhang   }
1875aac854edSJunchao Zhang 
1876aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1877aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
18789566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
18793ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
188086a27549SJunchao Zhang }
188186a27549SJunchao Zhang 
1882aac854edSJunchao Zhang // Solve A^T x = b, with RAC = LU, where R and C are row and col permutation matrices on A respectively.
1883aac854edSJunchao Zhang // R and C are represented by rowperm and colperm in factors.
1884aac854edSJunchao Zhang // If R or C is identity (i.e, no reordering), then rowperm or colperm is empty.
1885aac854edSJunchao Zhang // A = R^-1 L U C^-1, so A^T = C^-T U^T L^T R^-T. But since C^- = C^T, R^- = R^T, we have A^T = C U^T L^T R.
1886aac854edSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJKokkos_LU(Mat A, Vec bb, Vec xx)
1887aac854edSJunchao Zhang {
1888aac854edSJunchao Zhang   auto                        exec    = PetscGetKokkosExecutionSpace();
1889aac854edSJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)A->spptr;
1890aac854edSJunchao Zhang   PetscInt                    m       = A->rmap->n;
1891aac854edSJunchao Zhang   PetscScalarKokkosView       X, Y, B; // alias
1892aac854edSJunchao Zhang   ConstPetscScalarKokkosView  b;
1893aac854edSJunchao Zhang   PetscScalarKokkosView       x;
1894aac854edSJunchao Zhang   PetscIntKokkosView         &rowperm      = factors->rowperm;
1895aac854edSJunchao Zhang   PetscIntKokkosView         &colperm      = factors->colperm;
1896aac854edSJunchao Zhang   PetscBool                   row_identity = rowperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1897aac854edSJunchao Zhang   PetscBool                   col_identity = colperm.extent(0) ? PETSC_FALSE : PETSC_TRUE;
1898aac854edSJunchao Zhang 
1899aac854edSJunchao Zhang   PetscFunctionBegin;
1900aac854edSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1901aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosTransposeSolveCheck(A)); // Update L^T, U^T if needed, and do sptrsv symbolic for L^T, U^T
1902aac854edSJunchao Zhang   PetscCall(VecGetKokkosView(bb, &b));
1903aac854edSJunchao Zhang   PetscCall(VecGetKokkosViewWrite(xx, &x));
1904aac854edSJunchao Zhang 
1905aac854edSJunchao Zhang   // Solve U^T Y = B (i.e., U^T (L^T R x) = C^- b).  Note C^- b = C^T b, which means applying the column permutation on b.
1906aac854edSJunchao Zhang   if (col_identity) { // Reorder b with the col permutation
1907aac854edSJunchao Zhang     B = PetscScalarKokkosView(const_cast<PetscScalar *>(b.data()), b.extent(0));
1908aac854edSJunchao Zhang     Y = factors->workVector;
1909aac854edSJunchao Zhang   } else {
1910aac854edSJunchao Zhang     B = factors->workVector;
1911aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { B(i) = b(colperm(i)); }));
1912aac854edSJunchao Zhang     Y = x;
1913aac854edSJunchao Zhang   }
1914aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khUt, factors->iUt_d, factors->jUt_d, factors->aUt_d, B, Y));
1915aac854edSJunchao Zhang 
1916aac854edSJunchao Zhang   // Solve L^T X = Y
1917aac854edSJunchao Zhang   if (row_identity) {
1918aac854edSJunchao Zhang     X = x;
1919aac854edSJunchao Zhang   } else {
1920aac854edSJunchao Zhang     X = factors->workVector;
1921aac854edSJunchao Zhang   }
1922aac854edSJunchao Zhang   PetscCallCXX(sptrsv_solve(exec, &factors->khLt, factors->iLt_d, factors->jLt_d, factors->aLt_d, Y, X));
1923aac854edSJunchao Zhang 
1924aac854edSJunchao Zhang   // x = R^- X = R^T X; Reorder X with the inverse row permutation
1925aac854edSJunchao Zhang   if (!row_identity) {
1926aac854edSJunchao Zhang     PetscCallCXX(Kokkos::parallel_for(Kokkos::RangePolicy<>(exec, 0, m), KOKKOS_LAMBDA(const PetscInt i) { x(rowperm(i)) = X(i); }));
1927aac854edSJunchao Zhang   }
1928aac854edSJunchao Zhang 
1929aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosView(bb, &b));
1930aac854edSJunchao Zhang   PetscCall(VecRestoreKokkosViewWrite(xx, &x));
1931aac854edSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1932aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1933aac854edSJunchao Zhang }
1934aac854edSJunchao Zhang 
1935aac854edSJunchao Zhang static PetscErrorCode MatLUFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
1936aac854edSJunchao Zhang {
1937aac854edSJunchao Zhang   PetscFunctionBegin;
1938aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncHost(A));
1939aac854edSJunchao Zhang   PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info));
1940aac854edSJunchao Zhang 
1941aac854edSJunchao Zhang   if (!info->solveonhost) { // if solve on host, then we don't need to copy L, U to device
1942aac854edSJunchao Zhang     Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
1943aac854edSJunchao Zhang     Mat_SeqAIJ                 *b       = static_cast<Mat_SeqAIJ *>(B->data);
1944aac854edSJunchao Zhang     const PetscInt             *Bi = b->i, *Bj = b->j, *Bdiag = b->diag;
1945aac854edSJunchao Zhang     const MatScalar            *Ba = b->a;
1946aac854edSJunchao Zhang     PetscInt                    m = B->rmap->n, n = B->cmap->n;
1947aac854edSJunchao Zhang 
1948aac854edSJunchao Zhang     if (factors->iL_h.extent(0) == 0) { // Allocate memory and copy the L, U structure for the first time
1949aac854edSJunchao Zhang       // Allocate memory and copy the structure
1950aac854edSJunchao Zhang       factors->iL_h = MatRowMapKokkosViewHost(NoInit("iL_h"), m + 1);
1951aac854edSJunchao Zhang       factors->jL_h = MatColIdxKokkosViewHost(NoInit("jL_h"), (Bi[m] - Bi[0]) + m); // + the diagonal entries
1952aac854edSJunchao Zhang       factors->aL_h = MatScalarKokkosViewHost(NoInit("aL_h"), (Bi[m] - Bi[0]) + m);
1953aac854edSJunchao Zhang       factors->iU_h = MatRowMapKokkosViewHost(NoInit("iU_h"), m + 1);
1954aac854edSJunchao Zhang       factors->jU_h = MatColIdxKokkosViewHost(NoInit("jU_h"), (Bdiag[0] - Bdiag[m]));
1955aac854edSJunchao Zhang       factors->aU_h = MatScalarKokkosViewHost(NoInit("aU_h"), (Bdiag[0] - Bdiag[m]));
1956aac854edSJunchao Zhang 
1957aac854edSJunchao Zhang       PetscInt *Li = factors->iL_h.data();
1958aac854edSJunchao Zhang       PetscInt *Lj = factors->jL_h.data();
1959aac854edSJunchao Zhang       PetscInt *Ui = factors->iU_h.data();
1960aac854edSJunchao Zhang       PetscInt *Uj = factors->jU_h.data();
1961aac854edSJunchao Zhang 
1962aac854edSJunchao Zhang       Li[0] = Ui[0] = 0;
1963aac854edSJunchao Zhang       for (PetscInt i = 0; i < m; i++) {
1964aac854edSJunchao Zhang         PetscInt llen = Bi[i + 1] - Bi[i];       // exclusive of the diagonal entry
1965aac854edSJunchao Zhang         PetscInt ulen = Bdiag[i] - Bdiag[i + 1]; // inclusive of the diagonal entry
1966aac854edSJunchao Zhang 
1967aac854edSJunchao Zhang         PetscArraycpy(Lj + Li[i], Bj + Bi[i], llen); // entries of L on the left of the diagonal
1968aac854edSJunchao Zhang         Lj[Li[i] + llen] = i;                        // diagonal entry of L
1969aac854edSJunchao Zhang 
1970aac854edSJunchao Zhang         Uj[Ui[i]] = i;                                                  // diagonal entry of U
1971aac854edSJunchao Zhang         PetscArraycpy(Uj + Ui[i] + 1, Bj + Bdiag[i + 1] + 1, ulen - 1); // entries of U on  the right of the diagonal
1972aac854edSJunchao Zhang 
1973aac854edSJunchao Zhang         Li[i + 1] = Li[i] + llen + 1;
1974aac854edSJunchao Zhang         Ui[i + 1] = Ui[i] + ulen;
1975aac854edSJunchao Zhang       }
1976aac854edSJunchao Zhang 
1977aac854edSJunchao Zhang       factors->iL_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iL_h);
1978aac854edSJunchao Zhang       factors->jL_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jL_h);
1979aac854edSJunchao Zhang       factors->iU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iU_h);
1980aac854edSJunchao Zhang       factors->jU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jU_h);
1981aac854edSJunchao Zhang       factors->aL_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aL_h);
1982aac854edSJunchao Zhang       factors->aU_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aU_h);
1983aac854edSJunchao Zhang 
1984aac854edSJunchao Zhang       // Copy row/col permutation to device
1985aac854edSJunchao Zhang       IS        rowperm = ((Mat_SeqAIJ *)B->data)->row;
1986aac854edSJunchao Zhang       PetscBool row_identity;
1987aac854edSJunchao Zhang       PetscCall(ISIdentity(rowperm, &row_identity));
1988aac854edSJunchao Zhang       if (!row_identity) {
1989aac854edSJunchao Zhang         const PetscInt *ip;
1990aac854edSJunchao Zhang 
1991aac854edSJunchao Zhang         PetscCall(ISGetIndices(rowperm, &ip));
1992aac854edSJunchao Zhang         factors->rowperm = PetscIntKokkosView(NoInit("rowperm"), m);
1993aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->rowperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), m)));
1994aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(rowperm, &ip));
1995aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(m * sizeof(PetscInt)));
1996aac854edSJunchao Zhang       }
1997aac854edSJunchao Zhang 
1998aac854edSJunchao Zhang       IS        colperm = ((Mat_SeqAIJ *)B->data)->col;
1999aac854edSJunchao Zhang       PetscBool col_identity;
2000aac854edSJunchao Zhang       PetscCall(ISIdentity(colperm, &col_identity));
2001aac854edSJunchao Zhang       if (!col_identity) {
2002aac854edSJunchao Zhang         const PetscInt *ip;
2003aac854edSJunchao Zhang 
2004aac854edSJunchao Zhang         PetscCall(ISGetIndices(colperm, &ip));
2005aac854edSJunchao Zhang         factors->colperm = PetscIntKokkosView(NoInit("colperm"), n);
2006aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->colperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), n)));
2007aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(colperm, &ip));
2008aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt)));
2009aac854edSJunchao Zhang       }
2010aac854edSJunchao Zhang 
2011aac854edSJunchao Zhang       /* Create sptrsv handles for L, U and their transpose */
2012aac854edSJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2013aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
2014aac854edSJunchao Zhang #else
2015aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
2016aac854edSJunchao Zhang #endif
2017aac854edSJunchao Zhang       factors->khL.create_sptrsv_handle(sptrsv_alg, m, true /* L is lower tri */);
2018aac854edSJunchao Zhang       factors->khU.create_sptrsv_handle(sptrsv_alg, m, false /* U is not lower tri */);
2019aac854edSJunchao Zhang       factors->khLt.create_sptrsv_handle(sptrsv_alg, m, false /* L^T is not lower tri */);
2020aac854edSJunchao Zhang       factors->khUt.create_sptrsv_handle(sptrsv_alg, m, true /* U^T is lower tri */);
2021aac854edSJunchao Zhang     }
2022aac854edSJunchao Zhang 
2023aac854edSJunchao Zhang     // Copy the value
2024aac854edSJunchao Zhang     for (PetscInt i = 0; i < m; i++) {
2025aac854edSJunchao Zhang       PetscInt        llen = Bi[i + 1] - Bi[i];
2026aac854edSJunchao Zhang       PetscInt        ulen = Bdiag[i] - Bdiag[i + 1];
2027aac854edSJunchao Zhang       const PetscInt *Li   = factors->iL_h.data();
2028aac854edSJunchao Zhang       const PetscInt *Ui   = factors->iU_h.data();
2029aac854edSJunchao Zhang 
2030aac854edSJunchao Zhang       PetscScalar *La = factors->aL_h.data();
2031aac854edSJunchao Zhang       PetscScalar *Ua = factors->aU_h.data();
2032aac854edSJunchao Zhang 
2033aac854edSJunchao Zhang       PetscArraycpy(La + Li[i], Ba + Bi[i], llen); // entries of L
2034aac854edSJunchao Zhang       La[Li[i] + llen] = 1.0;                      // diagonal entry
2035aac854edSJunchao Zhang 
2036aac854edSJunchao Zhang       Ua[Ui[i]] = 1.0 / Ba[Bdiag[i]];                                 // diagonal entry
2037aac854edSJunchao Zhang       PetscArraycpy(Ua + Ui[i] + 1, Ba + Bdiag[i + 1] + 1, ulen - 1); // entries of U
2038aac854edSJunchao Zhang     }
2039aac854edSJunchao Zhang 
2040aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aL_d, factors->aL_h));
2041aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aU_d, factors->aU_h));
2042aac854edSJunchao Zhang     // Once the factors' values have changed, we need to update their transpose and redo sptrsv symbolic
2043aac854edSJunchao Zhang     factors->transpose_updated         = PETSC_FALSE;
2044aac854edSJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_FALSE;
2045aac854edSJunchao Zhang 
2046aac854edSJunchao Zhang     B->ops->solve          = MatSolve_SeqAIJKokkos_LU;
2047aac854edSJunchao Zhang     B->ops->solvetranspose = MatSolveTranspose_SeqAIJKokkos_LU;
2048aac854edSJunchao Zhang   }
2049aac854edSJunchao Zhang 
2050aac854edSJunchao Zhang   B->ops->matsolve          = NULL;
2051aac854edSJunchao Zhang   B->ops->matsolvetranspose = NULL;
2052aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2053aac854edSJunchao Zhang }
2054aac854edSJunchao Zhang 
2055aac854edSJunchao Zhang static PetscErrorCode MatILUFactorNumeric_SeqAIJKokkos_ILU0(Mat B, Mat A, const MatFactorInfo *info)
2056d71ae5a4SJacob Faibussowitsch {
205786a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
205886a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
205986a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
206086a27549SJunchao Zhang 
206186a27549SJunchao Zhang   PetscFunctionBegin;
20629566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeBegin());
2063aac854edSJunchao Zhang   PetscCheck(!info->factoronhost, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "MatFactorInfo.factoronhost should be false");
20649566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
2065076ba34aSJunchao Zhang 
2066076ba34aSJunchao Zhang   auto a_d = aijkok->a_dual.view_device();
2067076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
2068076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
2069076ba34aSJunchao Zhang 
2070aac854edSJunchao Zhang   PetscCallCXX(spiluk_numeric(&factors->kh, fill_lev, i_d, j_d, a_d, factors->iL_d, factors->jL_d, factors->aL_d, factors->iU_d, factors->jU_d, factors->aU_d));
207186a27549SJunchao Zhang 
207286a27549SJunchao Zhang   B->assembled              = PETSC_TRUE;
207386a27549SJunchao Zhang   B->preallocated           = PETSC_TRUE;
2074aac854edSJunchao Zhang   B->ops->solve             = MatSolve_SeqAIJKokkos_LU;
2075aac854edSJunchao Zhang   B->ops->solvetranspose    = MatSolveTranspose_SeqAIJKokkos_LU;
207686a27549SJunchao Zhang   B->ops->matsolve          = NULL;
207786a27549SJunchao Zhang   B->ops->matsolvetranspose = NULL;
207886a27549SJunchao Zhang 
207986a27549SJunchao Zhang   /* Once the factors' value changed, we need to update their transpose and sptrsv handle */
208086a27549SJunchao Zhang   factors->transpose_updated         = PETSC_FALSE;
208186a27549SJunchao Zhang   factors->sptrsv_symbolic_completed = PETSC_FALSE;
2082eeadb341SJunchao Zhang   /* TODO: log flops, but how to know that? */
20839566063dSJacob Faibussowitsch   PetscCall(PetscLogGpuTimeEnd());
20843ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
208586a27549SJunchao Zhang }
208686a27549SJunchao Zhang 
2087aac854edSJunchao Zhang // Use KK's spiluk_symbolic() to do ILU0 symbolic factorization, with no row/col reordering
2088aac854edSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos_ILU0(Mat B, Mat A, IS, IS, const MatFactorInfo *info)
2089d71ae5a4SJacob Faibussowitsch {
209086a27549SJunchao Zhang   Mat_SeqAIJKokkos           *aijkok;
209186a27549SJunchao Zhang   Mat_SeqAIJ                 *b;
209286a27549SJunchao Zhang   Mat_SeqAIJKokkosTriFactors *factors  = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
209386a27549SJunchao Zhang   PetscInt                    fill_lev = info->levels;
209486a27549SJunchao Zhang   PetscInt                    nnzA     = ((Mat_SeqAIJ *)A->data)->nz, nnzL, nnzU;
209586a27549SJunchao Zhang   PetscInt                    n        = A->rmap->n;
209686a27549SJunchao Zhang 
209786a27549SJunchao Zhang   PetscFunctionBegin;
2098aac854edSJunchao Zhang   PetscCheck(!info->factoronhost, PetscObjectComm((PetscObject)A), PETSC_ERR_PLIB, "MatFactorInfo's factoronhost should be false as we are doing it on device right now");
20999566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJKokkosSyncDevice(A));
210086a27549SJunchao Zhang 
210186a27549SJunchao Zhang   /* Create a spiluk handle and then do symbolic factorization */
210286a27549SJunchao Zhang   nnzL = nnzU = PetscRealIntMultTruncate(info->fill, nnzA);
2103aac854edSJunchao Zhang   factors->kh.create_spiluk_handle(SPILUKAlgorithm::SEQLVLSCHD_TP1, n, nnzL, nnzU);
210486a27549SJunchao Zhang 
210586a27549SJunchao Zhang   auto spiluk_handle = factors->kh.get_spiluk_handle();
210686a27549SJunchao Zhang 
210786a27549SJunchao Zhang   Kokkos::realloc(factors->iL_d, n + 1); /* Free old arrays and realloc */
210886a27549SJunchao Zhang   Kokkos::realloc(factors->jL_d, spiluk_handle->get_nnzL());
210986a27549SJunchao Zhang   Kokkos::realloc(factors->iU_d, n + 1);
211086a27549SJunchao Zhang   Kokkos::realloc(factors->jU_d, spiluk_handle->get_nnzU());
211186a27549SJunchao Zhang 
211286a27549SJunchao Zhang   aijkok   = (Mat_SeqAIJKokkos *)A->spptr;
2113076ba34aSJunchao Zhang   auto i_d = aijkok->i_dual.view_device();
2114076ba34aSJunchao Zhang   auto j_d = aijkok->j_dual.view_device();
2115aac854edSJunchao Zhang   PetscCallCXX(spiluk_symbolic(&factors->kh, fill_lev, i_d, j_d, factors->iL_d, factors->jL_d, factors->iU_d, factors->jU_d));
211686a27549SJunchao Zhang   /* TODO: if spiluk_symbolic is asynchronous, do we need to sync before calling get_nnzL()? */
211786a27549SJunchao Zhang 
211886a27549SJunchao Zhang   Kokkos::resize(factors->jL_d, spiluk_handle->get_nnzL()); /* Shrink or expand, and retain old value */
211986a27549SJunchao Zhang   Kokkos::resize(factors->jU_d, spiluk_handle->get_nnzU());
212086a27549SJunchao Zhang   Kokkos::realloc(factors->aL_d, spiluk_handle->get_nnzL()); /* No need to retain old value */
212186a27549SJunchao Zhang   Kokkos::realloc(factors->aU_d, spiluk_handle->get_nnzU());
212286a27549SJunchao Zhang 
212386a27549SJunchao Zhang   /* TODO: add options to select sptrsv algorithms */
212486a27549SJunchao Zhang   /* Create sptrsv handles for L, U and their transpose */
212586a27549SJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2126aac854edSJunchao Zhang   auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
212786a27549SJunchao Zhang #else
2128aac854edSJunchao Zhang   auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
212986a27549SJunchao Zhang #endif
213086a27549SJunchao Zhang 
213186a27549SJunchao Zhang   factors->khL.create_sptrsv_handle(sptrsv_alg, n, true /* L is lower tri */);
213286a27549SJunchao Zhang   factors->khU.create_sptrsv_handle(sptrsv_alg, n, false /* U is not lower tri */);
213386a27549SJunchao Zhang   factors->khLt.create_sptrsv_handle(sptrsv_alg, n, false /* L^T is not lower tri */);
213486a27549SJunchao Zhang   factors->khUt.create_sptrsv_handle(sptrsv_alg, n, true /* U^T is lower tri */);
213586a27549SJunchao Zhang 
213686a27549SJunchao Zhang   /* Fill fields of the factor matrix B */
21379566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(B, MAT_SKIP_ALLOCATION, NULL));
213886a27549SJunchao Zhang   b     = (Mat_SeqAIJ *)B->data;
213986a27549SJunchao Zhang   b->nz = b->maxnz          = spiluk_handle->get_nnzL() + spiluk_handle->get_nnzU();
214086a27549SJunchao Zhang   B->info.fill_ratio_given  = info->fill;
2141a1e4e190SStefano Zampini   B->info.fill_ratio_needed = nnzA > 0 ? ((PetscReal)b->nz) / ((PetscReal)nnzA) : 1.0;
214286a27549SJunchao Zhang 
2143aac854edSJunchao Zhang   B->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJKokkos_ILU0;
21443ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2145930e68a5SMark Adams }
2146930e68a5SMark Adams 
2147aac854edSJunchao Zhang static PetscErrorCode MatLUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
2148aac854edSJunchao Zhang {
2149aac854edSJunchao Zhang   PetscFunctionBegin;
2150aac854edSJunchao Zhang   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
2151aac854edSJunchao Zhang   PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2152aac854edSJunchao Zhang   PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2153aac854edSJunchao Zhang   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
2154aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2155aac854edSJunchao Zhang }
2156aac854edSJunchao Zhang 
2157aac854edSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
2158aac854edSJunchao Zhang {
2159aac854edSJunchao Zhang   PetscBool row_identity = PETSC_FALSE, col_identity = PETSC_FALSE;
2160aac854edSJunchao Zhang 
2161aac854edSJunchao Zhang   PetscFunctionBegin;
2162aac854edSJunchao Zhang   if (!info->factoronhost) {
2163aac854edSJunchao Zhang     PetscCall(ISIdentity(isrow, &row_identity));
2164aac854edSJunchao Zhang     PetscCall(ISIdentity(iscol, &col_identity));
2165aac854edSJunchao Zhang   }
2166aac854edSJunchao Zhang 
2167aac854edSJunchao Zhang   PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2168aac854edSJunchao Zhang   PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2169aac854edSJunchao Zhang 
2170aac854edSJunchao Zhang   if (!info->factoronhost && !info->levels && row_identity && col_identity) { // if level 0 and no reordering
2171aac854edSJunchao Zhang     PetscCall(MatILUFactorSymbolic_SeqAIJKokkos_ILU0(B, A, isrow, iscol, info));
2172aac854edSJunchao Zhang   } else {
2173aac854edSJunchao Zhang     PetscCall(MatILUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info)); // otherwise, use PETSc's ILU on host
2174aac854edSJunchao Zhang     B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJKokkos;
2175aac854edSJunchao Zhang   }
2176aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2177aac854edSJunchao Zhang }
2178aac854edSJunchao Zhang 
2179aac854edSJunchao Zhang static PetscErrorCode MatCholeskyFactorNumeric_SeqAIJKokkos(Mat B, Mat A, const MatFactorInfo *info)
2180aac854edSJunchao Zhang {
2181aac854edSJunchao Zhang   PetscFunctionBegin;
2182aac854edSJunchao Zhang   PetscCall(MatSeqAIJKokkosSyncHost(A));
2183aac854edSJunchao Zhang   PetscCall(MatCholeskyFactorNumeric_SeqAIJ(B, A, info));
2184aac854edSJunchao Zhang 
2185aac854edSJunchao Zhang   if (!info->solveonhost) { // if solve on host, then we don't need to copy L, U to device
2186aac854edSJunchao Zhang     Mat_SeqAIJKokkosTriFactors *factors = (Mat_SeqAIJKokkosTriFactors *)B->spptr;
2187aac854edSJunchao Zhang     Mat_SeqAIJ                 *b       = static_cast<Mat_SeqAIJ *>(B->data);
2188aac854edSJunchao Zhang     const PetscInt             *Bi = b->i, *Bj = b->j, *Bdiag = b->diag;
2189aac854edSJunchao Zhang     const MatScalar            *Ba = b->a;
2190aac854edSJunchao Zhang     PetscInt                    m  = B->rmap->n;
2191aac854edSJunchao Zhang 
2192aac854edSJunchao Zhang     if (factors->iU_h.extent(0) == 0) { // First time of numeric factorization
2193aac854edSJunchao Zhang       // Allocate memory and copy the structure
2194aac854edSJunchao Zhang       factors->iU_h = PetscIntKokkosViewHost(const_cast<PetscInt *>(Bi), m + 1); // wrap Bi as iU_h
2195aac854edSJunchao Zhang       factors->jU_h = MatColIdxKokkosViewHost(NoInit("jU_h"), Bi[m]);
2196aac854edSJunchao Zhang       factors->aU_h = MatScalarKokkosViewHost(NoInit("aU_h"), Bi[m]);
2197aac854edSJunchao Zhang       factors->D_h  = MatScalarKokkosViewHost(NoInit("D_h"), m);
2198aac854edSJunchao Zhang       factors->aU_d = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->aU_h);
2199aac854edSJunchao Zhang       factors->D_d  = Kokkos::create_mirror_view(DefaultMemorySpace(), factors->D_h);
2200aac854edSJunchao Zhang 
2201aac854edSJunchao Zhang       // Build jU_h from the skewed Aj
2202aac854edSJunchao Zhang       PetscInt *Uj = factors->jU_h.data();
2203aac854edSJunchao Zhang       for (PetscInt i = 0; i < m; i++) {
2204aac854edSJunchao Zhang         PetscInt ulen = Bi[i + 1] - Bi[i];
2205aac854edSJunchao Zhang         Uj[Bi[i]]     = i;                                              // diagonal entry
2206aac854edSJunchao Zhang         PetscCall(PetscArraycpy(Uj + Bi[i] + 1, Bj + Bi[i], ulen - 1)); // entries of U on the right of the diagonal
2207aac854edSJunchao Zhang       }
2208aac854edSJunchao Zhang 
2209aac854edSJunchao Zhang       // Copy iU, jU to device
2210aac854edSJunchao Zhang       PetscCallCXX(factors->iU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->iU_h));
2211aac854edSJunchao Zhang       PetscCallCXX(factors->jU_d = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), factors->jU_h));
2212aac854edSJunchao Zhang 
2213aac854edSJunchao Zhang       // Copy row/col permutation to device
2214aac854edSJunchao Zhang       IS        rowperm = ((Mat_SeqAIJ *)B->data)->row;
2215aac854edSJunchao Zhang       PetscBool row_identity;
2216aac854edSJunchao Zhang       PetscCall(ISIdentity(rowperm, &row_identity));
2217aac854edSJunchao Zhang       if (!row_identity) {
2218aac854edSJunchao Zhang         const PetscInt *ip;
2219aac854edSJunchao Zhang 
2220aac854edSJunchao Zhang         PetscCall(ISGetIndices(rowperm, &ip));
2221aac854edSJunchao Zhang         PetscCallCXX(factors->rowperm = PetscIntKokkosView(NoInit("rowperm"), m));
2222aac854edSJunchao Zhang         PetscCallCXX(Kokkos::deep_copy(factors->rowperm, PetscIntKokkosViewHost(const_cast<PetscInt *>(ip), m)));
2223aac854edSJunchao Zhang         PetscCall(ISRestoreIndices(rowperm, &ip));
2224aac854edSJunchao Zhang         PetscCall(PetscLogCpuToGpu(m * sizeof(PetscInt)));
2225aac854edSJunchao Zhang       }
2226aac854edSJunchao Zhang 
2227aac854edSJunchao Zhang       // Create sptrsv handles for U and U^T
2228aac854edSJunchao Zhang #if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
2229aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SPTRSV_CUSPARSE;
2230aac854edSJunchao Zhang #else
2231aac854edSJunchao Zhang       auto sptrsv_alg = SPTRSVAlgorithm::SEQLVLSCHD_TP1;
2232aac854edSJunchao Zhang #endif
2233aac854edSJunchao Zhang       factors->khU.create_sptrsv_handle(sptrsv_alg, m, false /* U is not lower tri */);
2234aac854edSJunchao Zhang       factors->khUt.create_sptrsv_handle(sptrsv_alg, m, true /* U^T is lower tri */);
2235aac854edSJunchao Zhang     }
2236aac854edSJunchao Zhang     // These pointers were set MatCholeskyFactorNumeric_SeqAIJ(), so we always need to update them
2237aac854edSJunchao Zhang     B->ops->solve          = MatSolve_SeqAIJKokkos_Cholesky;
2238aac854edSJunchao Zhang     B->ops->solvetranspose = MatSolve_SeqAIJKokkos_Cholesky;
2239aac854edSJunchao Zhang 
2240aac854edSJunchao Zhang     // Copy the value
2241aac854edSJunchao Zhang     PetscScalar *Ua = factors->aU_h.data();
2242aac854edSJunchao Zhang     PetscScalar *D  = factors->D_h.data();
2243aac854edSJunchao Zhang     for (PetscInt i = 0; i < m; i++) {
2244aac854edSJunchao Zhang       D[i]      = Ba[Bdiag[i]];     // actually Aa[Adiag[i]] is the inverse of the diagonal
2245aac854edSJunchao Zhang       Ua[Bi[i]] = (PetscScalar)1.0; // set the unit diagonal for U
2246aac854edSJunchao Zhang       for (PetscInt k = 0; k < Bi[i + 1] - Bi[i] - 1; k++) Ua[Bi[i] + 1 + k] = -Ba[Bi[i] + k];
2247aac854edSJunchao Zhang     }
2248aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->aU_d, factors->aU_h));
2249aac854edSJunchao Zhang     PetscCallCXX(Kokkos::deep_copy(factors->D_d, factors->D_h));
2250aac854edSJunchao Zhang 
2251aac854edSJunchao Zhang     factors->sptrsv_symbolic_completed = PETSC_FALSE; // When numeric value changed, we must do these again
2252aac854edSJunchao Zhang     factors->transpose_updated         = PETSC_FALSE;
2253aac854edSJunchao Zhang   }
2254aac854edSJunchao Zhang 
2255aac854edSJunchao Zhang   B->ops->matsolve          = NULL;
2256aac854edSJunchao Zhang   B->ops->matsolvetranspose = NULL;
2257aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2258aac854edSJunchao Zhang }
2259aac854edSJunchao Zhang 
2260aac854edSJunchao Zhang static PetscErrorCode MatICCFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS perm, const MatFactorInfo *info)
2261aac854edSJunchao Zhang {
2262aac854edSJunchao Zhang   PetscFunctionBegin;
2263aac854edSJunchao Zhang   if (info->solveonhost) {
2264aac854edSJunchao Zhang     // If solve on host, we have to change the type, as eventually we need to call MatSolve_SeqSBAIJ_1_NaturalOrdering() etc.
2265aac854edSJunchao Zhang     PetscCall(MatSetType(B, MATSEQSBAIJ));
2266aac854edSJunchao Zhang     PetscCall(MatSeqSBAIJSetPreallocation(B, 1, MAT_SKIP_ALLOCATION, NULL));
2267aac854edSJunchao Zhang   }
2268aac854edSJunchao Zhang 
2269aac854edSJunchao Zhang   PetscCall(MatICCFactorSymbolic_SeqAIJ(B, A, perm, info));
2270aac854edSJunchao Zhang 
2271aac854edSJunchao Zhang   if (!info->solveonhost) {
2272bfe80ac4SPierre Jolivet     // If solve on device, B is still a MATSEQAIJKOKKOS, so we are good to allocate B->spptr
2273aac854edSJunchao Zhang     PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2274aac854edSJunchao Zhang     PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2275aac854edSJunchao Zhang     B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJKokkos;
2276aac854edSJunchao Zhang   }
2277aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2278aac854edSJunchao Zhang }
2279aac854edSJunchao Zhang 
2280aac854edSJunchao Zhang static PetscErrorCode MatCholeskyFactorSymbolic_SeqAIJKokkos(Mat B, Mat A, IS perm, const MatFactorInfo *info)
2281aac854edSJunchao Zhang {
2282aac854edSJunchao Zhang   PetscFunctionBegin;
2283aac854edSJunchao Zhang   if (info->solveonhost) {
2284aac854edSJunchao Zhang     // If solve on host, we have to change the type, as eventually we need to call MatSolve_SeqSBAIJ_1_NaturalOrdering() etc.
2285aac854edSJunchao Zhang     PetscCall(MatSetType(B, MATSEQSBAIJ));
2286aac854edSJunchao Zhang     PetscCall(MatSeqSBAIJSetPreallocation(B, 1, MAT_SKIP_ALLOCATION, NULL));
2287aac854edSJunchao Zhang   }
2288aac854edSJunchao Zhang 
2289aac854edSJunchao Zhang   PetscCall(MatCholeskyFactorSymbolic_SeqAIJ(B, A, perm, info)); // it sets B's two ISes ((Mat_SeqAIJ*)B->data)->{row, col} to perm
2290aac854edSJunchao Zhang 
2291aac854edSJunchao Zhang   if (!info->solveonhost) {
2292bfe80ac4SPierre Jolivet     // If solve on device, B is still a MATSEQAIJKOKKOS, so we are good to allocate B->spptr
2293aac854edSJunchao Zhang     PetscCheck(!B->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2294aac854edSJunchao Zhang     PetscCallCXX(B->spptr = new Mat_SeqAIJKokkosTriFactors(B->rmap->n));
2295aac854edSJunchao Zhang     B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJKokkos;
2296aac854edSJunchao Zhang   }
2297aac854edSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2298aac854edSJunchao Zhang }
2299aac854edSJunchao Zhang 
2300aac854edSJunchao Zhang // The _Kokkos suffix means we will use Kokkos as a solver for the SeqAIJKokkos matrix
2301aac854edSJunchao Zhang static PetscErrorCode MatFactorGetSolverType_SeqAIJKokkos_Kokkos(Mat A, MatSolverType *type)
2302d71ae5a4SJacob Faibussowitsch {
2303930e68a5SMark Adams   PetscFunctionBegin;
2304930e68a5SMark Adams   *type = MATSOLVERKOKKOS;
23053ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2306930e68a5SMark Adams }
2307930e68a5SMark Adams 
/*MC
  MATSOLVERKOKKOS = "Kokkos" - A matrix solver type providing triangular solvers on a single GPU
  for sequential matrices of type `MATSEQAIJKOKKOS` or `MATAIJKOKKOS`.

  Level: beginner

.seealso: [](ch_matrices), `Mat`, `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJKokkos()`, `MATAIJKOKKOS`, `MatKokkosSetFormat()`, `MatKokkosStorageFormat`, `MatKokkosFormatOperation`
M*/
231686a27549SJunchao Zhang PETSC_EXTERN PetscErrorCode MatGetFactor_SeqAIJKokkos_Kokkos(Mat A, MatFactorType ftype, Mat *B) /* MatGetFactor_<MatType>_<MatSolverType> */
2317930e68a5SMark Adams {
2318930e68a5SMark Adams   PetscInt n = A->rmap->n;
2319aac854edSJunchao Zhang   MPI_Comm comm;
2320930e68a5SMark Adams 
2321930e68a5SMark Adams   PetscFunctionBegin;
2322aac854edSJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
2323aac854edSJunchao Zhang   PetscCall(MatCreate(comm, B));
23249566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(*B, n, n, n, n));
2325aac854edSJunchao Zhang   PetscCall(MatSetBlockSizesFromMats(*B, A, A));
2326930e68a5SMark Adams   (*B)->factortype = ftype;
23279566063dSJacob Faibussowitsch   PetscCall(MatSetType(*B, MATSEQAIJKOKKOS));
23289566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
2329aac854edSJunchao Zhang   PetscCheck(!(*B)->spptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Expected a NULL spptr");
2330aac854edSJunchao Zhang 
2331aac854edSJunchao Zhang   if (ftype == MAT_FACTOR_LU || ftype == MAT_FACTOR_ILU || ftype == MAT_FACTOR_ILUDT) {
2332aac854edSJunchao Zhang     (*B)->ops->lufactorsymbolic  = MatLUFactorSymbolic_SeqAIJKokkos;
2333aac854edSJunchao Zhang     (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJKokkos;
2334aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
2335aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILU]));
2336aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILUDT]));
2337aac854edSJunchao Zhang   } else if (ftype == MAT_FACTOR_CHOLESKY || ftype == MAT_FACTOR_ICC) {
2338aac854edSJunchao Zhang     (*B)->ops->iccfactorsymbolic      = MatICCFactorSymbolic_SeqAIJKokkos;
2339aac854edSJunchao Zhang     (*B)->ops->choleskyfactorsymbolic = MatCholeskyFactorSymbolic_SeqAIJKokkos;
2340aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_CHOLESKY]));
2341aac854edSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ICC]));
2342aac854edSJunchao Zhang   } else SETERRQ(comm, PETSC_ERR_SUP, "MatFactorType %s is not supported by MatType SeqAIJKokkos", MatFactorTypes[ftype]);
2343aac854edSJunchao Zhang 
2344aac854edSJunchao Zhang   // The factorization can use the ordering provided in MatLUFactorSymbolic(), MatCholeskyFactorSymbolic() etc, though we do it on host
2345aac854edSJunchao Zhang   (*B)->canuseordering = PETSC_TRUE;
2346aac854edSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)*B, "MatFactorGetSolverType_C", MatFactorGetSolverType_SeqAIJKokkos_Kokkos));
23473ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2348930e68a5SMark Adams }
23498f7e8f9dSMark Adams 
2350aac854edSJunchao Zhang PETSC_INTERN PetscErrorCode MatSolverTypeRegister_Kokkos(void)
2351d71ae5a4SJacob Faibussowitsch {
235286a27549SJunchao Zhang   PetscFunctionBegin;
23539566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_LU, MatGetFactor_SeqAIJKokkos_Kokkos));
2354aac854edSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_CHOLESKY, MatGetFactor_SeqAIJKokkos_Kokkos));
23559566063dSJacob Faibussowitsch   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ILU, MatGetFactor_SeqAIJKokkos_Kokkos));
2356aac854edSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERKOKKOS, MATSEQAIJKOKKOS, MAT_FACTOR_ICC, MatGetFactor_SeqAIJKokkos_Kokkos));
23573ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
235886a27549SJunchao Zhang }
235986a27549SJunchao Zhang 
2360076ba34aSJunchao Zhang /* Utility to print out a KokkosCsrMatrix for debugging */
2361d71ae5a4SJacob Faibussowitsch PETSC_INTERN PetscErrorCode PrintCsrMatrix(const KokkosCsrMatrix &csrmat)
2362d71ae5a4SJacob Faibussowitsch {
236345402d8aSJunchao Zhang   const auto        &iv = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.graph.row_map);
236445402d8aSJunchao Zhang   const auto        &jv = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.graph.entries);
236545402d8aSJunchao Zhang   const auto        &av = Kokkos::create_mirror_view_and_copy(HostMirrorMemorySpace(), csrmat.values);
2366076ba34aSJunchao Zhang   const PetscInt    *i  = iv.data();
2367076ba34aSJunchao Zhang   const PetscInt    *j  = jv.data();
2368076ba34aSJunchao Zhang   const PetscScalar *a  = av.data();
2369076ba34aSJunchao Zhang   PetscInt           m = csrmat.numRows(), n = csrmat.numCols(), nnz = csrmat.nnz();
2370076ba34aSJunchao Zhang 
2371076ba34aSJunchao Zhang   PetscFunctionBegin;
23729566063dSJacob Faibussowitsch   PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT " x %" PetscInt_FMT " SeqAIJKokkos, with %" PetscInt_FMT " nonzeros\n", m, n, nnz));
2373076ba34aSJunchao Zhang   for (PetscInt k = 0; k < m; k++) {
23749566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT ": ", k));
237548a46eb9SPierre Jolivet     for (PetscInt p = i[k]; p < i[k + 1]; p++) PetscCall(PetscPrintf(PETSC_COMM_SELF, "%" PetscInt_FMT "(%.1f), ", j[p], (double)PetscRealPart(a[p])));
23769566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(PETSC_COMM_SELF, "\n"));
2377076ba34aSJunchao Zhang   }
23783ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
2379076ba34aSJunchao Zhang }
2380