xref: /petsc/src/mat/impls/aij/seq/seqhipsparse/aijhipsparse.hip.cxx (revision d52a580b706c59ca78066c1e38754e45b6b56e2b)
1*d52a580bSJunchao Zhang /*
2*d52a580bSJunchao Zhang   Defines the basic matrix operations for the AIJ (compressed row)
  matrix storage format using the HIPSPARSE library.
4*d52a580bSJunchao Zhang   Portions of this code are under:
5*d52a580bSJunchao Zhang   Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved.
6*d52a580bSJunchao Zhang */
#include <petscconf.h>
#include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/
#include <../src/mat/impls/sbaij/seq/sbaij.h>
#include <../src/mat/impls/dense/seq/dense.h> // MatMatMultNumeric_SeqDenseHIP_SeqDenseHIP_Internal()
#include <../src/vec/vec/impls/dvecimpl.h>
#include <petsc/private/vecimpl.h>
#undef VecType
#include <../src/mat/impls/aij/seq/seqhipsparse/hipsparsematimpl.h>
#include <thrust/adjacent_difference.h>
#include <thrust/iterator/transform_iterator.h>
#if PETSC_CPP_VERSION >= 14
  #define PETSC_HAVE_THRUST_ASYNC 1
  #include <thrust/async/for_each.h>
#endif
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/binary_search.h>
#include <thrust/remove.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include <exception>
27*d52a580bSJunchao Zhang 
/*
  Name tables consumed by PetscOptionsEnum(). Each table lists the enum value names in enum order, then the
  enum type name, then the common prefix, then a null terminator. PetscOptionsEnum() derives the enum value from
  the position of the matching string, so the order of the value names must follow the numeric enum values.
*/
const char *const MatHIPSPARSEStorageFormats[] = {
  "CSR", "ELL", "HYB",                           /* values */
  "MatHIPSPARSEStorageFormat", "MAT_HIPSPARSE_", /* type name, prefix */
  nullptr};
const char *const MatHIPSPARSESpMVAlgorithms[] = {
  "MV_ALG_DEFAULT", "COOMV_ALG", "CSRMV_ALG1", "CSRMV_ALG2",                          /* values 0..3 */
  "SPMV_ALG_DEFAULT", "SPMV_COO_ALG1", "SPMV_COO_ALG2", "SPMV_CSR_ALG1", "SPMV_CSR_ALG2", /* values 4..8 */
  "hipsparseSpMVAlg_t", "HIPSPARSE_",                                                   /* type name, prefix */
  nullptr};
const char *const MatHIPSPARSESpMMAlgorithms[] = {
  "ALG_DEFAULT", "COO_ALG1", "COO_ALG2", "COO_ALG3", "CSR_ALG1", "COO_ALG4", "CSR_ALG2", /* values 0..6 */
  "hipsparseSpMMAlg_t", "HIPSPARSE_SPMM_",                                             /* type name, prefix */
  nullptr};
//const char *const MatHIPSPARSECsr2CscAlgorithms[] = {"INVALID"/*HIPSPARSE does not have enum 0! We created one*/, "ALG1", "ALG2", "hipsparseCsr2CscAlg_t", "HIPSPARSE_CSR2CSC_", 0};
32*d52a580bSJunchao Zhang 
/* Forward declarations of file-local (static) routines, so they can be referenced before their definitions in this file. */
static PetscErrorCode MatICCFactorSymbolic_SeqAIJHIPSPARSE(Mat, Mat, IS, const MatFactorInfo *);
static PetscErrorCode MatCholeskyFactorSymbolic_SeqAIJHIPSPARSE(Mat, Mat, IS, const MatFactorInfo *);
static PetscErrorCode MatCholeskyFactorNumeric_SeqAIJHIPSPARSE(Mat, Mat, const MatFactorInfo *);
static PetscErrorCode MatILUFactorSymbolic_SeqAIJHIPSPARSE(Mat, Mat, IS, IS, const MatFactorInfo *);
static PetscErrorCode MatLUFactorSymbolic_SeqAIJHIPSPARSE(Mat, Mat, IS, IS, const MatFactorInfo *);
static PetscErrorCode MatLUFactorNumeric_SeqAIJHIPSPARSE(Mat, Mat, const MatFactorInfo *);
static PetscErrorCode MatSolve_SeqAIJHIPSPARSE(Mat, Vec, Vec);
static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_NaturalOrdering(Mat, Vec, Vec);
static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE(Mat, Vec, Vec);
static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering(Mat, Vec, Vec);
static PetscErrorCode MatSetFromOptions_SeqAIJHIPSPARSE(Mat, PetscOptionItems PetscOptionsObject);
static PetscErrorCode MatAXPY_SeqAIJHIPSPARSE(Mat, PetscScalar, Mat, MatStructure);
static PetscErrorCode MatScale_SeqAIJHIPSPARSE(Mat, PetscScalar);
static PetscErrorCode MatMult_SeqAIJHIPSPARSE(Mat, Vec, Vec);
static PetscErrorCode MatMultAdd_SeqAIJHIPSPARSE(Mat, Vec, Vec, Vec);
static PetscErrorCode MatMultTranspose_SeqAIJHIPSPARSE(Mat, Vec, Vec);
static PetscErrorCode MatMultTransposeAdd_SeqAIJHIPSPARSE(Mat, Vec, Vec, Vec);
static PetscErrorCode MatMultHermitianTranspose_SeqAIJHIPSPARSE(Mat, Vec, Vec);
static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJHIPSPARSE(Mat, Vec, Vec, Vec);
static PetscErrorCode MatMultAddKernel_SeqAIJHIPSPARSE(Mat, Vec, Vec, Vec, PetscBool, PetscBool);
static PetscErrorCode CsrMatrix_Destroy(CsrMatrix **);
/* NOTE(review): this overload is named ...MultStruct_Destroy but takes a TriFactorStruct; confirm it is intentional and defined */
static PetscErrorCode MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSETriFactorStruct **);
static PetscErrorCode MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSEMultStruct **, MatHIPSPARSEStorageFormat);
static PetscErrorCode MatSeqAIJHIPSPARSETriFactors_Destroy(Mat_SeqAIJHIPSPARSETriFactors **);
static PetscErrorCode MatSeqAIJHIPSPARSE_Destroy(Mat);
static PetscErrorCode MatSeqAIJHIPSPARSECopyFromGPU(Mat);
static PetscErrorCode MatSeqAIJHIPSPARSEILUAnalysisAndCopyToGPU(Mat);
static PetscErrorCode MatSeqAIJHIPSPARSEInvalidateTranspose(Mat, PetscBool);
static PetscErrorCode MatSeqAIJCopySubArray_SeqAIJHIPSPARSE(Mat, PetscInt, const PetscInt[], PetscScalar[]);
static PetscErrorCode MatBindToCPU_SeqAIJHIPSPARSE(Mat, PetscBool);
static PetscErrorCode MatSetPreallocationCOO_SeqAIJHIPSPARSE(Mat, PetscCount, PetscInt[], PetscInt[]);
static PetscErrorCode MatSetValuesCOO_SeqAIJHIPSPARSE(Mat, const PetscScalar[], InsertMode);

/* Library-internal (PETSC_INTERN) routines used by this file; their definitions are not part of this excerpt */
PETSC_INTERN PetscErrorCode MatProductSetFromOptions_SeqAIJ_SeqDense(Mat);
PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJHIPSPARSE(Mat, MatType, MatReuse, Mat *);
68*d52a580bSJunchao Zhang 
69*d52a580bSJunchao Zhang /*
70*d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSESetStream(Mat A, const hipStream_t stream)
71*d52a580bSJunchao Zhang {
72*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE*)A->spptr;
73*d52a580bSJunchao Zhang 
74*d52a580bSJunchao Zhang   PetscFunctionBegin;
75*d52a580bSJunchao Zhang   PetscCheck(hipsparsestruct, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing spptr");
76*d52a580bSJunchao Zhang   hipsparsestruct->stream = stream;
77*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetStream(hipsparsestruct->handle, hipsparsestruct->stream));
78*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
79*d52a580bSJunchao Zhang }
80*d52a580bSJunchao Zhang 
81*d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSESetHandle(Mat A, const hipsparseHandle_t handle)
82*d52a580bSJunchao Zhang {
83*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE*)A->spptr;
84*d52a580bSJunchao Zhang 
85*d52a580bSJunchao Zhang   PetscFunctionBegin;
86*d52a580bSJunchao Zhang   PetscCheck(hipsparsestruct, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing spptr");
87*d52a580bSJunchao Zhang   if (hipsparsestruct->handle != handle) {
88*d52a580bSJunchao Zhang     if (hipsparsestruct->handle) PetscCallHIPSPARSE(hipsparseDestroy(hipsparsestruct->handle));
89*d52a580bSJunchao Zhang     hipsparsestruct->handle = handle;
90*d52a580bSJunchao Zhang   }
91*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetPointerMode(hipsparsestruct->handle, HIPSPARSE_POINTER_MODE_DEVICE));
92*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
93*d52a580bSJunchao Zhang }
94*d52a580bSJunchao Zhang 
95*d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSEClearHandle(Mat A)
96*d52a580bSJunchao Zhang {
97*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE*)A->spptr;
98*d52a580bSJunchao Zhang   PetscBool            flg;
99*d52a580bSJunchao Zhang 
100*d52a580bSJunchao Zhang   PetscFunctionBegin;
101*d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
102*d52a580bSJunchao Zhang   if (!flg || !hipsparsestruct) PetscFunctionReturn(PETSC_SUCCESS);
103*d52a580bSJunchao Zhang   if (hipsparsestruct->handle) hipsparsestruct->handle = 0;
104*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
105*d52a580bSJunchao Zhang }
106*d52a580bSJunchao Zhang */
107*d52a580bSJunchao Zhang 
108*d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatHIPSPARSESetFormat_SeqAIJHIPSPARSE(Mat A, MatHIPSPARSEFormatOperation op, MatHIPSPARSEStorageFormat format)
109*d52a580bSJunchao Zhang {
110*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
111*d52a580bSJunchao Zhang 
112*d52a580bSJunchao Zhang   PetscFunctionBegin;
113*d52a580bSJunchao Zhang   switch (op) {
114*d52a580bSJunchao Zhang   case MAT_HIPSPARSE_MULT:
115*d52a580bSJunchao Zhang     hipsparsestruct->format = format;
116*d52a580bSJunchao Zhang     break;
117*d52a580bSJunchao Zhang   case MAT_HIPSPARSE_ALL:
118*d52a580bSJunchao Zhang     hipsparsestruct->format = format;
119*d52a580bSJunchao Zhang     break;
120*d52a580bSJunchao Zhang   default:
121*d52a580bSJunchao Zhang     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "unsupported operation %d for MatHIPSPARSEFormatOperation. MAT_HIPSPARSE_MULT and MAT_HIPSPARSE_ALL are currently supported.", op);
122*d52a580bSJunchao Zhang   }
123*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
124*d52a580bSJunchao Zhang }
125*d52a580bSJunchao Zhang 
126*d52a580bSJunchao Zhang /*@
127*d52a580bSJunchao Zhang   MatHIPSPARSESetFormat - Sets the storage format of `MATSEQHIPSPARSE` matrices for a particular
128*d52a580bSJunchao Zhang   operation. Only the `MatMult()` operation can use different GPU storage formats
129*d52a580bSJunchao Zhang 
130*d52a580bSJunchao Zhang   Not Collective
131*d52a580bSJunchao Zhang 
132*d52a580bSJunchao Zhang   Input Parameters:
133*d52a580bSJunchao Zhang + A      - Matrix of type `MATSEQAIJHIPSPARSE`
134*d52a580bSJunchao Zhang . op     - `MatHIPSPARSEFormatOperation`. `MATSEQAIJHIPSPARSE` matrices support `MAT_HIPSPARSE_MULT` and `MAT_HIPSPARSE_ALL`.
135*d52a580bSJunchao Zhang          `MATMPIAIJHIPSPARSE` matrices support `MAT_HIPSPARSE_MULT_DIAG`, `MAT_HIPSPARSE_MULT_OFFDIAG`, and `MAT_HIPSPARSE_ALL`.
136*d52a580bSJunchao Zhang - format - `MatHIPSPARSEStorageFormat` (one of `MAT_HIPSPARSE_CSR`, `MAT_HIPSPARSE_ELL`, `MAT_HIPSPARSE_HYB`.)
137*d52a580bSJunchao Zhang 
138*d52a580bSJunchao Zhang   Level: intermediate
139*d52a580bSJunchao Zhang 
140*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MATSEQAIJHIPSPARSE`, `MatHIPSPARSEStorageFormat`, `MatHIPSPARSEFormatOperation`
141*d52a580bSJunchao Zhang @*/
142*d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSESetFormat(Mat A, MatHIPSPARSEFormatOperation op, MatHIPSPARSEStorageFormat format)
143*d52a580bSJunchao Zhang {
144*d52a580bSJunchao Zhang   PetscFunctionBegin;
145*d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
146*d52a580bSJunchao Zhang   PetscTryMethod(A, "MatHIPSPARSESetFormat_C", (Mat, MatHIPSPARSEFormatOperation, MatHIPSPARSEStorageFormat), (A, op, format));
147*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
148*d52a580bSJunchao Zhang }
149*d52a580bSJunchao Zhang 
150*d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatHIPSPARSESetUseCPUSolve_SeqAIJHIPSPARSE(Mat A, PetscBool use_cpu)
151*d52a580bSJunchao Zhang {
152*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
153*d52a580bSJunchao Zhang 
154*d52a580bSJunchao Zhang   PetscFunctionBegin;
155*d52a580bSJunchao Zhang   hipsparsestruct->use_cpu_solve = use_cpu;
156*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
157*d52a580bSJunchao Zhang }
158*d52a580bSJunchao Zhang 
159*d52a580bSJunchao Zhang /*@
160*d52a580bSJunchao Zhang   MatHIPSPARSESetUseCPUSolve - Sets use CPU `MatSolve()`.
161*d52a580bSJunchao Zhang 
162*d52a580bSJunchao Zhang   Input Parameters:
163*d52a580bSJunchao Zhang + A       - Matrix of type `MATSEQAIJHIPSPARSE`
164*d52a580bSJunchao Zhang - use_cpu - set flag for using the built-in CPU `MatSolve()`
165*d52a580bSJunchao Zhang 
166*d52a580bSJunchao Zhang   Level: intermediate
167*d52a580bSJunchao Zhang 
168*d52a580bSJunchao Zhang   Notes:
169*d52a580bSJunchao Zhang   The hipSparse LU solver currently computes the factors with the built-in CPU method
170*d52a580bSJunchao Zhang   and moves the factors to the GPU for the solve. We have observed better performance keeping the data on the CPU and computing the solve there.
171*d52a580bSJunchao Zhang   This method to specifies if the solve is done on the CPU or GPU (GPU is the default).
172*d52a580bSJunchao Zhang 
173*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSolve()`, `MATSEQAIJHIPSPARSE`, `MatHIPSPARSEStorageFormat`, `MatHIPSPARSEFormatOperation`
174*d52a580bSJunchao Zhang @*/
175*d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSESetUseCPUSolve(Mat A, PetscBool use_cpu)
176*d52a580bSJunchao Zhang {
177*d52a580bSJunchao Zhang   PetscFunctionBegin;
178*d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
179*d52a580bSJunchao Zhang   PetscTryMethod(A, "MatHIPSPARSESetUseCPUSolve_C", (Mat, PetscBool), (A, use_cpu));
180*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
181*d52a580bSJunchao Zhang }
182*d52a580bSJunchao Zhang 
183*d52a580bSJunchao Zhang static PetscErrorCode MatSetOption_SeqAIJHIPSPARSE(Mat A, MatOption op, PetscBool flg)
184*d52a580bSJunchao Zhang {
185*d52a580bSJunchao Zhang   PetscFunctionBegin;
186*d52a580bSJunchao Zhang   switch (op) {
187*d52a580bSJunchao Zhang   case MAT_FORM_EXPLICIT_TRANSPOSE:
188*d52a580bSJunchao Zhang     /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */
189*d52a580bSJunchao Zhang     if (A->form_explicit_transpose && !flg) PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_TRUE));
190*d52a580bSJunchao Zhang     A->form_explicit_transpose = flg;
191*d52a580bSJunchao Zhang     break;
192*d52a580bSJunchao Zhang   default:
193*d52a580bSJunchao Zhang     PetscCall(MatSetOption_SeqAIJ(A, op, flg));
194*d52a580bSJunchao Zhang     break;
195*d52a580bSJunchao Zhang   }
196*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
197*d52a580bSJunchao Zhang }
198*d52a580bSJunchao Zhang 
199*d52a580bSJunchao Zhang static PetscErrorCode MatLUFactorNumeric_SeqAIJHIPSPARSE(Mat B, Mat A, const MatFactorInfo *info)
200*d52a580bSJunchao Zhang {
201*d52a580bSJunchao Zhang   PetscBool            row_identity, col_identity;
202*d52a580bSJunchao Zhang   Mat_SeqAIJ          *b     = (Mat_SeqAIJ *)B->data;
203*d52a580bSJunchao Zhang   IS                   isrow = b->row, iscol = b->col;
204*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)B->spptr;
205*d52a580bSJunchao Zhang 
206*d52a580bSJunchao Zhang   PetscFunctionBegin;
207*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A));
208*d52a580bSJunchao Zhang   PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info));
209*d52a580bSJunchao Zhang   B->offloadmask = PETSC_OFFLOAD_CPU;
210*d52a580bSJunchao Zhang   /* determine which version of MatSolve needs to be used. */
211*d52a580bSJunchao Zhang   PetscCall(ISIdentity(isrow, &row_identity));
212*d52a580bSJunchao Zhang   PetscCall(ISIdentity(iscol, &col_identity));
213*d52a580bSJunchao Zhang   if (!hipsparsestruct->use_cpu_solve) {
214*d52a580bSJunchao Zhang     if (row_identity && col_identity) {
215*d52a580bSJunchao Zhang       B->ops->solve          = MatSolve_SeqAIJHIPSPARSE_NaturalOrdering;
216*d52a580bSJunchao Zhang       B->ops->solvetranspose = MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering;
217*d52a580bSJunchao Zhang     } else {
218*d52a580bSJunchao Zhang       B->ops->solve          = MatSolve_SeqAIJHIPSPARSE;
219*d52a580bSJunchao Zhang       B->ops->solvetranspose = MatSolveTranspose_SeqAIJHIPSPARSE;
220*d52a580bSJunchao Zhang     }
221*d52a580bSJunchao Zhang   }
222*d52a580bSJunchao Zhang   B->ops->matsolve          = NULL;
223*d52a580bSJunchao Zhang   B->ops->matsolvetranspose = NULL;
224*d52a580bSJunchao Zhang 
225*d52a580bSJunchao Zhang   /* get the triangular factors */
226*d52a580bSJunchao Zhang   if (!hipsparsestruct->use_cpu_solve) PetscCall(MatSeqAIJHIPSPARSEILUAnalysisAndCopyToGPU(B));
227*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
228*d52a580bSJunchao Zhang }
229*d52a580bSJunchao Zhang 
230*d52a580bSJunchao Zhang static PetscErrorCode MatSetFromOptions_SeqAIJHIPSPARSE(Mat A, PetscOptionItems PetscOptionsObject)
231*d52a580bSJunchao Zhang {
232*d52a580bSJunchao Zhang   MatHIPSPARSEStorageFormat format;
233*d52a580bSJunchao Zhang   PetscBool                 flg;
234*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE      *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
235*d52a580bSJunchao Zhang 
236*d52a580bSJunchao Zhang   PetscFunctionBegin;
237*d52a580bSJunchao Zhang   PetscOptionsHeadBegin(PetscOptionsObject, "SeqAIJHIPSPARSE options");
238*d52a580bSJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) {
239*d52a580bSJunchao Zhang     PetscCall(PetscOptionsEnum("-mat_hipsparse_mult_storage_format", "sets storage format of (seq)aijhipsparse gpu matrices for SpMV", "MatHIPSPARSESetFormat", MatHIPSPARSEStorageFormats, (PetscEnum)hipsparsestruct->format, (PetscEnum *)&format, &flg));
240*d52a580bSJunchao Zhang     if (flg) PetscCall(MatHIPSPARSESetFormat(A, MAT_HIPSPARSE_MULT, format));
241*d52a580bSJunchao Zhang     PetscCall(PetscOptionsEnum("-mat_hipsparse_storage_format", "sets storage format of (seq)aijhipsparse gpu matrices for SpMV and TriSolve", "MatHIPSPARSESetFormat", MatHIPSPARSEStorageFormats, (PetscEnum)hipsparsestruct->format, (PetscEnum *)&format, &flg));
242*d52a580bSJunchao Zhang     if (flg) PetscCall(MatHIPSPARSESetFormat(A, MAT_HIPSPARSE_ALL, format));
243*d52a580bSJunchao Zhang     PetscCall(PetscOptionsBool("-mat_hipsparse_use_cpu_solve", "Use CPU (I)LU solve", "MatHIPSPARSESetUseCPUSolve", hipsparsestruct->use_cpu_solve, &hipsparsestruct->use_cpu_solve, &flg));
244*d52a580bSJunchao Zhang     if (flg) PetscCall(MatHIPSPARSESetUseCPUSolve(A, hipsparsestruct->use_cpu_solve));
245*d52a580bSJunchao Zhang     PetscCall(
246*d52a580bSJunchao Zhang       PetscOptionsEnum("-mat_hipsparse_spmv_alg", "sets hipSPARSE algorithm used in sparse-mat dense-vector multiplication (SpMV)", "hipsparseSpMVAlg_t", MatHIPSPARSESpMVAlgorithms, (PetscEnum)hipsparsestruct->spmvAlg, (PetscEnum *)&hipsparsestruct->spmvAlg, &flg));
247*d52a580bSJunchao Zhang     /* If user did use this option, check its consistency with hipSPARSE, since PetscOptionsEnum() sets enum values based on their position in MatHIPSPARSESpMVAlgorithms[] */
248*d52a580bSJunchao Zhang     PetscCheck(!flg || HIPSPARSE_CSRMV_ALG1 == 2, PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSPARSE enum hipsparseSpMVAlg_t has been changed but PETSc has not been updated accordingly");
249*d52a580bSJunchao Zhang     PetscCall(
250*d52a580bSJunchao Zhang       PetscOptionsEnum("-mat_hipsparse_spmm_alg", "sets hipSPARSE algorithm used in sparse-mat dense-mat multiplication (SpMM)", "hipsparseSpMMAlg_t", MatHIPSPARSESpMMAlgorithms, (PetscEnum)hipsparsestruct->spmmAlg, (PetscEnum *)&hipsparsestruct->spmmAlg, &flg));
251*d52a580bSJunchao Zhang     PetscCheck(!flg || HIPSPARSE_SPMM_CSR_ALG1 == 4, PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSPARSE enum hipsparseSpMMAlg_t has been changed but PETSc has not been updated accordingly");
252*d52a580bSJunchao Zhang     /*
253*d52a580bSJunchao Zhang     PetscCall(PetscOptionsEnum("-mat_hipsparse_csr2csc_alg", "sets hipSPARSE algorithm used in converting CSR matrices to CSC matrices", "hipsparseCsr2CscAlg_t", MatHIPSPARSECsr2CscAlgorithms, (PetscEnum)hipsparsestruct->csr2cscAlg, (PetscEnum*)&hipsparsestruct->csr2cscAlg, &flg));
254*d52a580bSJunchao Zhang     PetscCheck(!flg || HIPSPARSE_CSR2CSC_ALG1 == 1, PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSPARSE enum hipsparseCsr2CscAlg_t has been changed but PETSc has not been updated accordingly");
255*d52a580bSJunchao Zhang     */
256*d52a580bSJunchao Zhang   }
257*d52a580bSJunchao Zhang   PetscOptionsHeadEnd();
258*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
259*d52a580bSJunchao Zhang }
260*d52a580bSJunchao Zhang 
261*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEBuildILULowerTriMatrix(Mat A)
262*d52a580bSJunchao Zhang {
263*d52a580bSJunchao Zhang   Mat_SeqAIJ                         *a                   = (Mat_SeqAIJ *)A->data;
264*d52a580bSJunchao Zhang   PetscInt                            n                   = A->rmap->n;
265*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors      *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
266*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr;
267*d52a580bSJunchao Zhang   const PetscInt                     *ai = a->i, *aj = a->j, *vi;
268*d52a580bSJunchao Zhang   const MatScalar                    *aa = a->a, *v;
269*d52a580bSJunchao Zhang   PetscInt                           *AiLo, *AjLo;
270*d52a580bSJunchao Zhang   PetscInt                            i, nz, nzLower, offset, rowOffset;
271*d52a580bSJunchao Zhang 
272*d52a580bSJunchao Zhang   PetscFunctionBegin;
273*d52a580bSJunchao Zhang   if (!n) PetscFunctionReturn(PETSC_SUCCESS);
274*d52a580bSJunchao Zhang   if (A->offloadmask == PETSC_OFFLOAD_UNALLOCATED || A->offloadmask == PETSC_OFFLOAD_CPU) {
275*d52a580bSJunchao Zhang     try {
276*d52a580bSJunchao Zhang       /* first figure out the number of nonzeros in the lower triangular matrix including 1's on the diagonal. */
277*d52a580bSJunchao Zhang       nzLower = n + ai[n] - ai[1];
278*d52a580bSJunchao Zhang       if (!loTriFactor) {
279*d52a580bSJunchao Zhang         PetscScalar *AALo;
280*d52a580bSJunchao Zhang         PetscCallHIP(hipHostMalloc((void **)&AALo, nzLower * sizeof(PetscScalar)));
281*d52a580bSJunchao Zhang 
282*d52a580bSJunchao Zhang         /* Allocate Space for the lower triangular matrix */
283*d52a580bSJunchao Zhang         PetscCallHIP(hipHostMalloc((void **)&AiLo, (n + 1) * sizeof(PetscInt)));
284*d52a580bSJunchao Zhang         PetscCallHIP(hipHostMalloc((void **)&AjLo, nzLower * sizeof(PetscInt)));
285*d52a580bSJunchao Zhang 
286*d52a580bSJunchao Zhang         /* Fill the lower triangular matrix */
287*d52a580bSJunchao Zhang         AiLo[0]   = (PetscInt)0;
288*d52a580bSJunchao Zhang         AiLo[n]   = nzLower;
289*d52a580bSJunchao Zhang         AjLo[0]   = (PetscInt)0;
290*d52a580bSJunchao Zhang         AALo[0]   = (MatScalar)1.0;
291*d52a580bSJunchao Zhang         v         = aa;
292*d52a580bSJunchao Zhang         vi        = aj;
293*d52a580bSJunchao Zhang         offset    = 1;
294*d52a580bSJunchao Zhang         rowOffset = 1;
295*d52a580bSJunchao Zhang         for (i = 1; i < n; i++) {
296*d52a580bSJunchao Zhang           nz = ai[i + 1] - ai[i];
297*d52a580bSJunchao Zhang           /* additional 1 for the term on the diagonal */
298*d52a580bSJunchao Zhang           AiLo[i] = rowOffset;
299*d52a580bSJunchao Zhang           rowOffset += nz + 1;
300*d52a580bSJunchao Zhang 
301*d52a580bSJunchao Zhang           PetscCall(PetscArraycpy(&AjLo[offset], vi, nz));
302*d52a580bSJunchao Zhang           PetscCall(PetscArraycpy(&AALo[offset], v, nz));
303*d52a580bSJunchao Zhang           offset += nz;
304*d52a580bSJunchao Zhang           AjLo[offset] = (PetscInt)i;
305*d52a580bSJunchao Zhang           AALo[offset] = (MatScalar)1.0;
306*d52a580bSJunchao Zhang           offset += 1;
307*d52a580bSJunchao Zhang           v += nz;
308*d52a580bSJunchao Zhang           vi += nz;
309*d52a580bSJunchao Zhang         }
310*d52a580bSJunchao Zhang 
311*d52a580bSJunchao Zhang         /* allocate space for the triangular factor information */
312*d52a580bSJunchao Zhang         PetscCall(PetscNew(&loTriFactor));
313*d52a580bSJunchao Zhang         loTriFactor->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
314*d52a580bSJunchao Zhang         /* Create the matrix description */
315*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateMatDescr(&loTriFactor->descr));
316*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatIndexBase(loTriFactor->descr, HIPSPARSE_INDEX_BASE_ZERO));
317*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatType(loTriFactor->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
318*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatFillMode(loTriFactor->descr, HIPSPARSE_FILL_MODE_LOWER));
319*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatDiagType(loTriFactor->descr, HIPSPARSE_DIAG_TYPE_UNIT));
320*d52a580bSJunchao Zhang 
321*d52a580bSJunchao Zhang         /* set the operation */
322*d52a580bSJunchao Zhang         loTriFactor->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE;
323*d52a580bSJunchao Zhang 
324*d52a580bSJunchao Zhang         /* set the matrix */
325*d52a580bSJunchao Zhang         loTriFactor->csrMat                 = new CsrMatrix;
326*d52a580bSJunchao Zhang         loTriFactor->csrMat->num_rows       = n;
327*d52a580bSJunchao Zhang         loTriFactor->csrMat->num_cols       = n;
328*d52a580bSJunchao Zhang         loTriFactor->csrMat->num_entries    = nzLower;
329*d52a580bSJunchao Zhang         loTriFactor->csrMat->row_offsets    = new THRUSTINTARRAY32(n + 1);
330*d52a580bSJunchao Zhang         loTriFactor->csrMat->column_indices = new THRUSTINTARRAY32(nzLower);
331*d52a580bSJunchao Zhang         loTriFactor->csrMat->values         = new THRUSTARRAY(nzLower);
332*d52a580bSJunchao Zhang 
333*d52a580bSJunchao Zhang         loTriFactor->csrMat->row_offsets->assign(AiLo, AiLo + n + 1);
334*d52a580bSJunchao Zhang         loTriFactor->csrMat->column_indices->assign(AjLo, AjLo + nzLower);
335*d52a580bSJunchao Zhang         loTriFactor->csrMat->values->assign(AALo, AALo + nzLower);
336*d52a580bSJunchao Zhang 
337*d52a580bSJunchao Zhang         /* Create the solve analysis information */
338*d52a580bSJunchao Zhang         PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
339*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&loTriFactor->solveInfo));
340*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
341*d52a580bSJunchao Zhang                                                     loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, &loTriFactor->solveBufferSize));
342*d52a580bSJunchao Zhang         PetscCallHIP(hipMalloc(&loTriFactor->solveBuffer, loTriFactor->solveBufferSize));
343*d52a580bSJunchao Zhang 
344*d52a580bSJunchao Zhang         /* perform the solve analysis */
345*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
346*d52a580bSJunchao Zhang                                                     loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, loTriFactor->solvePolicy, loTriFactor->solveBuffer));
347*d52a580bSJunchao Zhang 
348*d52a580bSJunchao Zhang         PetscCallHIP(WaitForHIP());
349*d52a580bSJunchao Zhang         PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
350*d52a580bSJunchao Zhang 
351*d52a580bSJunchao Zhang         /* assign the pointer */
352*d52a580bSJunchao Zhang         ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->loTriFactorPtr = loTriFactor;
353*d52a580bSJunchao Zhang         loTriFactor->AA_h                                           = AALo;
354*d52a580bSJunchao Zhang         PetscCallHIP(hipHostFree(AiLo));
355*d52a580bSJunchao Zhang         PetscCallHIP(hipHostFree(AjLo));
356*d52a580bSJunchao Zhang         PetscCall(PetscLogCpuToGpu((n + 1 + nzLower) * sizeof(int) + nzLower * sizeof(PetscScalar)));
357*d52a580bSJunchao Zhang       } else { /* update values only */
358*d52a580bSJunchao Zhang         if (!loTriFactor->AA_h) PetscCallHIP(hipHostMalloc((void **)&loTriFactor->AA_h, nzLower * sizeof(PetscScalar)));
359*d52a580bSJunchao Zhang         /* Fill the lower triangular matrix */
360*d52a580bSJunchao Zhang         loTriFactor->AA_h[0] = 1.0;
361*d52a580bSJunchao Zhang         v                    = aa;
362*d52a580bSJunchao Zhang         vi                   = aj;
363*d52a580bSJunchao Zhang         offset               = 1;
364*d52a580bSJunchao Zhang         for (i = 1; i < n; i++) {
365*d52a580bSJunchao Zhang           nz = ai[i + 1] - ai[i];
366*d52a580bSJunchao Zhang           PetscCall(PetscArraycpy(&loTriFactor->AA_h[offset], v, nz));
367*d52a580bSJunchao Zhang           offset += nz;
368*d52a580bSJunchao Zhang           loTriFactor->AA_h[offset] = 1.0;
369*d52a580bSJunchao Zhang           offset += 1;
370*d52a580bSJunchao Zhang           v += nz;
371*d52a580bSJunchao Zhang         }
372*d52a580bSJunchao Zhang         loTriFactor->csrMat->values->assign(loTriFactor->AA_h, loTriFactor->AA_h + nzLower);
373*d52a580bSJunchao Zhang         PetscCall(PetscLogCpuToGpu(nzLower * sizeof(PetscScalar)));
374*d52a580bSJunchao Zhang       }
375*d52a580bSJunchao Zhang     } catch (char *ex) {
376*d52a580bSJunchao Zhang       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex);
377*d52a580bSJunchao Zhang     }
378*d52a580bSJunchao Zhang   }
379*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
380*d52a580bSJunchao Zhang }
381*d52a580bSJunchao Zhang 
382*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEBuildILUUpperTriMatrix(Mat A)
383*d52a580bSJunchao Zhang {
384*d52a580bSJunchao Zhang   Mat_SeqAIJ                         *a                   = (Mat_SeqAIJ *)A->data;
385*d52a580bSJunchao Zhang   PetscInt                            n                   = A->rmap->n;
386*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors      *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
387*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr;
388*d52a580bSJunchao Zhang   const PetscInt                     *aj                  = a->j, *adiag, *vi;
389*d52a580bSJunchao Zhang   const MatScalar                    *aa                  = a->a, *v;
390*d52a580bSJunchao Zhang   PetscInt                           *AiUp, *AjUp;
391*d52a580bSJunchao Zhang   PetscInt                            i, nz, nzUpper, offset;
392*d52a580bSJunchao Zhang 
393*d52a580bSJunchao Zhang   PetscFunctionBegin;
394*d52a580bSJunchao Zhang   if (!n) PetscFunctionReturn(PETSC_SUCCESS);
395*d52a580bSJunchao Zhang   PetscCall(MatGetDiagonalMarkers_SeqAIJ(A, &adiag, NULL));
396*d52a580bSJunchao Zhang   if (A->offloadmask == PETSC_OFFLOAD_UNALLOCATED || A->offloadmask == PETSC_OFFLOAD_CPU) {
397*d52a580bSJunchao Zhang     try {
398*d52a580bSJunchao Zhang       /* next, figure out the number of nonzeros in the upper triangular matrix. */
399*d52a580bSJunchao Zhang       nzUpper = adiag[0] - adiag[n];
400*d52a580bSJunchao Zhang       if (!upTriFactor) {
401*d52a580bSJunchao Zhang         PetscScalar *AAUp;
402*d52a580bSJunchao Zhang         PetscCallHIP(hipHostMalloc((void **)&AAUp, nzUpper * sizeof(PetscScalar)));
403*d52a580bSJunchao Zhang 
404*d52a580bSJunchao Zhang         /* Allocate Space for the upper triangular matrix */
405*d52a580bSJunchao Zhang         PetscCallHIP(hipHostMalloc((void **)&AiUp, (n + 1) * sizeof(PetscInt)));
406*d52a580bSJunchao Zhang         PetscCallHIP(hipHostMalloc((void **)&AjUp, nzUpper * sizeof(PetscInt)));
407*d52a580bSJunchao Zhang 
408*d52a580bSJunchao Zhang         /* Fill the upper triangular matrix */
409*d52a580bSJunchao Zhang         AiUp[0] = (PetscInt)0;
410*d52a580bSJunchao Zhang         AiUp[n] = nzUpper;
411*d52a580bSJunchao Zhang         offset  = nzUpper;
412*d52a580bSJunchao Zhang         for (i = n - 1; i >= 0; i--) {
413*d52a580bSJunchao Zhang           v  = aa + adiag[i + 1] + 1;
414*d52a580bSJunchao Zhang           vi = aj + adiag[i + 1] + 1;
415*d52a580bSJunchao Zhang           nz = adiag[i] - adiag[i + 1] - 1; /* number of elements NOT on the diagonal */
416*d52a580bSJunchao Zhang           offset -= (nz + 1);               /* decrement the offset */
417*d52a580bSJunchao Zhang 
418*d52a580bSJunchao Zhang           /* first, set the diagonal elements */
419*d52a580bSJunchao Zhang           AjUp[offset] = (PetscInt)i;
420*d52a580bSJunchao Zhang           AAUp[offset] = (MatScalar)1. / v[nz];
421*d52a580bSJunchao Zhang           AiUp[i]      = AiUp[i + 1] - (nz + 1);
422*d52a580bSJunchao Zhang 
423*d52a580bSJunchao Zhang           PetscCall(PetscArraycpy(&AjUp[offset + 1], vi, nz));
424*d52a580bSJunchao Zhang           PetscCall(PetscArraycpy(&AAUp[offset + 1], v, nz));
425*d52a580bSJunchao Zhang         }
426*d52a580bSJunchao Zhang 
427*d52a580bSJunchao Zhang         /* allocate space for the triangular factor information */
428*d52a580bSJunchao Zhang         PetscCall(PetscNew(&upTriFactor));
429*d52a580bSJunchao Zhang         upTriFactor->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
430*d52a580bSJunchao Zhang 
431*d52a580bSJunchao Zhang         /* Create the matrix description */
432*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateMatDescr(&upTriFactor->descr));
433*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatIndexBase(upTriFactor->descr, HIPSPARSE_INDEX_BASE_ZERO));
434*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatType(upTriFactor->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
435*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatFillMode(upTriFactor->descr, HIPSPARSE_FILL_MODE_UPPER));
436*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatDiagType(upTriFactor->descr, HIPSPARSE_DIAG_TYPE_NON_UNIT));
437*d52a580bSJunchao Zhang 
438*d52a580bSJunchao Zhang         /* set the operation */
439*d52a580bSJunchao Zhang         upTriFactor->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE;
440*d52a580bSJunchao Zhang 
441*d52a580bSJunchao Zhang         /* set the matrix */
442*d52a580bSJunchao Zhang         upTriFactor->csrMat                 = new CsrMatrix;
443*d52a580bSJunchao Zhang         upTriFactor->csrMat->num_rows       = n;
444*d52a580bSJunchao Zhang         upTriFactor->csrMat->num_cols       = n;
445*d52a580bSJunchao Zhang         upTriFactor->csrMat->num_entries    = nzUpper;
446*d52a580bSJunchao Zhang         upTriFactor->csrMat->row_offsets    = new THRUSTINTARRAY32(n + 1);
447*d52a580bSJunchao Zhang         upTriFactor->csrMat->column_indices = new THRUSTINTARRAY32(nzUpper);
448*d52a580bSJunchao Zhang         upTriFactor->csrMat->values         = new THRUSTARRAY(nzUpper);
449*d52a580bSJunchao Zhang         upTriFactor->csrMat->row_offsets->assign(AiUp, AiUp + n + 1);
450*d52a580bSJunchao Zhang         upTriFactor->csrMat->column_indices->assign(AjUp, AjUp + nzUpper);
451*d52a580bSJunchao Zhang         upTriFactor->csrMat->values->assign(AAUp, AAUp + nzUpper);
452*d52a580bSJunchao Zhang 
453*d52a580bSJunchao Zhang         /* Create the solve analysis information */
454*d52a580bSJunchao Zhang         PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
455*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&upTriFactor->solveInfo));
456*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
457*d52a580bSJunchao Zhang                                                     upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, &upTriFactor->solveBufferSize));
458*d52a580bSJunchao Zhang         PetscCallHIP(hipMalloc(&upTriFactor->solveBuffer, upTriFactor->solveBufferSize));
459*d52a580bSJunchao Zhang 
460*d52a580bSJunchao Zhang         /* perform the solve analysis */
461*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
462*d52a580bSJunchao Zhang                                                     upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, upTriFactor->solvePolicy, upTriFactor->solveBuffer));
463*d52a580bSJunchao Zhang 
464*d52a580bSJunchao Zhang         PetscCallHIP(WaitForHIP());
465*d52a580bSJunchao Zhang         PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
466*d52a580bSJunchao Zhang 
467*d52a580bSJunchao Zhang         /* assign the pointer */
468*d52a580bSJunchao Zhang         ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->upTriFactorPtr = upTriFactor;
469*d52a580bSJunchao Zhang         upTriFactor->AA_h                                           = AAUp;
470*d52a580bSJunchao Zhang         PetscCallHIP(hipHostFree(AiUp));
471*d52a580bSJunchao Zhang         PetscCallHIP(hipHostFree(AjUp));
472*d52a580bSJunchao Zhang         PetscCall(PetscLogCpuToGpu((n + 1 + nzUpper) * sizeof(int) + nzUpper * sizeof(PetscScalar)));
473*d52a580bSJunchao Zhang       } else {
474*d52a580bSJunchao Zhang         if (!upTriFactor->AA_h) PetscCallHIP(hipHostMalloc((void **)&upTriFactor->AA_h, nzUpper * sizeof(PetscScalar)));
475*d52a580bSJunchao Zhang         /* Fill the upper triangular matrix */
476*d52a580bSJunchao Zhang         offset = nzUpper;
477*d52a580bSJunchao Zhang         for (i = n - 1; i >= 0; i--) {
478*d52a580bSJunchao Zhang           v  = aa + adiag[i + 1] + 1;
479*d52a580bSJunchao Zhang           nz = adiag[i] - adiag[i + 1] - 1; /* number of elements NOT on the diagonal */
480*d52a580bSJunchao Zhang           offset -= (nz + 1);               /* decrement the offset */
481*d52a580bSJunchao Zhang 
482*d52a580bSJunchao Zhang           /* first, set the diagonal elements */
483*d52a580bSJunchao Zhang           upTriFactor->AA_h[offset] = 1. / v[nz];
484*d52a580bSJunchao Zhang           PetscCall(PetscArraycpy(&upTriFactor->AA_h[offset + 1], v, nz));
485*d52a580bSJunchao Zhang         }
486*d52a580bSJunchao Zhang         upTriFactor->csrMat->values->assign(upTriFactor->AA_h, upTriFactor->AA_h + nzUpper);
487*d52a580bSJunchao Zhang         PetscCall(PetscLogCpuToGpu(nzUpper * sizeof(PetscScalar)));
488*d52a580bSJunchao Zhang       }
489*d52a580bSJunchao Zhang     } catch (char *ex) {
490*d52a580bSJunchao Zhang       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex);
491*d52a580bSJunchao Zhang     }
492*d52a580bSJunchao Zhang   }
493*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
494*d52a580bSJunchao Zhang }
495*d52a580bSJunchao Zhang 
496*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEILUAnalysisAndCopyToGPU(Mat A)
497*d52a580bSJunchao Zhang {
498*d52a580bSJunchao Zhang   PetscBool                      row_identity, col_identity;
499*d52a580bSJunchao Zhang   Mat_SeqAIJ                    *a                   = (Mat_SeqAIJ *)A->data;
500*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
501*d52a580bSJunchao Zhang   IS                             isrow = a->row, iscol = a->icol;
502*d52a580bSJunchao Zhang   PetscInt                       n = A->rmap->n;
503*d52a580bSJunchao Zhang 
504*d52a580bSJunchao Zhang   PetscFunctionBegin;
505*d52a580bSJunchao Zhang   PetscCheck(hipsparseTriFactors, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing hipsparseTriFactors");
506*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEBuildILULowerTriMatrix(A));
507*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEBuildILUUpperTriMatrix(A));
508*d52a580bSJunchao Zhang 
509*d52a580bSJunchao Zhang   if (!hipsparseTriFactors->workVector) hipsparseTriFactors->workVector = new THRUSTARRAY(n);
510*d52a580bSJunchao Zhang   hipsparseTriFactors->nnz = a->nz;
511*d52a580bSJunchao Zhang 
512*d52a580bSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_BOTH;
513*d52a580bSJunchao Zhang   /* lower triangular indices */
514*d52a580bSJunchao Zhang   PetscCall(ISIdentity(isrow, &row_identity));
515*d52a580bSJunchao Zhang   if (!row_identity && !hipsparseTriFactors->rpermIndices) {
516*d52a580bSJunchao Zhang     const PetscInt *r;
517*d52a580bSJunchao Zhang 
518*d52a580bSJunchao Zhang     PetscCall(ISGetIndices(isrow, &r));
519*d52a580bSJunchao Zhang     hipsparseTriFactors->rpermIndices = new THRUSTINTARRAY(n);
520*d52a580bSJunchao Zhang     hipsparseTriFactors->rpermIndices->assign(r, r + n);
521*d52a580bSJunchao Zhang     PetscCall(ISRestoreIndices(isrow, &r));
522*d52a580bSJunchao Zhang     PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt)));
523*d52a580bSJunchao Zhang   }
524*d52a580bSJunchao Zhang   /* upper triangular indices */
525*d52a580bSJunchao Zhang   PetscCall(ISIdentity(iscol, &col_identity));
526*d52a580bSJunchao Zhang   if (!col_identity && !hipsparseTriFactors->cpermIndices) {
527*d52a580bSJunchao Zhang     const PetscInt *c;
528*d52a580bSJunchao Zhang 
529*d52a580bSJunchao Zhang     PetscCall(ISGetIndices(iscol, &c));
530*d52a580bSJunchao Zhang     hipsparseTriFactors->cpermIndices = new THRUSTINTARRAY(n);
531*d52a580bSJunchao Zhang     hipsparseTriFactors->cpermIndices->assign(c, c + n);
532*d52a580bSJunchao Zhang     PetscCall(ISRestoreIndices(iscol, &c));
533*d52a580bSJunchao Zhang     PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt)));
534*d52a580bSJunchao Zhang   }
535*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
536*d52a580bSJunchao Zhang }
537*d52a580bSJunchao Zhang 
538*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEBuildICCTriMatrices(Mat A)
539*d52a580bSJunchao Zhang {
540*d52a580bSJunchao Zhang   Mat_SeqAIJ                         *a                   = (Mat_SeqAIJ *)A->data;
541*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors      *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
542*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr;
543*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr;
544*d52a580bSJunchao Zhang   PetscInt                           *AiUp, *AjUp;
545*d52a580bSJunchao Zhang   PetscScalar                        *AAUp;
546*d52a580bSJunchao Zhang   PetscScalar                        *AALo;
547*d52a580bSJunchao Zhang   PetscInt                            nzUpper = a->nz, n = A->rmap->n, i, offset, nz, j;
548*d52a580bSJunchao Zhang   Mat_SeqSBAIJ                       *b  = (Mat_SeqSBAIJ *)A->data;
549*d52a580bSJunchao Zhang   const PetscInt                     *ai = b->i, *aj = b->j, *vj;
550*d52a580bSJunchao Zhang   const MatScalar                    *aa = b->a, *v;
551*d52a580bSJunchao Zhang 
552*d52a580bSJunchao Zhang   PetscFunctionBegin;
553*d52a580bSJunchao Zhang   if (!n) PetscFunctionReturn(PETSC_SUCCESS);
554*d52a580bSJunchao Zhang   if (A->offloadmask == PETSC_OFFLOAD_UNALLOCATED || A->offloadmask == PETSC_OFFLOAD_CPU) {
555*d52a580bSJunchao Zhang     try {
556*d52a580bSJunchao Zhang       PetscCallHIP(hipHostMalloc((void **)&AAUp, nzUpper * sizeof(PetscScalar)));
557*d52a580bSJunchao Zhang       PetscCallHIP(hipHostMalloc((void **)&AALo, nzUpper * sizeof(PetscScalar)));
558*d52a580bSJunchao Zhang       if (!upTriFactor && !loTriFactor) {
559*d52a580bSJunchao Zhang         /* Allocate Space for the upper triangular matrix */
560*d52a580bSJunchao Zhang         PetscCallHIP(hipHostMalloc((void **)&AiUp, (n + 1) * sizeof(PetscInt)));
561*d52a580bSJunchao Zhang         PetscCallHIP(hipHostMalloc((void **)&AjUp, nzUpper * sizeof(PetscInt)));
562*d52a580bSJunchao Zhang 
563*d52a580bSJunchao Zhang         /* Fill the upper triangular matrix */
564*d52a580bSJunchao Zhang         AiUp[0] = (PetscInt)0;
565*d52a580bSJunchao Zhang         AiUp[n] = nzUpper;
566*d52a580bSJunchao Zhang         offset  = 0;
567*d52a580bSJunchao Zhang         for (i = 0; i < n; i++) {
568*d52a580bSJunchao Zhang           /* set the pointers */
569*d52a580bSJunchao Zhang           v  = aa + ai[i];
570*d52a580bSJunchao Zhang           vj = aj + ai[i];
571*d52a580bSJunchao Zhang           nz = ai[i + 1] - ai[i] - 1; /* exclude diag[i] */
572*d52a580bSJunchao Zhang 
573*d52a580bSJunchao Zhang           /* first, set the diagonal elements */
574*d52a580bSJunchao Zhang           AjUp[offset] = (PetscInt)i;
575*d52a580bSJunchao Zhang           AAUp[offset] = (MatScalar)1.0 / v[nz];
576*d52a580bSJunchao Zhang           AiUp[i]      = offset;
577*d52a580bSJunchao Zhang           AALo[offset] = (MatScalar)1.0 / v[nz];
578*d52a580bSJunchao Zhang 
579*d52a580bSJunchao Zhang           offset += 1;
580*d52a580bSJunchao Zhang           if (nz > 0) {
581*d52a580bSJunchao Zhang             PetscCall(PetscArraycpy(&AjUp[offset], vj, nz));
582*d52a580bSJunchao Zhang             PetscCall(PetscArraycpy(&AAUp[offset], v, nz));
583*d52a580bSJunchao Zhang             for (j = offset; j < offset + nz; j++) {
584*d52a580bSJunchao Zhang               AAUp[j] = -AAUp[j];
585*d52a580bSJunchao Zhang               AALo[j] = AAUp[j] / v[nz];
586*d52a580bSJunchao Zhang             }
587*d52a580bSJunchao Zhang             offset += nz;
588*d52a580bSJunchao Zhang           }
589*d52a580bSJunchao Zhang         }
590*d52a580bSJunchao Zhang 
591*d52a580bSJunchao Zhang         /* allocate space for the triangular factor information */
592*d52a580bSJunchao Zhang         PetscCall(PetscNew(&upTriFactor));
593*d52a580bSJunchao Zhang         upTriFactor->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
594*d52a580bSJunchao Zhang 
595*d52a580bSJunchao Zhang         /* Create the matrix description */
596*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateMatDescr(&upTriFactor->descr));
597*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatIndexBase(upTriFactor->descr, HIPSPARSE_INDEX_BASE_ZERO));
598*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatType(upTriFactor->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
599*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatFillMode(upTriFactor->descr, HIPSPARSE_FILL_MODE_UPPER));
600*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatDiagType(upTriFactor->descr, HIPSPARSE_DIAG_TYPE_UNIT));
601*d52a580bSJunchao Zhang 
602*d52a580bSJunchao Zhang         /* set the matrix */
603*d52a580bSJunchao Zhang         upTriFactor->csrMat                 = new CsrMatrix;
604*d52a580bSJunchao Zhang         upTriFactor->csrMat->num_rows       = A->rmap->n;
605*d52a580bSJunchao Zhang         upTriFactor->csrMat->num_cols       = A->cmap->n;
606*d52a580bSJunchao Zhang         upTriFactor->csrMat->num_entries    = a->nz;
607*d52a580bSJunchao Zhang         upTriFactor->csrMat->row_offsets    = new THRUSTINTARRAY32(A->rmap->n + 1);
608*d52a580bSJunchao Zhang         upTriFactor->csrMat->column_indices = new THRUSTINTARRAY32(a->nz);
609*d52a580bSJunchao Zhang         upTriFactor->csrMat->values         = new THRUSTARRAY(a->nz);
610*d52a580bSJunchao Zhang         upTriFactor->csrMat->row_offsets->assign(AiUp, AiUp + A->rmap->n + 1);
611*d52a580bSJunchao Zhang         upTriFactor->csrMat->column_indices->assign(AjUp, AjUp + a->nz);
612*d52a580bSJunchao Zhang         upTriFactor->csrMat->values->assign(AAUp, AAUp + a->nz);
613*d52a580bSJunchao Zhang 
614*d52a580bSJunchao Zhang         /* set the operation */
615*d52a580bSJunchao Zhang         upTriFactor->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE;
616*d52a580bSJunchao Zhang 
617*d52a580bSJunchao Zhang         /* Create the solve analysis information */
618*d52a580bSJunchao Zhang         PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
619*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&upTriFactor->solveInfo));
620*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
621*d52a580bSJunchao Zhang                                                     upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, &upTriFactor->solveBufferSize));
622*d52a580bSJunchao Zhang         PetscCallHIP(hipMalloc(&upTriFactor->solveBuffer, upTriFactor->solveBufferSize));
623*d52a580bSJunchao Zhang 
624*d52a580bSJunchao Zhang         /* perform the solve analysis */
625*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
626*d52a580bSJunchao Zhang                                                     upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, upTriFactor->solvePolicy, upTriFactor->solveBuffer));
627*d52a580bSJunchao Zhang 
628*d52a580bSJunchao Zhang         PetscCallHIP(WaitForHIP());
629*d52a580bSJunchao Zhang         PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
630*d52a580bSJunchao Zhang 
631*d52a580bSJunchao Zhang         /* assign the pointer */
632*d52a580bSJunchao Zhang         ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->upTriFactorPtr = upTriFactor;
633*d52a580bSJunchao Zhang 
634*d52a580bSJunchao Zhang         /* allocate space for the triangular factor information */
635*d52a580bSJunchao Zhang         PetscCall(PetscNew(&loTriFactor));
636*d52a580bSJunchao Zhang         loTriFactor->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
637*d52a580bSJunchao Zhang 
638*d52a580bSJunchao Zhang         /* Create the matrix description */
639*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateMatDescr(&loTriFactor->descr));
640*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatIndexBase(loTriFactor->descr, HIPSPARSE_INDEX_BASE_ZERO));
641*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatType(loTriFactor->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
642*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatFillMode(loTriFactor->descr, HIPSPARSE_FILL_MODE_UPPER));
643*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatDiagType(loTriFactor->descr, HIPSPARSE_DIAG_TYPE_NON_UNIT));
644*d52a580bSJunchao Zhang 
645*d52a580bSJunchao Zhang         /* set the operation */
646*d52a580bSJunchao Zhang         loTriFactor->solveOp = HIPSPARSE_OPERATION_TRANSPOSE;
647*d52a580bSJunchao Zhang 
648*d52a580bSJunchao Zhang         /* set the matrix */
649*d52a580bSJunchao Zhang         loTriFactor->csrMat                 = new CsrMatrix;
650*d52a580bSJunchao Zhang         loTriFactor->csrMat->num_rows       = A->rmap->n;
651*d52a580bSJunchao Zhang         loTriFactor->csrMat->num_cols       = A->cmap->n;
652*d52a580bSJunchao Zhang         loTriFactor->csrMat->num_entries    = a->nz;
653*d52a580bSJunchao Zhang         loTriFactor->csrMat->row_offsets    = new THRUSTINTARRAY32(A->rmap->n + 1);
654*d52a580bSJunchao Zhang         loTriFactor->csrMat->column_indices = new THRUSTINTARRAY32(a->nz);
655*d52a580bSJunchao Zhang         loTriFactor->csrMat->values         = new THRUSTARRAY(a->nz);
656*d52a580bSJunchao Zhang         loTriFactor->csrMat->row_offsets->assign(AiUp, AiUp + A->rmap->n + 1);
657*d52a580bSJunchao Zhang         loTriFactor->csrMat->column_indices->assign(AjUp, AjUp + a->nz);
658*d52a580bSJunchao Zhang         loTriFactor->csrMat->values->assign(AALo, AALo + a->nz);
659*d52a580bSJunchao Zhang 
660*d52a580bSJunchao Zhang         /* Create the solve analysis information */
661*d52a580bSJunchao Zhang         PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
662*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&loTriFactor->solveInfo));
663*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
664*d52a580bSJunchao Zhang                                                     loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, &loTriFactor->solveBufferSize));
665*d52a580bSJunchao Zhang         PetscCallHIP(hipMalloc(&loTriFactor->solveBuffer, loTriFactor->solveBufferSize));
666*d52a580bSJunchao Zhang 
667*d52a580bSJunchao Zhang         /* perform the solve analysis */
668*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
669*d52a580bSJunchao Zhang                                                     loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, loTriFactor->solvePolicy, loTriFactor->solveBuffer));
670*d52a580bSJunchao Zhang 
671*d52a580bSJunchao Zhang         PetscCallHIP(WaitForHIP());
672*d52a580bSJunchao Zhang         PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
673*d52a580bSJunchao Zhang 
674*d52a580bSJunchao Zhang         /* assign the pointer */
675*d52a580bSJunchao Zhang         ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->loTriFactorPtr = loTriFactor;
676*d52a580bSJunchao Zhang 
677*d52a580bSJunchao Zhang         PetscCall(PetscLogCpuToGpu(2 * (((A->rmap->n + 1) + (a->nz)) * sizeof(int) + (a->nz) * sizeof(PetscScalar))));
678*d52a580bSJunchao Zhang         PetscCallHIP(hipHostFree(AiUp));
679*d52a580bSJunchao Zhang         PetscCallHIP(hipHostFree(AjUp));
680*d52a580bSJunchao Zhang       } else {
681*d52a580bSJunchao Zhang         /* Fill the upper triangular matrix */
682*d52a580bSJunchao Zhang         offset = 0;
683*d52a580bSJunchao Zhang         for (i = 0; i < n; i++) {
684*d52a580bSJunchao Zhang           /* set the pointers */
685*d52a580bSJunchao Zhang           v  = aa + ai[i];
686*d52a580bSJunchao Zhang           nz = ai[i + 1] - ai[i] - 1; /* exclude diag[i] */
687*d52a580bSJunchao Zhang 
688*d52a580bSJunchao Zhang           /* first, set the diagonal elements */
689*d52a580bSJunchao Zhang           AAUp[offset] = 1.0 / v[nz];
690*d52a580bSJunchao Zhang           AALo[offset] = 1.0 / v[nz];
691*d52a580bSJunchao Zhang 
692*d52a580bSJunchao Zhang           offset += 1;
693*d52a580bSJunchao Zhang           if (nz > 0) {
694*d52a580bSJunchao Zhang             PetscCall(PetscArraycpy(&AAUp[offset], v, nz));
695*d52a580bSJunchao Zhang             for (j = offset; j < offset + nz; j++) {
696*d52a580bSJunchao Zhang               AAUp[j] = -AAUp[j];
697*d52a580bSJunchao Zhang               AALo[j] = AAUp[j] / v[nz];
698*d52a580bSJunchao Zhang             }
699*d52a580bSJunchao Zhang             offset += nz;
700*d52a580bSJunchao Zhang           }
701*d52a580bSJunchao Zhang         }
702*d52a580bSJunchao Zhang         PetscCheck(upTriFactor, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing hipsparseTriFactors");
703*d52a580bSJunchao Zhang         PetscCheck(loTriFactor, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing hipsparseTriFactors");
704*d52a580bSJunchao Zhang         upTriFactor->csrMat->values->assign(AAUp, AAUp + a->nz);
705*d52a580bSJunchao Zhang         loTriFactor->csrMat->values->assign(AALo, AALo + a->nz);
706*d52a580bSJunchao Zhang         PetscCall(PetscLogCpuToGpu(2 * (a->nz) * sizeof(PetscScalar)));
707*d52a580bSJunchao Zhang       }
708*d52a580bSJunchao Zhang       PetscCallHIP(hipHostFree(AAUp));
709*d52a580bSJunchao Zhang       PetscCallHIP(hipHostFree(AALo));
710*d52a580bSJunchao Zhang     } catch (char *ex) {
711*d52a580bSJunchao Zhang       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex);
712*d52a580bSJunchao Zhang     }
713*d52a580bSJunchao Zhang   }
714*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
715*d52a580bSJunchao Zhang }
716*d52a580bSJunchao Zhang 
717*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEICCAnalysisAndCopyToGPU(Mat A)
718*d52a580bSJunchao Zhang {
719*d52a580bSJunchao Zhang   PetscBool                      perm_identity;
720*d52a580bSJunchao Zhang   Mat_SeqAIJ                    *a                   = (Mat_SeqAIJ *)A->data;
721*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
722*d52a580bSJunchao Zhang   IS                             ip                  = a->row;
723*d52a580bSJunchao Zhang   PetscInt                       n                   = A->rmap->n;
724*d52a580bSJunchao Zhang 
725*d52a580bSJunchao Zhang   PetscFunctionBegin;
726*d52a580bSJunchao Zhang   PetscCheck(hipsparseTriFactors, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing hipsparseTriFactors");
727*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEBuildICCTriMatrices(A));
728*d52a580bSJunchao Zhang   if (!hipsparseTriFactors->workVector) hipsparseTriFactors->workVector = new THRUSTARRAY(n);
729*d52a580bSJunchao Zhang   hipsparseTriFactors->nnz = (a->nz - n) * 2 + n;
730*d52a580bSJunchao Zhang 
731*d52a580bSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_BOTH;
732*d52a580bSJunchao Zhang   /* lower triangular indices */
733*d52a580bSJunchao Zhang   PetscCall(ISIdentity(ip, &perm_identity));
734*d52a580bSJunchao Zhang   if (!perm_identity) {
735*d52a580bSJunchao Zhang     IS              iip;
736*d52a580bSJunchao Zhang     const PetscInt *irip, *rip;
737*d52a580bSJunchao Zhang 
738*d52a580bSJunchao Zhang     PetscCall(ISInvertPermutation(ip, PETSC_DECIDE, &iip));
739*d52a580bSJunchao Zhang     PetscCall(ISGetIndices(iip, &irip));
740*d52a580bSJunchao Zhang     PetscCall(ISGetIndices(ip, &rip));
741*d52a580bSJunchao Zhang     hipsparseTriFactors->rpermIndices = new THRUSTINTARRAY(n);
742*d52a580bSJunchao Zhang     hipsparseTriFactors->cpermIndices = new THRUSTINTARRAY(n);
743*d52a580bSJunchao Zhang     hipsparseTriFactors->rpermIndices->assign(rip, rip + n);
744*d52a580bSJunchao Zhang     hipsparseTriFactors->cpermIndices->assign(irip, irip + n);
745*d52a580bSJunchao Zhang     PetscCall(ISRestoreIndices(iip, &irip));
746*d52a580bSJunchao Zhang     PetscCall(ISDestroy(&iip));
747*d52a580bSJunchao Zhang     PetscCall(ISRestoreIndices(ip, &rip));
748*d52a580bSJunchao Zhang     PetscCall(PetscLogCpuToGpu(2. * n * sizeof(PetscInt)));
749*d52a580bSJunchao Zhang   }
750*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
751*d52a580bSJunchao Zhang }
752*d52a580bSJunchao Zhang 
753*d52a580bSJunchao Zhang static PetscErrorCode MatCholeskyFactorNumeric_SeqAIJHIPSPARSE(Mat B, Mat A, const MatFactorInfo *info)
754*d52a580bSJunchao Zhang {
755*d52a580bSJunchao Zhang   PetscBool   perm_identity;
756*d52a580bSJunchao Zhang   Mat_SeqAIJ *b  = (Mat_SeqAIJ *)B->data;
757*d52a580bSJunchao Zhang   IS          ip = b->row;
758*d52a580bSJunchao Zhang 
759*d52a580bSJunchao Zhang   PetscFunctionBegin;
760*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A));
761*d52a580bSJunchao Zhang   PetscCall(MatCholeskyFactorNumeric_SeqAIJ(B, A, info));
762*d52a580bSJunchao Zhang   B->offloadmask = PETSC_OFFLOAD_CPU;
763*d52a580bSJunchao Zhang   /* determine which version of MatSolve needs to be used. */
764*d52a580bSJunchao Zhang   PetscCall(ISIdentity(ip, &perm_identity));
765*d52a580bSJunchao Zhang   if (perm_identity) {
766*d52a580bSJunchao Zhang     B->ops->solve             = MatSolve_SeqAIJHIPSPARSE_NaturalOrdering;
767*d52a580bSJunchao Zhang     B->ops->solvetranspose    = MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering;
768*d52a580bSJunchao Zhang     B->ops->matsolve          = NULL;
769*d52a580bSJunchao Zhang     B->ops->matsolvetranspose = NULL;
770*d52a580bSJunchao Zhang   } else {
771*d52a580bSJunchao Zhang     B->ops->solve             = MatSolve_SeqAIJHIPSPARSE;
772*d52a580bSJunchao Zhang     B->ops->solvetranspose    = MatSolveTranspose_SeqAIJHIPSPARSE;
773*d52a580bSJunchao Zhang     B->ops->matsolve          = NULL;
774*d52a580bSJunchao Zhang     B->ops->matsolvetranspose = NULL;
775*d52a580bSJunchao Zhang   }
776*d52a580bSJunchao Zhang 
777*d52a580bSJunchao Zhang   /* get the triangular factors */
778*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEICCAnalysisAndCopyToGPU(B));
779*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
780*d52a580bSJunchao Zhang }
781*d52a580bSJunchao Zhang 
782*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEAnalyzeTransposeForSolve(Mat A)
783*d52a580bSJunchao Zhang {
784*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors      *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
785*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr;
786*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr;
787*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactorT;
788*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactorT;
789*d52a580bSJunchao Zhang   hipsparseIndexBase_t                indexBase;
790*d52a580bSJunchao Zhang   hipsparseMatrixType_t               matrixType;
791*d52a580bSJunchao Zhang   hipsparseFillMode_t                 fillMode;
792*d52a580bSJunchao Zhang   hipsparseDiagType_t                 diagType;
793*d52a580bSJunchao Zhang 
794*d52a580bSJunchao Zhang   PetscFunctionBegin;
795*d52a580bSJunchao Zhang   /* allocate space for the transpose of the lower triangular factor */
796*d52a580bSJunchao Zhang   PetscCall(PetscNew(&loTriFactorT));
797*d52a580bSJunchao Zhang   loTriFactorT->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
798*d52a580bSJunchao Zhang 
799*d52a580bSJunchao Zhang   /* set the matrix descriptors of the lower triangular factor */
800*d52a580bSJunchao Zhang   matrixType = hipsparseGetMatType(loTriFactor->descr);
801*d52a580bSJunchao Zhang   indexBase  = hipsparseGetMatIndexBase(loTriFactor->descr);
802*d52a580bSJunchao Zhang   fillMode   = hipsparseGetMatFillMode(loTriFactor->descr) == HIPSPARSE_FILL_MODE_UPPER ? HIPSPARSE_FILL_MODE_LOWER : HIPSPARSE_FILL_MODE_UPPER;
803*d52a580bSJunchao Zhang   diagType   = hipsparseGetMatDiagType(loTriFactor->descr);
804*d52a580bSJunchao Zhang 
805*d52a580bSJunchao Zhang   /* Create the matrix description */
806*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateMatDescr(&loTriFactorT->descr));
807*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatIndexBase(loTriFactorT->descr, indexBase));
808*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatType(loTriFactorT->descr, matrixType));
809*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatFillMode(loTriFactorT->descr, fillMode));
810*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatDiagType(loTriFactorT->descr, diagType));
811*d52a580bSJunchao Zhang 
812*d52a580bSJunchao Zhang   /* set the operation */
813*d52a580bSJunchao Zhang   loTriFactorT->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE;
814*d52a580bSJunchao Zhang 
815*d52a580bSJunchao Zhang   /* allocate GPU space for the CSC of the lower triangular factor*/
816*d52a580bSJunchao Zhang   loTriFactorT->csrMat                 = new CsrMatrix;
817*d52a580bSJunchao Zhang   loTriFactorT->csrMat->num_rows       = loTriFactor->csrMat->num_cols;
818*d52a580bSJunchao Zhang   loTriFactorT->csrMat->num_cols       = loTriFactor->csrMat->num_rows;
819*d52a580bSJunchao Zhang   loTriFactorT->csrMat->num_entries    = loTriFactor->csrMat->num_entries;
820*d52a580bSJunchao Zhang   loTriFactorT->csrMat->row_offsets    = new THRUSTINTARRAY32(loTriFactorT->csrMat->num_rows + 1);
821*d52a580bSJunchao Zhang   loTriFactorT->csrMat->column_indices = new THRUSTINTARRAY32(loTriFactorT->csrMat->num_entries);
822*d52a580bSJunchao Zhang   loTriFactorT->csrMat->values         = new THRUSTARRAY(loTriFactorT->csrMat->num_entries);
823*d52a580bSJunchao Zhang 
824*d52a580bSJunchao Zhang   /* compute the transpose of the lower triangular factor, i.e. the CSC */
825*d52a580bSJunchao Zhang   /* Csr2cscEx2 is not implemented in ROCm-5.2.0 and is planned for implementation in hipsparse with future releases of ROCm
826*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 2, 0)
827*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCsr2cscEx2_bufferSize(hipsparseTriFactors->handle, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_cols, loTriFactor->csrMat->num_entries, loTriFactor->csrMat->values->data().get(),
828*d52a580bSJunchao Zhang                                                   loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactorT->csrMat->values->data().get(), loTriFactorT->csrMat->row_offsets->data().get(),
829*d52a580bSJunchao Zhang                                                   loTriFactorT->csrMat->column_indices->data().get(), hipsparse_scalartype, HIPSPARSE_ACTION_NUMERIC, indexBase, HIPSPARSE_CSR2CSC_ALG1, &loTriFactor->csr2cscBufferSize));
830*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc(&loTriFactor->csr2cscBuffer, loTriFactor->csr2cscBufferSize));
831*d52a580bSJunchao Zhang #endif
832*d52a580bSJunchao Zhang */
833*d52a580bSJunchao Zhang   PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
834*d52a580bSJunchao Zhang 
835*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparse_csr2csc(hipsparseTriFactors->handle, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_cols, loTriFactor->csrMat->num_entries, loTriFactor->csrMat->values->data().get(), loTriFactor->csrMat->row_offsets->data().get(),
836*d52a580bSJunchao Zhang                                        loTriFactor->csrMat->column_indices->data().get(), loTriFactorT->csrMat->values->data().get(),
837*d52a580bSJunchao Zhang #if 0 /* when Csr2cscEx2 is implemented in hipSparse PETSC_PKG_HIP_VERSION_GE(5, 2, 0)*/
838*d52a580bSJunchao Zhang                           loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(),
839*d52a580bSJunchao Zhang                           hipsparse_scalartype, HIPSPARSE_ACTION_NUMERIC, indexBase, HIPSPARSE_CSR2CSC_ALG1, loTriFactor->csr2cscBuffer));
840*d52a580bSJunchao Zhang #else
841*d52a580bSJunchao Zhang                                        loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->csrMat->row_offsets->data().get(), HIPSPARSE_ACTION_NUMERIC, indexBase));
842*d52a580bSJunchao Zhang #endif
843*d52a580bSJunchao Zhang 
844*d52a580bSJunchao Zhang   PetscCallHIP(WaitForHIP());
845*d52a580bSJunchao Zhang   PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
846*d52a580bSJunchao Zhang 
847*d52a580bSJunchao Zhang   /* Create the solve analysis information */
848*d52a580bSJunchao Zhang   PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
849*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&loTriFactorT->solveInfo));
850*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, loTriFactorT->solveOp, loTriFactorT->csrMat->num_rows, loTriFactorT->csrMat->num_entries, loTriFactorT->descr, loTriFactorT->csrMat->values->data().get(),
851*d52a580bSJunchao Zhang                                               loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->solveInfo, &loTriFactorT->solveBufferSize));
852*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc(&loTriFactorT->solveBuffer, loTriFactorT->solveBufferSize));
853*d52a580bSJunchao Zhang 
854*d52a580bSJunchao Zhang   /* perform the solve analysis */
855*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, loTriFactorT->solveOp, loTriFactorT->csrMat->num_rows, loTriFactorT->csrMat->num_entries, loTriFactorT->descr, loTriFactorT->csrMat->values->data().get(),
856*d52a580bSJunchao Zhang                                               loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->solveInfo, loTriFactorT->solvePolicy, loTriFactorT->solveBuffer));
857*d52a580bSJunchao Zhang 
858*d52a580bSJunchao Zhang   PetscCallHIP(WaitForHIP());
859*d52a580bSJunchao Zhang   PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
860*d52a580bSJunchao Zhang 
861*d52a580bSJunchao Zhang   /* assign the pointer */
862*d52a580bSJunchao Zhang   ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->loTriFactorPtrTranspose = loTriFactorT;
863*d52a580bSJunchao Zhang 
864*d52a580bSJunchao Zhang   /*********************************************/
865*d52a580bSJunchao Zhang   /* Now the Transpose of the Upper Tri Factor */
866*d52a580bSJunchao Zhang   /*********************************************/
867*d52a580bSJunchao Zhang 
868*d52a580bSJunchao Zhang   /* allocate space for the transpose of the upper triangular factor */
869*d52a580bSJunchao Zhang   PetscCall(PetscNew(&upTriFactorT));
870*d52a580bSJunchao Zhang   upTriFactorT->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
871*d52a580bSJunchao Zhang 
872*d52a580bSJunchao Zhang   /* set the matrix descriptors of the upper triangular factor */
873*d52a580bSJunchao Zhang   matrixType = hipsparseGetMatType(upTriFactor->descr);
874*d52a580bSJunchao Zhang   indexBase  = hipsparseGetMatIndexBase(upTriFactor->descr);
875*d52a580bSJunchao Zhang   fillMode   = hipsparseGetMatFillMode(upTriFactor->descr) == HIPSPARSE_FILL_MODE_UPPER ? HIPSPARSE_FILL_MODE_LOWER : HIPSPARSE_FILL_MODE_UPPER;
876*d52a580bSJunchao Zhang   diagType   = hipsparseGetMatDiagType(upTriFactor->descr);
877*d52a580bSJunchao Zhang 
878*d52a580bSJunchao Zhang   /* Create the matrix description */
879*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateMatDescr(&upTriFactorT->descr));
880*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatIndexBase(upTriFactorT->descr, indexBase));
881*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatType(upTriFactorT->descr, matrixType));
882*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatFillMode(upTriFactorT->descr, fillMode));
883*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatDiagType(upTriFactorT->descr, diagType));
884*d52a580bSJunchao Zhang 
885*d52a580bSJunchao Zhang   /* set the operation */
886*d52a580bSJunchao Zhang   upTriFactorT->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE;
887*d52a580bSJunchao Zhang 
888*d52a580bSJunchao Zhang   /* allocate GPU space for the CSC of the upper triangular factor*/
889*d52a580bSJunchao Zhang   upTriFactorT->csrMat                 = new CsrMatrix;
890*d52a580bSJunchao Zhang   upTriFactorT->csrMat->num_rows       = upTriFactor->csrMat->num_cols;
891*d52a580bSJunchao Zhang   upTriFactorT->csrMat->num_cols       = upTriFactor->csrMat->num_rows;
892*d52a580bSJunchao Zhang   upTriFactorT->csrMat->num_entries    = upTriFactor->csrMat->num_entries;
893*d52a580bSJunchao Zhang   upTriFactorT->csrMat->row_offsets    = new THRUSTINTARRAY32(upTriFactorT->csrMat->num_rows + 1);
894*d52a580bSJunchao Zhang   upTriFactorT->csrMat->column_indices = new THRUSTINTARRAY32(upTriFactorT->csrMat->num_entries);
895*d52a580bSJunchao Zhang   upTriFactorT->csrMat->values         = new THRUSTARRAY(upTriFactorT->csrMat->num_entries);
896*d52a580bSJunchao Zhang 
897*d52a580bSJunchao Zhang   /* compute the transpose of the upper triangular factor, i.e. the CSC */
898*d52a580bSJunchao Zhang   /* Csr2cscEx2 is not implemented in ROCm-5.2.0 and is planned for implementation in hipsparse with future releases of ROCm
899*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 2, 0)
900*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCsr2cscEx2_bufferSize(hipsparseTriFactors->handle, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_cols, upTriFactor->csrMat->num_entries, upTriFactor->csrMat->values->data().get(),
901*d52a580bSJunchao Zhang                                                   upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactorT->csrMat->values->data().get(), upTriFactorT->csrMat->row_offsets->data().get(),
902*d52a580bSJunchao Zhang                                                   upTriFactorT->csrMat->column_indices->data().get(), hipsparse_scalartype, HIPSPARSE_ACTION_NUMERIC, indexBase, HIPSPARSE_CSR2CSC_ALG1, &upTriFactor->csr2cscBufferSize));
903*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc(&upTriFactor->csr2cscBuffer, upTriFactor->csr2cscBufferSize));
904*d52a580bSJunchao Zhang #endif
905*d52a580bSJunchao Zhang */
906*d52a580bSJunchao Zhang   PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
907*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparse_csr2csc(hipsparseTriFactors->handle, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_cols, upTriFactor->csrMat->num_entries, upTriFactor->csrMat->values->data().get(), upTriFactor->csrMat->row_offsets->data().get(),
908*d52a580bSJunchao Zhang                                        upTriFactor->csrMat->column_indices->data().get(), upTriFactorT->csrMat->values->data().get(),
909*d52a580bSJunchao Zhang #if 0 /* when Csr2cscEx2 is implemented in hipSparse PETSC_PKG_HIP_VERSION_GE(5, 2, 0)*/
910*d52a580bSJunchao Zhang                           upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(),
911*d52a580bSJunchao Zhang                           hipsparse_scalartype, HIPSPARSE_ACTION_NUMERIC, indexBase, HIPSPARSE_CSR2CSC_ALG1, upTriFactor->csr2cscBuffer));
912*d52a580bSJunchao Zhang #else
913*d52a580bSJunchao Zhang                                        upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->csrMat->row_offsets->data().get(), HIPSPARSE_ACTION_NUMERIC, indexBase));
914*d52a580bSJunchao Zhang #endif
915*d52a580bSJunchao Zhang 
916*d52a580bSJunchao Zhang   PetscCallHIP(WaitForHIP());
917*d52a580bSJunchao Zhang   PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
918*d52a580bSJunchao Zhang 
919*d52a580bSJunchao Zhang   /* Create the solve analysis information */
920*d52a580bSJunchao Zhang   PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
921*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&upTriFactorT->solveInfo));
922*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, upTriFactorT->solveOp, upTriFactorT->csrMat->num_rows, upTriFactorT->csrMat->num_entries, upTriFactorT->descr, upTriFactorT->csrMat->values->data().get(),
923*d52a580bSJunchao Zhang                                               upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->solveInfo, &upTriFactorT->solveBufferSize));
924*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc(&upTriFactorT->solveBuffer, upTriFactorT->solveBufferSize));
925*d52a580bSJunchao Zhang 
926*d52a580bSJunchao Zhang   /* perform the solve analysis */
927*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, upTriFactorT->solveOp, upTriFactorT->csrMat->num_rows, upTriFactorT->csrMat->num_entries, upTriFactorT->descr, upTriFactorT->csrMat->values->data().get(),
928*d52a580bSJunchao Zhang                                               upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->solveInfo, upTriFactorT->solvePolicy, upTriFactorT->solveBuffer));
929*d52a580bSJunchao Zhang 
930*d52a580bSJunchao Zhang   PetscCallHIP(WaitForHIP());
931*d52a580bSJunchao Zhang   PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0));
932*d52a580bSJunchao Zhang 
933*d52a580bSJunchao Zhang   /* assign the pointer */
934*d52a580bSJunchao Zhang   ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->upTriFactorPtrTranspose = upTriFactorT;
935*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
936*d52a580bSJunchao Zhang }
937*d52a580bSJunchao Zhang 
938*d52a580bSJunchao Zhang struct PetscScalarToPetscInt {
939*d52a580bSJunchao Zhang   __host__ __device__ PetscInt operator()(PetscScalar s) { return (PetscInt)PetscRealPart(s); }
940*d52a580bSJunchao Zhang };
941*d52a580bSJunchao Zhang 
942*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEFormExplicitTranspose(Mat A)
943*d52a580bSJunchao Zhang {
944*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
945*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSEMultStruct *matstruct, *matstructT;
946*d52a580bSJunchao Zhang   Mat_SeqAIJ                    *a = (Mat_SeqAIJ *)A->data;
947*d52a580bSJunchao Zhang   hipsparseIndexBase_t           indexBase;
948*d52a580bSJunchao Zhang 
949*d52a580bSJunchao Zhang   PetscFunctionBegin;
950*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
951*d52a580bSJunchao Zhang   matstruct = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->mat;
952*d52a580bSJunchao Zhang   PetscCheck(matstruct, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing mat struct");
953*d52a580bSJunchao Zhang   matstructT = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->matTranspose;
954*d52a580bSJunchao Zhang   PetscCheck(!A->transupdated || matstructT, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing matTranspose struct");
955*d52a580bSJunchao Zhang   if (A->transupdated) PetscFunctionReturn(PETSC_SUCCESS);
956*d52a580bSJunchao Zhang   PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
957*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
958*d52a580bSJunchao Zhang   if (hipsparsestruct->format != MAT_HIPSPARSE_CSR) PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_TRUE));
959*d52a580bSJunchao Zhang   if (!hipsparsestruct->matTranspose) { /* create hipsparse matrix */
960*d52a580bSJunchao Zhang     matstructT = new Mat_SeqAIJHIPSPARSEMultStruct;
961*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseCreateMatDescr(&matstructT->descr));
962*d52a580bSJunchao Zhang     indexBase = hipsparseGetMatIndexBase(matstruct->descr);
963*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSetMatIndexBase(matstructT->descr, indexBase));
964*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSetMatType(matstructT->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
965*d52a580bSJunchao Zhang 
966*d52a580bSJunchao Zhang     /* set alpha and beta */
967*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&matstructT->alpha_one, sizeof(PetscScalar)));
968*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&matstructT->beta_zero, sizeof(PetscScalar)));
969*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&matstructT->beta_one, sizeof(PetscScalar)));
970*d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(matstructT->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
971*d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(matstructT->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice));
972*d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(matstructT->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
973*d52a580bSJunchao Zhang 
974*d52a580bSJunchao Zhang     if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) {
975*d52a580bSJunchao Zhang       CsrMatrix *matrixT      = new CsrMatrix;
976*d52a580bSJunchao Zhang       matstructT->mat         = matrixT;
977*d52a580bSJunchao Zhang       matrixT->num_rows       = A->cmap->n;
978*d52a580bSJunchao Zhang       matrixT->num_cols       = A->rmap->n;
979*d52a580bSJunchao Zhang       matrixT->num_entries    = a->nz;
980*d52a580bSJunchao Zhang       matrixT->row_offsets    = new THRUSTINTARRAY32(matrixT->num_rows + 1);
981*d52a580bSJunchao Zhang       matrixT->column_indices = new THRUSTINTARRAY32(a->nz);
982*d52a580bSJunchao Zhang       matrixT->values         = new THRUSTARRAY(a->nz);
983*d52a580bSJunchao Zhang 
984*d52a580bSJunchao Zhang       if (!hipsparsestruct->rowoffsets_gpu) hipsparsestruct->rowoffsets_gpu = new THRUSTINTARRAY32(A->rmap->n + 1);
985*d52a580bSJunchao Zhang       hipsparsestruct->rowoffsets_gpu->assign(a->i, a->i + A->rmap->n + 1);
986*d52a580bSJunchao Zhang 
987*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreateCsr(&matstructT->matDescr, matrixT->num_rows, matrixT->num_cols, matrixT->num_entries, matrixT->row_offsets->data().get(), matrixT->column_indices->data().get(), matrixT->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, /* row offset, col idx type due to THRUSTINTARRAY32 */
988*d52a580bSJunchao Zhang                                             indexBase, hipsparse_scalartype));
989*d52a580bSJunchao Zhang     } else if (hipsparsestruct->format == MAT_HIPSPARSE_ELL || hipsparsestruct->format == MAT_HIPSPARSE_HYB) {
990*d52a580bSJunchao Zhang       CsrMatrix *temp  = new CsrMatrix;
991*d52a580bSJunchao Zhang       CsrMatrix *tempT = new CsrMatrix;
992*d52a580bSJunchao Zhang       /* First convert HYB to CSR */
993*d52a580bSJunchao Zhang       temp->num_rows       = A->rmap->n;
994*d52a580bSJunchao Zhang       temp->num_cols       = A->cmap->n;
995*d52a580bSJunchao Zhang       temp->num_entries    = a->nz;
996*d52a580bSJunchao Zhang       temp->row_offsets    = new THRUSTINTARRAY32(A->rmap->n + 1);
997*d52a580bSJunchao Zhang       temp->column_indices = new THRUSTINTARRAY32(a->nz);
998*d52a580bSJunchao Zhang       temp->values         = new THRUSTARRAY(a->nz);
999*d52a580bSJunchao Zhang 
1000*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparse_hyb2csr(hipsparsestruct->handle, matstruct->descr, (hipsparseHybMat_t)matstruct->mat, temp->values->data().get(), temp->row_offsets->data().get(), temp->column_indices->data().get()));
1001*d52a580bSJunchao Zhang 
1002*d52a580bSJunchao Zhang       /* Next, convert CSR to CSC (i.e. the matrix transpose) */
1003*d52a580bSJunchao Zhang       tempT->num_rows       = A->rmap->n;
1004*d52a580bSJunchao Zhang       tempT->num_cols       = A->cmap->n;
1005*d52a580bSJunchao Zhang       tempT->num_entries    = a->nz;
1006*d52a580bSJunchao Zhang       tempT->row_offsets    = new THRUSTINTARRAY32(A->rmap->n + 1);
1007*d52a580bSJunchao Zhang       tempT->column_indices = new THRUSTINTARRAY32(a->nz);
1008*d52a580bSJunchao Zhang       tempT->values         = new THRUSTARRAY(a->nz);
1009*d52a580bSJunchao Zhang 
1010*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparse_csr2csc(hipsparsestruct->handle, temp->num_rows, temp->num_cols, temp->num_entries, temp->values->data().get(), temp->row_offsets->data().get(), temp->column_indices->data().get(), tempT->values->data().get(),
1011*d52a580bSJunchao Zhang                                            tempT->column_indices->data().get(), tempT->row_offsets->data().get(), HIPSPARSE_ACTION_NUMERIC, indexBase));
1012*d52a580bSJunchao Zhang 
1013*d52a580bSJunchao Zhang       /* Last, convert CSC to HYB */
1014*d52a580bSJunchao Zhang       hipsparseHybMat_t hybMat;
1015*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreateHybMat(&hybMat));
1016*d52a580bSJunchao Zhang       hipsparseHybPartition_t partition = hipsparsestruct->format == MAT_HIPSPARSE_ELL ? HIPSPARSE_HYB_PARTITION_MAX : HIPSPARSE_HYB_PARTITION_AUTO;
1017*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparse_csr2hyb(hipsparsestruct->handle, A->rmap->n, A->cmap->n, matstructT->descr, tempT->values->data().get(), tempT->row_offsets->data().get(), tempT->column_indices->data().get(), hybMat, 0, partition));
1018*d52a580bSJunchao Zhang 
1019*d52a580bSJunchao Zhang       /* assign the pointer */
1020*d52a580bSJunchao Zhang       matstructT->mat = hybMat;
1021*d52a580bSJunchao Zhang       A->transupdated = PETSC_TRUE;
1022*d52a580bSJunchao Zhang       /* delete temporaries */
1023*d52a580bSJunchao Zhang       if (tempT) {
1024*d52a580bSJunchao Zhang         if (tempT->values) delete (THRUSTARRAY *)tempT->values;
1025*d52a580bSJunchao Zhang         if (tempT->column_indices) delete (THRUSTINTARRAY32 *)tempT->column_indices;
1026*d52a580bSJunchao Zhang         if (tempT->row_offsets) delete (THRUSTINTARRAY32 *)tempT->row_offsets;
1027*d52a580bSJunchao Zhang         delete (CsrMatrix *)tempT;
1028*d52a580bSJunchao Zhang       }
1029*d52a580bSJunchao Zhang       if (temp) {
1030*d52a580bSJunchao Zhang         if (temp->values) delete (THRUSTARRAY *)temp->values;
1031*d52a580bSJunchao Zhang         if (temp->column_indices) delete (THRUSTINTARRAY32 *)temp->column_indices;
1032*d52a580bSJunchao Zhang         if (temp->row_offsets) delete (THRUSTINTARRAY32 *)temp->row_offsets;
1033*d52a580bSJunchao Zhang         delete (CsrMatrix *)temp;
1034*d52a580bSJunchao Zhang       }
1035*d52a580bSJunchao Zhang     }
1036*d52a580bSJunchao Zhang   }
1037*d52a580bSJunchao Zhang   if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) { /* transpose mat struct may be already present, update data */
1038*d52a580bSJunchao Zhang     CsrMatrix *matrix  = (CsrMatrix *)matstruct->mat;
1039*d52a580bSJunchao Zhang     CsrMatrix *matrixT = (CsrMatrix *)matstructT->mat;
1040*d52a580bSJunchao Zhang     PetscCheck(matrix, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrix");
1041*d52a580bSJunchao Zhang     PetscCheck(matrix->row_offsets, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrix rows");
1042*d52a580bSJunchao Zhang     PetscCheck(matrix->column_indices, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrix cols");
1043*d52a580bSJunchao Zhang     PetscCheck(matrix->values, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrix values");
1044*d52a580bSJunchao Zhang     PetscCheck(matrixT, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrixT");
1045*d52a580bSJunchao Zhang     PetscCheck(matrixT->row_offsets, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrixT rows");
1046*d52a580bSJunchao Zhang     PetscCheck(matrixT->column_indices, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrixT cols");
1047*d52a580bSJunchao Zhang     PetscCheck(matrixT->values, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrixT values");
1048*d52a580bSJunchao Zhang     if (!hipsparsestruct->rowoffsets_gpu) { /* this may be absent when we did not construct the transpose with csr2csc */
1049*d52a580bSJunchao Zhang       hipsparsestruct->rowoffsets_gpu = new THRUSTINTARRAY32(A->rmap->n + 1);
1050*d52a580bSJunchao Zhang       hipsparsestruct->rowoffsets_gpu->assign(a->i, a->i + A->rmap->n + 1);
1051*d52a580bSJunchao Zhang       PetscCall(PetscLogCpuToGpu((A->rmap->n + 1) * sizeof(PetscInt)));
1052*d52a580bSJunchao Zhang     }
1053*d52a580bSJunchao Zhang     if (!hipsparsestruct->csr2csc_i) {
1054*d52a580bSJunchao Zhang       THRUSTARRAY csr2csc_a(matrix->num_entries);
1055*d52a580bSJunchao Zhang       PetscCallThrust(thrust::sequence(thrust::device, csr2csc_a.begin(), csr2csc_a.end(), 0.0));
1056*d52a580bSJunchao Zhang 
1057*d52a580bSJunchao Zhang       indexBase = hipsparseGetMatIndexBase(matstruct->descr);
1058*d52a580bSJunchao Zhang       if (matrix->num_entries) {
1059*d52a580bSJunchao Zhang         /* This routine is known to give errors with CUDA-11, but works fine with CUDA-10
1060*d52a580bSJunchao Zhang            Need to verify this for ROCm.
1061*d52a580bSJunchao Zhang         */
1062*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparse_csr2csc(hipsparsestruct->handle, A->rmap->n, A->cmap->n, matrix->num_entries, csr2csc_a.data().get(), hipsparsestruct->rowoffsets_gpu->data().get(), matrix->column_indices->data().get(), matrixT->values->data().get(),
1063*d52a580bSJunchao Zhang                                              matrixT->column_indices->data().get(), matrixT->row_offsets->data().get(), HIPSPARSE_ACTION_NUMERIC, indexBase));
1064*d52a580bSJunchao Zhang       } else {
1065*d52a580bSJunchao Zhang         matrixT->row_offsets->assign(matrixT->row_offsets->size(), indexBase);
1066*d52a580bSJunchao Zhang       }
1067*d52a580bSJunchao Zhang 
1068*d52a580bSJunchao Zhang       hipsparsestruct->csr2csc_i = new THRUSTINTARRAY(matrix->num_entries);
1069*d52a580bSJunchao Zhang       PetscCallThrust(thrust::transform(thrust::device, matrixT->values->begin(), matrixT->values->end(), hipsparsestruct->csr2csc_i->begin(), PetscScalarToPetscInt()));
1070*d52a580bSJunchao Zhang     }
1071*d52a580bSJunchao Zhang     PetscCallThrust(
1072*d52a580bSJunchao Zhang       thrust::copy(thrust::device, thrust::make_permutation_iterator(matrix->values->begin(), hipsparsestruct->csr2csc_i->begin()), thrust::make_permutation_iterator(matrix->values->begin(), hipsparsestruct->csr2csc_i->end()), matrixT->values->begin()));
1073*d52a580bSJunchao Zhang   }
1074*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1075*d52a580bSJunchao Zhang   PetscCall(PetscLogEventEnd(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0));
1076*d52a580bSJunchao Zhang   /* the compressed row indices is not used for matTranspose */
1077*d52a580bSJunchao Zhang   matstructT->cprowIndices = NULL;
1078*d52a580bSJunchao Zhang   /* assign the pointer */
1079*d52a580bSJunchao Zhang   ((Mat_SeqAIJHIPSPARSE *)A->spptr)->matTranspose = matstructT;
1080*d52a580bSJunchao Zhang   A->transupdated                                 = PETSC_TRUE;
1081*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1082*d52a580bSJunchao Zhang }
1083*d52a580bSJunchao Zhang 
1084*d52a580bSJunchao Zhang /* Why do we need to analyze the transposed matrix again? Can't we just use op(A) = HIPSPARSE_OPERATION_TRANSPOSE in MatSolve_SeqAIJHIPSPARSE? */
1085*d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE(Mat A, Vec bb, Vec xx)
1086*d52a580bSJunchao Zhang {
1087*d52a580bSJunchao Zhang   PetscInt                              n = xx->map->n;
1088*d52a580bSJunchao Zhang   const PetscScalar                    *barray;
1089*d52a580bSJunchao Zhang   PetscScalar                          *xarray;
1090*d52a580bSJunchao Zhang   thrust::device_ptr<const PetscScalar> bGPU;
1091*d52a580bSJunchao Zhang   thrust::device_ptr<PetscScalar>       xGPU;
1092*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors        *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
1093*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct   *loTriFactorT        = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtrTranspose;
1094*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct   *upTriFactorT        = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtrTranspose;
1095*d52a580bSJunchao Zhang   THRUSTARRAY                          *tempGPU             = (THRUSTARRAY *)hipsparseTriFactors->workVector;
1096*d52a580bSJunchao Zhang 
1097*d52a580bSJunchao Zhang   PetscFunctionBegin;
1098*d52a580bSJunchao Zhang   /* Analyze the matrix and create the transpose ... on the fly */
1099*d52a580bSJunchao Zhang   if (!loTriFactorT && !upTriFactorT) {
1100*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEAnalyzeTransposeForSolve(A));
1101*d52a580bSJunchao Zhang     loTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtrTranspose;
1102*d52a580bSJunchao Zhang     upTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtrTranspose;
1103*d52a580bSJunchao Zhang   }
1104*d52a580bSJunchao Zhang 
1105*d52a580bSJunchao Zhang   /* Get the GPU pointers */
1106*d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayWrite(xx, &xarray));
1107*d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayRead(bb, &barray));
1108*d52a580bSJunchao Zhang   xGPU = thrust::device_pointer_cast(xarray);
1109*d52a580bSJunchao Zhang   bGPU = thrust::device_pointer_cast(barray);
1110*d52a580bSJunchao Zhang 
1111*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1112*d52a580bSJunchao Zhang   /* First, reorder with the row permutation */
1113*d52a580bSJunchao Zhang   thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), thrust::make_permutation_iterator(bGPU, hipsparseTriFactors->rpermIndices->begin()), thrust::make_permutation_iterator(bGPU + n, hipsparseTriFactors->rpermIndices->end()), xGPU);
1114*d52a580bSJunchao Zhang 
1115*d52a580bSJunchao Zhang   /* First, solve U */
1116*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, upTriFactorT->solveOp, upTriFactorT->csrMat->num_rows, upTriFactorT->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, upTriFactorT->descr, upTriFactorT->csrMat->values->data().get(),
1117*d52a580bSJunchao Zhang                                            upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->solveInfo, xarray, tempGPU->data().get(), upTriFactorT->solvePolicy, upTriFactorT->solveBuffer));
1118*d52a580bSJunchao Zhang 
1119*d52a580bSJunchao Zhang   /* Then, solve L */
1120*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, loTriFactorT->solveOp, loTriFactorT->csrMat->num_rows, loTriFactorT->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, loTriFactorT->descr, loTriFactorT->csrMat->values->data().get(),
1121*d52a580bSJunchao Zhang                                            loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->solveInfo, tempGPU->data().get(), xarray, loTriFactorT->solvePolicy, loTriFactorT->solveBuffer));
1122*d52a580bSJunchao Zhang 
1123*d52a580bSJunchao Zhang   /* Last, copy the solution, xGPU, into a temporary with the column permutation ... can't be done in place. */
1124*d52a580bSJunchao Zhang   thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), thrust::make_permutation_iterator(xGPU, hipsparseTriFactors->cpermIndices->begin()), thrust::make_permutation_iterator(xGPU + n, hipsparseTriFactors->cpermIndices->end()), tempGPU->begin());
1125*d52a580bSJunchao Zhang 
1126*d52a580bSJunchao Zhang   /* Copy the temporary to the full solution. */
1127*d52a580bSJunchao Zhang   thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), tempGPU->begin(), tempGPU->end(), xGPU);
1128*d52a580bSJunchao Zhang 
1129*d52a580bSJunchao Zhang   /* restore */
1130*d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayRead(bb, &barray));
1131*d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayWrite(xx, &xarray));
1132*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1133*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * hipsparseTriFactors->nnz - A->cmap->n));
1134*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1135*d52a580bSJunchao Zhang }
1136*d52a580bSJunchao Zhang 
1137*d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering(Mat A, Vec bb, Vec xx)
1138*d52a580bSJunchao Zhang {
1139*d52a580bSJunchao Zhang   const PetscScalar                  *barray;
1140*d52a580bSJunchao Zhang   PetscScalar                        *xarray;
1141*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors      *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
1142*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactorT        = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtrTranspose;
1143*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactorT        = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtrTranspose;
1144*d52a580bSJunchao Zhang   THRUSTARRAY                        *tempGPU             = (THRUSTARRAY *)hipsparseTriFactors->workVector;
1145*d52a580bSJunchao Zhang 
1146*d52a580bSJunchao Zhang   PetscFunctionBegin;
1147*d52a580bSJunchao Zhang   /* Analyze the matrix and create the transpose ... on the fly */
1148*d52a580bSJunchao Zhang   if (!loTriFactorT && !upTriFactorT) {
1149*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEAnalyzeTransposeForSolve(A));
1150*d52a580bSJunchao Zhang     loTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtrTranspose;
1151*d52a580bSJunchao Zhang     upTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtrTranspose;
1152*d52a580bSJunchao Zhang   }
1153*d52a580bSJunchao Zhang 
1154*d52a580bSJunchao Zhang   /* Get the GPU pointers */
1155*d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayWrite(xx, &xarray));
1156*d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayRead(bb, &barray));
1157*d52a580bSJunchao Zhang 
1158*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1159*d52a580bSJunchao Zhang   /* First, solve U */
1160*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, upTriFactorT->solveOp, upTriFactorT->csrMat->num_rows, upTriFactorT->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, upTriFactorT->descr, upTriFactorT->csrMat->values->data().get(),
1161*d52a580bSJunchao Zhang                                            upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->solveInfo, barray, tempGPU->data().get(), upTriFactorT->solvePolicy, upTriFactorT->solveBuffer));
1162*d52a580bSJunchao Zhang 
1163*d52a580bSJunchao Zhang   /* Then, solve L */
1164*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, loTriFactorT->solveOp, loTriFactorT->csrMat->num_rows, loTriFactorT->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, loTriFactorT->descr, loTriFactorT->csrMat->values->data().get(),
1165*d52a580bSJunchao Zhang                                            loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->solveInfo, tempGPU->data().get(), xarray, loTriFactorT->solvePolicy, loTriFactorT->solveBuffer));
1166*d52a580bSJunchao Zhang 
1167*d52a580bSJunchao Zhang   /* restore */
1168*d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayRead(bb, &barray));
1169*d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayWrite(xx, &xarray));
1170*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1171*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * hipsparseTriFactors->nnz - A->cmap->n));
1172*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1173*d52a580bSJunchao Zhang }
1174*d52a580bSJunchao Zhang 
1175*d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE(Mat A, Vec bb, Vec xx)
1176*d52a580bSJunchao Zhang {
1177*d52a580bSJunchao Zhang   const PetscScalar                    *barray;
1178*d52a580bSJunchao Zhang   PetscScalar                          *xarray;
1179*d52a580bSJunchao Zhang   thrust::device_ptr<const PetscScalar> bGPU;
1180*d52a580bSJunchao Zhang   thrust::device_ptr<PetscScalar>       xGPU;
1181*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors        *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
1182*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct   *loTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr;
1183*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct   *upTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr;
1184*d52a580bSJunchao Zhang   THRUSTARRAY                          *tempGPU             = (THRUSTARRAY *)hipsparseTriFactors->workVector;
1185*d52a580bSJunchao Zhang 
1186*d52a580bSJunchao Zhang   PetscFunctionBegin;
1187*d52a580bSJunchao Zhang   /* Get the GPU pointers */
1188*d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayWrite(xx, &xarray));
1189*d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayRead(bb, &barray));
1190*d52a580bSJunchao Zhang   xGPU = thrust::device_pointer_cast(xarray);
1191*d52a580bSJunchao Zhang   bGPU = thrust::device_pointer_cast(barray);
1192*d52a580bSJunchao Zhang 
1193*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1194*d52a580bSJunchao Zhang   /* First, reorder with the row permutation */
1195*d52a580bSJunchao Zhang   thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), thrust::make_permutation_iterator(bGPU, hipsparseTriFactors->rpermIndices->begin()), thrust::make_permutation_iterator(bGPU, hipsparseTriFactors->rpermIndices->end()), tempGPU->begin());
1196*d52a580bSJunchao Zhang 
1197*d52a580bSJunchao Zhang   /* Next, solve L */
1198*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
1199*d52a580bSJunchao Zhang                                            loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, tempGPU->data().get(), xarray, loTriFactor->solvePolicy, loTriFactor->solveBuffer));
1200*d52a580bSJunchao Zhang 
1201*d52a580bSJunchao Zhang   /* Then, solve U */
1202*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
1203*d52a580bSJunchao Zhang                                            upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, xarray, tempGPU->data().get(), upTriFactor->solvePolicy, upTriFactor->solveBuffer));
1204*d52a580bSJunchao Zhang 
1205*d52a580bSJunchao Zhang   /* Last, reorder with the column permutation */
1206*d52a580bSJunchao Zhang   thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), thrust::make_permutation_iterator(tempGPU->begin(), hipsparseTriFactors->cpermIndices->begin()), thrust::make_permutation_iterator(tempGPU->begin(), hipsparseTriFactors->cpermIndices->end()), xGPU);
1207*d52a580bSJunchao Zhang 
1208*d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayRead(bb, &barray));
1209*d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayWrite(xx, &xarray));
1210*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1211*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * hipsparseTriFactors->nnz - A->cmap->n));
1212*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1213*d52a580bSJunchao Zhang }
1214*d52a580bSJunchao Zhang 
1215*d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_NaturalOrdering(Mat A, Vec bb, Vec xx)
1216*d52a580bSJunchao Zhang {
1217*d52a580bSJunchao Zhang   const PetscScalar                  *barray;
1218*d52a580bSJunchao Zhang   PetscScalar                        *xarray;
1219*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors      *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
1220*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr;
1221*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor         = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr;
1222*d52a580bSJunchao Zhang   THRUSTARRAY                        *tempGPU             = (THRUSTARRAY *)hipsparseTriFactors->workVector;
1223*d52a580bSJunchao Zhang 
1224*d52a580bSJunchao Zhang   PetscFunctionBegin;
1225*d52a580bSJunchao Zhang   /* Get the GPU pointers */
1226*d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayWrite(xx, &xarray));
1227*d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayRead(bb, &barray));
1228*d52a580bSJunchao Zhang 
1229*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1230*d52a580bSJunchao Zhang   /* First, solve L */
1231*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, loTriFactor->descr, loTriFactor->csrMat->values->data().get(),
1232*d52a580bSJunchao Zhang                                            loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, barray, tempGPU->data().get(), loTriFactor->solvePolicy, loTriFactor->solveBuffer));
1233*d52a580bSJunchao Zhang 
1234*d52a580bSJunchao Zhang   /* Next, solve U */
1235*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, upTriFactor->descr, upTriFactor->csrMat->values->data().get(),
1236*d52a580bSJunchao Zhang                                            upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, tempGPU->data().get(), xarray, upTriFactor->solvePolicy, upTriFactor->solveBuffer));
1237*d52a580bSJunchao Zhang 
1238*d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayRead(bb, &barray));
1239*d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayWrite(xx, &xarray));
1240*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1241*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * hipsparseTriFactors->nnz - A->cmap->n));
1242*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1243*d52a580bSJunchao Zhang }
1244*d52a580bSJunchao Zhang 
1245*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
1246*d52a580bSJunchao Zhang /* hipsparseSpSV_solve() and related functions first appeared in ROCm-4.5.0*/
1247*d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_ILU0(Mat fact, Vec b, Vec x)
1248*d52a580bSJunchao Zhang {
1249*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs  = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1250*d52a580bSJunchao Zhang   Mat_SeqAIJ                    *aij = (Mat_SeqAIJ *)fact->data;
1251*d52a580bSJunchao Zhang   const PetscScalar             *barray;
1252*d52a580bSJunchao Zhang   PetscScalar                   *xarray;
1253*d52a580bSJunchao Zhang 
1254*d52a580bSJunchao Zhang   PetscFunctionBegin;
1255*d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayWrite(x, &xarray));
1256*d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayRead(b, &barray));
1257*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1258*d52a580bSJunchao Zhang 
1259*d52a580bSJunchao Zhang   /* Solve L*y = b */
1260*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray));
1261*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y));
1262*d52a580bSJunchao Zhang   #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1263*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L,                   /* L Y = X */
1264*d52a580bSJunchao Zhang                                          fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L)); // hipsparseSpSV_solve() secretely uses the external buffer used in hipsparseSpSV_analysis()!
1265*d52a580bSJunchao Zhang   #else
1266*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L,                                     /* L Y = X */
1267*d52a580bSJunchao Zhang                                          fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L)); // hipsparseSpSV_solve() secretely uses the external buffer used in hipsparseSpSV_analysis()!
1268*d52a580bSJunchao Zhang   #endif
1269*d52a580bSJunchao Zhang   /* Solve U*x = y */
1270*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray));
1271*d52a580bSJunchao Zhang   #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1272*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* U X = Y */
1273*d52a580bSJunchao Zhang                                          fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U));
1274*d52a580bSJunchao Zhang   #else
1275*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* U X = Y */
1276*d52a580bSJunchao Zhang                                          fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U, fs->spsvBuffer_U));
1277*d52a580bSJunchao Zhang   #endif
1278*d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayRead(b, &barray));
1279*d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayWrite(x, &xarray));
1280*d52a580bSJunchao Zhang 
1281*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1282*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * aij->nz - fact->rmap->n));
1283*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1284*d52a580bSJunchao Zhang }
1285*d52a580bSJunchao Zhang 
1286*d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE_ILU0(Mat fact, Vec b, Vec x)
1287*d52a580bSJunchao Zhang {
1288*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs  = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1289*d52a580bSJunchao Zhang   Mat_SeqAIJ                    *aij = (Mat_SeqAIJ *)fact->data;
1290*d52a580bSJunchao Zhang   const PetscScalar             *barray;
1291*d52a580bSJunchao Zhang   PetscScalar                   *xarray;
1292*d52a580bSJunchao Zhang 
1293*d52a580bSJunchao Zhang   PetscFunctionBegin;
1294*d52a580bSJunchao Zhang   if (!fs->createdTransposeSpSVDescr) { /* Call MatSolveTranspose() for the first time */
1295*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_Lt));
1296*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* The matrix is still L. We only do transpose solve with it */
1297*d52a580bSJunchao Zhang                                                 fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, &fs->spsvBufferSize_Lt));
1298*d52a580bSJunchao Zhang 
1299*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_Ut));
1300*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut, &fs->spsvBufferSize_Ut));
1301*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_Lt, fs->spsvBufferSize_Lt));
1302*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_Ut, fs->spsvBufferSize_Ut));
1303*d52a580bSJunchao Zhang     fs->createdTransposeSpSVDescr = PETSC_TRUE;
1304*d52a580bSJunchao Zhang   }
1305*d52a580bSJunchao Zhang 
1306*d52a580bSJunchao Zhang   if (!fs->updatedTransposeSpSVAnalysis) {
1307*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt));
1308*d52a580bSJunchao Zhang 
1309*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut, fs->spsvBuffer_Ut));
1310*d52a580bSJunchao Zhang     fs->updatedTransposeSpSVAnalysis = PETSC_TRUE;
1311*d52a580bSJunchao Zhang   }
1312*d52a580bSJunchao Zhang 
1313*d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayWrite(x, &xarray));
1314*d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayRead(b, &barray));
1315*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1316*d52a580bSJunchao Zhang 
1317*d52a580bSJunchao Zhang   /* Solve Ut*y = b */
1318*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray));
1319*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y));
1320*d52a580bSJunchao Zhang   #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1321*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* Ut Y = X */
1322*d52a580bSJunchao Zhang                                          fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut));
1323*d52a580bSJunchao Zhang   #else
1324*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* Ut Y = X */
1325*d52a580bSJunchao Zhang                                          fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut, fs->spsvBuffer_Ut));
1326*d52a580bSJunchao Zhang   #endif
1327*d52a580bSJunchao Zhang   /* Solve Lt*x = y */
1328*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray));
1329*d52a580bSJunchao Zhang   #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1330*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
1331*d52a580bSJunchao Zhang                                          fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt));
1332*d52a580bSJunchao Zhang   #else
1333*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
1334*d52a580bSJunchao Zhang                                          fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt));
1335*d52a580bSJunchao Zhang   #endif
1336*d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayRead(b, &barray));
1337*d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayWrite(x, &xarray));
1338*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1339*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * aij->nz - fact->rmap->n));
1340*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1341*d52a580bSJunchao Zhang }
1342*d52a580bSJunchao Zhang 
1343*d52a580bSJunchao Zhang static PetscErrorCode MatILUFactorNumeric_SeqAIJHIPSPARSE_ILU0(Mat fact, Mat A, const MatFactorInfo *info)
1344*d52a580bSJunchao Zhang {
1345*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs    = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1346*d52a580bSJunchao Zhang   Mat_SeqAIJ                    *aij   = (Mat_SeqAIJ *)fact->data;
1347*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
1348*d52a580bSJunchao Zhang   CsrMatrix                     *Acsr;
1349*d52a580bSJunchao Zhang   PetscInt                       m, nz;
1350*d52a580bSJunchao Zhang   PetscBool                      flg;
1351*d52a580bSJunchao Zhang 
1352*d52a580bSJunchao Zhang   PetscFunctionBegin;
1353*d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG)) {
1354*d52a580bSJunchao Zhang     PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
1355*d52a580bSJunchao Zhang     PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Expected MATSEQAIJHIPSPARSE, but input is %s", ((PetscObject)A)->type_name);
1356*d52a580bSJunchao Zhang   }
1357*d52a580bSJunchao Zhang 
1358*d52a580bSJunchao Zhang   /* Copy A's value to fact */
1359*d52a580bSJunchao Zhang   m  = fact->rmap->n;
1360*d52a580bSJunchao Zhang   nz = aij->nz;
1361*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
1362*d52a580bSJunchao Zhang   Acsr = (CsrMatrix *)Acusp->mat->mat;
1363*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpyAsync(fs->csrVal, Acsr->values->data().get(), sizeof(PetscScalar) * nz, hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1364*d52a580bSJunchao Zhang 
1365*d52a580bSJunchao Zhang   /* Factorize fact inplace */
1366*d52a580bSJunchao Zhang   if (m)
1367*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseXcsrilu02(fs->handle, m, nz, /* hipsparseXcsrilu02 errors out with empty matrices (m=0) */
1368*d52a580bSJunchao Zhang                                           fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ilu0Info_M, fs->policy_M, fs->factBuffer_M));
1369*d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG)) {
1370*d52a580bSJunchao Zhang     int               numerical_zero;
1371*d52a580bSJunchao Zhang     hipsparseStatus_t status;
1372*d52a580bSJunchao Zhang     status = hipsparseXcsrilu02_zeroPivot(fs->handle, fs->ilu0Info_M, &numerical_zero);
1373*d52a580bSJunchao Zhang     PetscAssert(HIPSPARSE_STATUS_ZERO_PIVOT != status, PETSC_COMM_SELF, PETSC_ERR_USER_INPUT, "Numerical zero pivot detected in csrilu02: A(%d,%d) is zero", numerical_zero, numerical_zero);
1374*d52a580bSJunchao Zhang   }
1375*d52a580bSJunchao Zhang 
1376*d52a580bSJunchao Zhang   /* hipsparseSpSV_analysis() is numeric, i.e., it requires valid matrix values, therefore, we do it after hipsparseXcsrilu02() */
1377*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L));
1378*d52a580bSJunchao Zhang 
1379*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U, fs->spsvBuffer_U));
1380*d52a580bSJunchao Zhang 
1381*d52a580bSJunchao Zhang   /* L, U values have changed, reset the flag to indicate we need to redo hipsparseSpSV_analysis() for transpose solve */
1382*d52a580bSJunchao Zhang   fs->updatedTransposeSpSVAnalysis = PETSC_FALSE;
1383*d52a580bSJunchao Zhang 
1384*d52a580bSJunchao Zhang   fact->offloadmask            = PETSC_OFFLOAD_GPU;
1385*d52a580bSJunchao Zhang   fact->ops->solve             = MatSolve_SeqAIJHIPSPARSE_ILU0;
1386*d52a580bSJunchao Zhang   fact->ops->solvetranspose    = MatSolveTranspose_SeqAIJHIPSPARSE_ILU0;
1387*d52a580bSJunchao Zhang   fact->ops->matsolve          = NULL;
1388*d52a580bSJunchao Zhang   fact->ops->matsolvetranspose = NULL;
1389*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(fs->numericFactFlops));
1390*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1391*d52a580bSJunchao Zhang }
1392*d52a580bSJunchao Zhang 
1393*d52a580bSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJHIPSPARSE_ILU0(Mat fact, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1394*d52a580bSJunchao Zhang {
1395*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs  = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1396*d52a580bSJunchao Zhang   Mat_SeqAIJ                    *aij = (Mat_SeqAIJ *)fact->data;
1397*d52a580bSJunchao Zhang   PetscInt                       m, nz;
1398*d52a580bSJunchao Zhang 
1399*d52a580bSJunchao Zhang   PetscFunctionBegin;
1400*d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG)) {
1401*d52a580bSJunchao Zhang     PetscBool flg, diagDense;
1402*d52a580bSJunchao Zhang 
1403*d52a580bSJunchao Zhang     PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
1404*d52a580bSJunchao Zhang     PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Expected MATSEQAIJHIPSPARSE, but input is %s", ((PetscObject)A)->type_name);
1405*d52a580bSJunchao Zhang     PetscCheck(A->rmap->n == A->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must be square matrix, rows %" PetscInt_FMT " columns %" PetscInt_FMT, A->rmap->n, A->cmap->n);
1406*d52a580bSJunchao Zhang     PetscCall(MatGetDiagonalMarkers_SeqAIJ(A, NULL, &diagDense));
1407*d52a580bSJunchao Zhang     PetscCheck(diagDense, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Matrix is missing diagonal entries");
1408*d52a580bSJunchao Zhang   }
1409*d52a580bSJunchao Zhang 
1410*d52a580bSJunchao Zhang   /* Free the old stale stuff */
1411*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&fs));
1412*d52a580bSJunchao Zhang 
1413*d52a580bSJunchao Zhang   /* Copy over A's meta data to fact. Note that we also allocated fact's i,j,a on host,
1414*d52a580bSJunchao Zhang      but they will not be used. Allocate them just for easy debugging.
1415*d52a580bSJunchao Zhang    */
1416*d52a580bSJunchao Zhang   PetscCall(MatDuplicateNoCreate_SeqAIJ(fact, A, MAT_DO_NOT_COPY_VALUES, PETSC_TRUE /*malloc*/));
1417*d52a580bSJunchao Zhang 
1418*d52a580bSJunchao Zhang   fact->offloadmask            = PETSC_OFFLOAD_BOTH;
1419*d52a580bSJunchao Zhang   fact->factortype             = MAT_FACTOR_ILU;
1420*d52a580bSJunchao Zhang   fact->info.factor_mallocs    = 0;
1421*d52a580bSJunchao Zhang   fact->info.fill_ratio_given  = info->fill;
1422*d52a580bSJunchao Zhang   fact->info.fill_ratio_needed = 1.0;
1423*d52a580bSJunchao Zhang 
1424*d52a580bSJunchao Zhang   aij->row = NULL;
1425*d52a580bSJunchao Zhang   aij->col = NULL;
1426*d52a580bSJunchao Zhang 
1427*d52a580bSJunchao Zhang   /* ====================================================================== */
1428*d52a580bSJunchao Zhang   /* Copy A's i, j to fact and also allocate the value array of fact.       */
1429*d52a580bSJunchao Zhang   /* We'll do in-place factorization on fact                                */
1430*d52a580bSJunchao Zhang   /* ====================================================================== */
1431*d52a580bSJunchao Zhang   const int *Ai, *Aj;
1432*d52a580bSJunchao Zhang 
1433*d52a580bSJunchao Zhang   m  = fact->rmap->n;
1434*d52a580bSJunchao Zhang   nz = aij->nz;
1435*d52a580bSJunchao Zhang 
1436*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->csrRowPtr, sizeof(int) * (m + 1)));
1437*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->csrColIdx, sizeof(int) * nz));
1438*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->csrVal, sizeof(PetscScalar) * nz));
1439*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEGetIJ(A, PETSC_FALSE, &Ai, &Aj)); /* Do not use compressed Ai */
1440*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpyAsync(fs->csrRowPtr, Ai, sizeof(int) * (m + 1), hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1441*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpyAsync(fs->csrColIdx, Aj, sizeof(int) * nz, hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1442*d52a580bSJunchao Zhang 
1443*d52a580bSJunchao Zhang   /* ====================================================================== */
1444*d52a580bSJunchao Zhang   /* Create descriptors for M, L, U                                         */
1445*d52a580bSJunchao Zhang   /* ====================================================================== */
1446*d52a580bSJunchao Zhang   hipsparseFillMode_t fillMode;
1447*d52a580bSJunchao Zhang   hipsparseDiagType_t diagType;
1448*d52a580bSJunchao Zhang 
1449*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateMatDescr(&fs->matDescr_M));
1450*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatIndexBase(fs->matDescr_M, HIPSPARSE_INDEX_BASE_ZERO));
1451*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatType(fs->matDescr_M, HIPSPARSE_MATRIX_TYPE_GENERAL));
1452*d52a580bSJunchao Zhang 
1453*d52a580bSJunchao Zhang   /* https://docs.amd.com/bundle/hipSPARSE-Documentation---hipSPARSE-documentation/page/usermanual.html/#hipsparse_8h_1a79e036b6c0680cb37e2aa53d3542a054
1454*d52a580bSJunchao Zhang     hipsparseDiagType_t: This type indicates if the matrix diagonal entries are unity. The diagonal elements are always
1455*d52a580bSJunchao Zhang     assumed to be present, but if HIPSPARSE_DIAG_TYPE_UNIT is passed to an API routine, then the routine assumes that
1456*d52a580bSJunchao Zhang     all diagonal entries are unity and will not read or modify those entries. Note that in this case the routine
1457*d52a580bSJunchao Zhang     assumes the diagonal entries are equal to one, regardless of what those entries are actually set to in memory.
1458*d52a580bSJunchao Zhang   */
1459*d52a580bSJunchao Zhang   fillMode = HIPSPARSE_FILL_MODE_LOWER;
1460*d52a580bSJunchao Zhang   diagType = HIPSPARSE_DIAG_TYPE_UNIT;
1461*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateCsr(&fs->spMatDescr_L, m, m, nz, fs->csrRowPtr, fs->csrColIdx, fs->csrVal, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
1462*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_L, HIPSPARSE_SPMAT_FILL_MODE, &fillMode, sizeof(fillMode)));
1463*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_L, HIPSPARSE_SPMAT_DIAG_TYPE, &diagType, sizeof(diagType)));
1464*d52a580bSJunchao Zhang 
1465*d52a580bSJunchao Zhang   fillMode = HIPSPARSE_FILL_MODE_UPPER;
1466*d52a580bSJunchao Zhang   diagType = HIPSPARSE_DIAG_TYPE_NON_UNIT;
1467*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateCsr(&fs->spMatDescr_U, m, m, nz, fs->csrRowPtr, fs->csrColIdx, fs->csrVal, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
1468*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_U, HIPSPARSE_SPMAT_FILL_MODE, &fillMode, sizeof(fillMode)));
1469*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_U, HIPSPARSE_SPMAT_DIAG_TYPE, &diagType, sizeof(diagType)));
1470*d52a580bSJunchao Zhang 
1471*d52a580bSJunchao Zhang   /* ========================================================================= */
1472*d52a580bSJunchao Zhang   /* Query buffer sizes for csrilu0, SpSV and allocate buffers                 */
1473*d52a580bSJunchao Zhang   /* ========================================================================= */
1474*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateCsrilu02Info(&fs->ilu0Info_M));
1475*d52a580bSJunchao Zhang   if (m)
1476*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseXcsrilu02_bufferSize(fs->handle, m, nz, /* hipsparseXcsrilu02 errors out with empty matrices (m=0) */
1477*d52a580bSJunchao Zhang                                                      fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ilu0Info_M, &fs->factBufferSize_M));
1478*d52a580bSJunchao Zhang 
1479*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->X, sizeof(PetscScalar) * m));
1480*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->Y, sizeof(PetscScalar) * m));
1481*d52a580bSJunchao Zhang 
1482*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateDnVec(&fs->dnVecDescr_X, m, fs->X, hipsparse_scalartype));
1483*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateDnVec(&fs->dnVecDescr_Y, m, fs->Y, hipsparse_scalartype));
1484*d52a580bSJunchao Zhang 
1485*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_L));
1486*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, &fs->spsvBufferSize_L));
1487*d52a580bSJunchao Zhang 
1488*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_U));
1489*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U, &fs->spsvBufferSize_U));
1490*d52a580bSJunchao Zhang 
1491*d52a580bSJunchao Zhang   /* It appears spsvBuffer_L/U can not be shared (i.e., the same) for our case, but factBuffer_M can share with either of spsvBuffer_L/U.
1492*d52a580bSJunchao Zhang      To save memory, we make factBuffer_M share with the bigger of spsvBuffer_L/U.
1493*d52a580bSJunchao Zhang    */
1494*d52a580bSJunchao Zhang   if (fs->spsvBufferSize_L > fs->spsvBufferSize_U) {
1495*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->factBuffer_M, PetscMax(fs->spsvBufferSize_L, (size_t)fs->factBufferSize_M)));
1496*d52a580bSJunchao Zhang     fs->spsvBuffer_L = fs->factBuffer_M;
1497*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_U, fs->spsvBufferSize_U));
1498*d52a580bSJunchao Zhang   } else {
1499*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->factBuffer_M, PetscMax(fs->spsvBufferSize_U, (size_t)fs->factBufferSize_M)));
1500*d52a580bSJunchao Zhang     fs->spsvBuffer_U = fs->factBuffer_M;
1501*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_L, fs->spsvBufferSize_L));
1502*d52a580bSJunchao Zhang   }
1503*d52a580bSJunchao Zhang 
1504*d52a580bSJunchao Zhang   /* ========================================================================== */
1505*d52a580bSJunchao Zhang   /* Perform analysis of ilu0 on M, SpSv on L and U                             */
1506*d52a580bSJunchao Zhang   /* The lower(upper) triangular part of M has the same sparsity pattern as L(U)*/
1507*d52a580bSJunchao Zhang   /* ========================================================================== */
1508*d52a580bSJunchao Zhang   int structural_zero;
1509*d52a580bSJunchao Zhang 
1510*d52a580bSJunchao Zhang   fs->policy_M = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
1511*d52a580bSJunchao Zhang   if (m)
1512*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseXcsrilu02_analysis(fs->handle, m, nz, /* hipsparseXcsrilu02 errors out with empty matrices (m=0) */
1513*d52a580bSJunchao Zhang                                                    fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ilu0Info_M, fs->policy_M, fs->factBuffer_M));
1514*d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG)) {
1515*d52a580bSJunchao Zhang     /* Function hipsparseXcsrilu02_zeroPivot() is a blocking call. It calls hipDeviceSynchronize() to make sure all previous kernels are done. */
1516*d52a580bSJunchao Zhang     hipsparseStatus_t status;
1517*d52a580bSJunchao Zhang     status = hipsparseXcsrilu02_zeroPivot(fs->handle, fs->ilu0Info_M, &structural_zero);
1518*d52a580bSJunchao Zhang     PetscCheck(HIPSPARSE_STATUS_ZERO_PIVOT != status, PETSC_COMM_SELF, PETSC_ERR_USER_INPUT, "Structural zero pivot detected in csrilu02: A(%d,%d) is missing", structural_zero, structural_zero);
1519*d52a580bSJunchao Zhang   }
1520*d52a580bSJunchao Zhang 
1521*d52a580bSJunchao Zhang   /* Estimate FLOPs of the numeric factorization */
1522*d52a580bSJunchao Zhang   {
1523*d52a580bSJunchao Zhang     Mat_SeqAIJ     *Aseq = (Mat_SeqAIJ *)A->data;
1524*d52a580bSJunchao Zhang     PetscInt       *Ai, nzRow, nzLeft;
1525*d52a580bSJunchao Zhang     PetscLogDouble  flops = 0.0;
1526*d52a580bSJunchao Zhang     const PetscInt *Adiag;
1527*d52a580bSJunchao Zhang 
1528*d52a580bSJunchao Zhang     PetscCall(MatGetDiagonalMarkers_SeqAIJ(A, &Adiag, NULL));
1529*d52a580bSJunchao Zhang     Ai = Aseq->i;
1530*d52a580bSJunchao Zhang     for (PetscInt i = 0; i < m; i++) {
1531*d52a580bSJunchao Zhang       if (Ai[i] < Adiag[i] && Adiag[i] < Ai[i + 1]) { /* There are nonzeros left to the diagonal of row i */
1532*d52a580bSJunchao Zhang         nzRow  = Ai[i + 1] - Ai[i];
1533*d52a580bSJunchao Zhang         nzLeft = Adiag[i] - Ai[i];
1534*d52a580bSJunchao Zhang         /* We want to eliminate nonzeros left to the diagonal one by one. Assume each time, nonzeros right
1535*d52a580bSJunchao Zhang           and include the eliminated one will be updated, which incurs a multiplication and an addition.
1536*d52a580bSJunchao Zhang         */
1537*d52a580bSJunchao Zhang         nzLeft = (nzRow - 1) / 2;
1538*d52a580bSJunchao Zhang         flops += nzLeft * (2.0 * nzRow - nzLeft + 1);
1539*d52a580bSJunchao Zhang       }
1540*d52a580bSJunchao Zhang     }
1541*d52a580bSJunchao Zhang     fs->numericFactFlops = flops;
1542*d52a580bSJunchao Zhang   }
1543*d52a580bSJunchao Zhang   fact->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJHIPSPARSE_ILU0;
1544*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1545*d52a580bSJunchao Zhang }
1546*d52a580bSJunchao Zhang 
1547*d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_ICC0(Mat fact, Vec b, Vec x)
1548*d52a580bSJunchao Zhang {
1549*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs  = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1550*d52a580bSJunchao Zhang   Mat_SeqAIJ                    *aij = (Mat_SeqAIJ *)fact->data;
1551*d52a580bSJunchao Zhang   const PetscScalar             *barray;
1552*d52a580bSJunchao Zhang   PetscScalar                   *xarray;
1553*d52a580bSJunchao Zhang 
1554*d52a580bSJunchao Zhang   PetscFunctionBegin;
1555*d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayWrite(x, &xarray));
1556*d52a580bSJunchao Zhang   PetscCall(VecHIPGetArrayRead(b, &barray));
1557*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
1558*d52a580bSJunchao Zhang 
1559*d52a580bSJunchao Zhang   /* Solve L*y = b */
1560*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray));
1561*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y));
1562*d52a580bSJunchao Zhang   #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1563*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */
1564*d52a580bSJunchao Zhang                                          fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L));
1565*d52a580bSJunchao Zhang   #else
1566*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */
1567*d52a580bSJunchao Zhang                                          fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L));
1568*d52a580bSJunchao Zhang   #endif
1569*d52a580bSJunchao Zhang   /* Solve Lt*x = y */
1570*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray));
1571*d52a580bSJunchao Zhang   #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0)
1572*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
1573*d52a580bSJunchao Zhang                                          fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt));
1574*d52a580bSJunchao Zhang   #else
1575*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
1576*d52a580bSJunchao Zhang                                          fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt));
1577*d52a580bSJunchao Zhang   #endif
1578*d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayRead(b, &barray));
1579*d52a580bSJunchao Zhang   PetscCall(VecHIPRestoreArrayWrite(x, &xarray));
1580*d52a580bSJunchao Zhang 
1581*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
1582*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(2.0 * aij->nz - fact->rmap->n));
1583*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1584*d52a580bSJunchao Zhang }
1585*d52a580bSJunchao Zhang 
1586*d52a580bSJunchao Zhang static PetscErrorCode MatICCFactorNumeric_SeqAIJHIPSPARSE_ICC0(Mat fact, Mat A, const MatFactorInfo *info)
1587*d52a580bSJunchao Zhang {
1588*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs    = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1589*d52a580bSJunchao Zhang   Mat_SeqAIJ                    *aij   = (Mat_SeqAIJ *)fact->data;
1590*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
1591*d52a580bSJunchao Zhang   CsrMatrix                     *Acsr;
1592*d52a580bSJunchao Zhang   PetscInt                       m, nz;
1593*d52a580bSJunchao Zhang   PetscBool                      flg;
1594*d52a580bSJunchao Zhang 
1595*d52a580bSJunchao Zhang   PetscFunctionBegin;
1596*d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG)) {
1597*d52a580bSJunchao Zhang     PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
1598*d52a580bSJunchao Zhang     PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Expected MATSEQAIJHIPSPARSE, but input is %s", ((PetscObject)A)->type_name);
1599*d52a580bSJunchao Zhang   }
1600*d52a580bSJunchao Zhang 
1601*d52a580bSJunchao Zhang   /* Copy A's value to fact */
1602*d52a580bSJunchao Zhang   m  = fact->rmap->n;
1603*d52a580bSJunchao Zhang   nz = aij->nz;
1604*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
1605*d52a580bSJunchao Zhang   Acsr = (CsrMatrix *)Acusp->mat->mat;
1606*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpyAsync(fs->csrVal, Acsr->values->data().get(), sizeof(PetscScalar) * nz, hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1607*d52a580bSJunchao Zhang 
1608*d52a580bSJunchao Zhang   /* Factorize fact inplace */
1609*d52a580bSJunchao Zhang   /* Function csric02() only takes the lower triangular part of matrix A to perform factorization.
1610*d52a580bSJunchao Zhang      The matrix type must be HIPSPARSE_MATRIX_TYPE_GENERAL, the fill mode and diagonal type are ignored,
1611*d52a580bSJunchao Zhang      and the strictly upper triangular part is ignored and never touched. It does not matter if A is Hermitian or not.
1612*d52a580bSJunchao Zhang      In other words, from the point of view of csric02() A is Hermitian and only the lower triangular part is provided.
1613*d52a580bSJunchao Zhang    */
1614*d52a580bSJunchao Zhang   if (m) PetscCallHIPSPARSE(hipsparseXcsric02(fs->handle, m, nz, fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ic0Info_M, fs->policy_M, fs->factBuffer_M));
1615*d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG)) {
1616*d52a580bSJunchao Zhang     int               numerical_zero;
1617*d52a580bSJunchao Zhang     hipsparseStatus_t status;
1618*d52a580bSJunchao Zhang     status = hipsparseXcsric02_zeroPivot(fs->handle, fs->ic0Info_M, &numerical_zero);
1619*d52a580bSJunchao Zhang     PetscAssert(HIPSPARSE_STATUS_ZERO_PIVOT != status, PETSC_COMM_SELF, PETSC_ERR_USER_INPUT, "Numerical zero pivot detected in csric02: A(%d,%d) is zero", numerical_zero, numerical_zero);
1620*d52a580bSJunchao Zhang   }
1621*d52a580bSJunchao Zhang 
1622*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L));
1623*d52a580bSJunchao Zhang 
1624*d52a580bSJunchao Zhang   /* Note that hipsparse reports this error if we use double and HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE
1625*d52a580bSJunchao Zhang     ** On entry to hipsparseSpSV_analysis(): conjugate transpose (opA) is not supported for matA data type, current -> CUDA_R_64F
1626*d52a580bSJunchao Zhang   */
1627*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt));
1628*d52a580bSJunchao Zhang 
1629*d52a580bSJunchao Zhang   fact->offloadmask            = PETSC_OFFLOAD_GPU;
1630*d52a580bSJunchao Zhang   fact->ops->solve             = MatSolve_SeqAIJHIPSPARSE_ICC0;
1631*d52a580bSJunchao Zhang   fact->ops->solvetranspose    = MatSolve_SeqAIJHIPSPARSE_ICC0;
1632*d52a580bSJunchao Zhang   fact->ops->matsolve          = NULL;
1633*d52a580bSJunchao Zhang   fact->ops->matsolvetranspose = NULL;
1634*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(fs->numericFactFlops));
1635*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1636*d52a580bSJunchao Zhang }
1637*d52a580bSJunchao Zhang 
1638*d52a580bSJunchao Zhang static PetscErrorCode MatICCFactorSymbolic_SeqAIJHIPSPARSE_ICC0(Mat fact, Mat A, IS perm, const MatFactorInfo *info)
1639*d52a580bSJunchao Zhang {
1640*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs  = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr;
1641*d52a580bSJunchao Zhang   Mat_SeqAIJ                    *aij = (Mat_SeqAIJ *)fact->data;
1642*d52a580bSJunchao Zhang   PetscInt                       m, nz;
1643*d52a580bSJunchao Zhang 
1644*d52a580bSJunchao Zhang   PetscFunctionBegin;
1645*d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG)) {
1646*d52a580bSJunchao Zhang     PetscBool flg, diagDense;
1647*d52a580bSJunchao Zhang 
1648*d52a580bSJunchao Zhang     PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
1649*d52a580bSJunchao Zhang     PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Expected MATSEQAIJHIPSPARSE, but input is %s", ((PetscObject)A)->type_name);
1650*d52a580bSJunchao Zhang     PetscCheck(A->rmap->n == A->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must be square matrix, rows %" PetscInt_FMT " columns %" PetscInt_FMT, A->rmap->n, A->cmap->n);
1651*d52a580bSJunchao Zhang     PetscCall(MatGetDiagonalMarkers_SeqAIJ(A, NULL, &diagDense));
1652*d52a580bSJunchao Zhang     PetscCheck(diagDense, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Matrix is missing diagonal entries");
1653*d52a580bSJunchao Zhang   }
1654*d52a580bSJunchao Zhang 
1655*d52a580bSJunchao Zhang   /* Free the old stale stuff */
1656*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&fs));
1657*d52a580bSJunchao Zhang 
1658*d52a580bSJunchao Zhang   /* Copy over A's meta data to fact. Note that we also allocated fact's i,j,a on host,
1659*d52a580bSJunchao Zhang      but they will not be used. Allocate them just for easy debugging.
1660*d52a580bSJunchao Zhang    */
1661*d52a580bSJunchao Zhang   PetscCall(MatDuplicateNoCreate_SeqAIJ(fact, A, MAT_DO_NOT_COPY_VALUES, PETSC_TRUE /*malloc*/));
1662*d52a580bSJunchao Zhang 
1663*d52a580bSJunchao Zhang   fact->offloadmask            = PETSC_OFFLOAD_BOTH;
1664*d52a580bSJunchao Zhang   fact->factortype             = MAT_FACTOR_ICC;
1665*d52a580bSJunchao Zhang   fact->info.factor_mallocs    = 0;
1666*d52a580bSJunchao Zhang   fact->info.fill_ratio_given  = info->fill;
1667*d52a580bSJunchao Zhang   fact->info.fill_ratio_needed = 1.0;
1668*d52a580bSJunchao Zhang 
1669*d52a580bSJunchao Zhang   aij->row = NULL;
1670*d52a580bSJunchao Zhang   aij->col = NULL;
1671*d52a580bSJunchao Zhang 
1672*d52a580bSJunchao Zhang   /* ====================================================================== */
1673*d52a580bSJunchao Zhang   /* Copy A's i, j to fact and also allocate the value array of fact.       */
1674*d52a580bSJunchao Zhang   /* We'll do in-place factorization on fact                                */
1675*d52a580bSJunchao Zhang   /* ====================================================================== */
1676*d52a580bSJunchao Zhang   const int *Ai, *Aj;
1677*d52a580bSJunchao Zhang 
1678*d52a580bSJunchao Zhang   m  = fact->rmap->n;
1679*d52a580bSJunchao Zhang   nz = aij->nz;
1680*d52a580bSJunchao Zhang 
1681*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->csrRowPtr, sizeof(int) * (m + 1)));
1682*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->csrColIdx, sizeof(int) * nz));
1683*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->csrVal, sizeof(PetscScalar) * nz));
1684*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEGetIJ(A, PETSC_FALSE, &Ai, &Aj)); /* Do not use compressed Ai */
1685*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpyAsync(fs->csrRowPtr, Ai, sizeof(int) * (m + 1), hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1686*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpyAsync(fs->csrColIdx, Aj, sizeof(int) * nz, hipMemcpyDeviceToDevice, PetscDefaultHipStream));
1687*d52a580bSJunchao Zhang 
1688*d52a580bSJunchao Zhang   /* ====================================================================== */
1689*d52a580bSJunchao Zhang   /* Create mat descriptors for M, L                                        */
1690*d52a580bSJunchao Zhang   /* ====================================================================== */
1691*d52a580bSJunchao Zhang   hipsparseFillMode_t fillMode;
1692*d52a580bSJunchao Zhang   hipsparseDiagType_t diagType;
1693*d52a580bSJunchao Zhang 
1694*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateMatDescr(&fs->matDescr_M));
1695*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatIndexBase(fs->matDescr_M, HIPSPARSE_INDEX_BASE_ZERO));
1696*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatType(fs->matDescr_M, HIPSPARSE_MATRIX_TYPE_GENERAL));
1697*d52a580bSJunchao Zhang 
1698*d52a580bSJunchao Zhang   /* https://docs.amd.com/bundle/hipSPARSE-Documentation---hipSPARSE-documentation/page/usermanual.html/#hipsparse_8h_1a79e036b6c0680cb37e2aa53d3542a054
1699*d52a580bSJunchao Zhang     hipsparseDiagType_t: This type indicates if the matrix diagonal entries are unity. The diagonal elements are always
1700*d52a580bSJunchao Zhang     assumed to be present, but if HIPSPARSE_DIAG_TYPE_UNIT is passed to an API routine, then the routine assumes that
1701*d52a580bSJunchao Zhang     all diagonal entries are unity and will not read or modify those entries. Note that in this case the routine
1702*d52a580bSJunchao Zhang     assumes the diagonal entries are equal to one, regardless of what those entries are actually set to in memory.
1703*d52a580bSJunchao Zhang   */
1704*d52a580bSJunchao Zhang   fillMode = HIPSPARSE_FILL_MODE_LOWER;
1705*d52a580bSJunchao Zhang   diagType = HIPSPARSE_DIAG_TYPE_NON_UNIT;
1706*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateCsr(&fs->spMatDescr_L, m, m, nz, fs->csrRowPtr, fs->csrColIdx, fs->csrVal, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
1707*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_L, HIPSPARSE_SPMAT_FILL_MODE, &fillMode, sizeof(fillMode)));
1708*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_L, HIPSPARSE_SPMAT_DIAG_TYPE, &diagType, sizeof(diagType)));
1709*d52a580bSJunchao Zhang 
1710*d52a580bSJunchao Zhang   /* ========================================================================= */
1711*d52a580bSJunchao Zhang   /* Query buffer sizes for csric0, SpSV of L and Lt, and allocate buffers     */
1712*d52a580bSJunchao Zhang   /* ========================================================================= */
1713*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateCsric02Info(&fs->ic0Info_M));
1714*d52a580bSJunchao Zhang   if (m) PetscCallHIPSPARSE(hipsparseXcsric02_bufferSize(fs->handle, m, nz, fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ic0Info_M, &fs->factBufferSize_M));
1715*d52a580bSJunchao Zhang 
1716*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->X, sizeof(PetscScalar) * m));
1717*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&fs->Y, sizeof(PetscScalar) * m));
1718*d52a580bSJunchao Zhang 
1719*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateDnVec(&fs->dnVecDescr_X, m, fs->X, hipsparse_scalartype));
1720*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateDnVec(&fs->dnVecDescr_Y, m, fs->Y, hipsparse_scalartype));
1721*d52a580bSJunchao Zhang 
1722*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_L));
1723*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, &fs->spsvBufferSize_L));
1724*d52a580bSJunchao Zhang 
1725*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_Lt));
1726*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, &fs->spsvBufferSize_Lt));
1727*d52a580bSJunchao Zhang 
1728*d52a580bSJunchao Zhang   /* To save device memory, we make the factorization buffer share with one of the solver buffer.
1729*d52a580bSJunchao Zhang      See also comments in `MatILUFactorSymbolic_SeqAIJHIPSPARSE_ILU0()`.
1730*d52a580bSJunchao Zhang    */
1731*d52a580bSJunchao Zhang   if (fs->spsvBufferSize_L > fs->spsvBufferSize_Lt) {
1732*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->factBuffer_M, PetscMax(fs->spsvBufferSize_L, (size_t)fs->factBufferSize_M)));
1733*d52a580bSJunchao Zhang     fs->spsvBuffer_L = fs->factBuffer_M;
1734*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_Lt, fs->spsvBufferSize_Lt));
1735*d52a580bSJunchao Zhang   } else {
1736*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->factBuffer_M, PetscMax(fs->spsvBufferSize_Lt, (size_t)fs->factBufferSize_M)));
1737*d52a580bSJunchao Zhang     fs->spsvBuffer_Lt = fs->factBuffer_M;
1738*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_L, fs->spsvBufferSize_L));
1739*d52a580bSJunchao Zhang   }
1740*d52a580bSJunchao Zhang 
1741*d52a580bSJunchao Zhang   /* ========================================================================== */
1742*d52a580bSJunchao Zhang   /* Perform analysis of ic0 on M                                               */
1743*d52a580bSJunchao Zhang   /* The lower triangular part of M has the same sparsity pattern as L          */
1744*d52a580bSJunchao Zhang   /* ========================================================================== */
1745*d52a580bSJunchao Zhang   int structural_zero;
1746*d52a580bSJunchao Zhang 
1747*d52a580bSJunchao Zhang   fs->policy_M = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
1748*d52a580bSJunchao Zhang   if (m) PetscCallHIPSPARSE(hipsparseXcsric02_analysis(fs->handle, m, nz, fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ic0Info_M, fs->policy_M, fs->factBuffer_M));
1749*d52a580bSJunchao Zhang   if (PetscDefined(USE_DEBUG)) {
1750*d52a580bSJunchao Zhang     hipsparseStatus_t status;
1751*d52a580bSJunchao Zhang     /* Function hipsparseXcsric02_zeroPivot() is a blocking call. It calls hipDeviceSynchronize() to make sure all previous kernels are done. */
1752*d52a580bSJunchao Zhang     status = hipsparseXcsric02_zeroPivot(fs->handle, fs->ic0Info_M, &structural_zero);
1753*d52a580bSJunchao Zhang     PetscCheck(HIPSPARSE_STATUS_ZERO_PIVOT != status, PETSC_COMM_SELF, PETSC_ERR_USER_INPUT, "Structural zero pivot detected in csric02: A(%d,%d) is missing", structural_zero, structural_zero);
1754*d52a580bSJunchao Zhang   }
1755*d52a580bSJunchao Zhang 
1756*d52a580bSJunchao Zhang   /* Estimate FLOPs of the numeric factorization */
1757*d52a580bSJunchao Zhang   {
1758*d52a580bSJunchao Zhang     Mat_SeqAIJ    *Aseq = (Mat_SeqAIJ *)A->data;
1759*d52a580bSJunchao Zhang     PetscInt      *Ai, nzRow, nzLeft;
1760*d52a580bSJunchao Zhang     PetscLogDouble flops = 0.0;
1761*d52a580bSJunchao Zhang 
1762*d52a580bSJunchao Zhang     Ai = Aseq->i;
1763*d52a580bSJunchao Zhang     for (PetscInt i = 0; i < m; i++) {
1764*d52a580bSJunchao Zhang       nzRow = Ai[i + 1] - Ai[i];
1765*d52a580bSJunchao Zhang       if (nzRow > 1) {
1766*d52a580bSJunchao Zhang         /* We want to eliminate nonzeros left to the diagonal one by one. Assume each time, nonzeros right
1767*d52a580bSJunchao Zhang           and include the eliminated one will be updated, which incurs a multiplication and an addition.
1768*d52a580bSJunchao Zhang         */
1769*d52a580bSJunchao Zhang         nzLeft = (nzRow - 1) / 2;
1770*d52a580bSJunchao Zhang         flops += nzLeft * (2.0 * nzRow - nzLeft + 1);
1771*d52a580bSJunchao Zhang       }
1772*d52a580bSJunchao Zhang     }
1773*d52a580bSJunchao Zhang     fs->numericFactFlops = flops;
1774*d52a580bSJunchao Zhang   }
1775*d52a580bSJunchao Zhang   fact->ops->choleskyfactornumeric = MatICCFactorNumeric_SeqAIJHIPSPARSE_ICC0;
1776*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1777*d52a580bSJunchao Zhang }
1778*d52a580bSJunchao Zhang #endif
1779*d52a580bSJunchao Zhang 
1780*d52a580bSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJHIPSPARSE(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1781*d52a580bSJunchao Zhang {
1782*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)B->spptr;
1783*d52a580bSJunchao Zhang 
1784*d52a580bSJunchao Zhang   PetscFunctionBegin;
1785*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
1786*d52a580bSJunchao Zhang   PetscBool row_identity = PETSC_FALSE, col_identity = PETSC_FALSE;
1787*d52a580bSJunchao Zhang   if (!info->factoronhost) {
1788*d52a580bSJunchao Zhang     PetscCall(ISIdentity(isrow, &row_identity));
1789*d52a580bSJunchao Zhang     PetscCall(ISIdentity(iscol, &col_identity));
1790*d52a580bSJunchao Zhang   }
1791*d52a580bSJunchao Zhang   if (!info->levels && row_identity && col_identity) PetscCall(MatILUFactorSymbolic_SeqAIJHIPSPARSE_ILU0(B, A, isrow, iscol, info));
1792*d52a580bSJunchao Zhang   else
1793*d52a580bSJunchao Zhang #endif
1794*d52a580bSJunchao Zhang   {
1795*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&hipsparseTriFactors));
1796*d52a580bSJunchao Zhang     PetscCall(MatILUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
1797*d52a580bSJunchao Zhang     B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJHIPSPARSE;
1798*d52a580bSJunchao Zhang   }
1799*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1800*d52a580bSJunchao Zhang }
1801*d52a580bSJunchao Zhang 
1802*d52a580bSJunchao Zhang static PetscErrorCode MatLUFactorSymbolic_SeqAIJHIPSPARSE(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info)
1803*d52a580bSJunchao Zhang {
1804*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)B->spptr;
1805*d52a580bSJunchao Zhang 
1806*d52a580bSJunchao Zhang   PetscFunctionBegin;
1807*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&hipsparseTriFactors));
1808*d52a580bSJunchao Zhang   PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info));
1809*d52a580bSJunchao Zhang   B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJHIPSPARSE;
1810*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1811*d52a580bSJunchao Zhang }
1812*d52a580bSJunchao Zhang 
1813*d52a580bSJunchao Zhang static PetscErrorCode MatICCFactorSymbolic_SeqAIJHIPSPARSE(Mat B, Mat A, IS perm, const MatFactorInfo *info)
1814*d52a580bSJunchao Zhang {
1815*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)B->spptr;
1816*d52a580bSJunchao Zhang 
1817*d52a580bSJunchao Zhang   PetscFunctionBegin;
1818*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
1819*d52a580bSJunchao Zhang   PetscBool perm_identity = PETSC_FALSE;
1820*d52a580bSJunchao Zhang   if (!info->factoronhost) PetscCall(ISIdentity(perm, &perm_identity));
1821*d52a580bSJunchao Zhang   if (!info->levels && perm_identity) PetscCall(MatICCFactorSymbolic_SeqAIJHIPSPARSE_ICC0(B, A, perm, info));
1822*d52a580bSJunchao Zhang   else
1823*d52a580bSJunchao Zhang #endif
1824*d52a580bSJunchao Zhang   {
1825*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&hipsparseTriFactors));
1826*d52a580bSJunchao Zhang     PetscCall(MatICCFactorSymbolic_SeqAIJ(B, A, perm, info));
1827*d52a580bSJunchao Zhang     B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJHIPSPARSE;
1828*d52a580bSJunchao Zhang   }
1829*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1830*d52a580bSJunchao Zhang }
1831*d52a580bSJunchao Zhang 
1832*d52a580bSJunchao Zhang static PetscErrorCode MatCholeskyFactorSymbolic_SeqAIJHIPSPARSE(Mat B, Mat A, IS perm, const MatFactorInfo *info)
1833*d52a580bSJunchao Zhang {
1834*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)B->spptr;
1835*d52a580bSJunchao Zhang 
1836*d52a580bSJunchao Zhang   PetscFunctionBegin;
1837*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&hipsparseTriFactors));
1838*d52a580bSJunchao Zhang   PetscCall(MatCholeskyFactorSymbolic_SeqAIJ(B, A, perm, info));
1839*d52a580bSJunchao Zhang   B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJHIPSPARSE;
1840*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1841*d52a580bSJunchao Zhang }
1842*d52a580bSJunchao Zhang 
1843*d52a580bSJunchao Zhang static PetscErrorCode MatFactorGetSolverType_seqaij_hipsparse(Mat A, MatSolverType *type)
1844*d52a580bSJunchao Zhang {
1845*d52a580bSJunchao Zhang   PetscFunctionBegin;
1846*d52a580bSJunchao Zhang   *type = MATSOLVERHIPSPARSE;
1847*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1848*d52a580bSJunchao Zhang }
1849*d52a580bSJunchao Zhang 
1850*d52a580bSJunchao Zhang /*MC
  MATSOLVERHIPSPARSE = "hipsparse" - A matrix solver type providing triangular solvers for sequential matrices
  of type `MATSEQAIJHIPSPARSE` on a single GPU. Currently supported
  algorithms are ILU(k) and ICC(k). Typically, deeper factorizations (larger k) result in poorer
  performance in the triangular solves. Full LU and Cholesky decompositions can be solved through the
  hipSPARSE triangular solve algorithm. However, the performance can be quite poor and thus these
  algorithms are not recommended. This class does NOT support direct solver operations.
1857*d52a580bSJunchao Zhang 
1858*d52a580bSJunchao Zhang   Level: beginner
1859*d52a580bSJunchao Zhang 
1860*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MATSEQAIJHIPSPARSE`, `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJHIPSPARSE()`, `MATAIJHIPSPARSE`, `MatCreateAIJHIPSPARSE()`, `MatHIPSPARSESetFormat()`, `MatHIPSPARSEStorageFormat`, `MatHIPSPARSEFormatOperation`
1861*d52a580bSJunchao Zhang M*/
1862*d52a580bSJunchao Zhang 
1863*d52a580bSJunchao Zhang PETSC_EXTERN PetscErrorCode MatGetFactor_seqaijhipsparse_hipsparse(Mat A, MatFactorType ftype, Mat *B)
1864*d52a580bSJunchao Zhang {
1865*d52a580bSJunchao Zhang   PetscInt n = A->rmap->n;
1866*d52a580bSJunchao Zhang 
1867*d52a580bSJunchao Zhang   PetscFunctionBegin;
1868*d52a580bSJunchao Zhang   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), B));
1869*d52a580bSJunchao Zhang   PetscCall(MatSetSizes(*B, n, n, n, n));
1870*d52a580bSJunchao Zhang   (*B)->factortype = ftype;
1871*d52a580bSJunchao Zhang   PetscCall(MatSetType(*B, MATSEQAIJHIPSPARSE));
1872*d52a580bSJunchao Zhang 
1873*d52a580bSJunchao Zhang   if (A->boundtocpu && A->bindingpropagates) PetscCall(MatBindToCPU(*B, PETSC_TRUE));
1874*d52a580bSJunchao Zhang   if (ftype == MAT_FACTOR_LU || ftype == MAT_FACTOR_ILU || ftype == MAT_FACTOR_ILUDT) {
1875*d52a580bSJunchao Zhang     PetscCall(MatSetBlockSizesFromMats(*B, A, A));
1876*d52a580bSJunchao Zhang     if (!A->boundtocpu) {
1877*d52a580bSJunchao Zhang       (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJHIPSPARSE;
1878*d52a580bSJunchao Zhang       (*B)->ops->lufactorsymbolic  = MatLUFactorSymbolic_SeqAIJHIPSPARSE;
1879*d52a580bSJunchao Zhang     } else {
1880*d52a580bSJunchao Zhang       (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJ;
1881*d52a580bSJunchao Zhang       (*B)->ops->lufactorsymbolic  = MatLUFactorSymbolic_SeqAIJ;
1882*d52a580bSJunchao Zhang     }
1883*d52a580bSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_LU]));
1884*d52a580bSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILU]));
1885*d52a580bSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILUDT]));
1886*d52a580bSJunchao Zhang   } else if (ftype == MAT_FACTOR_CHOLESKY || ftype == MAT_FACTOR_ICC) {
1887*d52a580bSJunchao Zhang     if (!A->boundtocpu) {
1888*d52a580bSJunchao Zhang       (*B)->ops->iccfactorsymbolic      = MatICCFactorSymbolic_SeqAIJHIPSPARSE;
1889*d52a580bSJunchao Zhang       (*B)->ops->choleskyfactorsymbolic = MatCholeskyFactorSymbolic_SeqAIJHIPSPARSE;
1890*d52a580bSJunchao Zhang     } else {
1891*d52a580bSJunchao Zhang       (*B)->ops->iccfactorsymbolic      = MatICCFactorSymbolic_SeqAIJ;
1892*d52a580bSJunchao Zhang       (*B)->ops->choleskyfactorsymbolic = MatCholeskyFactorSymbolic_SeqAIJ;
1893*d52a580bSJunchao Zhang     }
1894*d52a580bSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_CHOLESKY]));
1895*d52a580bSJunchao Zhang     PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ICC]));
1896*d52a580bSJunchao Zhang   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Factor type not supported for HIPSPARSE Matrix Types");
1897*d52a580bSJunchao Zhang 
1898*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL));
1899*d52a580bSJunchao Zhang   (*B)->canuseordering = PETSC_TRUE;
1900*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)*B, "MatFactorGetSolverType_C", MatFactorGetSolverType_seqaij_hipsparse));
1901*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1902*d52a580bSJunchao Zhang }
1903*d52a580bSJunchao Zhang 
1904*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSECopyFromGPU(Mat A)
1905*d52a580bSJunchao Zhang {
1906*d52a580bSJunchao Zhang   Mat_SeqAIJ          *a    = (Mat_SeqAIJ *)A->data;
1907*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
1908*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
1909*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr;
1910*d52a580bSJunchao Zhang #endif
1911*d52a580bSJunchao Zhang 
1912*d52a580bSJunchao Zhang   PetscFunctionBegin;
1913*d52a580bSJunchao Zhang   if (A->offloadmask == PETSC_OFFLOAD_GPU) {
1914*d52a580bSJunchao Zhang     PetscCall(PetscLogEventBegin(MAT_HIPSPARSECopyFromGPU, A, 0, 0, 0));
1915*d52a580bSJunchao Zhang     if (A->factortype == MAT_FACTOR_NONE) {
1916*d52a580bSJunchao Zhang       CsrMatrix *matrix = (CsrMatrix *)cusp->mat->mat;
1917*d52a580bSJunchao Zhang       PetscCallHIP(hipMemcpy(a->a, matrix->values->data().get(), a->nz * sizeof(PetscScalar), hipMemcpyDeviceToHost));
1918*d52a580bSJunchao Zhang     }
1919*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
1920*d52a580bSJunchao Zhang     else if (fs->csrVal) {
1921*d52a580bSJunchao Zhang       /* We have a factorized matrix on device and are able to copy it to host */
1922*d52a580bSJunchao Zhang       PetscCallHIP(hipMemcpy(a->a, fs->csrVal, a->nz * sizeof(PetscScalar), hipMemcpyDeviceToHost));
1923*d52a580bSJunchao Zhang     }
1924*d52a580bSJunchao Zhang #endif
1925*d52a580bSJunchao Zhang     else
1926*d52a580bSJunchao Zhang       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "No support for copying this type of factorized matrix from device to host");
1927*d52a580bSJunchao Zhang     PetscCall(PetscLogGpuToCpu(a->nz * sizeof(PetscScalar)));
1928*d52a580bSJunchao Zhang     PetscCall(PetscLogEventEnd(MAT_HIPSPARSECopyFromGPU, A, 0, 0, 0));
1929*d52a580bSJunchao Zhang     A->offloadmask = PETSC_OFFLOAD_BOTH;
1930*d52a580bSJunchao Zhang   }
1931*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1932*d52a580bSJunchao Zhang }
1933*d52a580bSJunchao Zhang 
1934*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJGetArray_SeqAIJHIPSPARSE(Mat A, PetscScalar *array[])
1935*d52a580bSJunchao Zhang {
1936*d52a580bSJunchao Zhang   PetscFunctionBegin;
1937*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A));
1938*d52a580bSJunchao Zhang   *array = ((Mat_SeqAIJ *)A->data)->a;
1939*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1940*d52a580bSJunchao Zhang }
1941*d52a580bSJunchao Zhang 
1942*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJHIPSPARSE(Mat A, PetscScalar *array[])
1943*d52a580bSJunchao Zhang {
1944*d52a580bSJunchao Zhang   PetscFunctionBegin;
1945*d52a580bSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_CPU;
1946*d52a580bSJunchao Zhang   *array         = NULL;
1947*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1948*d52a580bSJunchao Zhang }
1949*d52a580bSJunchao Zhang 
1950*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJHIPSPARSE(Mat A, const PetscScalar *array[])
1951*d52a580bSJunchao Zhang {
1952*d52a580bSJunchao Zhang   PetscFunctionBegin;
1953*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A));
1954*d52a580bSJunchao Zhang   *array = ((Mat_SeqAIJ *)A->data)->a;
1955*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1956*d52a580bSJunchao Zhang }
1957*d52a580bSJunchao Zhang 
1958*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJHIPSPARSE(Mat A, const PetscScalar *array[])
1959*d52a580bSJunchao Zhang {
1960*d52a580bSJunchao Zhang   PetscFunctionBegin;
1961*d52a580bSJunchao Zhang   *array = NULL;
1962*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1963*d52a580bSJunchao Zhang }
1964*d52a580bSJunchao Zhang 
1965*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJHIPSPARSE(Mat A, PetscScalar *array[])
1966*d52a580bSJunchao Zhang {
1967*d52a580bSJunchao Zhang   PetscFunctionBegin;
1968*d52a580bSJunchao Zhang   *array = ((Mat_SeqAIJ *)A->data)->a;
1969*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1970*d52a580bSJunchao Zhang }
1971*d52a580bSJunchao Zhang 
1972*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJHIPSPARSE(Mat A, PetscScalar *array[])
1973*d52a580bSJunchao Zhang {
1974*d52a580bSJunchao Zhang   PetscFunctionBegin;
1975*d52a580bSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_CPU;
1976*d52a580bSJunchao Zhang   *array         = NULL;
1977*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
1978*d52a580bSJunchao Zhang }
1979*d52a580bSJunchao Zhang 
1980*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJHIPSPARSE(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype)
1981*d52a580bSJunchao Zhang {
1982*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cusp;
1983*d52a580bSJunchao Zhang   CsrMatrix           *matrix;
1984*d52a580bSJunchao Zhang 
1985*d52a580bSJunchao Zhang   PetscFunctionBegin;
1986*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
1987*d52a580bSJunchao Zhang   PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "Not for factored matrix");
1988*d52a580bSJunchao Zhang   cusp = static_cast<Mat_SeqAIJHIPSPARSE *>(A->spptr);
1989*d52a580bSJunchao Zhang   PetscCheck(cusp != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "cusp is NULL");
1990*d52a580bSJunchao Zhang   matrix = (CsrMatrix *)cusp->mat->mat;
1991*d52a580bSJunchao Zhang 
1992*d52a580bSJunchao Zhang   if (i) {
1993*d52a580bSJunchao Zhang #if !defined(PETSC_USE_64BIT_INDICES)
1994*d52a580bSJunchao Zhang     *i = matrix->row_offsets->data().get();
1995*d52a580bSJunchao Zhang #else
1996*d52a580bSJunchao Zhang     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSparse does not supported 64-bit indices");
1997*d52a580bSJunchao Zhang #endif
1998*d52a580bSJunchao Zhang   }
1999*d52a580bSJunchao Zhang   if (j) {
2000*d52a580bSJunchao Zhang #if !defined(PETSC_USE_64BIT_INDICES)
2001*d52a580bSJunchao Zhang     *j = matrix->column_indices->data().get();
2002*d52a580bSJunchao Zhang #else
2003*d52a580bSJunchao Zhang     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSparse does not supported 64-bit indices");
2004*d52a580bSJunchao Zhang #endif
2005*d52a580bSJunchao Zhang   }
2006*d52a580bSJunchao Zhang   if (a) *a = matrix->values->data().get();
2007*d52a580bSJunchao Zhang   if (mtype) *mtype = PETSC_MEMTYPE_HIP;
2008*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2009*d52a580bSJunchao Zhang }
2010*d52a580bSJunchao Zhang 
/*
  MatSeqAIJHIPSPARSECopyToGPU - Makes the device (hipSPARSE) copy of a MATSEQAIJHIPSPARSE matrix current.

  Raises PETSC_ERR_GPU if A is bound to the CPU. Otherwise it acts only when A->offloadmask is
  PETSC_OFFLOAD_UNALLOCATED or PETSC_OFFLOAD_CPU:
  - If the nonzero pattern is unchanged (A->nonzerostate equals the cached nonzerostate) and the storage
    format is CSR, only the values a->a are copied into the existing device CSR matrix.
  - Otherwise the device structures are destroyed and rebuilt from the host CSR data. When
    a->compressedrow.use is set, the compressed-row arrays are used instead. The result is stored as CSR,
    or converted to a hipsparseHybMat_t for the ELL/HYB formats.
  Afterwards A->offloadmask is PETSC_OFFLOAD_BOTH. The exception is a rebuild with a NULL host values
  array a->a: only the pattern is uploaded, and the offload mask is left unchanged.

  Note: A->spptr and A->data are dereferenced in the declarations below, so both must be non-NULL.
*/
PETSC_INTERN PetscErrorCode MatSeqAIJHIPSPARSECopyToGPU(Mat A)
{
  Mat_SeqAIJHIPSPARSE           *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
  Mat_SeqAIJHIPSPARSEMultStruct *matstruct       = hipsparsestruct->mat;
  Mat_SeqAIJ                    *a               = (Mat_SeqAIJ *)A->data;
  PetscBool                      both            = PETSC_TRUE; /* set to PETSC_FALSE when only the pattern (no values) is uploaded */
  PetscInt                       m               = A->rmap->n, *ii, *ridx, tmp;

  PetscFunctionBegin;
  PetscCheck(!A->boundtocpu, PETSC_COMM_SELF, PETSC_ERR_GPU, "Cannot copy to GPU");
  /* Upload only when the host holds the newest data or nothing has been uploaded yet */
  if (A->offloadmask == PETSC_OFFLOAD_UNALLOCATED || A->offloadmask == PETSC_OFFLOAD_CPU) {
    if (A->nonzerostate == hipsparsestruct->nonzerostate && hipsparsestruct->format == MAT_HIPSPARSE_CSR) { /* Copy values only */
      CsrMatrix *matrix;
      matrix = (CsrMatrix *)hipsparsestruct->mat->mat;

      PetscCheck(!a->nz || a->a, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CSR values");
      PetscCall(PetscLogEventBegin(MAT_HIPSPARSECopyToGPU, A, 0, 0, 0));
      matrix->values->assign(a->a, a->a + a->nz);
      PetscCallHIP(WaitForHIP());
      PetscCall(PetscLogCpuToGpu(a->nz * sizeof(PetscScalar)));
      PetscCall(PetscLogEventEnd(MAT_HIPSPARSECopyToGPU, A, 0, 0, 0));
      /* Values changed, so any cached transpose is stale. PETSC_FALSE here vs PETSC_TRUE on the rebuild path
         below; presumably this keeps the transpose's structure and drops only its values -- confirm */
      PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_FALSE));
    } else {
      PetscInt nnz;
      PetscCall(PetscLogEventBegin(MAT_HIPSPARSECopyToGPU, A, 0, 0, 0));
      /* Pattern changed (or non-CSR format): discard all device-side structures before rebuilding them */
      PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&hipsparsestruct->mat, hipsparsestruct->format));
      PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_TRUE));
      delete hipsparsestruct->workVector;
      delete hipsparsestruct->rowoffsets_gpu;
      hipsparsestruct->workVector     = NULL;
      hipsparsestruct->rowoffsets_gpu = NULL;
      try {
        /* Choose the row data: the compressed-row arrays (m rows, row indices ridx) or the full CSR row offsets */
        if (a->compressedrow.use) {
          m    = a->compressedrow.nrows;
          ii   = a->compressedrow.i;
          ridx = a->compressedrow.rindex;
        } else {
          m    = A->rmap->n;
          ii   = a->i;
          ridx = NULL;
        }
        PetscCheck(ii, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CSR row data");
        /* Without host values, take the nonzero count from the row offsets and upload only the pattern;
           both = PETSC_FALSE keeps offloadmask from being set to PETSC_OFFLOAD_BOTH at the end */
        if (!a->a) {
          nnz  = ii[m];
          both = PETSC_FALSE;
        } else nnz = a->nz;
        PetscCheck(!nnz || a->j, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CSR column data");

        /* create hipsparse matrix */
        hipsparsestruct->nrows = m;
        matstruct              = new Mat_SeqAIJHIPSPARSEMultStruct;
        PetscCallHIPSPARSE(hipsparseCreateMatDescr(&matstruct->descr));
        PetscCallHIPSPARSE(hipsparseSetMatIndexBase(matstruct->descr, HIPSPARSE_INDEX_BASE_ZERO));
        PetscCallHIPSPARSE(hipsparseSetMatType(matstruct->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));

        /* Device-resident scalar constants (1, 0, 1) for alpha/beta; the pointer mode is set to DEVICE to match */
        PetscCallHIP(hipMalloc((void **)&matstruct->alpha_one, sizeof(PetscScalar)));
        PetscCallHIP(hipMalloc((void **)&matstruct->beta_zero, sizeof(PetscScalar)));
        PetscCallHIP(hipMalloc((void **)&matstruct->beta_one, sizeof(PetscScalar)));
        PetscCallHIP(hipMemcpy(matstruct->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
        PetscCallHIP(hipMemcpy(matstruct->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice));
        PetscCallHIP(hipMemcpy(matstruct->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
        PetscCallHIPSPARSE(hipsparseSetPointerMode(hipsparsestruct->handle, HIPSPARSE_POINTER_MODE_DEVICE));

        /* Build a hybrid/ellpack matrix if this option is chosen for the storage */
        if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) {
          /* set the matrix */
          CsrMatrix *mat      = new CsrMatrix;
          mat->num_rows       = m;
          mat->num_cols       = A->cmap->n;
          mat->num_entries    = nnz;
          mat->row_offsets    = new THRUSTINTARRAY32(m + 1);
          mat->column_indices = new THRUSTINTARRAY32(nnz);
          mat->values         = new THRUSTARRAY(nnz);
          mat->row_offsets->assign(ii, ii + m + 1);
          mat->column_indices->assign(a->j, a->j + nnz);
          /* values are left unset when the host has no values array (pattern-only upload) */
          if (a->a) mat->values->assign(a->a, a->a + nnz);

          /* assign the pointer */
          matstruct->mat = mat;
          if (mat->num_rows) { /* hipsparse errors on empty matrices! */
            PetscCallHIPSPARSE(hipsparseCreateCsr(&matstruct->matDescr, mat->num_rows, mat->num_cols, mat->num_entries, mat->row_offsets->data().get(), mat->column_indices->data().get(), mat->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, /* row offset, col idx types due to THRUSTINTARRAY32 */
                                                  HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
          }
        } else if (hipsparsestruct->format == MAT_HIPSPARSE_ELL || hipsparsestruct->format == MAT_HIPSPARSE_HYB) {
          /* Build a temporary device CSR matrix, convert it with hipsparse_csr2hyb (ELL uses HIPSPARSE_HYB_PARTITION_MAX,
             HYB uses HIPSPARSE_HYB_PARTITION_AUTO), keep the hybrid matrix, then free the temporary CSR */
          CsrMatrix *mat      = new CsrMatrix;
          mat->num_rows       = m;
          mat->num_cols       = A->cmap->n;
          mat->num_entries    = nnz;
          mat->row_offsets    = new THRUSTINTARRAY32(m + 1);
          mat->column_indices = new THRUSTINTARRAY32(nnz);
          mat->values         = new THRUSTARRAY(nnz);
          mat->row_offsets->assign(ii, ii + m + 1);
          mat->column_indices->assign(a->j, a->j + nnz);
          if (a->a) mat->values->assign(a->a, a->a + nnz);

          hipsparseHybMat_t hybMat;
          PetscCallHIPSPARSE(hipsparseCreateHybMat(&hybMat));
          hipsparseHybPartition_t partition = hipsparsestruct->format == MAT_HIPSPARSE_ELL ? HIPSPARSE_HYB_PARTITION_MAX : HIPSPARSE_HYB_PARTITION_AUTO;
          PetscCallHIPSPARSE(hipsparse_csr2hyb(hipsparsestruct->handle, mat->num_rows, mat->num_cols, matstruct->descr, mat->values->data().get(), mat->row_offsets->data().get(), mat->column_indices->data().get(), hybMat, 0, partition));
          /* assign the pointer */
          matstruct->mat = hybMat;

          if (mat) {
            if (mat->values) delete (THRUSTARRAY *)mat->values;
            if (mat->column_indices) delete (THRUSTINTARRAY32 *)mat->column_indices;
            if (mat->row_offsets) delete (THRUSTINTARRAY32 *)mat->row_offsets;
            delete (CsrMatrix *)mat;
          }
        }

        /* assign the compressed row indices */
        /* With compressed rows: an m-entry scalar work vector plus the row indices ridx on the device */
        if (a->compressedrow.use) {
          hipsparsestruct->workVector = new THRUSTARRAY(m);
          matstruct->cprowIndices     = new THRUSTINTARRAY(m);
          matstruct->cprowIndices->assign(ridx, ridx + m);
          tmp = m;
        } else {
          hipsparsestruct->workVector = NULL;
          matstruct->cprowIndices     = NULL;
          tmp                         = 0;
        }
        /* NOTE(review): the logged byte count uses a->nz even when nnz was taken from ii[m] (pattern-only upload) */
        PetscCall(PetscLogCpuToGpu(((m + 1) + (a->nz)) * sizeof(int) + tmp * sizeof(PetscInt) + (3 + (a->nz)) * sizeof(PetscScalar)));

        /* assign the pointer */
        hipsparsestruct->mat = matstruct;
      } catch (char *ex) {
        /* NOTE(review): thrust reports allocation/copy failures with exceptions derived from std::exception,
           which this char * handler does not catch -- confirm whether a broader handler is intended */
        SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex);
      }
      PetscCallHIP(WaitForHIP());
      PetscCall(PetscLogEventEnd(MAT_HIPSPARSECopyToGPU, A, 0, 0, 0));
      /* Record the pattern now on the device so the next call can take the values-only path */
      hipsparsestruct->nonzerostate = A->nonzerostate;
    }
    if (both) A->offloadmask = PETSC_OFFLOAD_BOTH;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
2147*d52a580bSJunchao Zhang 
2148*d52a580bSJunchao Zhang struct VecHIPPlusEquals {
2149*d52a580bSJunchao Zhang   template <typename Tuple>
2150*d52a580bSJunchao Zhang   __host__ __device__ void operator()(Tuple t)
2151*d52a580bSJunchao Zhang   {
2152*d52a580bSJunchao Zhang     thrust::get<1>(t) = thrust::get<1>(t) + thrust::get<0>(t);
2153*d52a580bSJunchao Zhang   }
2154*d52a580bSJunchao Zhang };
2155*d52a580bSJunchao Zhang 
2156*d52a580bSJunchao Zhang struct VecHIPEquals {
2157*d52a580bSJunchao Zhang   template <typename Tuple>
2158*d52a580bSJunchao Zhang   __host__ __device__ void operator()(Tuple t)
2159*d52a580bSJunchao Zhang   {
2160*d52a580bSJunchao Zhang     thrust::get<1>(t) = thrust::get<0>(t);
2161*d52a580bSJunchao Zhang   }
2162*d52a580bSJunchao Zhang };
2163*d52a580bSJunchao Zhang 
2164*d52a580bSJunchao Zhang struct VecHIPEqualsReverse {
2165*d52a580bSJunchao Zhang   template <typename Tuple>
2166*d52a580bSJunchao Zhang   __host__ __device__ void operator()(Tuple t)
2167*d52a580bSJunchao Zhang   {
2168*d52a580bSJunchao Zhang     thrust::get<0>(t) = thrust::get<1>(t);
2169*d52a580bSJunchao Zhang   }
2170*d52a580bSJunchao Zhang };
2171*d52a580bSJunchao Zhang 
/*
  MatProductCtx_MatMatHipsparse - Per-product context stored in C->product->data for products of a
  MATSEQAIJHIPSPARSE matrix with another matrix. It caches hipSPARSE descriptors and device work buffers
  across repeated numeric phases. It is released by MatProductCtxDestroy_MatMatHipsparse().
*/
struct MatProductCtx_MatMatHipsparse {
  PetscBool             cisdense;
  PetscScalar          *Bt; /* device buffer, released with hipFree() */
  Mat                   X;  /* intermediate dense result for MATPRODUCT_RARt / MATPRODUCT_PtAP */
  PetscBool             reusesym; /* Hipsparse does not have split symbolic and numeric phases for sparse matmat operations */
  PetscLogDouble        flops;
  CsrMatrix            *Bcsr;        /* released with delete */
  hipsparseSpMatDescr_t matSpBDescr; /* sparse descriptor for B, destroyed if non-NULL */
  PetscBool             initialized; /* dense descriptors and mmBuffer are set up for SpMM of C = alpha op(A) op(B) + beta C */
  hipsparseDnMatDescr_t matBDescr;   /* dense descriptor for B */
  hipsparseDnMatDescr_t matCDescr;   /* dense descriptor for C, or for X in the RARt/PtAP cases */
  PetscInt              Blda, Clda; /* Record leading dimensions of B and C here to detect changes*/
#if PETSC_PKG_HIP_VERSION_GE(5, 1, 0)
  void *dBuffer4, *dBuffer5; /* extra device buffers, only present for HIP >= 5.1 */
#endif
  size_t                 mmBufferSize;           /* size in bytes of mmBuffer, grown when SpMM needs more */
  void                  *mmBuffer, *mmBuffer2; /* SpGEMM WorkEstimation buffer */
  hipsparseSpGEMMDescr_t spgemmDesc;
};
2191*d52a580bSJunchao Zhang 
2192*d52a580bSJunchao Zhang static PetscErrorCode MatProductCtxDestroy_MatMatHipsparse(PetscCtxRt data)
2193*d52a580bSJunchao Zhang {
2194*d52a580bSJunchao Zhang   MatProductCtx_MatMatHipsparse *mmdata = *(MatProductCtx_MatMatHipsparse **)data;
2195*d52a580bSJunchao Zhang 
2196*d52a580bSJunchao Zhang   PetscFunctionBegin;
2197*d52a580bSJunchao Zhang   PetscCallHIP(hipFree(mmdata->Bt));
2198*d52a580bSJunchao Zhang   delete mmdata->Bcsr;
2199*d52a580bSJunchao Zhang   if (mmdata->matSpBDescr) PetscCallHIPSPARSE(hipsparseDestroySpMat(mmdata->matSpBDescr));
2200*d52a580bSJunchao Zhang   if (mmdata->matBDescr) PetscCallHIPSPARSE(hipsparseDestroyDnMat(mmdata->matBDescr));
2201*d52a580bSJunchao Zhang   if (mmdata->matCDescr) PetscCallHIPSPARSE(hipsparseDestroyDnMat(mmdata->matCDescr));
2202*d52a580bSJunchao Zhang   if (mmdata->spgemmDesc) PetscCallHIPSPARSE(hipsparseSpGEMM_destroyDescr(mmdata->spgemmDesc));
2203*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0)
2204*d52a580bSJunchao Zhang   if (mmdata->dBuffer4) PetscCallHIP(hipFree(mmdata->dBuffer4));
2205*d52a580bSJunchao Zhang   if (mmdata->dBuffer5) PetscCallHIP(hipFree(mmdata->dBuffer5));
2206*d52a580bSJunchao Zhang #endif
2207*d52a580bSJunchao Zhang   if (mmdata->mmBuffer) PetscCallHIP(hipFree(mmdata->mmBuffer));
2208*d52a580bSJunchao Zhang   if (mmdata->mmBuffer2) PetscCallHIP(hipFree(mmdata->mmBuffer2));
2209*d52a580bSJunchao Zhang   PetscCall(MatDestroy(&mmdata->X));
2210*d52a580bSJunchao Zhang   PetscCall(PetscFree(*(void **)data));
2211*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2212*d52a580bSJunchao Zhang }
2213*d52a580bSJunchao Zhang 
2214*d52a580bSJunchao Zhang static PetscErrorCode MatProductNumeric_SeqAIJHIPSPARSE_SeqDENSEHIP(Mat C)
2215*d52a580bSJunchao Zhang {
2216*d52a580bSJunchao Zhang   Mat_Product                   *product = C->product;
2217*d52a580bSJunchao Zhang   Mat                            A, B;
2218*d52a580bSJunchao Zhang   PetscInt                       m, n, blda, clda;
2219*d52a580bSJunchao Zhang   PetscBool                      flg, biship;
2220*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *cusp;
2221*d52a580bSJunchao Zhang   hipsparseOperation_t           opA;
2222*d52a580bSJunchao Zhang   const PetscScalar             *barray;
2223*d52a580bSJunchao Zhang   PetscScalar                   *carray;
2224*d52a580bSJunchao Zhang   MatProductCtx_MatMatHipsparse *mmdata;
2225*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSEMultStruct *mat;
2226*d52a580bSJunchao Zhang   CsrMatrix                     *csrmat;
2227*d52a580bSJunchao Zhang 
2228*d52a580bSJunchao Zhang   PetscFunctionBegin;
2229*d52a580bSJunchao Zhang   MatCheckProduct(C, 1);
2230*d52a580bSJunchao Zhang   PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Product data empty");
2231*d52a580bSJunchao Zhang   mmdata = (MatProductCtx_MatMatHipsparse *)product->data;
2232*d52a580bSJunchao Zhang   A      = product->A;
2233*d52a580bSJunchao Zhang   B      = product->B;
2234*d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
2235*d52a580bSJunchao Zhang   PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Not for type %s", ((PetscObject)A)->type_name);
2236*d52a580bSJunchao Zhang   /* currently CopyToGpu does not copy if the matrix is bound to CPU
2237*d52a580bSJunchao Zhang      Instead of silently accepting the wrong answer, I prefer to raise the error */
2238*d52a580bSJunchao Zhang   PetscCheck(!A->boundtocpu, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Cannot bind to CPU a HIPSPARSE matrix between MatProductSymbolic and MatProductNumeric phases");
2239*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
2240*d52a580bSJunchao Zhang   cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
2241*d52a580bSJunchao Zhang   switch (product->type) {
2242*d52a580bSJunchao Zhang   case MATPRODUCT_AB:
2243*d52a580bSJunchao Zhang   case MATPRODUCT_PtAP:
2244*d52a580bSJunchao Zhang     mat = cusp->mat;
2245*d52a580bSJunchao Zhang     opA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
2246*d52a580bSJunchao Zhang     m   = A->rmap->n;
2247*d52a580bSJunchao Zhang     n   = B->cmap->n;
2248*d52a580bSJunchao Zhang     break;
2249*d52a580bSJunchao Zhang   case MATPRODUCT_AtB:
2250*d52a580bSJunchao Zhang     if (!A->form_explicit_transpose) {
2251*d52a580bSJunchao Zhang       mat = cusp->mat;
2252*d52a580bSJunchao Zhang       opA = HIPSPARSE_OPERATION_TRANSPOSE;
2253*d52a580bSJunchao Zhang     } else {
2254*d52a580bSJunchao Zhang       PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(A));
2255*d52a580bSJunchao Zhang       mat = cusp->matTranspose;
2256*d52a580bSJunchao Zhang       opA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
2257*d52a580bSJunchao Zhang     }
2258*d52a580bSJunchao Zhang     m = A->cmap->n;
2259*d52a580bSJunchao Zhang     n = B->cmap->n;
2260*d52a580bSJunchao Zhang     break;
2261*d52a580bSJunchao Zhang   case MATPRODUCT_ABt:
2262*d52a580bSJunchao Zhang   case MATPRODUCT_RARt:
2263*d52a580bSJunchao Zhang     mat = cusp->mat;
2264*d52a580bSJunchao Zhang     opA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
2265*d52a580bSJunchao Zhang     m   = A->rmap->n;
2266*d52a580bSJunchao Zhang     n   = B->rmap->n;
2267*d52a580bSJunchao Zhang     break;
2268*d52a580bSJunchao Zhang   default:
2269*d52a580bSJunchao Zhang     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Unsupported product type %s", MatProductTypes[product->type]);
2270*d52a580bSJunchao Zhang   }
2271*d52a580bSJunchao Zhang   PetscCheck(mat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
2272*d52a580bSJunchao Zhang   csrmat = (CsrMatrix *)mat->mat;
2273*d52a580bSJunchao Zhang   /* if the user passed a CPU matrix, copy the data to the GPU */
2274*d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)B, MATSEQDENSEHIP, &biship));
2275*d52a580bSJunchao Zhang   if (!biship) PetscCall(MatConvert(B, MATSEQDENSEHIP, MAT_INPLACE_MATRIX, &B));
2276*d52a580bSJunchao Zhang   PetscCall(MatDenseGetArrayReadAndMemType(B, &barray, nullptr));
2277*d52a580bSJunchao Zhang   PetscCall(MatDenseGetLDA(B, &blda));
2278*d52a580bSJunchao Zhang   if (product->type == MATPRODUCT_RARt || product->type == MATPRODUCT_PtAP) {
2279*d52a580bSJunchao Zhang     PetscCall(MatDenseGetArrayWriteAndMemType(mmdata->X, &carray, nullptr));
2280*d52a580bSJunchao Zhang     PetscCall(MatDenseGetLDA(mmdata->X, &clda));
2281*d52a580bSJunchao Zhang   } else {
2282*d52a580bSJunchao Zhang     PetscCall(MatDenseGetArrayWriteAndMemType(C, &carray, nullptr));
2283*d52a580bSJunchao Zhang     PetscCall(MatDenseGetLDA(C, &clda));
2284*d52a580bSJunchao Zhang   }
2285*d52a580bSJunchao Zhang 
2286*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
2287*d52a580bSJunchao Zhang   hipsparseOperation_t opB = (product->type == MATPRODUCT_ABt || product->type == MATPRODUCT_RARt) ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE;
2288*d52a580bSJunchao Zhang   /* (re)allocate mmBuffer if not initialized or LDAs are different */
2289*d52a580bSJunchao Zhang   if (!mmdata->initialized || mmdata->Blda != blda || mmdata->Clda != clda) {
2290*d52a580bSJunchao Zhang     size_t mmBufferSize;
2291*d52a580bSJunchao Zhang     if (mmdata->initialized && mmdata->Blda != blda) {
2292*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseDestroyDnMat(mmdata->matBDescr));
2293*d52a580bSJunchao Zhang       mmdata->matBDescr = NULL;
2294*d52a580bSJunchao Zhang     }
2295*d52a580bSJunchao Zhang     if (!mmdata->matBDescr) {
2296*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreateDnMat(&mmdata->matBDescr, B->rmap->n, B->cmap->n, blda, (void *)barray, hipsparse_scalartype, HIPSPARSE_ORDER_COL));
2297*d52a580bSJunchao Zhang       mmdata->Blda = blda;
2298*d52a580bSJunchao Zhang     }
2299*d52a580bSJunchao Zhang     if (mmdata->initialized && mmdata->Clda != clda) {
2300*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseDestroyDnMat(mmdata->matCDescr));
2301*d52a580bSJunchao Zhang       mmdata->matCDescr = NULL;
2302*d52a580bSJunchao Zhang     }
2303*d52a580bSJunchao Zhang     if (!mmdata->matCDescr) { /* matCDescr is for C or mmdata->X */
2304*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreateDnMat(&mmdata->matCDescr, m, n, clda, (void *)carray, hipsparse_scalartype, HIPSPARSE_ORDER_COL));
2305*d52a580bSJunchao Zhang       mmdata->Clda = clda;
2306*d52a580bSJunchao Zhang     }
2307*d52a580bSJunchao Zhang     if (!mat->matDescr) {
2308*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreateCsr(&mat->matDescr, csrmat->num_rows, csrmat->num_cols, csrmat->num_entries, csrmat->row_offsets->data().get(), csrmat->column_indices->data().get(), csrmat->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, /* row offset, col idx types due to THRUSTINTARRAY32 */
2309*d52a580bSJunchao Zhang                                             HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
2310*d52a580bSJunchao Zhang     }
2311*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpMM_bufferSize(cusp->handle, opA, opB, mat->alpha_one, mat->matDescr, mmdata->matBDescr, mat->beta_zero, mmdata->matCDescr, hipsparse_scalartype, cusp->spmmAlg, &mmBufferSize));
2312*d52a580bSJunchao Zhang     if ((mmdata->mmBuffer && mmdata->mmBufferSize < mmBufferSize) || !mmdata->mmBuffer) {
2313*d52a580bSJunchao Zhang       PetscCallHIP(hipFree(mmdata->mmBuffer));
2314*d52a580bSJunchao Zhang       PetscCallHIP(hipMalloc(&mmdata->mmBuffer, mmBufferSize));
2315*d52a580bSJunchao Zhang       mmdata->mmBufferSize = mmBufferSize;
2316*d52a580bSJunchao Zhang     }
2317*d52a580bSJunchao Zhang     mmdata->initialized = PETSC_TRUE;
2318*d52a580bSJunchao Zhang   } else {
2319*d52a580bSJunchao Zhang     /* to be safe, always update pointers of the mats */
2320*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpMatSetValues(mat->matDescr, csrmat->values->data().get()));
2321*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseDnMatSetValues(mmdata->matBDescr, (void *)barray));
2322*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseDnMatSetValues(mmdata->matCDescr, (void *)carray));
2323*d52a580bSJunchao Zhang   }
2324*d52a580bSJunchao Zhang 
2325*d52a580bSJunchao Zhang   /* do hipsparseSpMM, which supports transpose on B */
2326*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpMM(cusp->handle, opA, opB, mat->alpha_one, mat->matDescr, mmdata->matBDescr, mat->beta_zero, mmdata->matCDescr, hipsparse_scalartype, cusp->spmmAlg, mmdata->mmBuffer));
2327*d52a580bSJunchao Zhang 
2328*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
2329*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(n * 2.0 * csrmat->num_entries));
2330*d52a580bSJunchao Zhang   PetscCall(MatDenseRestoreArrayReadAndMemType(B, &barray));
2331*d52a580bSJunchao Zhang   if (product->type == MATPRODUCT_RARt) {
2332*d52a580bSJunchao Zhang     PetscCall(MatDenseRestoreArrayWriteAndMemType(mmdata->X, &carray));
2333*d52a580bSJunchao Zhang     PetscCall(MatMatMultNumeric_SeqDenseHIP_SeqDenseHIP_Internal(B, mmdata->X, C, PETSC_FALSE, PETSC_FALSE));
2334*d52a580bSJunchao Zhang   } else if (product->type == MATPRODUCT_PtAP) {
2335*d52a580bSJunchao Zhang     PetscCall(MatDenseRestoreArrayWriteAndMemType(mmdata->X, &carray));
2336*d52a580bSJunchao Zhang     PetscCall(MatMatMultNumeric_SeqDenseHIP_SeqDenseHIP_Internal(B, mmdata->X, C, PETSC_TRUE, PETSC_FALSE));
2337*d52a580bSJunchao Zhang   } else PetscCall(MatDenseRestoreArrayWriteAndMemType(C, &carray));
2338*d52a580bSJunchao Zhang   if (mmdata->cisdense) PetscCall(MatConvert(C, MATSEQDENSE, MAT_INPLACE_MATRIX, &C));
2339*d52a580bSJunchao Zhang   if (!biship) PetscCall(MatConvert(B, MATSEQDENSE, MAT_INPLACE_MATRIX, &B));
2340*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2341*d52a580bSJunchao Zhang }
2342*d52a580bSJunchao Zhang 
2343*d52a580bSJunchao Zhang static PetscErrorCode MatProductSymbolic_SeqAIJHIPSPARSE_SeqDENSEHIP(Mat C)
2344*d52a580bSJunchao Zhang {
2345*d52a580bSJunchao Zhang   Mat_Product                   *product = C->product;
2346*d52a580bSJunchao Zhang   Mat                            A, B;
2347*d52a580bSJunchao Zhang   PetscInt                       m, n;
2348*d52a580bSJunchao Zhang   PetscBool                      cisdense, flg;
2349*d52a580bSJunchao Zhang   MatProductCtx_MatMatHipsparse *mmdata;
2350*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *cusp;
2351*d52a580bSJunchao Zhang 
2352*d52a580bSJunchao Zhang   PetscFunctionBegin;
2353*d52a580bSJunchao Zhang   MatCheckProduct(C, 1);
2354*d52a580bSJunchao Zhang   PetscCheck(!C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Product data not empty");
2355*d52a580bSJunchao Zhang   A = product->A;
2356*d52a580bSJunchao Zhang   B = product->B;
2357*d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
2358*d52a580bSJunchao Zhang   PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for type %s", ((PetscObject)A)->type_name);
2359*d52a580bSJunchao Zhang   cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
2360*d52a580bSJunchao Zhang   PetscCheck(cusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2361*d52a580bSJunchao Zhang   switch (product->type) {
2362*d52a580bSJunchao Zhang   case MATPRODUCT_AB:
2363*d52a580bSJunchao Zhang     m = A->rmap->n;
2364*d52a580bSJunchao Zhang     n = B->cmap->n;
2365*d52a580bSJunchao Zhang     break;
2366*d52a580bSJunchao Zhang   case MATPRODUCT_AtB:
2367*d52a580bSJunchao Zhang     m = A->cmap->n;
2368*d52a580bSJunchao Zhang     n = B->cmap->n;
2369*d52a580bSJunchao Zhang     break;
2370*d52a580bSJunchao Zhang   case MATPRODUCT_ABt:
2371*d52a580bSJunchao Zhang     m = A->rmap->n;
2372*d52a580bSJunchao Zhang     n = B->rmap->n;
2373*d52a580bSJunchao Zhang     break;
2374*d52a580bSJunchao Zhang   case MATPRODUCT_PtAP:
2375*d52a580bSJunchao Zhang     m = B->cmap->n;
2376*d52a580bSJunchao Zhang     n = B->cmap->n;
2377*d52a580bSJunchao Zhang     break;
2378*d52a580bSJunchao Zhang   case MATPRODUCT_RARt:
2379*d52a580bSJunchao Zhang     m = B->rmap->n;
2380*d52a580bSJunchao Zhang     n = B->rmap->n;
2381*d52a580bSJunchao Zhang     break;
2382*d52a580bSJunchao Zhang   default:
2383*d52a580bSJunchao Zhang     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Unsupported product type %s", MatProductTypes[product->type]);
2384*d52a580bSJunchao Zhang   }
2385*d52a580bSJunchao Zhang   PetscCall(MatSetSizes(C, m, n, m, n));
2386*d52a580bSJunchao Zhang   /* if C is of type MATSEQDENSE (CPU), perform the operation on the GPU and then copy on the CPU */
2387*d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)C, MATSEQDENSE, &cisdense));
2388*d52a580bSJunchao Zhang   PetscCall(MatSetType(C, MATSEQDENSEHIP));
2389*d52a580bSJunchao Zhang 
2390*d52a580bSJunchao Zhang   /* product data */
2391*d52a580bSJunchao Zhang   PetscCall(PetscNew(&mmdata));
2392*d52a580bSJunchao Zhang   mmdata->cisdense = cisdense;
2393*d52a580bSJunchao Zhang   /* for these products we need intermediate storage */
2394*d52a580bSJunchao Zhang   if (product->type == MATPRODUCT_RARt || product->type == MATPRODUCT_PtAP) {
2395*d52a580bSJunchao Zhang     PetscCall(MatCreate(PetscObjectComm((PetscObject)C), &mmdata->X));
2396*d52a580bSJunchao Zhang     PetscCall(MatSetType(mmdata->X, MATSEQDENSEHIP));
2397*d52a580bSJunchao Zhang     /* do not preallocate, since the first call to MatDenseHIPGetArray will preallocate on the GPU for us */
2398*d52a580bSJunchao Zhang     if (product->type == MATPRODUCT_RARt) PetscCall(MatSetSizes(mmdata->X, A->rmap->n, B->rmap->n, A->rmap->n, B->rmap->n));
2399*d52a580bSJunchao Zhang     else PetscCall(MatSetSizes(mmdata->X, A->rmap->n, B->cmap->n, A->rmap->n, B->cmap->n));
2400*d52a580bSJunchao Zhang   }
2401*d52a580bSJunchao Zhang   C->product->data       = mmdata;
2402*d52a580bSJunchao Zhang   C->product->destroy    = MatProductCtxDestroy_MatMatHipsparse;
2403*d52a580bSJunchao Zhang   C->ops->productnumeric = MatProductNumeric_SeqAIJHIPSPARSE_SeqDENSEHIP;
2404*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2405*d52a580bSJunchao Zhang }
2406*d52a580bSJunchao Zhang 
2407*d52a580bSJunchao Zhang static PetscErrorCode MatProductNumeric_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE(Mat C)
2408*d52a580bSJunchao Zhang {
2409*d52a580bSJunchao Zhang   Mat_Product                   *product = C->product;
2410*d52a580bSJunchao Zhang   Mat                            A, B;
2411*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *Acusp, *Bcusp, *Ccusp;
2412*d52a580bSJunchao Zhang   Mat_SeqAIJ                    *c = (Mat_SeqAIJ *)C->data;
2413*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSEMultStruct *Amat, *Bmat, *Cmat;
2414*d52a580bSJunchao Zhang   CsrMatrix                     *Acsr, *Bcsr, *Ccsr;
2415*d52a580bSJunchao Zhang   PetscBool                      flg;
2416*d52a580bSJunchao Zhang   MatProductType                 ptype;
2417*d52a580bSJunchao Zhang   MatProductCtx_MatMatHipsparse *mmdata;
2418*d52a580bSJunchao Zhang   hipsparseSpMatDescr_t          BmatSpDescr;
2419*d52a580bSJunchao Zhang   hipsparseOperation_t           opA = HIPSPARSE_OPERATION_NON_TRANSPOSE, opB = HIPSPARSE_OPERATION_NON_TRANSPOSE; /* hipSPARSE spgemm doesn't support transpose yet */
2420*d52a580bSJunchao Zhang 
2421*d52a580bSJunchao Zhang   PetscFunctionBegin;
2422*d52a580bSJunchao Zhang   MatCheckProduct(C, 1);
2423*d52a580bSJunchao Zhang   PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Product data empty");
2424*d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)C, MATSEQAIJHIPSPARSE, &flg));
2425*d52a580bSJunchao Zhang   PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for C of type %s", ((PetscObject)C)->type_name);
2426*d52a580bSJunchao Zhang   mmdata = (MatProductCtx_MatMatHipsparse *)C->product->data;
2427*d52a580bSJunchao Zhang   A      = product->A;
2428*d52a580bSJunchao Zhang   B      = product->B;
2429*d52a580bSJunchao Zhang   if (mmdata->reusesym) { /* this happens when api_user is true, meaning that the matrix values have been already computed in the MatProductSymbolic phase */
2430*d52a580bSJunchao Zhang     mmdata->reusesym = PETSC_FALSE;
2431*d52a580bSJunchao Zhang     Ccusp            = (Mat_SeqAIJHIPSPARSE *)C->spptr;
2432*d52a580bSJunchao Zhang     PetscCheck(Ccusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2433*d52a580bSJunchao Zhang     Cmat = Ccusp->mat;
2434*d52a580bSJunchao Zhang     PetscCheck(Cmat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing C mult struct for product type %s", MatProductTypes[C->product->type]);
2435*d52a580bSJunchao Zhang     Ccsr = (CsrMatrix *)Cmat->mat;
2436*d52a580bSJunchao Zhang     PetscCheck(Ccsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing C CSR struct");
2437*d52a580bSJunchao Zhang     goto finalize;
2438*d52a580bSJunchao Zhang   }
2439*d52a580bSJunchao Zhang   if (!c->nz) goto finalize;
2440*d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
2441*d52a580bSJunchao Zhang   PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for type %s", ((PetscObject)A)->type_name);
2442*d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)B, MATSEQAIJHIPSPARSE, &flg));
2443*d52a580bSJunchao Zhang   PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for B of type %s", ((PetscObject)B)->type_name);
2444*d52a580bSJunchao Zhang   PetscCheck(!A->boundtocpu, PetscObjectComm((PetscObject)C), PETSC_ERR_ARG_WRONG, "Cannot bind to CPU a HIPSPARSE matrix between MatProductSymbolic and MatProductNumeric phases");
2445*d52a580bSJunchao Zhang   PetscCheck(!B->boundtocpu, PetscObjectComm((PetscObject)C), PETSC_ERR_ARG_WRONG, "Cannot bind to CPU a HIPSPARSE matrix between MatProductSymbolic and MatProductNumeric phases");
2446*d52a580bSJunchao Zhang   Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
2447*d52a580bSJunchao Zhang   Bcusp = (Mat_SeqAIJHIPSPARSE *)B->spptr;
2448*d52a580bSJunchao Zhang   Ccusp = (Mat_SeqAIJHIPSPARSE *)C->spptr;
2449*d52a580bSJunchao Zhang   PetscCheck(Acusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2450*d52a580bSJunchao Zhang   PetscCheck(Bcusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2451*d52a580bSJunchao Zhang   PetscCheck(Ccusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2452*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
2453*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(B));
2454*d52a580bSJunchao Zhang 
2455*d52a580bSJunchao Zhang   ptype = product->type;
2456*d52a580bSJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) {
2457*d52a580bSJunchao Zhang     ptype = MATPRODUCT_AB;
2458*d52a580bSJunchao Zhang     PetscCheck(product->symbolic_used_the_fact_A_is_symmetric, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Symbolic should have been built using the fact that A is symmetric");
2459*d52a580bSJunchao Zhang   }
2460*d52a580bSJunchao Zhang   if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) {
2461*d52a580bSJunchao Zhang     ptype = MATPRODUCT_AB;
2462*d52a580bSJunchao Zhang     PetscCheck(product->symbolic_used_the_fact_B_is_symmetric, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Symbolic should have been built using the fact that B is symmetric");
2463*d52a580bSJunchao Zhang   }
2464*d52a580bSJunchao Zhang   switch (ptype) {
2465*d52a580bSJunchao Zhang   case MATPRODUCT_AB:
2466*d52a580bSJunchao Zhang     Amat = Acusp->mat;
2467*d52a580bSJunchao Zhang     Bmat = Bcusp->mat;
2468*d52a580bSJunchao Zhang     break;
2469*d52a580bSJunchao Zhang   case MATPRODUCT_AtB:
2470*d52a580bSJunchao Zhang     Amat = Acusp->matTranspose;
2471*d52a580bSJunchao Zhang     Bmat = Bcusp->mat;
2472*d52a580bSJunchao Zhang     break;
2473*d52a580bSJunchao Zhang   case MATPRODUCT_ABt:
2474*d52a580bSJunchao Zhang     Amat = Acusp->mat;
2475*d52a580bSJunchao Zhang     Bmat = Bcusp->matTranspose;
2476*d52a580bSJunchao Zhang     break;
2477*d52a580bSJunchao Zhang   default:
2478*d52a580bSJunchao Zhang     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Unsupported product type %s", MatProductTypes[product->type]);
2479*d52a580bSJunchao Zhang   }
2480*d52a580bSJunchao Zhang   Cmat = Ccusp->mat;
2481*d52a580bSJunchao Zhang   PetscCheck(Amat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing A mult struct for product type %s", MatProductTypes[ptype]);
2482*d52a580bSJunchao Zhang   PetscCheck(Bmat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing B mult struct for product type %s", MatProductTypes[ptype]);
2483*d52a580bSJunchao Zhang   PetscCheck(Cmat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing C mult struct for product type %s", MatProductTypes[ptype]);
2484*d52a580bSJunchao Zhang   Acsr = (CsrMatrix *)Amat->mat;
2485*d52a580bSJunchao Zhang   Bcsr = mmdata->Bcsr ? mmdata->Bcsr : (CsrMatrix *)Bmat->mat; /* B may be in compressed row storage */
2486*d52a580bSJunchao Zhang   Ccsr = (CsrMatrix *)Cmat->mat;
2487*d52a580bSJunchao Zhang   PetscCheck(Acsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing A CSR struct");
2488*d52a580bSJunchao Zhang   PetscCheck(Bcsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing B CSR struct");
2489*d52a580bSJunchao Zhang   PetscCheck(Ccsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing C CSR struct");
2490*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
2491*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 0, 0)
2492*d52a580bSJunchao Zhang   BmatSpDescr = mmdata->Bcsr ? mmdata->matSpBDescr : Bmat->matDescr; /* B may be in compressed row storage */
2493*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetPointerMode(Ccusp->handle, HIPSPARSE_POINTER_MODE_DEVICE));
2494*d52a580bSJunchao Zhang   #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0)
2495*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMMreuse_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc));
2496*d52a580bSJunchao Zhang   #else
2497*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMM_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &mmdata->mmBufferSize, mmdata->mmBuffer));
2498*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMM_copy(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc));
2499*d52a580bSJunchao Zhang   #endif
2500*d52a580bSJunchao Zhang #else
2501*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparse_csr_spgemm(Ccusp->handle, opA, opB, Acsr->num_rows, Bcsr->num_cols, Acsr->num_cols, Amat->descr, Acsr->num_entries, Acsr->values->data().get(), Acsr->row_offsets->data().get(), Acsr->column_indices->data().get(), Bmat->descr,
2502*d52a580bSJunchao Zhang                                           Bcsr->num_entries, Bcsr->values->data().get(), Bcsr->row_offsets->data().get(), Bcsr->column_indices->data().get(), Cmat->descr, Ccsr->values->data().get(), Ccsr->row_offsets->data().get(),
2503*d52a580bSJunchao Zhang                                           Ccsr->column_indices->data().get()));
2504*d52a580bSJunchao Zhang #endif
2505*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(mmdata->flops));
2506*d52a580bSJunchao Zhang   PetscCallHIP(WaitForHIP());
2507*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
2508*d52a580bSJunchao Zhang   C->offloadmask = PETSC_OFFLOAD_GPU;
2509*d52a580bSJunchao Zhang finalize:
2510*d52a580bSJunchao Zhang   /* shorter version of MatAssemblyEnd_SeqAIJ */
2511*d52a580bSJunchao Zhang   PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded, %" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz));
2512*d52a580bSJunchao Zhang   PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n"));
2513*d52a580bSJunchao Zhang   PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax));
2514*d52a580bSJunchao Zhang   c->reallocs = 0;
2515*d52a580bSJunchao Zhang   C->info.mallocs += 0;
2516*d52a580bSJunchao Zhang   C->info.nz_unneeded = 0;
2517*d52a580bSJunchao Zhang   C->assembled = C->was_assembled = PETSC_TRUE;
2518*d52a580bSJunchao Zhang   C->num_ass++;
2519*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2520*d52a580bSJunchao Zhang }
2521*d52a580bSJunchao Zhang 
2522*d52a580bSJunchao Zhang static PetscErrorCode MatProductSymbolic_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE(Mat C)
2523*d52a580bSJunchao Zhang {
2524*d52a580bSJunchao Zhang   Mat_Product                   *product = C->product;
2525*d52a580bSJunchao Zhang   Mat                            A, B;
2526*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *Acusp, *Bcusp, *Ccusp;
2527*d52a580bSJunchao Zhang   Mat_SeqAIJ                    *a, *b, *c;
2528*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSEMultStruct *Amat, *Bmat, *Cmat;
2529*d52a580bSJunchao Zhang   CsrMatrix                     *Acsr, *Bcsr, *Ccsr;
2530*d52a580bSJunchao Zhang   PetscInt                       i, j, m, n, k;
2531*d52a580bSJunchao Zhang   PetscBool                      flg;
2532*d52a580bSJunchao Zhang   MatProductType                 ptype;
2533*d52a580bSJunchao Zhang   MatProductCtx_MatMatHipsparse *mmdata;
2534*d52a580bSJunchao Zhang   PetscLogDouble                 flops;
2535*d52a580bSJunchao Zhang   PetscBool                      biscompressed, ciscompressed;
2536*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 0, 0)
2537*d52a580bSJunchao Zhang   int64_t               C_num_rows1, C_num_cols1, C_nnz1;
2538*d52a580bSJunchao Zhang   hipsparseSpMatDescr_t BmatSpDescr;
2539*d52a580bSJunchao Zhang #else
2540*d52a580bSJunchao Zhang   int cnz;
2541*d52a580bSJunchao Zhang #endif
2542*d52a580bSJunchao Zhang   hipsparseOperation_t opA = HIPSPARSE_OPERATION_NON_TRANSPOSE, opB = HIPSPARSE_OPERATION_NON_TRANSPOSE; /* hipSPARSE spgemm doesn't support transpose yet */
2543*d52a580bSJunchao Zhang 
2544*d52a580bSJunchao Zhang   PetscFunctionBegin;
2545*d52a580bSJunchao Zhang   MatCheckProduct(C, 1);
2546*d52a580bSJunchao Zhang   PetscCheck(!C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Product data not empty");
2547*d52a580bSJunchao Zhang   A = product->A;
2548*d52a580bSJunchao Zhang   B = product->B;
2549*d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg));
2550*d52a580bSJunchao Zhang   PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for type %s", ((PetscObject)A)->type_name);
2551*d52a580bSJunchao Zhang   PetscCall(PetscObjectTypeCompare((PetscObject)B, MATSEQAIJHIPSPARSE, &flg));
2552*d52a580bSJunchao Zhang   PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for B of type %s", ((PetscObject)B)->type_name);
2553*d52a580bSJunchao Zhang   a = (Mat_SeqAIJ *)A->data;
2554*d52a580bSJunchao Zhang   b = (Mat_SeqAIJ *)B->data;
2555*d52a580bSJunchao Zhang   /* product data */
2556*d52a580bSJunchao Zhang   PetscCall(PetscNew(&mmdata));
2557*d52a580bSJunchao Zhang   C->product->data    = mmdata;
2558*d52a580bSJunchao Zhang   C->product->destroy = MatProductCtxDestroy_MatMatHipsparse;
2559*d52a580bSJunchao Zhang 
2560*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
2561*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(B));
2562*d52a580bSJunchao Zhang   Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr; /* Access spptr after MatSeqAIJHIPSPARSECopyToGPU, not before */
2563*d52a580bSJunchao Zhang   Bcusp = (Mat_SeqAIJHIPSPARSE *)B->spptr;
2564*d52a580bSJunchao Zhang   PetscCheck(Acusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2565*d52a580bSJunchao Zhang   PetscCheck(Bcusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format");
2566*d52a580bSJunchao Zhang 
2567*d52a580bSJunchao Zhang   ptype = product->type;
2568*d52a580bSJunchao Zhang   if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) {
2569*d52a580bSJunchao Zhang     ptype                                          = MATPRODUCT_AB;
2570*d52a580bSJunchao Zhang     product->symbolic_used_the_fact_A_is_symmetric = PETSC_TRUE;
2571*d52a580bSJunchao Zhang   }
2572*d52a580bSJunchao Zhang   if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) {
2573*d52a580bSJunchao Zhang     ptype                                          = MATPRODUCT_AB;
2574*d52a580bSJunchao Zhang     product->symbolic_used_the_fact_B_is_symmetric = PETSC_TRUE;
2575*d52a580bSJunchao Zhang   }
2576*d52a580bSJunchao Zhang   biscompressed = PETSC_FALSE;
2577*d52a580bSJunchao Zhang   ciscompressed = PETSC_FALSE;
2578*d52a580bSJunchao Zhang   switch (ptype) {
2579*d52a580bSJunchao Zhang   case MATPRODUCT_AB:
2580*d52a580bSJunchao Zhang     m    = A->rmap->n;
2581*d52a580bSJunchao Zhang     n    = B->cmap->n;
2582*d52a580bSJunchao Zhang     k    = A->cmap->n;
2583*d52a580bSJunchao Zhang     Amat = Acusp->mat;
2584*d52a580bSJunchao Zhang     Bmat = Bcusp->mat;
2585*d52a580bSJunchao Zhang     if (a->compressedrow.use) ciscompressed = PETSC_TRUE;
2586*d52a580bSJunchao Zhang     if (b->compressedrow.use) biscompressed = PETSC_TRUE;
2587*d52a580bSJunchao Zhang     break;
2588*d52a580bSJunchao Zhang   case MATPRODUCT_AtB:
2589*d52a580bSJunchao Zhang     m = A->cmap->n;
2590*d52a580bSJunchao Zhang     n = B->cmap->n;
2591*d52a580bSJunchao Zhang     k = A->rmap->n;
2592*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(A));
2593*d52a580bSJunchao Zhang     Amat = Acusp->matTranspose;
2594*d52a580bSJunchao Zhang     Bmat = Bcusp->mat;
2595*d52a580bSJunchao Zhang     if (b->compressedrow.use) biscompressed = PETSC_TRUE;
2596*d52a580bSJunchao Zhang     break;
2597*d52a580bSJunchao Zhang   case MATPRODUCT_ABt:
2598*d52a580bSJunchao Zhang     m = A->rmap->n;
2599*d52a580bSJunchao Zhang     n = B->rmap->n;
2600*d52a580bSJunchao Zhang     k = A->cmap->n;
2601*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(B));
2602*d52a580bSJunchao Zhang     Amat = Acusp->mat;
2603*d52a580bSJunchao Zhang     Bmat = Bcusp->matTranspose;
2604*d52a580bSJunchao Zhang     if (a->compressedrow.use) ciscompressed = PETSC_TRUE;
2605*d52a580bSJunchao Zhang     break;
2606*d52a580bSJunchao Zhang   default:
2607*d52a580bSJunchao Zhang     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Unsupported product type %s", MatProductTypes[product->type]);
2608*d52a580bSJunchao Zhang   }
2609*d52a580bSJunchao Zhang 
2610*d52a580bSJunchao Zhang   /* create hipsparse matrix */
2611*d52a580bSJunchao Zhang   PetscCall(MatSetSizes(C, m, n, m, n));
2612*d52a580bSJunchao Zhang   PetscCall(MatSetType(C, MATSEQAIJHIPSPARSE));
2613*d52a580bSJunchao Zhang   c     = (Mat_SeqAIJ *)C->data;
2614*d52a580bSJunchao Zhang   Ccusp = (Mat_SeqAIJHIPSPARSE *)C->spptr;
2615*d52a580bSJunchao Zhang   Cmat  = new Mat_SeqAIJHIPSPARSEMultStruct;
2616*d52a580bSJunchao Zhang   Ccsr  = new CsrMatrix;
2617*d52a580bSJunchao Zhang 
2618*d52a580bSJunchao Zhang   c->compressedrow.use = ciscompressed;
2619*d52a580bSJunchao Zhang   if (c->compressedrow.use) { /* if a is in compressed row, than c will be in compressed row format */
2620*d52a580bSJunchao Zhang     c->compressedrow.nrows = a->compressedrow.nrows;
2621*d52a580bSJunchao Zhang     PetscCall(PetscMalloc2(c->compressedrow.nrows + 1, &c->compressedrow.i, c->compressedrow.nrows, &c->compressedrow.rindex));
2622*d52a580bSJunchao Zhang     PetscCall(PetscArraycpy(c->compressedrow.rindex, a->compressedrow.rindex, c->compressedrow.nrows));
2623*d52a580bSJunchao Zhang     Ccusp->workVector  = new THRUSTARRAY(c->compressedrow.nrows);
2624*d52a580bSJunchao Zhang     Cmat->cprowIndices = new THRUSTINTARRAY(c->compressedrow.nrows);
2625*d52a580bSJunchao Zhang     Cmat->cprowIndices->assign(c->compressedrow.rindex, c->compressedrow.rindex + c->compressedrow.nrows);
2626*d52a580bSJunchao Zhang   } else {
2627*d52a580bSJunchao Zhang     c->compressedrow.nrows  = 0;
2628*d52a580bSJunchao Zhang     c->compressedrow.i      = NULL;
2629*d52a580bSJunchao Zhang     c->compressedrow.rindex = NULL;
2630*d52a580bSJunchao Zhang     Ccusp->workVector       = NULL;
2631*d52a580bSJunchao Zhang     Cmat->cprowIndices      = NULL;
2632*d52a580bSJunchao Zhang   }
2633*d52a580bSJunchao Zhang   Ccusp->nrows      = ciscompressed ? c->compressedrow.nrows : m;
2634*d52a580bSJunchao Zhang   Ccusp->mat        = Cmat;
2635*d52a580bSJunchao Zhang   Ccusp->mat->mat   = Ccsr;
2636*d52a580bSJunchao Zhang   Ccsr->num_rows    = Ccusp->nrows;
2637*d52a580bSJunchao Zhang   Ccsr->num_cols    = n;
2638*d52a580bSJunchao Zhang   Ccsr->row_offsets = new THRUSTINTARRAY32(Ccusp->nrows + 1);
2639*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateMatDescr(&Cmat->descr));
2640*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatIndexBase(Cmat->descr, HIPSPARSE_INDEX_BASE_ZERO));
2641*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetMatType(Cmat->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
2642*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&Cmat->alpha_one, sizeof(PetscScalar)));
2643*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&Cmat->beta_zero, sizeof(PetscScalar)));
2644*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&Cmat->beta_one, sizeof(PetscScalar)));
2645*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(Cmat->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
2646*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(Cmat->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice));
2647*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(Cmat->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
2648*d52a580bSJunchao Zhang   if (!Ccsr->num_rows || !Ccsr->num_cols || !a->nz || !b->nz) { /* hipsparse raise errors in different calls when matrices have zero rows/columns! */
2649*d52a580bSJunchao Zhang     thrust::fill(thrust::device, Ccsr->row_offsets->begin(), Ccsr->row_offsets->end(), 0);
2650*d52a580bSJunchao Zhang     c->nz                = 0;
2651*d52a580bSJunchao Zhang     Ccsr->column_indices = new THRUSTINTARRAY32(c->nz);
2652*d52a580bSJunchao Zhang     Ccsr->values         = new THRUSTARRAY(c->nz);
2653*d52a580bSJunchao Zhang     goto finalizesym;
2654*d52a580bSJunchao Zhang   }
2655*d52a580bSJunchao Zhang 
2656*d52a580bSJunchao Zhang   PetscCheck(Amat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing A mult struct for product type %s", MatProductTypes[ptype]);
2657*d52a580bSJunchao Zhang   PetscCheck(Bmat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing B mult struct for product type %s", MatProductTypes[ptype]);
2658*d52a580bSJunchao Zhang   Acsr = (CsrMatrix *)Amat->mat;
2659*d52a580bSJunchao Zhang   if (!biscompressed) {
2660*d52a580bSJunchao Zhang     Bcsr        = (CsrMatrix *)Bmat->mat;
2661*d52a580bSJunchao Zhang     BmatSpDescr = Bmat->matDescr;
2662*d52a580bSJunchao Zhang   } else { /* we need to use row offsets for the full matrix */
2663*d52a580bSJunchao Zhang     CsrMatrix *cBcsr     = (CsrMatrix *)Bmat->mat;
2664*d52a580bSJunchao Zhang     Bcsr                 = new CsrMatrix;
2665*d52a580bSJunchao Zhang     Bcsr->num_rows       = B->rmap->n;
2666*d52a580bSJunchao Zhang     Bcsr->num_cols       = cBcsr->num_cols;
2667*d52a580bSJunchao Zhang     Bcsr->num_entries    = cBcsr->num_entries;
2668*d52a580bSJunchao Zhang     Bcsr->column_indices = cBcsr->column_indices;
2669*d52a580bSJunchao Zhang     Bcsr->values         = cBcsr->values;
2670*d52a580bSJunchao Zhang     if (!Bcusp->rowoffsets_gpu) {
2671*d52a580bSJunchao Zhang       Bcusp->rowoffsets_gpu = new THRUSTINTARRAY32(B->rmap->n + 1);
2672*d52a580bSJunchao Zhang       Bcusp->rowoffsets_gpu->assign(b->i, b->i + B->rmap->n + 1);
2673*d52a580bSJunchao Zhang       PetscCall(PetscLogCpuToGpu((B->rmap->n + 1) * sizeof(PetscInt)));
2674*d52a580bSJunchao Zhang     }
2675*d52a580bSJunchao Zhang     Bcsr->row_offsets = Bcusp->rowoffsets_gpu;
2676*d52a580bSJunchao Zhang     mmdata->Bcsr      = Bcsr;
2677*d52a580bSJunchao Zhang     if (Bcsr->num_rows && Bcsr->num_cols) {
2678*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreateCsr(&mmdata->matSpBDescr, Bcsr->num_rows, Bcsr->num_cols, Bcsr->num_entries, Bcsr->row_offsets->data().get(), Bcsr->column_indices->data().get(), Bcsr->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
2679*d52a580bSJunchao Zhang     }
2680*d52a580bSJunchao Zhang     BmatSpDescr = mmdata->matSpBDescr;
2681*d52a580bSJunchao Zhang   }
2682*d52a580bSJunchao Zhang   PetscCheck(Acsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing A CSR struct");
2683*d52a580bSJunchao Zhang   PetscCheck(Bcsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing B CSR struct");
2684*d52a580bSJunchao Zhang   /* precompute flops count */
2685*d52a580bSJunchao Zhang   if (ptype == MATPRODUCT_AB) {
2686*d52a580bSJunchao Zhang     for (i = 0, flops = 0; i < A->rmap->n; i++) {
2687*d52a580bSJunchao Zhang       const PetscInt st = a->i[i];
2688*d52a580bSJunchao Zhang       const PetscInt en = a->i[i + 1];
2689*d52a580bSJunchao Zhang       for (j = st; j < en; j++) {
2690*d52a580bSJunchao Zhang         const PetscInt brow = a->j[j];
2691*d52a580bSJunchao Zhang         flops += 2. * (b->i[brow + 1] - b->i[brow]);
2692*d52a580bSJunchao Zhang       }
2693*d52a580bSJunchao Zhang     }
2694*d52a580bSJunchao Zhang   } else if (ptype == MATPRODUCT_AtB) {
2695*d52a580bSJunchao Zhang     for (i = 0, flops = 0; i < A->rmap->n; i++) {
2696*d52a580bSJunchao Zhang       const PetscInt anzi = a->i[i + 1] - a->i[i];
2697*d52a580bSJunchao Zhang       const PetscInt bnzi = b->i[i + 1] - b->i[i];
2698*d52a580bSJunchao Zhang       flops += (2. * anzi) * bnzi;
2699*d52a580bSJunchao Zhang     }
2700*d52a580bSJunchao Zhang   } else flops = 0.; /* TODO */
2701*d52a580bSJunchao Zhang 
2702*d52a580bSJunchao Zhang   mmdata->flops = flops;
2703*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
2704*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 0, 0)
2705*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetPointerMode(Ccusp->handle, HIPSPARSE_POINTER_MODE_DEVICE));
2706*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCreateCsr(&Cmat->matDescr, Ccsr->num_rows, Ccsr->num_cols, 0, Ccsr->row_offsets->data().get(), NULL, NULL, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
2707*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMM_createDescr(&mmdata->spgemmDesc));
2708*d52a580bSJunchao Zhang   #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0)
2709*d52a580bSJunchao Zhang   {
2710*d52a580bSJunchao Zhang     /* hipsparseSpGEMMreuse has more reasonable APIs than hipsparseSpGEMM, so we prefer to use it.
2711*d52a580bSJunchao Zhang      We follow the sample code at https://github.com/ROCmSoftwarePlatform/hipSPARSE/blob/develop/clients/include/testing_spgemmreuse_csr.hpp
2712*d52a580bSJunchao Zhang   */
2713*d52a580bSJunchao Zhang     void *dBuffer1 = NULL;
2714*d52a580bSJunchao Zhang     void *dBuffer2 = NULL;
2715*d52a580bSJunchao Zhang     void *dBuffer3 = NULL;
2716*d52a580bSJunchao Zhang     /* dBuffer4, dBuffer5 are needed by hipsparseSpGEMMreuse_compute, and therefore are stored in mmdata */
2717*d52a580bSJunchao Zhang     size_t bufferSize1 = 0;
2718*d52a580bSJunchao Zhang     size_t bufferSize2 = 0;
2719*d52a580bSJunchao Zhang     size_t bufferSize3 = 0;
2720*d52a580bSJunchao Zhang     size_t bufferSize4 = 0;
2721*d52a580bSJunchao Zhang     size_t bufferSize5 = 0;
2722*d52a580bSJunchao Zhang 
2723*d52a580bSJunchao Zhang     /* ask bufferSize1 bytes for external memory */
2724*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpGEMMreuse_workEstimation(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize1, NULL));
2725*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&dBuffer1, bufferSize1));
2726*d52a580bSJunchao Zhang     /* inspect the matrices A and B to understand the memory requirement for the next step */
2727*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpGEMMreuse_workEstimation(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize1, dBuffer1));
2728*d52a580bSJunchao Zhang 
2729*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpGEMMreuse_nnz(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize2, NULL, &bufferSize3, NULL, &bufferSize4, NULL));
2730*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&dBuffer2, bufferSize2));
2731*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&dBuffer3, bufferSize3));
2732*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&mmdata->dBuffer4, bufferSize4));
2733*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpGEMMreuse_nnz(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize2, dBuffer2, &bufferSize3, dBuffer3, &bufferSize4, mmdata->dBuffer4));
2734*d52a580bSJunchao Zhang     PetscCallHIP(hipFree(dBuffer1));
2735*d52a580bSJunchao Zhang     PetscCallHIP(hipFree(dBuffer2));
2736*d52a580bSJunchao Zhang 
2737*d52a580bSJunchao Zhang     /* get matrix C non-zero entries C_nnz1 */
2738*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpMatGetSize(Cmat->matDescr, &C_num_rows1, &C_num_cols1, &C_nnz1));
2739*d52a580bSJunchao Zhang     c->nz = (PetscInt)C_nnz1;
2740*d52a580bSJunchao Zhang     /* allocate matrix C */
2741*d52a580bSJunchao Zhang     Ccsr->column_indices = new THRUSTINTARRAY32(c->nz);
2742*d52a580bSJunchao Zhang     PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2743*d52a580bSJunchao Zhang     Ccsr->values = new THRUSTARRAY(c->nz);
2744*d52a580bSJunchao Zhang     PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2745*d52a580bSJunchao Zhang     /* update matC with the new pointers */
2746*d52a580bSJunchao Zhang     if (c->nz) { /* 5.5.1 has a bug with nz = 0, exposed by mat_tests_ex123_2_hypre */
2747*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCsrSetPointers(Cmat->matDescr, Ccsr->row_offsets->data().get(), Ccsr->column_indices->data().get(), Ccsr->values->data().get()));
2748*d52a580bSJunchao Zhang 
2749*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseSpGEMMreuse_copy(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize5, NULL));
2750*d52a580bSJunchao Zhang       PetscCallHIP(hipMalloc((void **)&mmdata->dBuffer5, bufferSize5));
2751*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseSpGEMMreuse_copy(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize5, mmdata->dBuffer5));
2752*d52a580bSJunchao Zhang       PetscCallHIP(hipFree(dBuffer3));
2753*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseSpGEMMreuse_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc));
2754*d52a580bSJunchao Zhang     }
2755*d52a580bSJunchao Zhang     PetscCall(PetscInfo(C, "Buffer sizes for type %s, result %" PetscInt_FMT " x %" PetscInt_FMT " (k %" PetscInt_FMT ", nzA %" PetscInt_FMT ", nzB %" PetscInt_FMT ", nzC %" PetscInt_FMT ") are: %ldKB %ldKB\n", MatProductTypes[ptype], m, n, k, a->nz, b->nz, c->nz, bufferSize4 / 1024, bufferSize5 / 1024));
2756*d52a580bSJunchao Zhang   }
2757*d52a580bSJunchao Zhang   #else
2758*d52a580bSJunchao Zhang   size_t bufSize2;
2759*d52a580bSJunchao Zhang   /* ask bufferSize bytes for external memory */
2760*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMM_workEstimation(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufSize2, NULL));
2761*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&mmdata->mmBuffer2, bufSize2));
2762*d52a580bSJunchao Zhang   /* inspect the matrices A and B to understand the memory requirement for the next step */
2763*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMM_workEstimation(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufSize2, mmdata->mmBuffer2));
2764*d52a580bSJunchao Zhang   /* ask bufferSize again bytes for external memory */
2765*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMM_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &mmdata->mmBufferSize, NULL));
2766*d52a580bSJunchao Zhang   /* Similar to CUSPARSE, we need both buffers to perform the operations properly!
2767*d52a580bSJunchao Zhang      mmdata->mmBuffer2 does not appear anywhere in the compute/copy API
2768*d52a580bSJunchao Zhang      it only appears for the workEstimation stuff, but it seems it is needed in compute, so probably the address
2769*d52a580bSJunchao Zhang      is stored in the descriptor! What a messy API... */
2770*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&mmdata->mmBuffer, mmdata->mmBufferSize));
2771*d52a580bSJunchao Zhang   /* compute the intermediate product of A * B */
2772*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMM_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &mmdata->mmBufferSize, mmdata->mmBuffer));
2773*d52a580bSJunchao Zhang   /* get matrix C non-zero entries C_nnz1 */
2774*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpMatGetSize(Cmat->matDescr, &C_num_rows1, &C_num_cols1, &C_nnz1));
2775*d52a580bSJunchao Zhang   c->nz = (PetscInt)C_nnz1;
2776*d52a580bSJunchao Zhang   PetscCall(PetscInfo(C, "Buffer sizes for type %s, result %" PetscInt_FMT " x %" PetscInt_FMT " (k %" PetscInt_FMT ", nzA %" PetscInt_FMT ", nzB %" PetscInt_FMT ", nzC %" PetscInt_FMT ") are: %ldKB %ldKB\n", MatProductTypes[ptype], m, n, k, a->nz, b->nz, c->nz, bufSize2 / 1024,
2777*d52a580bSJunchao Zhang                       mmdata->mmBufferSize / 1024));
2778*d52a580bSJunchao Zhang   Ccsr->column_indices = new THRUSTINTARRAY32(c->nz);
2779*d52a580bSJunchao Zhang   PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2780*d52a580bSJunchao Zhang   Ccsr->values = new THRUSTARRAY(c->nz);
2781*d52a580bSJunchao Zhang   PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2782*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseCsrSetPointers(Cmat->matDescr, Ccsr->row_offsets->data().get(), Ccsr->column_indices->data().get(), Ccsr->values->data().get()));
2783*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSpGEMM_copy(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc));
2784*d52a580bSJunchao Zhang   #endif
2785*d52a580bSJunchao Zhang #else
2786*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetPointerMode(Ccusp->handle, HIPSPARSE_POINTER_MODE_HOST));
2787*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseXcsrgemmNnz(Ccusp->handle, opA, opB, Acsr->num_rows, Bcsr->num_cols, Acsr->num_cols, Amat->descr, Acsr->num_entries, Acsr->row_offsets->data().get(), Acsr->column_indices->data().get(), Bmat->descr, Bcsr->num_entries,
2788*d52a580bSJunchao Zhang                                           Bcsr->row_offsets->data().get(), Bcsr->column_indices->data().get(), Cmat->descr, Ccsr->row_offsets->data().get(), &cnz));
2789*d52a580bSJunchao Zhang   c->nz                = cnz;
2790*d52a580bSJunchao Zhang   Ccsr->column_indices = new THRUSTINTARRAY32(c->nz);
2791*d52a580bSJunchao Zhang   PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2792*d52a580bSJunchao Zhang   Ccsr->values = new THRUSTARRAY(c->nz);
2793*d52a580bSJunchao Zhang   PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */
2794*d52a580bSJunchao Zhang 
2795*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparseSetPointerMode(Ccusp->handle, HIPSPARSE_POINTER_MODE_DEVICE));
2796*d52a580bSJunchao Zhang   /* with the old gemm interface (removed from 11.0 on) we cannot compute the symbolic factorization only.
2797*d52a580bSJunchao Zhang       I have tried using the gemm2 interface (alpha * A * B + beta * D), which allows to do symbolic by passing NULL for values, but it seems quite buggy when
2798*d52a580bSJunchao Zhang       D is NULL, despite the fact that CUSPARSE documentation claims it is supported! */
2799*d52a580bSJunchao Zhang   PetscCallHIPSPARSE(hipsparse_csr_spgemm(Ccusp->handle, opA, opB, Acsr->num_rows, Bcsr->num_cols, Acsr->num_cols, Amat->descr, Acsr->num_entries, Acsr->values->data().get(), Acsr->row_offsets->data().get(), Acsr->column_indices->data().get(), Bmat->descr,
2800*d52a580bSJunchao Zhang                                           Bcsr->num_entries, Bcsr->values->data().get(), Bcsr->row_offsets->data().get(), Bcsr->column_indices->data().get(), Cmat->descr, Ccsr->values->data().get(), Ccsr->row_offsets->data().get(),
2801*d52a580bSJunchao Zhang                                           Ccsr->column_indices->data().get()));
2802*d52a580bSJunchao Zhang #endif
2803*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(mmdata->flops));
2804*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
2805*d52a580bSJunchao Zhang finalizesym:
2806*d52a580bSJunchao Zhang   c->free_a = PETSC_TRUE;
2807*d52a580bSJunchao Zhang   PetscCall(PetscShmgetAllocateArray(c->nz, sizeof(PetscInt), (void **)&c->j));
2808*d52a580bSJunchao Zhang   PetscCall(PetscShmgetAllocateArray(m + 1, sizeof(PetscInt), (void **)&c->i));
2809*d52a580bSJunchao Zhang   c->free_ij = PETSC_TRUE;
2810*d52a580bSJunchao Zhang   if (PetscDefined(USE_64BIT_INDICES)) { /* 32 to 64-bit conversion on the GPU and then copy to host (lazy) */
2811*d52a580bSJunchao Zhang     PetscInt      *d_i = c->i;
2812*d52a580bSJunchao Zhang     THRUSTINTARRAY ii(Ccsr->row_offsets->size());
2813*d52a580bSJunchao Zhang     THRUSTINTARRAY jj(Ccsr->column_indices->size());
2814*d52a580bSJunchao Zhang     ii = *Ccsr->row_offsets;
2815*d52a580bSJunchao Zhang     jj = *Ccsr->column_indices;
2816*d52a580bSJunchao Zhang     if (ciscompressed) d_i = c->compressedrow.i;
2817*d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(d_i, ii.data().get(), Ccsr->row_offsets->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
2818*d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(c->j, jj.data().get(), Ccsr->column_indices->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
2819*d52a580bSJunchao Zhang   } else {
2820*d52a580bSJunchao Zhang     PetscInt *d_i = c->i;
2821*d52a580bSJunchao Zhang     if (ciscompressed) d_i = c->compressedrow.i;
2822*d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(d_i, Ccsr->row_offsets->data().get(), Ccsr->row_offsets->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
2823*d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(c->j, Ccsr->column_indices->data().get(), Ccsr->column_indices->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
2824*d52a580bSJunchao Zhang   }
2825*d52a580bSJunchao Zhang   if (ciscompressed) { /* need to expand host row offsets */
2826*d52a580bSJunchao Zhang     PetscInt r = 0;
2827*d52a580bSJunchao Zhang     c->i[0]    = 0;
2828*d52a580bSJunchao Zhang     for (k = 0; k < c->compressedrow.nrows; k++) {
2829*d52a580bSJunchao Zhang       const PetscInt next = c->compressedrow.rindex[k];
2830*d52a580bSJunchao Zhang       const PetscInt old  = c->compressedrow.i[k];
2831*d52a580bSJunchao Zhang       for (; r < next; r++) c->i[r + 1] = old;
2832*d52a580bSJunchao Zhang     }
2833*d52a580bSJunchao Zhang     for (; r < m; r++) c->i[r + 1] = c->compressedrow.i[c->compressedrow.nrows];
2834*d52a580bSJunchao Zhang   }
2835*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuToCpu((Ccsr->column_indices->size() + Ccsr->row_offsets->size()) * sizeof(PetscInt)));
2836*d52a580bSJunchao Zhang   PetscCall(PetscMalloc1(m, &c->ilen));
2837*d52a580bSJunchao Zhang   PetscCall(PetscMalloc1(m, &c->imax));
2838*d52a580bSJunchao Zhang   c->maxnz         = c->nz;
2839*d52a580bSJunchao Zhang   c->nonzerorowcnt = 0;
2840*d52a580bSJunchao Zhang   c->rmax          = 0;
2841*d52a580bSJunchao Zhang   for (k = 0; k < m; k++) {
2842*d52a580bSJunchao Zhang     const PetscInt nn = c->i[k + 1] - c->i[k];
2843*d52a580bSJunchao Zhang     c->ilen[k] = c->imax[k] = nn;
2844*d52a580bSJunchao Zhang     c->nonzerorowcnt += (PetscInt)!!nn;
2845*d52a580bSJunchao Zhang     c->rmax = PetscMax(c->rmax, nn);
2846*d52a580bSJunchao Zhang   }
2847*d52a580bSJunchao Zhang   PetscCall(PetscMalloc1(c->nz, &c->a));
2848*d52a580bSJunchao Zhang   Ccsr->num_entries = c->nz;
2849*d52a580bSJunchao Zhang 
2850*d52a580bSJunchao Zhang   C->nonzerostate++;
2851*d52a580bSJunchao Zhang   PetscCall(PetscLayoutSetUp(C->rmap));
2852*d52a580bSJunchao Zhang   PetscCall(PetscLayoutSetUp(C->cmap));
2853*d52a580bSJunchao Zhang   Ccusp->nonzerostate = C->nonzerostate;
2854*d52a580bSJunchao Zhang   C->offloadmask      = PETSC_OFFLOAD_UNALLOCATED;
2855*d52a580bSJunchao Zhang   C->preallocated     = PETSC_TRUE;
2856*d52a580bSJunchao Zhang   C->assembled        = PETSC_FALSE;
2857*d52a580bSJunchao Zhang   C->was_assembled    = PETSC_FALSE;
2858*d52a580bSJunchao Zhang   if (product->api_user && A->offloadmask == PETSC_OFFLOAD_BOTH && B->offloadmask == PETSC_OFFLOAD_BOTH) { /* flag the matrix C values as computed, so that the numeric phase will only call MatAssembly */
2859*d52a580bSJunchao Zhang     mmdata->reusesym = PETSC_TRUE;
2860*d52a580bSJunchao Zhang     C->offloadmask   = PETSC_OFFLOAD_GPU;
2861*d52a580bSJunchao Zhang   }
2862*d52a580bSJunchao Zhang   C->ops->productnumeric = MatProductNumeric_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE;
2863*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2864*d52a580bSJunchao Zhang }
2865*d52a580bSJunchao Zhang 
2866*d52a580bSJunchao Zhang /* handles sparse or dense B */
2867*d52a580bSJunchao Zhang static PetscErrorCode MatProductSetFromOptions_SeqAIJHIPSPARSE(Mat mat)
2868*d52a580bSJunchao Zhang {
2869*d52a580bSJunchao Zhang   Mat_Product *product = mat->product;
2870*d52a580bSJunchao Zhang   PetscBool    isdense = PETSC_FALSE, Biscusp = PETSC_FALSE, Ciscusp = PETSC_TRUE;
2871*d52a580bSJunchao Zhang 
2872*d52a580bSJunchao Zhang   PetscFunctionBegin;
2873*d52a580bSJunchao Zhang   MatCheckProduct(mat, 1);
2874*d52a580bSJunchao Zhang   PetscCall(PetscObjectBaseTypeCompare((PetscObject)product->B, MATSEQDENSE, &isdense));
2875*d52a580bSJunchao Zhang   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJHIPSPARSE, &Biscusp));
2876*d52a580bSJunchao Zhang   if (product->type == MATPRODUCT_ABC) {
2877*d52a580bSJunchao Zhang     Ciscusp = PETSC_FALSE;
2878*d52a580bSJunchao Zhang     if (!product->C->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJHIPSPARSE, &Ciscusp));
2879*d52a580bSJunchao Zhang   }
2880*d52a580bSJunchao Zhang   if (Biscusp && Ciscusp) { /* we can always select the CPU backend */
2881*d52a580bSJunchao Zhang     PetscBool usecpu = PETSC_FALSE;
2882*d52a580bSJunchao Zhang     switch (product->type) {
2883*d52a580bSJunchao Zhang     case MATPRODUCT_AB:
2884*d52a580bSJunchao Zhang       if (product->api_user) {
2885*d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
2886*d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
2887*d52a580bSJunchao Zhang         PetscOptionsEnd();
2888*d52a580bSJunchao Zhang       } else {
2889*d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
2890*d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
2891*d52a580bSJunchao Zhang         PetscOptionsEnd();
2892*d52a580bSJunchao Zhang       }
2893*d52a580bSJunchao Zhang       break;
2894*d52a580bSJunchao Zhang     case MATPRODUCT_AtB:
2895*d52a580bSJunchao Zhang       if (product->api_user) {
2896*d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
2897*d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
2898*d52a580bSJunchao Zhang         PetscOptionsEnd();
2899*d52a580bSJunchao Zhang       } else {
2900*d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
2901*d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
2902*d52a580bSJunchao Zhang         PetscOptionsEnd();
2903*d52a580bSJunchao Zhang       }
2904*d52a580bSJunchao Zhang       break;
2905*d52a580bSJunchao Zhang     case MATPRODUCT_PtAP:
2906*d52a580bSJunchao Zhang       if (product->api_user) {
2907*d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
2908*d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
2909*d52a580bSJunchao Zhang         PetscOptionsEnd();
2910*d52a580bSJunchao Zhang       } else {
2911*d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
2912*d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
2913*d52a580bSJunchao Zhang         PetscOptionsEnd();
2914*d52a580bSJunchao Zhang       }
2915*d52a580bSJunchao Zhang       break;
2916*d52a580bSJunchao Zhang     case MATPRODUCT_RARt:
2917*d52a580bSJunchao Zhang       if (product->api_user) {
2918*d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatRARt", "Mat");
2919*d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-matrart_backend_cpu", "Use CPU code", "MatRARt", usecpu, &usecpu, NULL));
2920*d52a580bSJunchao Zhang         PetscOptionsEnd();
2921*d52a580bSJunchao Zhang       } else {
2922*d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_RARt", "Mat");
2923*d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatRARt", usecpu, &usecpu, NULL));
2924*d52a580bSJunchao Zhang         PetscOptionsEnd();
2925*d52a580bSJunchao Zhang       }
2926*d52a580bSJunchao Zhang       break;
2927*d52a580bSJunchao Zhang     case MATPRODUCT_ABC:
2928*d52a580bSJunchao Zhang       if (product->api_user) {
2929*d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMatMult", "Mat");
2930*d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-matmatmatmult_backend_cpu", "Use CPU code", "MatMatMatMult", usecpu, &usecpu, NULL));
2931*d52a580bSJunchao Zhang         PetscOptionsEnd();
2932*d52a580bSJunchao Zhang       } else {
2933*d52a580bSJunchao Zhang         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_ABC", "Mat");
2934*d52a580bSJunchao Zhang         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMatMult", usecpu, &usecpu, NULL));
2935*d52a580bSJunchao Zhang         PetscOptionsEnd();
2936*d52a580bSJunchao Zhang       }
2937*d52a580bSJunchao Zhang       break;
2938*d52a580bSJunchao Zhang     default:
2939*d52a580bSJunchao Zhang       break;
2940*d52a580bSJunchao Zhang     }
2941*d52a580bSJunchao Zhang     if (usecpu) Biscusp = Ciscusp = PETSC_FALSE;
2942*d52a580bSJunchao Zhang   }
2943*d52a580bSJunchao Zhang   /* dispatch */
2944*d52a580bSJunchao Zhang   if (isdense) {
2945*d52a580bSJunchao Zhang     switch (product->type) {
2946*d52a580bSJunchao Zhang     case MATPRODUCT_AB:
2947*d52a580bSJunchao Zhang     case MATPRODUCT_AtB:
2948*d52a580bSJunchao Zhang     case MATPRODUCT_ABt:
2949*d52a580bSJunchao Zhang     case MATPRODUCT_PtAP:
2950*d52a580bSJunchao Zhang     case MATPRODUCT_RARt:
2951*d52a580bSJunchao Zhang       if (product->A->boundtocpu) PetscCall(MatProductSetFromOptions_SeqAIJ_SeqDense(mat));
2952*d52a580bSJunchao Zhang       else mat->ops->productsymbolic = MatProductSymbolic_SeqAIJHIPSPARSE_SeqDENSEHIP;
2953*d52a580bSJunchao Zhang       break;
2954*d52a580bSJunchao Zhang     case MATPRODUCT_ABC:
2955*d52a580bSJunchao Zhang       mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
2956*d52a580bSJunchao Zhang       break;
2957*d52a580bSJunchao Zhang     default:
2958*d52a580bSJunchao Zhang       break;
2959*d52a580bSJunchao Zhang     }
2960*d52a580bSJunchao Zhang   } else if (Biscusp && Ciscusp) {
2961*d52a580bSJunchao Zhang     switch (product->type) {
2962*d52a580bSJunchao Zhang     case MATPRODUCT_AB:
2963*d52a580bSJunchao Zhang     case MATPRODUCT_AtB:
2964*d52a580bSJunchao Zhang     case MATPRODUCT_ABt:
2965*d52a580bSJunchao Zhang       mat->ops->productsymbolic = MatProductSymbolic_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE;
2966*d52a580bSJunchao Zhang       break;
2967*d52a580bSJunchao Zhang     case MATPRODUCT_PtAP:
2968*d52a580bSJunchao Zhang     case MATPRODUCT_RARt:
2969*d52a580bSJunchao Zhang     case MATPRODUCT_ABC:
2970*d52a580bSJunchao Zhang       mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
2971*d52a580bSJunchao Zhang       break;
2972*d52a580bSJunchao Zhang     default:
2973*d52a580bSJunchao Zhang       break;
2974*d52a580bSJunchao Zhang     }
2975*d52a580bSJunchao Zhang   } else PetscCall(MatProductSetFromOptions_SeqAIJ(mat)); /* fallback for AIJ */
2976*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2977*d52a580bSJunchao Zhang }
2978*d52a580bSJunchao Zhang 
2979*d52a580bSJunchao Zhang static PetscErrorCode MatMult_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy)
2980*d52a580bSJunchao Zhang {
2981*d52a580bSJunchao Zhang   PetscFunctionBegin;
2982*d52a580bSJunchao Zhang   PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, NULL, yy, PETSC_FALSE, PETSC_FALSE));
2983*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2984*d52a580bSJunchao Zhang }
2985*d52a580bSJunchao Zhang 
2986*d52a580bSJunchao Zhang static PetscErrorCode MatMultAdd_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy, Vec zz)
2987*d52a580bSJunchao Zhang {
2988*d52a580bSJunchao Zhang   PetscFunctionBegin;
2989*d52a580bSJunchao Zhang   PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, yy, zz, PETSC_FALSE, PETSC_FALSE));
2990*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2991*d52a580bSJunchao Zhang }
2992*d52a580bSJunchao Zhang 
2993*d52a580bSJunchao Zhang static PetscErrorCode MatMultHermitianTranspose_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy)
2994*d52a580bSJunchao Zhang {
2995*d52a580bSJunchao Zhang   PetscFunctionBegin;
2996*d52a580bSJunchao Zhang   PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, NULL, yy, PETSC_TRUE, PETSC_TRUE));
2997*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
2998*d52a580bSJunchao Zhang }
2999*d52a580bSJunchao Zhang 
3000*d52a580bSJunchao Zhang static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy, Vec zz)
3001*d52a580bSJunchao Zhang {
3002*d52a580bSJunchao Zhang   PetscFunctionBegin;
3003*d52a580bSJunchao Zhang   PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, yy, zz, PETSC_TRUE, PETSC_TRUE));
3004*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3005*d52a580bSJunchao Zhang }
3006*d52a580bSJunchao Zhang 
3007*d52a580bSJunchao Zhang static PetscErrorCode MatMultTranspose_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy)
3008*d52a580bSJunchao Zhang {
3009*d52a580bSJunchao Zhang   PetscFunctionBegin;
3010*d52a580bSJunchao Zhang   PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, NULL, yy, PETSC_TRUE, PETSC_FALSE));
3011*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3012*d52a580bSJunchao Zhang }
3013*d52a580bSJunchao Zhang 
3014*d52a580bSJunchao Zhang __global__ static void ScatterAdd(PetscInt n, PetscInt *idx, const PetscScalar *x, PetscScalar *y)
3015*d52a580bSJunchao Zhang {
3016*d52a580bSJunchao Zhang   int i = blockIdx.x * blockDim.x + threadIdx.x;
3017*d52a580bSJunchao Zhang   if (i < n) y[idx[i]] += x[i];
3018*d52a580bSJunchao Zhang }
3019*d52a580bSJunchao Zhang 
3020*d52a580bSJunchao Zhang /* z = op(A) x + y. If trans & !herm, op = ^T; if trans & herm, op = ^H; if !trans, op = no-op */
3021*d52a580bSJunchao Zhang static PetscErrorCode MatMultAddKernel_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy, Vec zz, PetscBool trans, PetscBool herm)
3022*d52a580bSJunchao Zhang {
3023*d52a580bSJunchao Zhang   Mat_SeqAIJ                    *a               = (Mat_SeqAIJ *)A->data;
3024*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3025*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSEMultStruct *matstruct;
3026*d52a580bSJunchao Zhang   PetscScalar                   *xarray, *zarray, *dptr, *beta, *xptr;
3027*d52a580bSJunchao Zhang   hipsparseOperation_t           opA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
3028*d52a580bSJunchao Zhang   PetscBool                      compressed;
3029*d52a580bSJunchao Zhang   PetscInt                       nx, ny;
3030*d52a580bSJunchao Zhang 
3031*d52a580bSJunchao Zhang   PetscFunctionBegin;
3032*d52a580bSJunchao Zhang   PetscCheck(!herm || trans, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Hermitian and not transpose not supported");
3033*d52a580bSJunchao Zhang   if (!a->nz) {
3034*d52a580bSJunchao Zhang     if (yy) PetscCall(VecSeq_HIP::Copy(yy, zz));
3035*d52a580bSJunchao Zhang     else PetscCall(VecSeq_HIP::Set(zz, 0));
3036*d52a580bSJunchao Zhang     PetscFunctionReturn(PETSC_SUCCESS);
3037*d52a580bSJunchao Zhang   }
3038*d52a580bSJunchao Zhang   /* The line below is necessary due to the operations that modify the matrix on the CPU (axpy, scale, etc) */
3039*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
3040*d52a580bSJunchao Zhang   if (!trans) {
3041*d52a580bSJunchao Zhang     matstruct = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->mat;
3042*d52a580bSJunchao Zhang     PetscCheck(matstruct, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "SeqAIJHIPSPARSE does not have a 'mat' (need to fix)");
3043*d52a580bSJunchao Zhang   } else {
3044*d52a580bSJunchao Zhang     if (herm || !A->form_explicit_transpose) {
3045*d52a580bSJunchao Zhang       opA       = herm ? HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE : HIPSPARSE_OPERATION_TRANSPOSE;
3046*d52a580bSJunchao Zhang       matstruct = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->mat;
3047*d52a580bSJunchao Zhang     } else {
3048*d52a580bSJunchao Zhang       if (!hipsparsestruct->matTranspose) PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(A));
3049*d52a580bSJunchao Zhang       matstruct = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->matTranspose;
3050*d52a580bSJunchao Zhang     }
3051*d52a580bSJunchao Zhang   }
3052*d52a580bSJunchao Zhang   /* Does the matrix use compressed rows (i.e., drop zero rows)? */
3053*d52a580bSJunchao Zhang   compressed = matstruct->cprowIndices ? PETSC_TRUE : PETSC_FALSE;
3054*d52a580bSJunchao Zhang   try {
3055*d52a580bSJunchao Zhang     PetscCall(VecHIPGetArrayRead(xx, (const PetscScalar **)&xarray));
3056*d52a580bSJunchao Zhang     if (yy == zz) PetscCall(VecHIPGetArray(zz, &zarray)); /* read & write zz, so need to get up-to-date zarray on GPU */
3057*d52a580bSJunchao Zhang     else PetscCall(VecHIPGetArrayWrite(zz, &zarray));     /* write zz, so no need to init zarray on GPU */
3058*d52a580bSJunchao Zhang 
3059*d52a580bSJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
3060*d52a580bSJunchao Zhang     if (opA == HIPSPARSE_OPERATION_NON_TRANSPOSE) {
3061*d52a580bSJunchao Zhang       /* z = A x + beta y.
3062*d52a580bSJunchao Zhang          If A is compressed (with less rows), then Ax is shorter than the full z, so we need a work vector to store Ax.
3063*d52a580bSJunchao Zhang          When A is non-compressed, and z = y, we can set beta=1 to compute y = Ax + y in one call.
3064*d52a580bSJunchao Zhang       */
3065*d52a580bSJunchao Zhang       xptr = xarray;
3066*d52a580bSJunchao Zhang       dptr = compressed ? hipsparsestruct->workVector->data().get() : zarray;
3067*d52a580bSJunchao Zhang       beta = (yy == zz && !compressed) ? matstruct->beta_one : matstruct->beta_zero;
3068*d52a580bSJunchao Zhang       /* Get length of x, y for y=Ax. ny might be shorter than the work vector's allocated length, since the work vector is
3069*d52a580bSJunchao Zhang           allocated to accommodate different uses. So we get the length info directly from mat.
3070*d52a580bSJunchao Zhang        */
3071*d52a580bSJunchao Zhang       if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) {
3072*d52a580bSJunchao Zhang         CsrMatrix *mat = (CsrMatrix *)matstruct->mat;
3073*d52a580bSJunchao Zhang         nx             = mat->num_cols;
3074*d52a580bSJunchao Zhang         ny             = mat->num_rows;
3075*d52a580bSJunchao Zhang       }
3076*d52a580bSJunchao Zhang     } else {
3077*d52a580bSJunchao Zhang       /* z = A^T x + beta y
3078*d52a580bSJunchao Zhang          If A is compressed, then we need a work vector as the shorter version of x to compute A^T x.
3079*d52a580bSJunchao Zhang          Note A^Tx is of full length, so we set beta to 1.0 if y exists.
3080*d52a580bSJunchao Zhang        */
3081*d52a580bSJunchao Zhang       xptr = compressed ? hipsparsestruct->workVector->data().get() : xarray;
3082*d52a580bSJunchao Zhang       dptr = zarray;
3083*d52a580bSJunchao Zhang       beta = yy ? matstruct->beta_one : matstruct->beta_zero;
3084*d52a580bSJunchao Zhang       if (compressed) { /* Scatter x to work vector */
3085*d52a580bSJunchao Zhang         thrust::device_ptr<PetscScalar> xarr = thrust::device_pointer_cast(xarray);
3086*d52a580bSJunchao Zhang         thrust::for_each(
3087*d52a580bSJunchao Zhang #if PetscDefined(HAVE_THRUST_ASYNC)
3088*d52a580bSJunchao Zhang           thrust::hip::par.on(PetscDefaultHipStream),
3089*d52a580bSJunchao Zhang #endif
3090*d52a580bSJunchao Zhang           thrust::make_zip_iterator(thrust::make_tuple(hipsparsestruct->workVector->begin(), thrust::make_permutation_iterator(xarr, matstruct->cprowIndices->begin()))),
3091*d52a580bSJunchao Zhang           thrust::make_zip_iterator(thrust::make_tuple(hipsparsestruct->workVector->begin(), thrust::make_permutation_iterator(xarr, matstruct->cprowIndices->begin()))) + matstruct->cprowIndices->size(), VecHIPEqualsReverse());
3092*d52a580bSJunchao Zhang       }
3093*d52a580bSJunchao Zhang       if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) {
3094*d52a580bSJunchao Zhang         CsrMatrix *mat = (CsrMatrix *)matstruct->mat;
3095*d52a580bSJunchao Zhang         nx             = mat->num_rows;
3096*d52a580bSJunchao Zhang         ny             = mat->num_cols;
3097*d52a580bSJunchao Zhang       }
3098*d52a580bSJunchao Zhang     }
3099*d52a580bSJunchao Zhang     /* csr_spmv does y = alpha op(A) x + beta y */
3100*d52a580bSJunchao Zhang     if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) {
3101*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0)
3102*d52a580bSJunchao Zhang       PetscCheck(opA >= 0 && opA <= 2, PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSPARSE API on hipsparseOperation_t has changed and PETSc has not been updated accordingly");
3103*d52a580bSJunchao Zhang       if (!matstruct->hipSpMV[opA].initialized) { /* built on demand */
3104*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateDnVec(&matstruct->hipSpMV[opA].vecXDescr, nx, xptr, hipsparse_scalartype));
3105*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateDnVec(&matstruct->hipSpMV[opA].vecYDescr, ny, dptr, hipsparse_scalartype));
3106*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSpMV_bufferSize(hipsparsestruct->handle, opA, matstruct->alpha_one, matstruct->matDescr, matstruct->hipSpMV[opA].vecXDescr, beta, matstruct->hipSpMV[opA].vecYDescr, hipsparse_scalartype, hipsparsestruct->spmvAlg,
3107*d52a580bSJunchao Zhang                                                     &matstruct->hipSpMV[opA].spmvBufferSize));
3108*d52a580bSJunchao Zhang         PetscCallHIP(hipMalloc(&matstruct->hipSpMV[opA].spmvBuffer, matstruct->hipSpMV[opA].spmvBufferSize));
3109*d52a580bSJunchao Zhang         matstruct->hipSpMV[opA].initialized = PETSC_TRUE;
3110*d52a580bSJunchao Zhang       } else {
3111*d52a580bSJunchao Zhang         /* x, y's value pointers might change between calls, but their shape is kept, so we just update pointers */
3112*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseDnVecSetValues(matstruct->hipSpMV[opA].vecXDescr, xptr));
3113*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseDnVecSetValues(matstruct->hipSpMV[opA].vecYDescr, dptr));
3114*d52a580bSJunchao Zhang       }
3115*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseSpMV(hipsparsestruct->handle, opA, matstruct->alpha_one, matstruct->matDescr, /* built in MatSeqAIJHIPSPARSECopyToGPU() or MatSeqAIJHIPSPARSEFormExplicitTranspose() */
3116*d52a580bSJunchao Zhang                                        matstruct->hipSpMV[opA].vecXDescr, beta, matstruct->hipSpMV[opA].vecYDescr, hipsparse_scalartype, hipsparsestruct->spmvAlg, matstruct->hipSpMV[opA].spmvBuffer));
3117*d52a580bSJunchao Zhang #else
3118*d52a580bSJunchao Zhang       CsrMatrix *mat = (CsrMatrix *)matstruct->mat;
3119*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparse_csr_spmv(hipsparsestruct->handle, opA, mat->num_rows, mat->num_cols, mat->num_entries, matstruct->alpha_one, matstruct->descr, mat->values->data().get(), mat->row_offsets->data().get(), mat->column_indices->data().get(), xptr, beta, dptr));
3120*d52a580bSJunchao Zhang #endif
3121*d52a580bSJunchao Zhang     } else {
3122*d52a580bSJunchao Zhang       if (hipsparsestruct->nrows) {
3123*d52a580bSJunchao Zhang         hipsparseHybMat_t hybMat = (hipsparseHybMat_t)matstruct->mat;
3124*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparse_hyb_spmv(hipsparsestruct->handle, opA, matstruct->alpha_one, matstruct->descr, hybMat, xptr, beta, dptr));
3125*d52a580bSJunchao Zhang       }
3126*d52a580bSJunchao Zhang     }
3127*d52a580bSJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
3128*d52a580bSJunchao Zhang 
3129*d52a580bSJunchao Zhang     if (opA == HIPSPARSE_OPERATION_NON_TRANSPOSE) {
3130*d52a580bSJunchao Zhang       if (yy) {                                     /* MatMultAdd: zz = A*xx + yy */
3131*d52a580bSJunchao Zhang         if (compressed) {                           /* A is compressed. We first copy yy to zz, then ScatterAdd the work vector to zz */
3132*d52a580bSJunchao Zhang           PetscCall(VecSeq_HIP::Copy(yy, zz));      /* zz = yy */
3133*d52a580bSJunchao Zhang         } else if (zz != yy) {                      /* A is not compressed. zz already contains A*xx, and we just need to add yy */
3134*d52a580bSJunchao Zhang           PetscCall(VecSeq_HIP::AXPY(zz, 1.0, yy)); /* zz += yy */
3135*d52a580bSJunchao Zhang         }
3136*d52a580bSJunchao Zhang       } else if (compressed) { /* MatMult: zz = A*xx. A is compressed, so we zero zz first, then ScatterAdd the work vector to zz */
3137*d52a580bSJunchao Zhang         PetscCall(VecSeq_HIP::Set(zz, 0));
3138*d52a580bSJunchao Zhang       }
3139*d52a580bSJunchao Zhang 
3140*d52a580bSJunchao Zhang       /* ScatterAdd the result from work vector into the full vector when A is compressed */
3141*d52a580bSJunchao Zhang       if (compressed) {
3142*d52a580bSJunchao Zhang         PetscCall(PetscLogGpuTimeBegin());
3143*d52a580bSJunchao Zhang         /* I wanted to make this for_each asynchronous but failed. thrust::async::for_each() returns an event (internally registered)
3144*d52a580bSJunchao Zhang            and in the destructor of the scope, it will call hipStreamSynchronize() on this stream. One has to store all events to
3145*d52a580bSJunchao Zhang            prevent that. So I just add a ScatterAdd kernel.
3146*d52a580bSJunchao Zhang          */
3147*d52a580bSJunchao Zhang #if 0
3148*d52a580bSJunchao Zhang         thrust::device_ptr<PetscScalar> zptr = thrust::device_pointer_cast(zarray);
3149*d52a580bSJunchao Zhang         thrust::async::for_each(thrust::hip::par.on(hipsparsestruct->stream),
3150*d52a580bSJunchao Zhang                          thrust::make_zip_iterator(thrust::make_tuple(hipsparsestruct->workVector->begin(), thrust::make_permutation_iterator(zptr, matstruct->cprowIndices->begin()))),
3151*d52a580bSJunchao Zhang                          thrust::make_zip_iterator(thrust::make_tuple(hipsparsestruct->workVector->begin(), thrust::make_permutation_iterator(zptr, matstruct->cprowIndices->begin()))) + matstruct->cprowIndices->size(),
3152*d52a580bSJunchao Zhang                          VecHIPPlusEquals());
3153*d52a580bSJunchao Zhang #else
3154*d52a580bSJunchao Zhang         PetscInt n = matstruct->cprowIndices->size();
3155*d52a580bSJunchao Zhang         hipLaunchKernelGGL(ScatterAdd, dim3((n + 255) / 256), dim3(256), 0, PetscDefaultHipStream, n, matstruct->cprowIndices->data().get(), hipsparsestruct->workVector->data().get(), zarray);
3156*d52a580bSJunchao Zhang #endif
3157*d52a580bSJunchao Zhang         PetscCall(PetscLogGpuTimeEnd());
3158*d52a580bSJunchao Zhang       }
3159*d52a580bSJunchao Zhang     } else {
3160*d52a580bSJunchao Zhang       if (yy && yy != zz) PetscCall(VecSeq_HIP::AXPY(zz, 1.0, yy)); /* zz += yy */
3161*d52a580bSJunchao Zhang     }
3162*d52a580bSJunchao Zhang     PetscCall(VecHIPRestoreArrayRead(xx, (const PetscScalar **)&xarray));
3163*d52a580bSJunchao Zhang     if (yy == zz) PetscCall(VecHIPRestoreArray(zz, &zarray));
3164*d52a580bSJunchao Zhang     else PetscCall(VecHIPRestoreArrayWrite(zz, &zarray));
3165*d52a580bSJunchao Zhang   } catch (char *ex) {
3166*d52a580bSJunchao Zhang     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex);
3167*d52a580bSJunchao Zhang   }
3168*d52a580bSJunchao Zhang   if (yy) PetscCall(PetscLogGpuFlops(2.0 * a->nz));
3169*d52a580bSJunchao Zhang   else PetscCall(PetscLogGpuFlops(2.0 * a->nz - a->nonzerorowcnt));
3170*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3171*d52a580bSJunchao Zhang }
3172*d52a580bSJunchao Zhang 
3173*d52a580bSJunchao Zhang static PetscErrorCode MatMultTransposeAdd_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy, Vec zz)
3174*d52a580bSJunchao Zhang {
3175*d52a580bSJunchao Zhang   PetscFunctionBegin;
3176*d52a580bSJunchao Zhang   PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, yy, zz, PETSC_TRUE, PETSC_FALSE));
3177*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3178*d52a580bSJunchao Zhang }
3179*d52a580bSJunchao Zhang 
3180*d52a580bSJunchao Zhang static PetscErrorCode MatAssemblyEnd_SeqAIJHIPSPARSE(Mat A, MatAssemblyType mode)
3181*d52a580bSJunchao Zhang {
3182*d52a580bSJunchao Zhang   PetscFunctionBegin;
3183*d52a580bSJunchao Zhang   PetscCall(MatAssemblyEnd_SeqAIJ(A, mode));
3184*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3185*d52a580bSJunchao Zhang }
3186*d52a580bSJunchao Zhang 
3187*d52a580bSJunchao Zhang /*@
3188*d52a580bSJunchao Zhang   MatCreateSeqAIJHIPSPARSE - Creates a sparse matrix in `MATAIJHIPSPARSE` (compressed row) format.
3189*d52a580bSJunchao Zhang   This matrix will ultimately pushed down to AMD GPUs and use the HIPSPARSE library for calculations.
3190*d52a580bSJunchao Zhang 
3191*d52a580bSJunchao Zhang   Collective
3192*d52a580bSJunchao Zhang 
3193*d52a580bSJunchao Zhang   Input Parameters:
3194*d52a580bSJunchao Zhang + comm - MPI communicator, set to `PETSC_COMM_SELF`
3195*d52a580bSJunchao Zhang . m    - number of rows
3196*d52a580bSJunchao Zhang . n    - number of columns
3197*d52a580bSJunchao Zhang . nz   - number of nonzeros per row (same for all rows), ignored if `nnz` is set
3198*d52a580bSJunchao Zhang - nnz  - array containing the number of nonzeros in the various rows (possibly different for each row) or `NULL`
3199*d52a580bSJunchao Zhang 
3200*d52a580bSJunchao Zhang   Output Parameter:
3201*d52a580bSJunchao Zhang . A - the matrix
3202*d52a580bSJunchao Zhang 
3203*d52a580bSJunchao Zhang   Level: intermediate
3204*d52a580bSJunchao Zhang 
3205*d52a580bSJunchao Zhang   Notes:
3206*d52a580bSJunchao Zhang   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
3207*d52a580bSJunchao Zhang   `MatXXXXSetPreallocation()` paradgm instead of this routine directly.
3208*d52a580bSJunchao Zhang   [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation`]
3209*d52a580bSJunchao Zhang 
3210*d52a580bSJunchao Zhang   The AIJ format (compressed row storage), is fully compatible with standard Fortran
3211*d52a580bSJunchao Zhang   storage.  That is, the stored row and column indices can begin at
3212*d52a580bSJunchao Zhang   either one (as in Fortran) or zero.
3213*d52a580bSJunchao Zhang 
3214*d52a580bSJunchao Zhang   Specify the preallocated storage with either `nz` or `nnz` (not both).
3215*d52a580bSJunchao Zhang   Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory
3216*d52a580bSJunchao Zhang   allocation.
3217*d52a580bSJunchao Zhang 
3218*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATSEQAIJHIPSPARSE`, `MATAIJHIPSPARSE`
3219*d52a580bSJunchao Zhang @*/
3220*d52a580bSJunchao Zhang PetscErrorCode MatCreateSeqAIJHIPSPARSE(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A)
3221*d52a580bSJunchao Zhang {
3222*d52a580bSJunchao Zhang   PetscFunctionBegin;
3223*d52a580bSJunchao Zhang   PetscCall(MatCreate(comm, A));
3224*d52a580bSJunchao Zhang   PetscCall(MatSetSizes(*A, m, n, m, n));
3225*d52a580bSJunchao Zhang   PetscCall(MatSetType(*A, MATSEQAIJHIPSPARSE));
3226*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz));
3227*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3228*d52a580bSJunchao Zhang }
3229*d52a580bSJunchao Zhang 
3230*d52a580bSJunchao Zhang static PetscErrorCode MatDestroy_SeqAIJHIPSPARSE(Mat A)
3231*d52a580bSJunchao Zhang {
3232*d52a580bSJunchao Zhang   PetscFunctionBegin;
3233*d52a580bSJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) PetscCall(MatSeqAIJHIPSPARSE_Destroy(A));
3234*d52a580bSJunchao Zhang   else PetscCall(MatSeqAIJHIPSPARSETriFactors_Destroy((Mat_SeqAIJHIPSPARSETriFactors **)&A->spptr));
3235*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSeqAIJCopySubArray_C", NULL));
3236*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatHIPSPARSESetFormat_C", NULL));
3237*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatHIPSPARSESetUseCPUSolve_C", NULL));
3238*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdensehip_C", NULL));
3239*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdense_C", NULL));
3240*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqaijhipsparse_C", NULL));
3241*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL));
3242*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
3243*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
3244*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijhipsparse_hypre_C", NULL));
3245*d52a580bSJunchao Zhang   PetscCall(MatDestroy_SeqAIJ(A));
3246*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3247*d52a580bSJunchao Zhang }
3248*d52a580bSJunchao Zhang 
3249*d52a580bSJunchao Zhang static PetscErrorCode MatDuplicate_SeqAIJHIPSPARSE(Mat A, MatDuplicateOption cpvalues, Mat *B)
3250*d52a580bSJunchao Zhang {
3251*d52a580bSJunchao Zhang   PetscFunctionBegin;
3252*d52a580bSJunchao Zhang   PetscCall(MatDuplicate_SeqAIJ(A, cpvalues, B));
3253*d52a580bSJunchao Zhang   PetscCall(MatConvert_SeqAIJ_SeqAIJHIPSPARSE(*B, MATSEQAIJHIPSPARSE, MAT_INPLACE_MATRIX, B));
3254*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3255*d52a580bSJunchao Zhang }
3256*d52a580bSJunchao Zhang 
3257*d52a580bSJunchao Zhang static PetscErrorCode MatAXPY_SeqAIJHIPSPARSE(Mat Y, PetscScalar a, Mat X, MatStructure str)
3258*d52a580bSJunchao Zhang {
3259*d52a580bSJunchao Zhang   Mat_SeqAIJ          *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data;
3260*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cy;
3261*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cx;
3262*d52a580bSJunchao Zhang   PetscScalar         *ay;
3263*d52a580bSJunchao Zhang   const PetscScalar   *ax;
3264*d52a580bSJunchao Zhang   CsrMatrix           *csry, *csrx;
3265*d52a580bSJunchao Zhang 
3266*d52a580bSJunchao Zhang   PetscFunctionBegin;
3267*d52a580bSJunchao Zhang   cy = (Mat_SeqAIJHIPSPARSE *)Y->spptr;
3268*d52a580bSJunchao Zhang   cx = (Mat_SeqAIJHIPSPARSE *)X->spptr;
3269*d52a580bSJunchao Zhang   if (X->ops->axpy != Y->ops->axpy) {
3270*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(Y, PETSC_FALSE));
3271*d52a580bSJunchao Zhang     PetscCall(MatAXPY_SeqAIJ(Y, a, X, str));
3272*d52a580bSJunchao Zhang     PetscFunctionReturn(PETSC_SUCCESS);
3273*d52a580bSJunchao Zhang   }
3274*d52a580bSJunchao Zhang   /* if we are here, it means both matrices are bound to GPU */
3275*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(Y));
3276*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(X));
3277*d52a580bSJunchao Zhang   PetscCheck(cy->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)Y), PETSC_ERR_GPU, "only MAT_HIPSPARSE_CSR supported");
3278*d52a580bSJunchao Zhang   PetscCheck(cx->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)X), PETSC_ERR_GPU, "only MAT_HIPSPARSE_CSR supported");
3279*d52a580bSJunchao Zhang   csry = (CsrMatrix *)cy->mat->mat;
3280*d52a580bSJunchao Zhang   csrx = (CsrMatrix *)cx->mat->mat;
3281*d52a580bSJunchao Zhang   /* see if we can turn this into a hipblas axpy */
3282*d52a580bSJunchao Zhang   if (str != SAME_NONZERO_PATTERN && x->nz == y->nz && !x->compressedrow.use && !y->compressedrow.use) {
3283*d52a580bSJunchao Zhang     bool eq = thrust::equal(thrust::device, csry->row_offsets->begin(), csry->row_offsets->end(), csrx->row_offsets->begin());
3284*d52a580bSJunchao Zhang     if (eq) eq = thrust::equal(thrust::device, csry->column_indices->begin(), csry->column_indices->end(), csrx->column_indices->begin());
3285*d52a580bSJunchao Zhang     if (eq) str = SAME_NONZERO_PATTERN;
3286*d52a580bSJunchao Zhang   }
3287*d52a580bSJunchao Zhang   /* spgeam is buggy with one column */
3288*d52a580bSJunchao Zhang   if (Y->cmap->n == 1 && str != SAME_NONZERO_PATTERN) str = DIFFERENT_NONZERO_PATTERN;
3289*d52a580bSJunchao Zhang   if (str == SUBSET_NONZERO_PATTERN) {
3290*d52a580bSJunchao Zhang     PetscScalar b = 1.0;
3291*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
3292*d52a580bSJunchao Zhang     size_t bufferSize;
3293*d52a580bSJunchao Zhang     void  *buffer;
3294*d52a580bSJunchao Zhang #endif
3295*d52a580bSJunchao Zhang 
3296*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEGetArrayRead(X, &ax));
3297*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEGetArray(Y, &ay));
3298*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSetPointerMode(cy->handle, HIPSPARSE_POINTER_MODE_HOST));
3299*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
3300*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparse_csr_spgeam_bufferSize(cy->handle, Y->rmap->n, Y->cmap->n, &a, cx->mat->descr, x->nz, ax, csrx->row_offsets->data().get(), csrx->column_indices->data().get(), &b, cy->mat->descr, y->nz, ay, csry->row_offsets->data().get(),
3301*d52a580bSJunchao Zhang                                                        csry->column_indices->data().get(), cy->mat->descr, ay, csry->row_offsets->data().get(), csry->column_indices->data().get(), &bufferSize));
3302*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc(&buffer, bufferSize));
3303*d52a580bSJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
3304*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparse_csr_spgeam(cy->handle, Y->rmap->n, Y->cmap->n, &a, cx->mat->descr, x->nz, ax, csrx->row_offsets->data().get(), csrx->column_indices->data().get(), &b, cy->mat->descr, y->nz, ay, csry->row_offsets->data().get(),
3305*d52a580bSJunchao Zhang                                             csry->column_indices->data().get(), cy->mat->descr, ay, csry->row_offsets->data().get(), csry->column_indices->data().get(), buffer));
3306*d52a580bSJunchao Zhang     PetscCall(PetscLogGpuFlops(x->nz + y->nz));
3307*d52a580bSJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
3308*d52a580bSJunchao Zhang     PetscCallHIP(hipFree(buffer));
3309*d52a580bSJunchao Zhang #else
3310*d52a580bSJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
3311*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparse_csr_spgeam(cy->handle, Y->rmap->n, Y->cmap->n, &a, cx->mat->descr, x->nz, ax, csrx->row_offsets->data().get(), csrx->column_indices->data().get(), &b, cy->mat->descr, y->nz, ay, csry->row_offsets->data().get(),
3312*d52a580bSJunchao Zhang                                             csry->column_indices->data().get(), cy->mat->descr, ay, csry->row_offsets->data().get(), csry->column_indices->data().get()));
3313*d52a580bSJunchao Zhang     PetscCall(PetscLogGpuFlops(x->nz + y->nz));
3314*d52a580bSJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
3315*d52a580bSJunchao Zhang #endif
3316*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSetPointerMode(cy->handle, HIPSPARSE_POINTER_MODE_DEVICE));
3317*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSERestoreArrayRead(X, &ax));
3318*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSERestoreArray(Y, &ay));
3319*d52a580bSJunchao Zhang   } else if (str == SAME_NONZERO_PATTERN) {
3320*d52a580bSJunchao Zhang     hipblasHandle_t hipblasv2handle;
3321*d52a580bSJunchao Zhang     PetscBLASInt    one = 1, bnz = 1;
3322*d52a580bSJunchao Zhang 
3323*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEGetArrayRead(X, &ax));
3324*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEGetArray(Y, &ay));
3325*d52a580bSJunchao Zhang     PetscCall(PetscHIPBLASGetHandle(&hipblasv2handle));
3326*d52a580bSJunchao Zhang     PetscCall(PetscBLASIntCast(x->nz, &bnz));
3327*d52a580bSJunchao Zhang     PetscCall(PetscLogGpuTimeBegin());
3328*d52a580bSJunchao Zhang     PetscCallHIPBLAS(hipblasXaxpy(hipblasv2handle, bnz, &a, ax, one, ay, one));
3329*d52a580bSJunchao Zhang     PetscCall(PetscLogGpuFlops(2.0 * bnz));
3330*d52a580bSJunchao Zhang     PetscCall(PetscLogGpuTimeEnd());
3331*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSERestoreArrayRead(X, &ax));
3332*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSERestoreArray(Y, &ay));
3333*d52a580bSJunchao Zhang   } else {
3334*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(Y, PETSC_FALSE));
3335*d52a580bSJunchao Zhang     PetscCall(MatAXPY_SeqAIJ(Y, a, X, str));
3336*d52a580bSJunchao Zhang   }
3337*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3338*d52a580bSJunchao Zhang }
3339*d52a580bSJunchao Zhang 
3340*d52a580bSJunchao Zhang static PetscErrorCode MatScale_SeqAIJHIPSPARSE(Mat Y, PetscScalar a)
3341*d52a580bSJunchao Zhang {
3342*d52a580bSJunchao Zhang   Mat_SeqAIJ     *y = (Mat_SeqAIJ *)Y->data;
3343*d52a580bSJunchao Zhang   PetscScalar    *ay;
3344*d52a580bSJunchao Zhang   hipblasHandle_t hipblasv2handle;
3345*d52a580bSJunchao Zhang   PetscBLASInt    one = 1, bnz = 1;
3346*d52a580bSJunchao Zhang 
3347*d52a580bSJunchao Zhang   PetscFunctionBegin;
3348*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEGetArray(Y, &ay));
3349*d52a580bSJunchao Zhang   PetscCall(PetscHIPBLASGetHandle(&hipblasv2handle));
3350*d52a580bSJunchao Zhang   PetscCall(PetscBLASIntCast(y->nz, &bnz));
3351*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
3352*d52a580bSJunchao Zhang   PetscCallHIPBLAS(hipblasXscal(hipblasv2handle, bnz, &a, ay, one));
3353*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuFlops(bnz));
3354*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
3355*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSERestoreArray(Y, &ay));
3356*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3357*d52a580bSJunchao Zhang }
3358*d52a580bSJunchao Zhang 
3359*d52a580bSJunchao Zhang static PetscErrorCode MatZeroEntries_SeqAIJHIPSPARSE(Mat A)
3360*d52a580bSJunchao Zhang {
3361*d52a580bSJunchao Zhang   PetscBool   both = PETSC_FALSE;
3362*d52a580bSJunchao Zhang   Mat_SeqAIJ *a    = (Mat_SeqAIJ *)A->data;
3363*d52a580bSJunchao Zhang 
3364*d52a580bSJunchao Zhang   PetscFunctionBegin;
3365*d52a580bSJunchao Zhang   if (A->factortype == MAT_FACTOR_NONE) {
3366*d52a580bSJunchao Zhang     Mat_SeqAIJHIPSPARSE *spptr = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3367*d52a580bSJunchao Zhang     if (spptr->mat) {
3368*d52a580bSJunchao Zhang       CsrMatrix *matrix = (CsrMatrix *)spptr->mat->mat;
3369*d52a580bSJunchao Zhang       if (matrix->values) {
3370*d52a580bSJunchao Zhang         both = PETSC_TRUE;
3371*d52a580bSJunchao Zhang         thrust::fill(thrust::device, matrix->values->begin(), matrix->values->end(), 0.);
3372*d52a580bSJunchao Zhang       }
3373*d52a580bSJunchao Zhang     }
3374*d52a580bSJunchao Zhang     if (spptr->matTranspose) {
3375*d52a580bSJunchao Zhang       CsrMatrix *matrix = (CsrMatrix *)spptr->matTranspose->mat;
3376*d52a580bSJunchao Zhang       if (matrix->values) thrust::fill(thrust::device, matrix->values->begin(), matrix->values->end(), 0.);
3377*d52a580bSJunchao Zhang     }
3378*d52a580bSJunchao Zhang   }
3379*d52a580bSJunchao Zhang   //PetscCall(MatZeroEntries_SeqAIJ(A));
3380*d52a580bSJunchao Zhang   PetscCall(PetscArrayzero(a->a, a->i[A->rmap->n]));
3381*d52a580bSJunchao Zhang   if (both) A->offloadmask = PETSC_OFFLOAD_BOTH;
3382*d52a580bSJunchao Zhang   else A->offloadmask = PETSC_OFFLOAD_CPU;
3383*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3384*d52a580bSJunchao Zhang }
3385*d52a580bSJunchao Zhang 
3386*d52a580bSJunchao Zhang static PetscErrorCode MatGetCurrentMemType_SeqAIJHIPSPARSE(PETSC_UNUSED Mat A, PetscMemType *m)
3387*d52a580bSJunchao Zhang {
3388*d52a580bSJunchao Zhang   PetscFunctionBegin;
3389*d52a580bSJunchao Zhang   *m = PETSC_MEMTYPE_HIP;
3390*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3391*d52a580bSJunchao Zhang }
3392*d52a580bSJunchao Zhang 
3393*d52a580bSJunchao Zhang static PetscErrorCode MatBindToCPU_SeqAIJHIPSPARSE(Mat A, PetscBool flg)
3394*d52a580bSJunchao Zhang {
3395*d52a580bSJunchao Zhang   Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data;
3396*d52a580bSJunchao Zhang 
3397*d52a580bSJunchao Zhang   PetscFunctionBegin;
3398*d52a580bSJunchao Zhang   if (A->factortype != MAT_FACTOR_NONE) {
3399*d52a580bSJunchao Zhang     A->boundtocpu = flg;
3400*d52a580bSJunchao Zhang     PetscFunctionReturn(PETSC_SUCCESS);
3401*d52a580bSJunchao Zhang   }
3402*d52a580bSJunchao Zhang   if (flg) {
3403*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A));
3404*d52a580bSJunchao Zhang 
3405*d52a580bSJunchao Zhang     A->ops->scale                     = MatScale_SeqAIJ;
3406*d52a580bSJunchao Zhang     A->ops->axpy                      = MatAXPY_SeqAIJ;
3407*d52a580bSJunchao Zhang     A->ops->zeroentries               = MatZeroEntries_SeqAIJ;
3408*d52a580bSJunchao Zhang     A->ops->mult                      = MatMult_SeqAIJ;
3409*d52a580bSJunchao Zhang     A->ops->multadd                   = MatMultAdd_SeqAIJ;
3410*d52a580bSJunchao Zhang     A->ops->multtranspose             = MatMultTranspose_SeqAIJ;
3411*d52a580bSJunchao Zhang     A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJ;
3412*d52a580bSJunchao Zhang     A->ops->multhermitiantranspose    = NULL;
3413*d52a580bSJunchao Zhang     A->ops->multhermitiantransposeadd = NULL;
3414*d52a580bSJunchao Zhang     A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJ;
3415*d52a580bSJunchao Zhang     A->ops->getcurrentmemtype         = NULL;
3416*d52a580bSJunchao Zhang     PetscCall(PetscMemzero(a->ops, sizeof(Mat_SeqAIJOps)));
3417*d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSeqAIJCopySubArray_C", NULL));
3418*d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdensehip_C", NULL));
3419*d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdense_C", NULL));
3420*d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
3421*d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
3422*d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqaijhipsparse_C", NULL));
3423*d52a580bSJunchao Zhang   } else {
3424*d52a580bSJunchao Zhang     A->ops->scale                     = MatScale_SeqAIJHIPSPARSE;
3425*d52a580bSJunchao Zhang     A->ops->axpy                      = MatAXPY_SeqAIJHIPSPARSE;
3426*d52a580bSJunchao Zhang     A->ops->zeroentries               = MatZeroEntries_SeqAIJHIPSPARSE;
3427*d52a580bSJunchao Zhang     A->ops->mult                      = MatMult_SeqAIJHIPSPARSE;
3428*d52a580bSJunchao Zhang     A->ops->multadd                   = MatMultAdd_SeqAIJHIPSPARSE;
3429*d52a580bSJunchao Zhang     A->ops->multtranspose             = MatMultTranspose_SeqAIJHIPSPARSE;
3430*d52a580bSJunchao Zhang     A->ops->multtransposeadd          = MatMultTransposeAdd_SeqAIJHIPSPARSE;
3431*d52a580bSJunchao Zhang     A->ops->multhermitiantranspose    = MatMultHermitianTranspose_SeqAIJHIPSPARSE;
3432*d52a580bSJunchao Zhang     A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJHIPSPARSE;
3433*d52a580bSJunchao Zhang     A->ops->productsetfromoptions     = MatProductSetFromOptions_SeqAIJHIPSPARSE;
3434*d52a580bSJunchao Zhang     A->ops->getcurrentmemtype         = MatGetCurrentMemType_SeqAIJHIPSPARSE;
3435*d52a580bSJunchao Zhang     a->ops->getarray                  = MatSeqAIJGetArray_SeqAIJHIPSPARSE;
3436*d52a580bSJunchao Zhang     a->ops->restorearray              = MatSeqAIJRestoreArray_SeqAIJHIPSPARSE;
3437*d52a580bSJunchao Zhang     a->ops->getarrayread              = MatSeqAIJGetArrayRead_SeqAIJHIPSPARSE;
3438*d52a580bSJunchao Zhang     a->ops->restorearrayread          = MatSeqAIJRestoreArrayRead_SeqAIJHIPSPARSE;
3439*d52a580bSJunchao Zhang     a->ops->getarraywrite             = MatSeqAIJGetArrayWrite_SeqAIJHIPSPARSE;
3440*d52a580bSJunchao Zhang     a->ops->restorearraywrite         = MatSeqAIJRestoreArrayWrite_SeqAIJHIPSPARSE;
3441*d52a580bSJunchao Zhang     a->ops->getcsrandmemtype          = MatSeqAIJGetCSRAndMemType_SeqAIJHIPSPARSE;
3442*d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSeqAIJCopySubArray_C", MatSeqAIJCopySubArray_SeqAIJHIPSPARSE));
3443*d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdensehip_C", MatProductSetFromOptions_SeqAIJHIPSPARSE));
3444*d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdense_C", MatProductSetFromOptions_SeqAIJHIPSPARSE));
3445*d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJHIPSPARSE));
3446*d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJHIPSPARSE));
3447*d52a580bSJunchao Zhang     PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqaijhipsparse_C", MatProductSetFromOptions_SeqAIJHIPSPARSE));
3448*d52a580bSJunchao Zhang   }
3449*d52a580bSJunchao Zhang   A->boundtocpu = flg;
3450*d52a580bSJunchao Zhang   if (flg && a->inode.size_csr) a->inode.use = PETSC_TRUE;
3451*d52a580bSJunchao Zhang   else a->inode.use = PETSC_FALSE;
3452*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3453*d52a580bSJunchao Zhang }
3454*d52a580bSJunchao Zhang 
3455*d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJHIPSPARSE(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
3456*d52a580bSJunchao Zhang {
3457*d52a580bSJunchao Zhang   Mat B;
3458*d52a580bSJunchao Zhang 
3459*d52a580bSJunchao Zhang   PetscFunctionBegin;
3460*d52a580bSJunchao Zhang   PetscCall(PetscDeviceInitialize(PETSC_DEVICE_HIP)); /* first use of HIPSPARSE may be via MatConvert */
3461*d52a580bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
3462*d52a580bSJunchao Zhang     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
3463*d52a580bSJunchao Zhang   } else if (reuse == MAT_REUSE_MATRIX) {
3464*d52a580bSJunchao Zhang     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
3465*d52a580bSJunchao Zhang   }
3466*d52a580bSJunchao Zhang   B = *newmat;
3467*d52a580bSJunchao Zhang   PetscCall(PetscFree(B->defaultvectype));
3468*d52a580bSJunchao Zhang   PetscCall(PetscStrallocpy(VECHIP, &B->defaultvectype));
3469*d52a580bSJunchao Zhang   if (reuse != MAT_REUSE_MATRIX && !B->spptr) {
3470*d52a580bSJunchao Zhang     if (B->factortype == MAT_FACTOR_NONE) {
3471*d52a580bSJunchao Zhang       Mat_SeqAIJHIPSPARSE *spptr;
3472*d52a580bSJunchao Zhang       PetscCall(PetscNew(&spptr));
3473*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreate(&spptr->handle));
3474*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseSetStream(spptr->handle, PetscDefaultHipStream));
3475*d52a580bSJunchao Zhang       spptr->format = MAT_HIPSPARSE_CSR;
3476*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
3477*d52a580bSJunchao Zhang       spptr->spmvAlg = HIPSPARSE_SPMV_CSR_ALG1;
3478*d52a580bSJunchao Zhang #else
3479*d52a580bSJunchao Zhang       spptr->spmvAlg = HIPSPARSE_CSRMV_ALG1; /* default, since we only support csr */
3480*d52a580bSJunchao Zhang #endif
3481*d52a580bSJunchao Zhang       spptr->spmmAlg = HIPSPARSE_SPMM_CSR_ALG1; /* default, only support column-major dense matrix B */
3482*d52a580bSJunchao Zhang       //spptr->csr2cscAlg = HIPSPARSE_CSR2CSC_ALG1;
3483*d52a580bSJunchao Zhang 
3484*d52a580bSJunchao Zhang       B->spptr = spptr;
3485*d52a580bSJunchao Zhang     } else {
3486*d52a580bSJunchao Zhang       Mat_SeqAIJHIPSPARSETriFactors *spptr;
3487*d52a580bSJunchao Zhang 
3488*d52a580bSJunchao Zhang       PetscCall(PetscNew(&spptr));
3489*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreate(&spptr->handle));
3490*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseSetStream(spptr->handle, PetscDefaultHipStream));
3491*d52a580bSJunchao Zhang       B->spptr = spptr;
3492*d52a580bSJunchao Zhang     }
3493*d52a580bSJunchao Zhang     B->offloadmask = PETSC_OFFLOAD_UNALLOCATED;
3494*d52a580bSJunchao Zhang   }
3495*d52a580bSJunchao Zhang   B->ops->assemblyend       = MatAssemblyEnd_SeqAIJHIPSPARSE;
3496*d52a580bSJunchao Zhang   B->ops->destroy           = MatDestroy_SeqAIJHIPSPARSE;
3497*d52a580bSJunchao Zhang   B->ops->setoption         = MatSetOption_SeqAIJHIPSPARSE;
3498*d52a580bSJunchao Zhang   B->ops->setfromoptions    = MatSetFromOptions_SeqAIJHIPSPARSE;
3499*d52a580bSJunchao Zhang   B->ops->bindtocpu         = MatBindToCPU_SeqAIJHIPSPARSE;
3500*d52a580bSJunchao Zhang   B->ops->duplicate         = MatDuplicate_SeqAIJHIPSPARSE;
3501*d52a580bSJunchao Zhang   B->ops->getcurrentmemtype = MatGetCurrentMemType_SeqAIJHIPSPARSE;
3502*d52a580bSJunchao Zhang 
3503*d52a580bSJunchao Zhang   PetscCall(MatBindToCPU_SeqAIJHIPSPARSE(B, PETSC_FALSE));
3504*d52a580bSJunchao Zhang   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATSEQAIJHIPSPARSE));
3505*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatHIPSPARSESetFormat_C", MatHIPSPARSESetFormat_SeqAIJHIPSPARSE));
3506*d52a580bSJunchao Zhang #if defined(PETSC_HAVE_HYPRE)
3507*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatConvert_seqaijhipsparse_hypre_C", MatConvert_AIJ_HYPRE));
3508*d52a580bSJunchao Zhang #endif
3509*d52a580bSJunchao Zhang   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatHIPSPARSESetUseCPUSolve_C", MatHIPSPARSESetUseCPUSolve_SeqAIJHIPSPARSE));
3510*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3511*d52a580bSJunchao Zhang }
3512*d52a580bSJunchao Zhang 
3513*d52a580bSJunchao Zhang PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJHIPSPARSE(Mat B)
3514*d52a580bSJunchao Zhang {
3515*d52a580bSJunchao Zhang   PetscFunctionBegin;
3516*d52a580bSJunchao Zhang   PetscCall(MatCreate_SeqAIJ(B));
3517*d52a580bSJunchao Zhang   PetscCall(MatConvert_SeqAIJ_SeqAIJHIPSPARSE(B, MATSEQAIJHIPSPARSE, MAT_INPLACE_MATRIX, &B));
3518*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3519*d52a580bSJunchao Zhang }
3520*d52a580bSJunchao Zhang 
3521*d52a580bSJunchao Zhang /*MC
3522*d52a580bSJunchao Zhang    MATSEQAIJHIPSPARSE - MATAIJHIPSPARSE = "(seq)aijhipsparse" - A matrix type to be used for sparse matrices on AMD GPUs
3523*d52a580bSJunchao Zhang 
   A matrix type whose data resides on AMD GPUs. These matrices can be stored in
   CSR, ELL, or hybrid (HYB) format.
3526*d52a580bSJunchao Zhang    All matrix calculations are performed on AMD/NVIDIA GPUs using the HIPSPARSE library.
3527*d52a580bSJunchao Zhang 
3528*d52a580bSJunchao Zhang    Options Database Keys:
3529*d52a580bSJunchao Zhang +  -mat_type aijhipsparse - sets the matrix type to `MATSEQAIJHIPSPARSE`
3530*d52a580bSJunchao Zhang .  -mat_hipsparse_storage_format csr - sets the storage format of matrices (for `MatMult()` and factors in `MatSolve()`).
3531*d52a580bSJunchao Zhang                                        Other options include ell (ellpack) or hyb (hybrid).
3532*d52a580bSJunchao Zhang . -mat_hipsparse_mult_storage_format csr - sets the storage format of matrices (for `MatMult()`). Other options include ell (ellpack) or hyb (hybrid).
3533*d52a580bSJunchao Zhang -  -mat_hipsparse_use_cpu_solve - Do `MatSolve()` on the CPU
3534*d52a580bSJunchao Zhang 
3535*d52a580bSJunchao Zhang   Level: beginner
3536*d52a580bSJunchao Zhang 
3537*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatCreateSeqAIJHIPSPARSE()`, `MATAIJHIPSPARSE`, `MatCreateAIJHIPSPARSE()`, `MatHIPSPARSESetFormat()`, `MatHIPSPARSEStorageFormat`, `MatHIPSPARSEFormatOperation`
3538*d52a580bSJunchao Zhang M*/
3539*d52a580bSJunchao Zhang 
3540*d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatSolverTypeRegister_HIPSPARSE(void)
3541*d52a580bSJunchao Zhang {
3542*d52a580bSJunchao Zhang   PetscFunctionBegin;
3543*d52a580bSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERHIPSPARSE, MATSEQAIJHIPSPARSE, MAT_FACTOR_LU, MatGetFactor_seqaijhipsparse_hipsparse));
3544*d52a580bSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERHIPSPARSE, MATSEQAIJHIPSPARSE, MAT_FACTOR_CHOLESKY, MatGetFactor_seqaijhipsparse_hipsparse));
3545*d52a580bSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERHIPSPARSE, MATSEQAIJHIPSPARSE, MAT_FACTOR_ILU, MatGetFactor_seqaijhipsparse_hipsparse));
3546*d52a580bSJunchao Zhang   PetscCall(MatSolverTypeRegister(MATSOLVERHIPSPARSE, MATSEQAIJHIPSPARSE, MAT_FACTOR_ICC, MatGetFactor_seqaijhipsparse_hipsparse));
3547*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3548*d52a580bSJunchao Zhang }
3549*d52a580bSJunchao Zhang 
3550*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSE_Destroy(Mat mat)
3551*d52a580bSJunchao Zhang {
3552*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cusp = static_cast<Mat_SeqAIJHIPSPARSE *>(mat->spptr);
3553*d52a580bSJunchao Zhang 
3554*d52a580bSJunchao Zhang   PetscFunctionBegin;
3555*d52a580bSJunchao Zhang   if (cusp) {
3556*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&cusp->mat, cusp->format));
3557*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&cusp->matTranspose, cusp->format));
3558*d52a580bSJunchao Zhang     delete cusp->workVector;
3559*d52a580bSJunchao Zhang     delete cusp->rowoffsets_gpu;
3560*d52a580bSJunchao Zhang     delete cusp->csr2csc_i;
3561*d52a580bSJunchao Zhang     delete cusp->coords;
3562*d52a580bSJunchao Zhang     if (cusp->handle) PetscCallHIPSPARSE(hipsparseDestroy(cusp->handle));
3563*d52a580bSJunchao Zhang     PetscCall(PetscFree(mat->spptr));
3564*d52a580bSJunchao Zhang   }
3565*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3566*d52a580bSJunchao Zhang }
3567*d52a580bSJunchao Zhang 
3568*d52a580bSJunchao Zhang static PetscErrorCode CsrMatrix_Destroy(CsrMatrix **mat)
3569*d52a580bSJunchao Zhang {
3570*d52a580bSJunchao Zhang   PetscFunctionBegin;
3571*d52a580bSJunchao Zhang   if (*mat) {
3572*d52a580bSJunchao Zhang     delete (*mat)->values;
3573*d52a580bSJunchao Zhang     delete (*mat)->column_indices;
3574*d52a580bSJunchao Zhang     delete (*mat)->row_offsets;
3575*d52a580bSJunchao Zhang     delete *mat;
3576*d52a580bSJunchao Zhang     *mat = 0;
3577*d52a580bSJunchao Zhang   }
3578*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3579*d52a580bSJunchao Zhang }
3580*d52a580bSJunchao Zhang 
3581*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSETriFactorStruct **trifactor)
3582*d52a580bSJunchao Zhang {
3583*d52a580bSJunchao Zhang   PetscFunctionBegin;
3584*d52a580bSJunchao Zhang   if (*trifactor) {
3585*d52a580bSJunchao Zhang     if ((*trifactor)->descr) PetscCallHIPSPARSE(hipsparseDestroyMatDescr((*trifactor)->descr));
3586*d52a580bSJunchao Zhang     if ((*trifactor)->solveInfo) PetscCallHIPSPARSE(hipsparseDestroyCsrsvInfo((*trifactor)->solveInfo));
3587*d52a580bSJunchao Zhang     PetscCall(CsrMatrix_Destroy(&(*trifactor)->csrMat));
3588*d52a580bSJunchao Zhang     if ((*trifactor)->solveBuffer) PetscCallHIP(hipFree((*trifactor)->solveBuffer));
3589*d52a580bSJunchao Zhang     if ((*trifactor)->AA_h) PetscCallHIP(hipHostFree((*trifactor)->AA_h));
3590*d52a580bSJunchao Zhang     if ((*trifactor)->csr2cscBuffer) PetscCallHIP(hipFree((*trifactor)->csr2cscBuffer));
3591*d52a580bSJunchao Zhang     PetscCall(PetscFree(*trifactor));
3592*d52a580bSJunchao Zhang   }
3593*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3594*d52a580bSJunchao Zhang }
3595*d52a580bSJunchao Zhang 
3596*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSEMultStruct **matstruct, MatHIPSPARSEStorageFormat format)
3597*d52a580bSJunchao Zhang {
3598*d52a580bSJunchao Zhang   CsrMatrix *mat;
3599*d52a580bSJunchao Zhang 
3600*d52a580bSJunchao Zhang   PetscFunctionBegin;
3601*d52a580bSJunchao Zhang   if (*matstruct) {
3602*d52a580bSJunchao Zhang     if ((*matstruct)->mat) {
3603*d52a580bSJunchao Zhang       if (format == MAT_HIPSPARSE_ELL || format == MAT_HIPSPARSE_HYB) {
3604*d52a580bSJunchao Zhang         hipsparseHybMat_t hybMat = (hipsparseHybMat_t)(*matstruct)->mat;
3605*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseDestroyHybMat(hybMat));
3606*d52a580bSJunchao Zhang       } else {
3607*d52a580bSJunchao Zhang         mat = (CsrMatrix *)(*matstruct)->mat;
3608*d52a580bSJunchao Zhang         PetscCall(CsrMatrix_Destroy(&mat));
3609*d52a580bSJunchao Zhang       }
3610*d52a580bSJunchao Zhang     }
3611*d52a580bSJunchao Zhang     if ((*matstruct)->descr) PetscCallHIPSPARSE(hipsparseDestroyMatDescr((*matstruct)->descr));
3612*d52a580bSJunchao Zhang     delete (*matstruct)->cprowIndices;
3613*d52a580bSJunchao Zhang     if ((*matstruct)->alpha_one) PetscCallHIP(hipFree((*matstruct)->alpha_one));
3614*d52a580bSJunchao Zhang     if ((*matstruct)->beta_zero) PetscCallHIP(hipFree((*matstruct)->beta_zero));
3615*d52a580bSJunchao Zhang     if ((*matstruct)->beta_one) PetscCallHIP(hipFree((*matstruct)->beta_one));
3616*d52a580bSJunchao Zhang 
3617*d52a580bSJunchao Zhang     Mat_SeqAIJHIPSPARSEMultStruct *mdata = *matstruct;
3618*d52a580bSJunchao Zhang     if (mdata->matDescr) PetscCallHIPSPARSE(hipsparseDestroySpMat(mdata->matDescr));
3619*d52a580bSJunchao Zhang     for (int i = 0; i < 3; i++) {
3620*d52a580bSJunchao Zhang       if (mdata->hipSpMV[i].initialized) {
3621*d52a580bSJunchao Zhang         PetscCallHIP(hipFree(mdata->hipSpMV[i].spmvBuffer));
3622*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseDestroyDnVec(mdata->hipSpMV[i].vecXDescr));
3623*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseDestroyDnVec(mdata->hipSpMV[i].vecYDescr));
3624*d52a580bSJunchao Zhang       }
3625*d52a580bSJunchao Zhang     }
3626*d52a580bSJunchao Zhang     delete *matstruct;
3627*d52a580bSJunchao Zhang     *matstruct = NULL;
3628*d52a580bSJunchao Zhang   }
3629*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3630*d52a580bSJunchao Zhang }
3631*d52a580bSJunchao Zhang 
3632*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSETriFactors_Reset(Mat_SeqAIJHIPSPARSETriFactors_p *trifactors)
3633*d52a580bSJunchao Zhang {
3634*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSETriFactors *fs = *trifactors;
3635*d52a580bSJunchao Zhang 
3636*d52a580bSJunchao Zhang   PetscFunctionBegin;
3637*d52a580bSJunchao Zhang   if (fs) {
3638*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&fs->loTriFactorPtr));
3639*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&fs->upTriFactorPtr));
3640*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&fs->loTriFactorPtrTranspose));
3641*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&fs->upTriFactorPtrTranspose));
3642*d52a580bSJunchao Zhang     delete fs->rpermIndices;
3643*d52a580bSJunchao Zhang     delete fs->cpermIndices;
3644*d52a580bSJunchao Zhang     delete fs->workVector;
3645*d52a580bSJunchao Zhang     fs->rpermIndices  = NULL;
3646*d52a580bSJunchao Zhang     fs->cpermIndices  = NULL;
3647*d52a580bSJunchao Zhang     fs->workVector    = NULL;
3648*d52a580bSJunchao Zhang     fs->init_dev_prop = PETSC_FALSE;
3649*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0)
3650*d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->csrRowPtr));
3651*d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->csrColIdx));
3652*d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->csrVal));
3653*d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->X));
3654*d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->Y));
3655*d52a580bSJunchao Zhang     // PetscCallHIP(hipFree(fs->factBuffer_M)); /* No needed since factBuffer_M shares with one of spsvBuffer_L/U */
3656*d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->spsvBuffer_L));
3657*d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->spsvBuffer_U));
3658*d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->spsvBuffer_Lt));
3659*d52a580bSJunchao Zhang     PetscCallHIP(hipFree(fs->spsvBuffer_Ut));
3660*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseDestroyMatDescr(fs->matDescr_M));
3661*d52a580bSJunchao Zhang     if (fs->spMatDescr_L) PetscCallHIPSPARSE(hipsparseDestroySpMat(fs->spMatDescr_L));
3662*d52a580bSJunchao Zhang     if (fs->spMatDescr_U) PetscCallHIPSPARSE(hipsparseDestroySpMat(fs->spMatDescr_U));
3663*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_destroyDescr(fs->spsvDescr_L));
3664*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_destroyDescr(fs->spsvDescr_Lt));
3665*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_destroyDescr(fs->spsvDescr_U));
3666*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSpSV_destroyDescr(fs->spsvDescr_Ut));
3667*d52a580bSJunchao Zhang     if (fs->dnVecDescr_X) PetscCallHIPSPARSE(hipsparseDestroyDnVec(fs->dnVecDescr_X));
3668*d52a580bSJunchao Zhang     if (fs->dnVecDescr_Y) PetscCallHIPSPARSE(hipsparseDestroyDnVec(fs->dnVecDescr_Y));
3669*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseDestroyCsrilu02Info(fs->ilu0Info_M));
3670*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseDestroyCsric02Info(fs->ic0Info_M));
3671*d52a580bSJunchao Zhang 
3672*d52a580bSJunchao Zhang     fs->createdTransposeSpSVDescr    = PETSC_FALSE;
3673*d52a580bSJunchao Zhang     fs->updatedTransposeSpSVAnalysis = PETSC_FALSE;
3674*d52a580bSJunchao Zhang #endif
3675*d52a580bSJunchao Zhang   }
3676*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3677*d52a580bSJunchao Zhang }
3678*d52a580bSJunchao Zhang 
3679*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSETriFactors_Destroy(Mat_SeqAIJHIPSPARSETriFactors **trifactors)
3680*d52a580bSJunchao Zhang {
3681*d52a580bSJunchao Zhang   hipsparseHandle_t handle;
3682*d52a580bSJunchao Zhang 
3683*d52a580bSJunchao Zhang   PetscFunctionBegin;
3684*d52a580bSJunchao Zhang   if (*trifactors) {
3685*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(trifactors));
3686*d52a580bSJunchao Zhang     if ((handle = (*trifactors)->handle)) PetscCallHIPSPARSE(hipsparseDestroy(handle));
3687*d52a580bSJunchao Zhang     PetscCall(PetscFree(*trifactors));
3688*d52a580bSJunchao Zhang   }
3689*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3690*d52a580bSJunchao Zhang }
3691*d52a580bSJunchao Zhang 
3692*d52a580bSJunchao Zhang struct IJCompare {
3693*d52a580bSJunchao Zhang   __host__ __device__ inline bool operator()(const thrust::tuple<PetscInt, PetscInt> &t1, const thrust::tuple<PetscInt, PetscInt> &t2)
3694*d52a580bSJunchao Zhang   {
3695*d52a580bSJunchao Zhang     if (t1.get<0>() < t2.get<0>()) return true;
3696*d52a580bSJunchao Zhang     if (t1.get<0>() == t2.get<0>()) return t1.get<1>() < t2.get<1>();
3697*d52a580bSJunchao Zhang     return false;
3698*d52a580bSJunchao Zhang   }
3699*d52a580bSJunchao Zhang };
3700*d52a580bSJunchao Zhang 
3701*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEInvalidateTranspose(Mat A, PetscBool destroy)
3702*d52a580bSJunchao Zhang {
3703*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3704*d52a580bSJunchao Zhang 
3705*d52a580bSJunchao Zhang   PetscFunctionBegin;
3706*d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3707*d52a580bSJunchao Zhang   if (!cusp) PetscFunctionReturn(PETSC_SUCCESS);
3708*d52a580bSJunchao Zhang   if (destroy) {
3709*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&cusp->matTranspose, cusp->format));
3710*d52a580bSJunchao Zhang     delete cusp->csr2csc_i;
3711*d52a580bSJunchao Zhang     cusp->csr2csc_i = NULL;
3712*d52a580bSJunchao Zhang   }
3713*d52a580bSJunchao Zhang   A->transupdated = PETSC_FALSE;
3714*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3715*d52a580bSJunchao Zhang }
3716*d52a580bSJunchao Zhang 
3717*d52a580bSJunchao Zhang static PetscErrorCode MatCOOStructDestroy_SeqAIJHIPSPARSE(PetscCtxRt data)
3718*d52a580bSJunchao Zhang {
3719*d52a580bSJunchao Zhang   MatCOOStruct_SeqAIJ *coo = *(MatCOOStruct_SeqAIJ **)data;
3720*d52a580bSJunchao Zhang 
3721*d52a580bSJunchao Zhang   PetscFunctionBegin;
3722*d52a580bSJunchao Zhang   PetscCallHIP(hipFree(coo->perm));
3723*d52a580bSJunchao Zhang   PetscCallHIP(hipFree(coo->jmap));
3724*d52a580bSJunchao Zhang   PetscCall(PetscFree(coo));
3725*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3726*d52a580bSJunchao Zhang }
3727*d52a580bSJunchao Zhang 
3728*d52a580bSJunchao Zhang static PetscErrorCode MatSetPreallocationCOO_SeqAIJHIPSPARSE(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
3729*d52a580bSJunchao Zhang {
3730*d52a580bSJunchao Zhang   PetscBool            dev_ij = PETSC_FALSE;
3731*d52a580bSJunchao Zhang   PetscMemType         mtype  = PETSC_MEMTYPE_HOST;
3732*d52a580bSJunchao Zhang   PetscInt            *i, *j;
3733*d52a580bSJunchao Zhang   PetscContainer       container_h;
3734*d52a580bSJunchao Zhang   MatCOOStruct_SeqAIJ *coo_h, *coo_d;
3735*d52a580bSJunchao Zhang 
3736*d52a580bSJunchao Zhang   PetscFunctionBegin;
3737*d52a580bSJunchao Zhang   PetscCall(PetscGetMemType(coo_i, &mtype));
3738*d52a580bSJunchao Zhang   if (PetscMemTypeDevice(mtype)) {
3739*d52a580bSJunchao Zhang     dev_ij = PETSC_TRUE;
3740*d52a580bSJunchao Zhang     PetscCall(PetscMalloc2(coo_n, &i, coo_n, &j));
3741*d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(i, coo_i, coo_n * sizeof(PetscInt), hipMemcpyDeviceToHost));
3742*d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(j, coo_j, coo_n * sizeof(PetscInt), hipMemcpyDeviceToHost));
3743*d52a580bSJunchao Zhang   } else {
3744*d52a580bSJunchao Zhang     i = coo_i;
3745*d52a580bSJunchao Zhang     j = coo_j;
3746*d52a580bSJunchao Zhang   }
3747*d52a580bSJunchao Zhang   PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, i, j));
3748*d52a580bSJunchao Zhang   if (dev_ij) PetscCall(PetscFree2(i, j));
3749*d52a580bSJunchao Zhang   mat->offloadmask = PETSC_OFFLOAD_CPU;
3750*d52a580bSJunchao Zhang   // Create the GPU memory
3751*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(mat));
3752*d52a580bSJunchao Zhang 
3753*d52a580bSJunchao Zhang   // Copy the COO struct to device
3754*d52a580bSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
3755*d52a580bSJunchao Zhang   PetscCall(PetscContainerGetPointer(container_h, &coo_h));
3756*d52a580bSJunchao Zhang   PetscCall(PetscMalloc1(1, &coo_d));
3757*d52a580bSJunchao Zhang   *coo_d = *coo_h; // do a shallow copy and then amend some fields that need to be different
3758*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&coo_d->jmap, (coo_h->nz + 1) * sizeof(PetscCount)));
3759*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(coo_d->jmap, coo_h->jmap, (coo_h->nz + 1) * sizeof(PetscCount), hipMemcpyHostToDevice));
3760*d52a580bSJunchao Zhang   PetscCallHIP(hipMalloc((void **)&coo_d->perm, coo_h->Atot * sizeof(PetscCount)));
3761*d52a580bSJunchao Zhang   PetscCallHIP(hipMemcpy(coo_d->perm, coo_h->perm, coo_h->Atot * sizeof(PetscCount), hipMemcpyHostToDevice));
3762*d52a580bSJunchao Zhang 
3763*d52a580bSJunchao Zhang   // Put the COO struct in a container and then attach that to the matrix
3764*d52a580bSJunchao Zhang   PetscCall(PetscObjectContainerCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", coo_d, MatCOOStructDestroy_SeqAIJHIPSPARSE));
3765*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3766*d52a580bSJunchao Zhang }
3767*d52a580bSJunchao Zhang 
3768*d52a580bSJunchao Zhang __global__ static void MatAddCOOValues(const PetscScalar kv[], PetscCount nnz, const PetscCount jmap[], const PetscCount perm[], InsertMode imode, PetscScalar a[])
3769*d52a580bSJunchao Zhang {
3770*d52a580bSJunchao Zhang   PetscCount       i         = blockIdx.x * blockDim.x + threadIdx.x;
3771*d52a580bSJunchao Zhang   const PetscCount grid_size = gridDim.x * blockDim.x;
3772*d52a580bSJunchao Zhang   for (; i < nnz; i += grid_size) {
3773*d52a580bSJunchao Zhang     PetscScalar sum = 0.0;
3774*d52a580bSJunchao Zhang     for (PetscCount k = jmap[i]; k < jmap[i + 1]; k++) sum += kv[perm[k]];
3775*d52a580bSJunchao Zhang     a[i] = (imode == INSERT_VALUES ? 0.0 : a[i]) + sum;
3776*d52a580bSJunchao Zhang   }
3777*d52a580bSJunchao Zhang }
3778*d52a580bSJunchao Zhang 
3779*d52a580bSJunchao Zhang static PetscErrorCode MatSetValuesCOO_SeqAIJHIPSPARSE(Mat A, const PetscScalar v[], InsertMode imode)
3780*d52a580bSJunchao Zhang {
3781*d52a580bSJunchao Zhang   Mat_SeqAIJ          *seq  = (Mat_SeqAIJ *)A->data;
3782*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *dev  = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3783*d52a580bSJunchao Zhang   PetscCount           Annz = seq->nz;
3784*d52a580bSJunchao Zhang   PetscMemType         memtype;
3785*d52a580bSJunchao Zhang   const PetscScalar   *v1 = v;
3786*d52a580bSJunchao Zhang   PetscScalar         *Aa;
3787*d52a580bSJunchao Zhang   PetscContainer       container;
3788*d52a580bSJunchao Zhang   MatCOOStruct_SeqAIJ *coo;
3789*d52a580bSJunchao Zhang 
3790*d52a580bSJunchao Zhang   PetscFunctionBegin;
3791*d52a580bSJunchao Zhang   if (!dev->mat) PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
3792*d52a580bSJunchao Zhang 
3793*d52a580bSJunchao Zhang   PetscCall(PetscObjectQuery((PetscObject)A, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
3794*d52a580bSJunchao Zhang   PetscCall(PetscContainerGetPointer(container, &coo));
3795*d52a580bSJunchao Zhang 
3796*d52a580bSJunchao Zhang   PetscCall(PetscGetMemType(v, &memtype));
3797*d52a580bSJunchao Zhang   if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */
3798*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&v1, coo->n * sizeof(PetscScalar)));
3799*d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy((void *)v1, v, coo->n * sizeof(PetscScalar), hipMemcpyHostToDevice));
3800*d52a580bSJunchao Zhang   }
3801*d52a580bSJunchao Zhang 
3802*d52a580bSJunchao Zhang   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJHIPSPARSEGetArrayWrite(A, &Aa));
3803*d52a580bSJunchao Zhang   else PetscCall(MatSeqAIJHIPSPARSEGetArray(A, &Aa));
3804*d52a580bSJunchao Zhang 
3805*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeBegin());
3806*d52a580bSJunchao Zhang   if (Annz) {
3807*d52a580bSJunchao Zhang     hipLaunchKernelGGL(HIP_KERNEL_NAME(MatAddCOOValues), dim3((Annz + 255) / 256), dim3(256), 0, PetscDefaultHipStream, v1, Annz, coo->jmap, coo->perm, imode, Aa);
3808*d52a580bSJunchao Zhang     PetscCallHIP(hipPeekAtLastError());
3809*d52a580bSJunchao Zhang   }
3810*d52a580bSJunchao Zhang   PetscCall(PetscLogGpuTimeEnd());
3811*d52a580bSJunchao Zhang 
3812*d52a580bSJunchao Zhang   if (imode == INSERT_VALUES) PetscCall(MatSeqAIJHIPSPARSERestoreArrayWrite(A, &Aa));
3813*d52a580bSJunchao Zhang   else PetscCall(MatSeqAIJHIPSPARSERestoreArray(A, &Aa));
3814*d52a580bSJunchao Zhang 
3815*d52a580bSJunchao Zhang   if (PetscMemTypeHost(memtype)) PetscCallHIP(hipFree((void *)v1));
3816*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3817*d52a580bSJunchao Zhang }
3818*d52a580bSJunchao Zhang 
3819*d52a580bSJunchao Zhang /*@C
3820*d52a580bSJunchao Zhang   MatSeqAIJHIPSPARSEGetIJ - returns the device row storage `i` and `j` indices for `MATSEQAIJHIPSPARSE` matrices.
3821*d52a580bSJunchao Zhang 
3822*d52a580bSJunchao Zhang   Not Collective
3823*d52a580bSJunchao Zhang 
3824*d52a580bSJunchao Zhang   Input Parameters:
3825*d52a580bSJunchao Zhang + A          - the matrix
3826*d52a580bSJunchao Zhang - compressed - `PETSC_TRUE` or `PETSC_FALSE` indicating the matrix data structure should be always returned in compressed form
3827*d52a580bSJunchao Zhang 
3828*d52a580bSJunchao Zhang   Output Parameters:
3829*d52a580bSJunchao Zhang + i - the CSR row pointers
3830*d52a580bSJunchao Zhang - j - the CSR column indices
3831*d52a580bSJunchao Zhang 
3832*d52a580bSJunchao Zhang   Level: developer
3833*d52a580bSJunchao Zhang 
3834*d52a580bSJunchao Zhang   Note:
3835*d52a580bSJunchao Zhang   When compressed is true, the CSR structure does not contain empty rows
3836*d52a580bSJunchao Zhang 
3837*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSERestoreIJ()`, `MatSeqAIJHIPSPARSEGetArrayRead()`
3838*d52a580bSJunchao Zhang @*/
3839*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEGetIJ(Mat A, PetscBool compressed, const int *i[], const int *j[])
3840*d52a580bSJunchao Zhang {
3841*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3842*d52a580bSJunchao Zhang   Mat_SeqAIJ          *a    = (Mat_SeqAIJ *)A->data;
3843*d52a580bSJunchao Zhang   CsrMatrix           *csr;
3844*d52a580bSJunchao Zhang 
3845*d52a580bSJunchao Zhang   PetscFunctionBegin;
3846*d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
3847*d52a580bSJunchao Zhang   if (!i || !j) PetscFunctionReturn(PETSC_SUCCESS);
3848*d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3849*d52a580bSJunchao Zhang   PetscCheck(cusp->format != MAT_HIPSPARSE_ELL && cusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
3850*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
3851*d52a580bSJunchao Zhang   PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
3852*d52a580bSJunchao Zhang   csr = (CsrMatrix *)cusp->mat->mat;
3853*d52a580bSJunchao Zhang   if (i) {
3854*d52a580bSJunchao Zhang     if (!compressed && a->compressedrow.use) { /* need full row offset */
3855*d52a580bSJunchao Zhang       if (!cusp->rowoffsets_gpu) {
3856*d52a580bSJunchao Zhang         cusp->rowoffsets_gpu = new THRUSTINTARRAY32(A->rmap->n + 1);
3857*d52a580bSJunchao Zhang         cusp->rowoffsets_gpu->assign(a->i, a->i + A->rmap->n + 1);
3858*d52a580bSJunchao Zhang         PetscCall(PetscLogCpuToGpu((A->rmap->n + 1) * sizeof(PetscInt)));
3859*d52a580bSJunchao Zhang       }
3860*d52a580bSJunchao Zhang       *i = cusp->rowoffsets_gpu->data().get();
3861*d52a580bSJunchao Zhang     } else *i = csr->row_offsets->data().get();
3862*d52a580bSJunchao Zhang   }
3863*d52a580bSJunchao Zhang   if (j) *j = csr->column_indices->data().get();
3864*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3865*d52a580bSJunchao Zhang }
3866*d52a580bSJunchao Zhang 
3867*d52a580bSJunchao Zhang /*@C
3868*d52a580bSJunchao Zhang   MatSeqAIJHIPSPARSERestoreIJ - restore the device row storage `i` and `j` indices obtained with `MatSeqAIJHIPSPARSEGetIJ()`
3869*d52a580bSJunchao Zhang 
3870*d52a580bSJunchao Zhang   Not Collective
3871*d52a580bSJunchao Zhang 
3872*d52a580bSJunchao Zhang   Input Parameters:
3873*d52a580bSJunchao Zhang + A          - the matrix
3874*d52a580bSJunchao Zhang . compressed - `PETSC_TRUE` or `PETSC_FALSE` indicating the matrix data structure should be always returned in compressed form
3875*d52a580bSJunchao Zhang . i          - the CSR row pointers
3876*d52a580bSJunchao Zhang - j          - the CSR column indices
3877*d52a580bSJunchao Zhang 
3878*d52a580bSJunchao Zhang   Level: developer
3879*d52a580bSJunchao Zhang 
3880*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetIJ()`
3881*d52a580bSJunchao Zhang @*/
3882*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSERestoreIJ(Mat A, PetscBool compressed, const int *i[], const int *j[])
3883*d52a580bSJunchao Zhang {
3884*d52a580bSJunchao Zhang   PetscFunctionBegin;
3885*d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
3886*d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3887*d52a580bSJunchao Zhang   if (i) *i = NULL;
3888*d52a580bSJunchao Zhang   if (j) *j = NULL;
3889*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3890*d52a580bSJunchao Zhang }
3891*d52a580bSJunchao Zhang 
3892*d52a580bSJunchao Zhang /*@C
3893*d52a580bSJunchao Zhang   MatSeqAIJHIPSPARSEGetArrayRead - gives read-only access to the array where the device data for a `MATSEQAIJHIPSPARSE` matrix is stored
3894*d52a580bSJunchao Zhang 
3895*d52a580bSJunchao Zhang   Not Collective
3896*d52a580bSJunchao Zhang 
3897*d52a580bSJunchao Zhang   Input Parameter:
3898*d52a580bSJunchao Zhang . A - a `MATSEQAIJHIPSPARSE` matrix
3899*d52a580bSJunchao Zhang 
3900*d52a580bSJunchao Zhang   Output Parameter:
3901*d52a580bSJunchao Zhang . a - pointer to the device data
3902*d52a580bSJunchao Zhang 
3903*d52a580bSJunchao Zhang   Level: developer
3904*d52a580bSJunchao Zhang 
3905*d52a580bSJunchao Zhang   Note:
3906*d52a580bSJunchao Zhang   May trigger host-device copies if the up-to-date matrix data is on host
3907*d52a580bSJunchao Zhang 
3908*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArray()`, `MatSeqAIJHIPSPARSEGetArrayWrite()`, `MatSeqAIJHIPSPARSERestoreArrayRead()`
3909*d52a580bSJunchao Zhang @*/
3910*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEGetArrayRead(Mat A, const PetscScalar *a[])
3911*d52a580bSJunchao Zhang {
3912*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3913*d52a580bSJunchao Zhang   CsrMatrix           *csr;
3914*d52a580bSJunchao Zhang 
3915*d52a580bSJunchao Zhang   PetscFunctionBegin;
3916*d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
3917*d52a580bSJunchao Zhang   PetscAssertPointer(a, 2);
3918*d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3919*d52a580bSJunchao Zhang   PetscCheck(cusp->format != MAT_HIPSPARSE_ELL && cusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
3920*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
3921*d52a580bSJunchao Zhang   PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
3922*d52a580bSJunchao Zhang   csr = (CsrMatrix *)cusp->mat->mat;
3923*d52a580bSJunchao Zhang   PetscCheck(csr->values, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing HIP memory");
3924*d52a580bSJunchao Zhang   *a = csr->values->data().get();
3925*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3926*d52a580bSJunchao Zhang }
3927*d52a580bSJunchao Zhang 
3928*d52a580bSJunchao Zhang /*@C
3929*d52a580bSJunchao Zhang   MatSeqAIJHIPSPARSERestoreArrayRead - restore the read-only access array obtained from `MatSeqAIJHIPSPARSEGetArrayRead()`
3930*d52a580bSJunchao Zhang 
3931*d52a580bSJunchao Zhang   Not Collective
3932*d52a580bSJunchao Zhang 
3933*d52a580bSJunchao Zhang   Input Parameters:
3934*d52a580bSJunchao Zhang + A - a `MATSEQAIJHIPSPARSE` matrix
3935*d52a580bSJunchao Zhang - a - pointer to the device data
3936*d52a580bSJunchao Zhang 
3937*d52a580bSJunchao Zhang   Level: developer
3938*d52a580bSJunchao Zhang 
3939*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArrayRead()`
3940*d52a580bSJunchao Zhang @*/
3941*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSERestoreArrayRead(Mat A, const PetscScalar *a[])
3942*d52a580bSJunchao Zhang {
3943*d52a580bSJunchao Zhang   PetscFunctionBegin;
3944*d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
3945*d52a580bSJunchao Zhang   PetscAssertPointer(a, 2);
3946*d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3947*d52a580bSJunchao Zhang   *a = NULL;
3948*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3949*d52a580bSJunchao Zhang }
3950*d52a580bSJunchao Zhang 
3951*d52a580bSJunchao Zhang /*@C
3952*d52a580bSJunchao Zhang   MatSeqAIJHIPSPARSEGetArray - gives read-write access to the array where the device data for a `MATSEQAIJHIPSPARSE` matrix is stored
3953*d52a580bSJunchao Zhang 
3954*d52a580bSJunchao Zhang   Not Collective
3955*d52a580bSJunchao Zhang 
3956*d52a580bSJunchao Zhang   Input Parameter:
3957*d52a580bSJunchao Zhang . A - a `MATSEQAIJHIPSPARSE` matrix
3958*d52a580bSJunchao Zhang 
3959*d52a580bSJunchao Zhang   Output Parameter:
3960*d52a580bSJunchao Zhang . a - pointer to the device data
3961*d52a580bSJunchao Zhang 
3962*d52a580bSJunchao Zhang   Level: developer
3963*d52a580bSJunchao Zhang 
3964*d52a580bSJunchao Zhang   Note:
3965*d52a580bSJunchao Zhang   May trigger host-device copies if up-to-date matrix data is on host
3966*d52a580bSJunchao Zhang 
3967*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArrayRead()`, `MatSeqAIJHIPSPARSEGetArrayWrite()`, `MatSeqAIJHIPSPARSERestoreArray()`
3968*d52a580bSJunchao Zhang @*/
3969*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEGetArray(Mat A, PetscScalar *a[])
3970*d52a580bSJunchao Zhang {
3971*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
3972*d52a580bSJunchao Zhang   CsrMatrix           *csr;
3973*d52a580bSJunchao Zhang 
3974*d52a580bSJunchao Zhang   PetscFunctionBegin;
3975*d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
3976*d52a580bSJunchao Zhang   PetscAssertPointer(a, 2);
3977*d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
3978*d52a580bSJunchao Zhang   PetscCheck(cusp->format != MAT_HIPSPARSE_ELL && cusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
3979*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
3980*d52a580bSJunchao Zhang   PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
3981*d52a580bSJunchao Zhang   csr = (CsrMatrix *)cusp->mat->mat;
3982*d52a580bSJunchao Zhang   PetscCheck(csr->values, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing HIP memory");
3983*d52a580bSJunchao Zhang   *a             = csr->values->data().get();
3984*d52a580bSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_GPU;
3985*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_FALSE));
3986*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
3987*d52a580bSJunchao Zhang }
3988*d52a580bSJunchao Zhang /*@C
3989*d52a580bSJunchao Zhang   MatSeqAIJHIPSPARSERestoreArray - restore the read-write access array obtained from `MatSeqAIJHIPSPARSEGetArray()`
3990*d52a580bSJunchao Zhang 
3991*d52a580bSJunchao Zhang   Not Collective
3992*d52a580bSJunchao Zhang 
3993*d52a580bSJunchao Zhang   Input Parameters:
3994*d52a580bSJunchao Zhang + A - a `MATSEQAIJHIPSPARSE` matrix
3995*d52a580bSJunchao Zhang - a - pointer to the device data
3996*d52a580bSJunchao Zhang 
3997*d52a580bSJunchao Zhang   Level: developer
3998*d52a580bSJunchao Zhang 
3999*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArray()`
4000*d52a580bSJunchao Zhang @*/
4001*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSERestoreArray(Mat A, PetscScalar *a[])
4002*d52a580bSJunchao Zhang {
4003*d52a580bSJunchao Zhang   PetscFunctionBegin;
4004*d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
4005*d52a580bSJunchao Zhang   PetscAssertPointer(a, 2);
4006*d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
4007*d52a580bSJunchao Zhang   PetscCall(PetscObjectStateIncrease((PetscObject)A));
4008*d52a580bSJunchao Zhang   *a = NULL;
4009*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
4010*d52a580bSJunchao Zhang }
4011*d52a580bSJunchao Zhang 
4012*d52a580bSJunchao Zhang /*@C
4013*d52a580bSJunchao Zhang   MatSeqAIJHIPSPARSEGetArrayWrite - gives write access to the array where the device data for a `MATSEQAIJHIPSPARSE` matrix is stored
4014*d52a580bSJunchao Zhang 
4015*d52a580bSJunchao Zhang   Not Collective
4016*d52a580bSJunchao Zhang 
4017*d52a580bSJunchao Zhang   Input Parameter:
4018*d52a580bSJunchao Zhang . A - a `MATSEQAIJHIPSPARSE` matrix
4019*d52a580bSJunchao Zhang 
4020*d52a580bSJunchao Zhang   Output Parameter:
4021*d52a580bSJunchao Zhang . a - pointer to the device data
4022*d52a580bSJunchao Zhang 
4023*d52a580bSJunchao Zhang   Level: developer
4024*d52a580bSJunchao Zhang 
4025*d52a580bSJunchao Zhang   Note:
4026*d52a580bSJunchao Zhang   Does not trigger host-device copies and flags data validity on the GPU
4027*d52a580bSJunchao Zhang 
4028*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArray()`, `MatSeqAIJHIPSPARSEGetArrayRead()`, `MatSeqAIJHIPSPARSERestoreArrayWrite()`
4029*d52a580bSJunchao Zhang @*/
4030*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEGetArrayWrite(Mat A, PetscScalar *a[])
4031*d52a580bSJunchao Zhang {
4032*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr;
4033*d52a580bSJunchao Zhang   CsrMatrix           *csr;
4034*d52a580bSJunchao Zhang 
4035*d52a580bSJunchao Zhang   PetscFunctionBegin;
4036*d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
4037*d52a580bSJunchao Zhang   PetscAssertPointer(a, 2);
4038*d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
4039*d52a580bSJunchao Zhang   PetscCheck(cusp->format != MAT_HIPSPARSE_ELL && cusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
4040*d52a580bSJunchao Zhang   PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
4041*d52a580bSJunchao Zhang   csr = (CsrMatrix *)cusp->mat->mat;
4042*d52a580bSJunchao Zhang   PetscCheck(csr->values, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing HIP memory");
4043*d52a580bSJunchao Zhang   *a             = csr->values->data().get();
4044*d52a580bSJunchao Zhang   A->offloadmask = PETSC_OFFLOAD_GPU;
4045*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_FALSE));
4046*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
4047*d52a580bSJunchao Zhang }
4048*d52a580bSJunchao Zhang 
4049*d52a580bSJunchao Zhang /*@C
4050*d52a580bSJunchao Zhang   MatSeqAIJHIPSPARSERestoreArrayWrite - restore the write-only access array obtained from `MatSeqAIJHIPSPARSEGetArrayWrite()`
4051*d52a580bSJunchao Zhang 
4052*d52a580bSJunchao Zhang   Not Collective
4053*d52a580bSJunchao Zhang 
4054*d52a580bSJunchao Zhang   Input Parameters:
4055*d52a580bSJunchao Zhang + A - a `MATSEQAIJHIPSPARSE` matrix
4056*d52a580bSJunchao Zhang - a - pointer to the device data
4057*d52a580bSJunchao Zhang 
4058*d52a580bSJunchao Zhang   Level: developer
4059*d52a580bSJunchao Zhang 
4060*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArrayWrite()`
4061*d52a580bSJunchao Zhang @*/
4062*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSERestoreArrayWrite(Mat A, PetscScalar *a[])
4063*d52a580bSJunchao Zhang {
4064*d52a580bSJunchao Zhang   PetscFunctionBegin;
4065*d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
4066*d52a580bSJunchao Zhang   PetscAssertPointer(a, 2);
4067*d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
4068*d52a580bSJunchao Zhang   PetscCall(PetscObjectStateIncrease((PetscObject)A));
4069*d52a580bSJunchao Zhang   *a = NULL;
4070*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
4071*d52a580bSJunchao Zhang }
4072*d52a580bSJunchao Zhang 
4073*d52a580bSJunchao Zhang struct IJCompare4 {
4074*d52a580bSJunchao Zhang   __host__ __device__ inline bool operator()(const thrust::tuple<int, int, PetscScalar, int> &t1, const thrust::tuple<int, int, PetscScalar, int> &t2)
4075*d52a580bSJunchao Zhang   {
4076*d52a580bSJunchao Zhang     if (t1.get<0>() < t2.get<0>()) return true;
4077*d52a580bSJunchao Zhang     if (t1.get<0>() == t2.get<0>()) return t1.get<1>() < t2.get<1>();
4078*d52a580bSJunchao Zhang     return false;
4079*d52a580bSJunchao Zhang   }
4080*d52a580bSJunchao Zhang };
4081*d52a580bSJunchao Zhang 
4082*d52a580bSJunchao Zhang struct Shift {
4083*d52a580bSJunchao Zhang   int _shift;
4084*d52a580bSJunchao Zhang 
4085*d52a580bSJunchao Zhang   Shift(int shift) : _shift(shift) { }
4086*d52a580bSJunchao Zhang   __host__ __device__ inline int operator()(const int &c) { return c + _shift; }
4087*d52a580bSJunchao Zhang };
4088*d52a580bSJunchao Zhang 
4089*d52a580bSJunchao Zhang /* merges two SeqAIJHIPSPARSE matrices A, B by concatenating their rows. [A';B']' operation in MATLAB notation */
4090*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C)
4091*d52a580bSJunchao Zhang {
4092*d52a580bSJunchao Zhang   Mat_SeqAIJ                    *a = (Mat_SeqAIJ *)A->data, *b = (Mat_SeqAIJ *)B->data, *c;
4093*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSE           *Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr, *Bcusp = (Mat_SeqAIJHIPSPARSE *)B->spptr, *Ccusp;
4094*d52a580bSJunchao Zhang   Mat_SeqAIJHIPSPARSEMultStruct *Cmat;
4095*d52a580bSJunchao Zhang   CsrMatrix                     *Acsr, *Bcsr, *Ccsr;
4096*d52a580bSJunchao Zhang   PetscInt                       Annz, Bnnz;
4097*d52a580bSJunchao Zhang   PetscInt                       i, m, n, zero = 0;
4098*d52a580bSJunchao Zhang 
4099*d52a580bSJunchao Zhang   PetscFunctionBegin;
4100*d52a580bSJunchao Zhang   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
4101*d52a580bSJunchao Zhang   PetscValidHeaderSpecific(B, MAT_CLASSID, 2);
4102*d52a580bSJunchao Zhang   PetscAssertPointer(C, 4);
4103*d52a580bSJunchao Zhang   PetscCheckTypeName(A, MATSEQAIJHIPSPARSE);
4104*d52a580bSJunchao Zhang   PetscCheckTypeName(B, MATSEQAIJHIPSPARSE);
4105*d52a580bSJunchao Zhang   PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n);
4106*d52a580bSJunchao Zhang   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
4107*d52a580bSJunchao Zhang   PetscCheck(Acusp->format != MAT_HIPSPARSE_ELL && Acusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
4108*d52a580bSJunchao Zhang   PetscCheck(Bcusp->format != MAT_HIPSPARSE_ELL && Bcusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
4109*d52a580bSJunchao Zhang   if (reuse == MAT_INITIAL_MATRIX) {
4110*d52a580bSJunchao Zhang     m = A->rmap->n;
4111*d52a580bSJunchao Zhang     n = A->cmap->n + B->cmap->n;
4112*d52a580bSJunchao Zhang     PetscCall(MatCreate(PETSC_COMM_SELF, C));
4113*d52a580bSJunchao Zhang     PetscCall(MatSetSizes(*C, m, n, m, n));
4114*d52a580bSJunchao Zhang     PetscCall(MatSetType(*C, MATSEQAIJHIPSPARSE));
4115*d52a580bSJunchao Zhang     c                       = (Mat_SeqAIJ *)(*C)->data;
4116*d52a580bSJunchao Zhang     Ccusp                   = (Mat_SeqAIJHIPSPARSE *)(*C)->spptr;
4117*d52a580bSJunchao Zhang     Cmat                    = new Mat_SeqAIJHIPSPARSEMultStruct;
4118*d52a580bSJunchao Zhang     Ccsr                    = new CsrMatrix;
4119*d52a580bSJunchao Zhang     Cmat->cprowIndices      = NULL;
4120*d52a580bSJunchao Zhang     c->compressedrow.use    = PETSC_FALSE;
4121*d52a580bSJunchao Zhang     c->compressedrow.nrows  = 0;
4122*d52a580bSJunchao Zhang     c->compressedrow.i      = NULL;
4123*d52a580bSJunchao Zhang     c->compressedrow.rindex = NULL;
4124*d52a580bSJunchao Zhang     Ccusp->workVector       = NULL;
4125*d52a580bSJunchao Zhang     Ccusp->nrows            = m;
4126*d52a580bSJunchao Zhang     Ccusp->mat              = Cmat;
4127*d52a580bSJunchao Zhang     Ccusp->mat->mat         = Ccsr;
4128*d52a580bSJunchao Zhang     Ccsr->num_rows          = m;
4129*d52a580bSJunchao Zhang     Ccsr->num_cols          = n;
4130*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseCreateMatDescr(&Cmat->descr));
4131*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSetMatIndexBase(Cmat->descr, HIPSPARSE_INDEX_BASE_ZERO));
4132*d52a580bSJunchao Zhang     PetscCallHIPSPARSE(hipsparseSetMatType(Cmat->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
4133*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&Cmat->alpha_one, sizeof(PetscScalar)));
4134*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&Cmat->beta_zero, sizeof(PetscScalar)));
4135*d52a580bSJunchao Zhang     PetscCallHIP(hipMalloc((void **)&Cmat->beta_one, sizeof(PetscScalar)));
4136*d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(Cmat->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
4137*d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(Cmat->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice));
4138*d52a580bSJunchao Zhang     PetscCallHIP(hipMemcpy(Cmat->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
4139*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
4140*d52a580bSJunchao Zhang     PetscCall(MatSeqAIJHIPSPARSECopyToGPU(B));
4141*d52a580bSJunchao Zhang     PetscCheck(Acusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
4142*d52a580bSJunchao Zhang     PetscCheck(Bcusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
4143*d52a580bSJunchao Zhang 
4144*d52a580bSJunchao Zhang     Acsr                 = (CsrMatrix *)Acusp->mat->mat;
4145*d52a580bSJunchao Zhang     Bcsr                 = (CsrMatrix *)Bcusp->mat->mat;
4146*d52a580bSJunchao Zhang     Annz                 = (PetscInt)Acsr->column_indices->size();
4147*d52a580bSJunchao Zhang     Bnnz                 = (PetscInt)Bcsr->column_indices->size();
4148*d52a580bSJunchao Zhang     c->nz                = Annz + Bnnz;
4149*d52a580bSJunchao Zhang     Ccsr->row_offsets    = new THRUSTINTARRAY32(m + 1);
4150*d52a580bSJunchao Zhang     Ccsr->column_indices = new THRUSTINTARRAY32(c->nz);
4151*d52a580bSJunchao Zhang     Ccsr->values         = new THRUSTARRAY(c->nz);
4152*d52a580bSJunchao Zhang     Ccsr->num_entries    = c->nz;
4153*d52a580bSJunchao Zhang     Ccusp->coords        = new THRUSTINTARRAY(c->nz);
4154*d52a580bSJunchao Zhang     if (c->nz) {
4155*d52a580bSJunchao Zhang       auto              Acoo = new THRUSTINTARRAY32(Annz);
4156*d52a580bSJunchao Zhang       auto              Bcoo = new THRUSTINTARRAY32(Bnnz);
4157*d52a580bSJunchao Zhang       auto              Ccoo = new THRUSTINTARRAY32(c->nz);
4158*d52a580bSJunchao Zhang       THRUSTINTARRAY32 *Aroff, *Broff;
4159*d52a580bSJunchao Zhang 
4160*d52a580bSJunchao Zhang       if (a->compressedrow.use) { /* need full row offset */
4161*d52a580bSJunchao Zhang         if (!Acusp->rowoffsets_gpu) {
4162*d52a580bSJunchao Zhang           Acusp->rowoffsets_gpu = new THRUSTINTARRAY32(A->rmap->n + 1);
4163*d52a580bSJunchao Zhang           Acusp->rowoffsets_gpu->assign(a->i, a->i + A->rmap->n + 1);
4164*d52a580bSJunchao Zhang           PetscCall(PetscLogCpuToGpu((A->rmap->n + 1) * sizeof(PetscInt)));
4165*d52a580bSJunchao Zhang         }
4166*d52a580bSJunchao Zhang         Aroff = Acusp->rowoffsets_gpu;
4167*d52a580bSJunchao Zhang       } else Aroff = Acsr->row_offsets;
4168*d52a580bSJunchao Zhang       if (b->compressedrow.use) { /* need full row offset */
4169*d52a580bSJunchao Zhang         if (!Bcusp->rowoffsets_gpu) {
4170*d52a580bSJunchao Zhang           Bcusp->rowoffsets_gpu = new THRUSTINTARRAY32(B->rmap->n + 1);
4171*d52a580bSJunchao Zhang           Bcusp->rowoffsets_gpu->assign(b->i, b->i + B->rmap->n + 1);
4172*d52a580bSJunchao Zhang           PetscCall(PetscLogCpuToGpu((B->rmap->n + 1) * sizeof(PetscInt)));
4173*d52a580bSJunchao Zhang         }
4174*d52a580bSJunchao Zhang         Broff = Bcusp->rowoffsets_gpu;
4175*d52a580bSJunchao Zhang       } else Broff = Bcsr->row_offsets;
4176*d52a580bSJunchao Zhang       PetscCall(PetscLogGpuTimeBegin());
4177*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseXcsr2coo(Acusp->handle, Aroff->data().get(), Annz, m, Acoo->data().get(), HIPSPARSE_INDEX_BASE_ZERO));
4178*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseXcsr2coo(Bcusp->handle, Broff->data().get(), Bnnz, m, Bcoo->data().get(), HIPSPARSE_INDEX_BASE_ZERO));
4179*d52a580bSJunchao Zhang       /* Issues when using bool with large matrices on SUMMIT 10.2.89 */
4180*d52a580bSJunchao Zhang       auto Aperm = thrust::make_constant_iterator(1);
4181*d52a580bSJunchao Zhang       auto Bperm = thrust::make_constant_iterator(0);
4182*d52a580bSJunchao Zhang       auto Bcib  = thrust::make_transform_iterator(Bcsr->column_indices->begin(), Shift(A->cmap->n));
4183*d52a580bSJunchao Zhang       auto Bcie  = thrust::make_transform_iterator(Bcsr->column_indices->end(), Shift(A->cmap->n));
4184*d52a580bSJunchao Zhang       auto wPerm = new THRUSTINTARRAY32(Annz + Bnnz);
4185*d52a580bSJunchao Zhang       auto Azb   = thrust::make_zip_iterator(thrust::make_tuple(Acoo->begin(), Acsr->column_indices->begin(), Acsr->values->begin(), Aperm));
4186*d52a580bSJunchao Zhang       auto Aze   = thrust::make_zip_iterator(thrust::make_tuple(Acoo->end(), Acsr->column_indices->end(), Acsr->values->end(), Aperm));
4187*d52a580bSJunchao Zhang       auto Bzb   = thrust::make_zip_iterator(thrust::make_tuple(Bcoo->begin(), Bcib, Bcsr->values->begin(), Bperm));
4188*d52a580bSJunchao Zhang       auto Bze   = thrust::make_zip_iterator(thrust::make_tuple(Bcoo->end(), Bcie, Bcsr->values->end(), Bperm));
4189*d52a580bSJunchao Zhang       auto Czb   = thrust::make_zip_iterator(thrust::make_tuple(Ccoo->begin(), Ccsr->column_indices->begin(), Ccsr->values->begin(), wPerm->begin()));
4190*d52a580bSJunchao Zhang       auto p1    = Ccusp->coords->begin();
4191*d52a580bSJunchao Zhang       auto p2    = Ccusp->coords->begin();
4192*d52a580bSJunchao Zhang       thrust::advance(p2, Annz);
4193*d52a580bSJunchao Zhang       PetscCallThrust(thrust::merge(thrust::device, Azb, Aze, Bzb, Bze, Czb, IJCompare4()));
4194*d52a580bSJunchao Zhang       auto cci = thrust::make_counting_iterator(zero);
4195*d52a580bSJunchao Zhang       auto cce = thrust::make_counting_iterator(c->nz);
4196*d52a580bSJunchao Zhang #if 0 //Errors on SUMMIT cuda 11.1.0
4197*d52a580bSJunchao Zhang       PetscCallThrust(thrust::partition_copy(thrust::device, cci, cce, wPerm->begin(), p1, p2, thrust::identity<int>()));
4198*d52a580bSJunchao Zhang #else
4199*d52a580bSJunchao Zhang       auto pred = thrust::identity<int>();
4200*d52a580bSJunchao Zhang       PetscCallThrust(thrust::copy_if(thrust::device, cci, cce, wPerm->begin(), p1, pred));
4201*d52a580bSJunchao Zhang       PetscCallThrust(thrust::remove_copy_if(thrust::device, cci, cce, wPerm->begin(), p2, pred));
4202*d52a580bSJunchao Zhang #endif
4203*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseXcoo2csr(Ccusp->handle, Ccoo->data().get(), c->nz, m, Ccsr->row_offsets->data().get(), HIPSPARSE_INDEX_BASE_ZERO));
4204*d52a580bSJunchao Zhang       PetscCall(PetscLogGpuTimeEnd());
4205*d52a580bSJunchao Zhang       delete wPerm;
4206*d52a580bSJunchao Zhang       delete Acoo;
4207*d52a580bSJunchao Zhang       delete Bcoo;
4208*d52a580bSJunchao Zhang       delete Ccoo;
4209*d52a580bSJunchao Zhang       PetscCallHIPSPARSE(hipsparseCreateCsr(&Cmat->matDescr, Ccsr->num_rows, Ccsr->num_cols, Ccsr->num_entries, Ccsr->row_offsets->data().get(), Ccsr->column_indices->data().get(), Ccsr->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
4210*d52a580bSJunchao Zhang 
4211*d52a580bSJunchao Zhang       if (A->form_explicit_transpose && B->form_explicit_transpose) { /* if A and B have the transpose, generate C transpose too */
4212*d52a580bSJunchao Zhang         PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(A));
4213*d52a580bSJunchao Zhang         PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(B));
4214*d52a580bSJunchao Zhang         PetscBool                      AT = Acusp->matTranspose ? PETSC_TRUE : PETSC_FALSE, BT = Bcusp->matTranspose ? PETSC_TRUE : PETSC_FALSE;
4215*d52a580bSJunchao Zhang         Mat_SeqAIJHIPSPARSEMultStruct *CmatT = new Mat_SeqAIJHIPSPARSEMultStruct;
4216*d52a580bSJunchao Zhang         CsrMatrix                     *CcsrT = new CsrMatrix;
4217*d52a580bSJunchao Zhang         CsrMatrix                     *AcsrT = AT ? (CsrMatrix *)Acusp->matTranspose->mat : NULL;
4218*d52a580bSJunchao Zhang         CsrMatrix                     *BcsrT = BT ? (CsrMatrix *)Bcusp->matTranspose->mat : NULL;
4219*d52a580bSJunchao Zhang 
4220*d52a580bSJunchao Zhang         (*C)->form_explicit_transpose = PETSC_TRUE;
4221*d52a580bSJunchao Zhang         (*C)->transupdated            = PETSC_TRUE;
4222*d52a580bSJunchao Zhang         Ccusp->rowoffsets_gpu         = NULL;
4223*d52a580bSJunchao Zhang         CmatT->cprowIndices           = NULL;
4224*d52a580bSJunchao Zhang         CmatT->mat                    = CcsrT;
4225*d52a580bSJunchao Zhang         CcsrT->num_rows               = n;
4226*d52a580bSJunchao Zhang         CcsrT->num_cols               = m;
4227*d52a580bSJunchao Zhang         CcsrT->num_entries            = c->nz;
4228*d52a580bSJunchao Zhang         CcsrT->row_offsets            = new THRUSTINTARRAY32(n + 1);
4229*d52a580bSJunchao Zhang         CcsrT->column_indices         = new THRUSTINTARRAY32(c->nz);
4230*d52a580bSJunchao Zhang         CcsrT->values                 = new THRUSTARRAY(c->nz);
4231*d52a580bSJunchao Zhang 
4232*d52a580bSJunchao Zhang         PetscCall(PetscLogGpuTimeBegin());
4233*d52a580bSJunchao Zhang         auto rT = CcsrT->row_offsets->begin();
4234*d52a580bSJunchao Zhang         if (AT) {
4235*d52a580bSJunchao Zhang           rT = thrust::copy(AcsrT->row_offsets->begin(), AcsrT->row_offsets->end(), rT);
4236*d52a580bSJunchao Zhang           thrust::advance(rT, -1);
4237*d52a580bSJunchao Zhang         }
4238*d52a580bSJunchao Zhang         if (BT) {
4239*d52a580bSJunchao Zhang           auto titb = thrust::make_transform_iterator(BcsrT->row_offsets->begin(), Shift(a->nz));
4240*d52a580bSJunchao Zhang           auto tite = thrust::make_transform_iterator(BcsrT->row_offsets->end(), Shift(a->nz));
4241*d52a580bSJunchao Zhang           thrust::copy(titb, tite, rT);
4242*d52a580bSJunchao Zhang         }
4243*d52a580bSJunchao Zhang         auto cT = CcsrT->column_indices->begin();
4244*d52a580bSJunchao Zhang         if (AT) cT = thrust::copy(AcsrT->column_indices->begin(), AcsrT->column_indices->end(), cT);
4245*d52a580bSJunchao Zhang         if (BT) thrust::copy(BcsrT->column_indices->begin(), BcsrT->column_indices->end(), cT);
4246*d52a580bSJunchao Zhang         auto vT = CcsrT->values->begin();
4247*d52a580bSJunchao Zhang         if (AT) vT = thrust::copy(AcsrT->values->begin(), AcsrT->values->end(), vT);
4248*d52a580bSJunchao Zhang         if (BT) thrust::copy(BcsrT->values->begin(), BcsrT->values->end(), vT);
4249*d52a580bSJunchao Zhang         PetscCall(PetscLogGpuTimeEnd());
4250*d52a580bSJunchao Zhang 
4251*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateMatDescr(&CmatT->descr));
4252*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatIndexBase(CmatT->descr, HIPSPARSE_INDEX_BASE_ZERO));
4253*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseSetMatType(CmatT->descr, HIPSPARSE_MATRIX_TYPE_GENERAL));
4254*d52a580bSJunchao Zhang         PetscCallHIP(hipMalloc((void **)&CmatT->alpha_one, sizeof(PetscScalar)));
4255*d52a580bSJunchao Zhang         PetscCallHIP(hipMalloc((void **)&CmatT->beta_zero, sizeof(PetscScalar)));
4256*d52a580bSJunchao Zhang         PetscCallHIP(hipMalloc((void **)&CmatT->beta_one, sizeof(PetscScalar)));
4257*d52a580bSJunchao Zhang         PetscCallHIP(hipMemcpy(CmatT->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
4258*d52a580bSJunchao Zhang         PetscCallHIP(hipMemcpy(CmatT->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice));
4259*d52a580bSJunchao Zhang         PetscCallHIP(hipMemcpy(CmatT->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice));
4260*d52a580bSJunchao Zhang 
4261*d52a580bSJunchao Zhang         PetscCallHIPSPARSE(hipsparseCreateCsr(&CmatT->matDescr, CcsrT->num_rows, CcsrT->num_cols, CcsrT->num_entries, CcsrT->row_offsets->data().get(), CcsrT->column_indices->data().get(), CcsrT->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype));
4262*d52a580bSJunchao Zhang         Ccusp->matTranspose = CmatT;
4263*d52a580bSJunchao Zhang       }
4264*d52a580bSJunchao Zhang     }
4265*d52a580bSJunchao Zhang 
4266*d52a580bSJunchao Zhang     c->free_a = PETSC_TRUE;
4267*d52a580bSJunchao Zhang     PetscCall(PetscShmgetAllocateArray(c->nz, sizeof(PetscInt), (void **)&c->j));
4268*d52a580bSJunchao Zhang     PetscCall(PetscShmgetAllocateArray(m + 1, sizeof(PetscInt), (void **)&c->i));
4269*d52a580bSJunchao Zhang     c->free_ij = PETSC_TRUE;
4270*d52a580bSJunchao Zhang     if (PetscDefined(USE_64BIT_INDICES)) { /* 32 to 64-bit conversion on the GPU and then copy to host (lazy) */
4271*d52a580bSJunchao Zhang       THRUSTINTARRAY ii(Ccsr->row_offsets->size());
4272*d52a580bSJunchao Zhang       THRUSTINTARRAY jj(Ccsr->column_indices->size());
4273*d52a580bSJunchao Zhang       ii = *Ccsr->row_offsets;
4274*d52a580bSJunchao Zhang       jj = *Ccsr->column_indices;
4275*d52a580bSJunchao Zhang       PetscCallHIP(hipMemcpy(c->i, ii.data().get(), Ccsr->row_offsets->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
4276*d52a580bSJunchao Zhang       PetscCallHIP(hipMemcpy(c->j, jj.data().get(), Ccsr->column_indices->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
4277*d52a580bSJunchao Zhang     } else {
4278*d52a580bSJunchao Zhang       PetscCallHIP(hipMemcpy(c->i, Ccsr->row_offsets->data().get(), Ccsr->row_offsets->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
4279*d52a580bSJunchao Zhang       PetscCallHIP(hipMemcpy(c->j, Ccsr->column_indices->data().get(), Ccsr->column_indices->size() * sizeof(PetscInt), hipMemcpyDeviceToHost));
4280*d52a580bSJunchao Zhang     }
4281*d52a580bSJunchao Zhang     PetscCall(PetscLogGpuToCpu((Ccsr->column_indices->size() + Ccsr->row_offsets->size()) * sizeof(PetscInt)));
4282*d52a580bSJunchao Zhang     PetscCall(PetscMalloc1(m, &c->ilen));
4283*d52a580bSJunchao Zhang     PetscCall(PetscMalloc1(m, &c->imax));
4284*d52a580bSJunchao Zhang     c->maxnz         = c->nz;
4285*d52a580bSJunchao Zhang     c->nonzerorowcnt = 0;
4286*d52a580bSJunchao Zhang     c->rmax          = 0;
4287*d52a580bSJunchao Zhang     for (i = 0; i < m; i++) {
4288*d52a580bSJunchao Zhang       const PetscInt nn = c->i[i + 1] - c->i[i];
4289*d52a580bSJunchao Zhang       c->ilen[i] = c->imax[i] = nn;
4290*d52a580bSJunchao Zhang       c->nonzerorowcnt += (PetscInt)!!nn;
4291*d52a580bSJunchao Zhang       c->rmax = PetscMax(c->rmax, nn);
4292*d52a580bSJunchao Zhang     }
4293*d52a580bSJunchao Zhang     PetscCall(PetscMalloc1(c->nz, &c->a));
4294*d52a580bSJunchao Zhang     (*C)->nonzerostate++;
4295*d52a580bSJunchao Zhang     PetscCall(PetscLayoutSetUp((*C)->rmap));
4296*d52a580bSJunchao Zhang     PetscCall(PetscLayoutSetUp((*C)->cmap));
4297*d52a580bSJunchao Zhang     Ccusp->nonzerostate = (*C)->nonzerostate;
4298*d52a580bSJunchao Zhang     (*C)->preallocated  = PETSC_TRUE;
4299*d52a580bSJunchao Zhang   } else {
4300*d52a580bSJunchao Zhang     PetscCheck((*C)->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, (*C)->rmap->n, B->rmap->n);
4301*d52a580bSJunchao Zhang     c = (Mat_SeqAIJ *)(*C)->data;
4302*d52a580bSJunchao Zhang     if (c->nz) {
4303*d52a580bSJunchao Zhang       Ccusp = (Mat_SeqAIJHIPSPARSE *)(*C)->spptr;
4304*d52a580bSJunchao Zhang       PetscCheck(Ccusp->coords, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing coords");
4305*d52a580bSJunchao Zhang       PetscCheck(Ccusp->format != MAT_HIPSPARSE_ELL && Ccusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
4306*d52a580bSJunchao Zhang       PetscCheck(Ccusp->nonzerostate == (*C)->nonzerostate, PETSC_COMM_SELF, PETSC_ERR_COR, "Wrong nonzerostate");
4307*d52a580bSJunchao Zhang       PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A));
4308*d52a580bSJunchao Zhang       PetscCall(MatSeqAIJHIPSPARSECopyToGPU(B));
4309*d52a580bSJunchao Zhang       PetscCheck(Acusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
4310*d52a580bSJunchao Zhang       PetscCheck(Bcusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct");
4311*d52a580bSJunchao Zhang       Acsr = (CsrMatrix *)Acusp->mat->mat;
4312*d52a580bSJunchao Zhang       Bcsr = (CsrMatrix *)Bcusp->mat->mat;
4313*d52a580bSJunchao Zhang       Ccsr = (CsrMatrix *)Ccusp->mat->mat;
4314*d52a580bSJunchao Zhang       PetscCheck(Acsr->num_entries == (PetscInt)Acsr->values->size(), PETSC_COMM_SELF, PETSC_ERR_COR, "A nnz %" PetscInt_FMT " != %" PetscInt_FMT, Acsr->num_entries, (PetscInt)Acsr->values->size());
4315*d52a580bSJunchao Zhang       PetscCheck(Bcsr->num_entries == (PetscInt)Bcsr->values->size(), PETSC_COMM_SELF, PETSC_ERR_COR, "B nnz %" PetscInt_FMT " != %" PetscInt_FMT, Bcsr->num_entries, (PetscInt)Bcsr->values->size());
4316*d52a580bSJunchao Zhang       PetscCheck(Ccsr->num_entries == (PetscInt)Ccsr->values->size(), PETSC_COMM_SELF, PETSC_ERR_COR, "C nnz %" PetscInt_FMT " != %" PetscInt_FMT, Ccsr->num_entries, (PetscInt)Ccsr->values->size());
4317*d52a580bSJunchao Zhang       PetscCheck(Ccsr->num_entries == Acsr->num_entries + Bcsr->num_entries, PETSC_COMM_SELF, PETSC_ERR_COR, "C nnz %" PetscInt_FMT " != %" PetscInt_FMT " + %" PetscInt_FMT, Ccsr->num_entries, Acsr->num_entries, Bcsr->num_entries);
4318*d52a580bSJunchao Zhang       PetscCheck(Ccusp->coords->size() == Ccsr->values->size(), PETSC_COMM_SELF, PETSC_ERR_COR, "permSize %" PetscInt_FMT " != %" PetscInt_FMT, (PetscInt)Ccusp->coords->size(), (PetscInt)Ccsr->values->size());
4319*d52a580bSJunchao Zhang       auto pmid = Ccusp->coords->begin();
4320*d52a580bSJunchao Zhang       thrust::advance(pmid, Acsr->num_entries);
4321*d52a580bSJunchao Zhang       PetscCall(PetscLogGpuTimeBegin());
4322*d52a580bSJunchao Zhang       auto zibait = thrust::make_zip_iterator(thrust::make_tuple(Acsr->values->begin(), thrust::make_permutation_iterator(Ccsr->values->begin(), Ccusp->coords->begin())));
4323*d52a580bSJunchao Zhang       auto zieait = thrust::make_zip_iterator(thrust::make_tuple(Acsr->values->end(), thrust::make_permutation_iterator(Ccsr->values->begin(), pmid)));
4324*d52a580bSJunchao Zhang       thrust::for_each(zibait, zieait, VecHIPEquals());
4325*d52a580bSJunchao Zhang       auto zibbit = thrust::make_zip_iterator(thrust::make_tuple(Bcsr->values->begin(), thrust::make_permutation_iterator(Ccsr->values->begin(), pmid)));
4326*d52a580bSJunchao Zhang       auto ziebit = thrust::make_zip_iterator(thrust::make_tuple(Bcsr->values->end(), thrust::make_permutation_iterator(Ccsr->values->begin(), Ccusp->coords->end())));
4327*d52a580bSJunchao Zhang       thrust::for_each(zibbit, ziebit, VecHIPEquals());
4328*d52a580bSJunchao Zhang       PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(*C, PETSC_FALSE));
4329*d52a580bSJunchao Zhang       if (A->form_explicit_transpose && B->form_explicit_transpose && (*C)->form_explicit_transpose) {
4330*d52a580bSJunchao Zhang         PetscCheck(Ccusp->matTranspose, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing transpose Mat_SeqAIJHIPSPARSEMultStruct");
4331*d52a580bSJunchao Zhang         PetscBool  AT = Acusp->matTranspose ? PETSC_TRUE : PETSC_FALSE, BT = Bcusp->matTranspose ? PETSC_TRUE : PETSC_FALSE;
4332*d52a580bSJunchao Zhang         CsrMatrix *AcsrT = AT ? (CsrMatrix *)Acusp->matTranspose->mat : NULL;
4333*d52a580bSJunchao Zhang         CsrMatrix *BcsrT = BT ? (CsrMatrix *)Bcusp->matTranspose->mat : NULL;
4334*d52a580bSJunchao Zhang         CsrMatrix *CcsrT = (CsrMatrix *)Ccusp->matTranspose->mat;
4335*d52a580bSJunchao Zhang         auto       vT    = CcsrT->values->begin();
4336*d52a580bSJunchao Zhang         if (AT) vT = thrust::copy(AcsrT->values->begin(), AcsrT->values->end(), vT);
4337*d52a580bSJunchao Zhang         if (BT) thrust::copy(BcsrT->values->begin(), BcsrT->values->end(), vT);
4338*d52a580bSJunchao Zhang         (*C)->transupdated = PETSC_TRUE;
4339*d52a580bSJunchao Zhang       }
4340*d52a580bSJunchao Zhang       PetscCall(PetscLogGpuTimeEnd());
4341*d52a580bSJunchao Zhang     }
4342*d52a580bSJunchao Zhang   }
4343*d52a580bSJunchao Zhang   PetscCall(PetscObjectStateIncrease((PetscObject)*C));
4344*d52a580bSJunchao Zhang   (*C)->assembled     = PETSC_TRUE;
4345*d52a580bSJunchao Zhang   (*C)->was_assembled = PETSC_FALSE;
4346*d52a580bSJunchao Zhang   (*C)->offloadmask   = PETSC_OFFLOAD_GPU;
4347*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
4348*d52a580bSJunchao Zhang }
4349*d52a580bSJunchao Zhang 
4350*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJCopySubArray_SeqAIJHIPSPARSE(Mat A, PetscInt n, const PetscInt idx[], PetscScalar v[])
4351*d52a580bSJunchao Zhang {
4352*d52a580bSJunchao Zhang   bool               dmem;
4353*d52a580bSJunchao Zhang   const PetscScalar *av;
4354*d52a580bSJunchao Zhang 
4355*d52a580bSJunchao Zhang   PetscFunctionBegin;
4356*d52a580bSJunchao Zhang   dmem = isHipMem(v);
4357*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSEGetArrayRead(A, &av));
4358*d52a580bSJunchao Zhang   if (n && idx) {
4359*d52a580bSJunchao Zhang     THRUSTINTARRAY widx(n);
4360*d52a580bSJunchao Zhang     widx.assign(idx, idx + n);
4361*d52a580bSJunchao Zhang     PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt)));
4362*d52a580bSJunchao Zhang 
4363*d52a580bSJunchao Zhang     THRUSTARRAY                    *w = NULL;
4364*d52a580bSJunchao Zhang     thrust::device_ptr<PetscScalar> dv;
4365*d52a580bSJunchao Zhang     if (dmem) dv = thrust::device_pointer_cast(v);
4366*d52a580bSJunchao Zhang     else {
4367*d52a580bSJunchao Zhang       w  = new THRUSTARRAY(n);
4368*d52a580bSJunchao Zhang       dv = w->data();
4369*d52a580bSJunchao Zhang     }
4370*d52a580bSJunchao Zhang     thrust::device_ptr<const PetscScalar> dav = thrust::device_pointer_cast(av);
4371*d52a580bSJunchao Zhang 
4372*d52a580bSJunchao Zhang     auto zibit = thrust::make_zip_iterator(thrust::make_tuple(thrust::make_permutation_iterator(dav, widx.begin()), dv));
4373*d52a580bSJunchao Zhang     auto zieit = thrust::make_zip_iterator(thrust::make_tuple(thrust::make_permutation_iterator(dav, widx.end()), dv + n));
4374*d52a580bSJunchao Zhang     thrust::for_each(zibit, zieit, VecHIPEquals());
4375*d52a580bSJunchao Zhang     if (w) PetscCallHIP(hipMemcpy(v, w->data().get(), n * sizeof(PetscScalar), hipMemcpyDeviceToHost));
4376*d52a580bSJunchao Zhang     delete w;
4377*d52a580bSJunchao Zhang   } else PetscCallHIP(hipMemcpy(v, av, n * sizeof(PetscScalar), dmem ? hipMemcpyDeviceToDevice : hipMemcpyDeviceToHost));
4378*d52a580bSJunchao Zhang 
4379*d52a580bSJunchao Zhang   if (!dmem) PetscCall(PetscLogCpuToGpu(n * sizeof(PetscScalar)));
4380*d52a580bSJunchao Zhang   PetscCall(MatSeqAIJHIPSPARSERestoreArrayRead(A, &av));
4381*d52a580bSJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
4382*d52a580bSJunchao Zhang }
4383