1*d52a580bSJunchao Zhang /* 2*d52a580bSJunchao Zhang Defines the basic matrix operations for the AIJ (compressed row) 3*d52a580bSJunchao Zhang matrix storage format using the HIPSPARSE library, 4*d52a580bSJunchao Zhang Portions of this code are under: 5*d52a580bSJunchao Zhang Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. 6*d52a580bSJunchao Zhang */ 7*d52a580bSJunchao Zhang #include <petscconf.h> 8*d52a580bSJunchao Zhang #include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/ 9*d52a580bSJunchao Zhang #include <../src/mat/impls/sbaij/seq/sbaij.h> 10*d52a580bSJunchao Zhang #include <../src/mat/impls/dense/seq/dense.h> // MatMatMultNumeric_SeqDenseHIP_SeqDenseHIP_Internal() 11*d52a580bSJunchao Zhang #include <../src/vec/vec/impls/dvecimpl.h> 12*d52a580bSJunchao Zhang #include <petsc/private/vecimpl.h> 13*d52a580bSJunchao Zhang #undef VecType 14*d52a580bSJunchao Zhang #include <../src/mat/impls/aij/seq/seqhipsparse/hipsparsematimpl.h> 15*d52a580bSJunchao Zhang #include <thrust/adjacent_difference.h> 16*d52a580bSJunchao Zhang #include <thrust/iterator/transform_iterator.h> 17*d52a580bSJunchao Zhang #if PETSC_CPP_VERSION >= 14 18*d52a580bSJunchao Zhang #define PETSC_HAVE_THRUST_ASYNC 1 19*d52a580bSJunchao Zhang #include <thrust/async/for_each.h> 20*d52a580bSJunchao Zhang #endif 21*d52a580bSJunchao Zhang #include <thrust/iterator/constant_iterator.h> 22*d52a580bSJunchao Zhang #include <thrust/iterator/discard_iterator.h> 23*d52a580bSJunchao Zhang #include <thrust/binary_search.h> 24*d52a580bSJunchao Zhang #include <thrust/remove.h> 25*d52a580bSJunchao Zhang #include <thrust/sort.h> 26*d52a580bSJunchao Zhang #include <thrust/unique.h> 27*d52a580bSJunchao Zhang 28*d52a580bSJunchao Zhang const char *const MatHIPSPARSEStorageFormats[] = {"CSR", "ELL", "HYB", "MatHIPSPARSEStorageFormat", "MAT_HIPSPARSE_", 0}; 29*d52a580bSJunchao Zhang const char *const MatHIPSPARSESpMVAlgorithms[] = {"MV_ALG_DEFAULT", "COOMV_ALG", "CSRMV_ALG1", "CSRMV_ALG2", 
"SPMV_ALG_DEFAULT", "SPMV_COO_ALG1", "SPMV_COO_ALG2", "SPMV_CSR_ALG1", "SPMV_CSR_ALG2", "hipsparseSpMVAlg_t", "HIPSPARSE_", 0}; 30*d52a580bSJunchao Zhang const char *const MatHIPSPARSESpMMAlgorithms[] = {"ALG_DEFAULT", "COO_ALG1", "COO_ALG2", "COO_ALG3", "CSR_ALG1", "COO_ALG4", "CSR_ALG2", "hipsparseSpMMAlg_t", "HIPSPARSE_SPMM_", 0}; 31*d52a580bSJunchao Zhang //const char *const MatHIPSPARSECsr2CscAlgorithms[] = {"INVALID"/*HIPSPARSE does not have enum 0! We created one*/, "ALG1", "ALG2", "hipsparseCsr2CscAlg_t", "HIPSPARSE_CSR2CSC_", 0}; 32*d52a580bSJunchao Zhang 33*d52a580bSJunchao Zhang static PetscErrorCode MatICCFactorSymbolic_SeqAIJHIPSPARSE(Mat, Mat, IS, const MatFactorInfo *); 34*d52a580bSJunchao Zhang static PetscErrorCode MatCholeskyFactorSymbolic_SeqAIJHIPSPARSE(Mat, Mat, IS, const MatFactorInfo *); 35*d52a580bSJunchao Zhang static PetscErrorCode MatCholeskyFactorNumeric_SeqAIJHIPSPARSE(Mat, Mat, const MatFactorInfo *); 36*d52a580bSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJHIPSPARSE(Mat, Mat, IS, IS, const MatFactorInfo *); 37*d52a580bSJunchao Zhang static PetscErrorCode MatLUFactorSymbolic_SeqAIJHIPSPARSE(Mat, Mat, IS, IS, const MatFactorInfo *); 38*d52a580bSJunchao Zhang static PetscErrorCode MatLUFactorNumeric_SeqAIJHIPSPARSE(Mat, Mat, const MatFactorInfo *); 39*d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE(Mat, Vec, Vec); 40*d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_NaturalOrdering(Mat, Vec, Vec); 41*d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE(Mat, Vec, Vec); 42*d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering(Mat, Vec, Vec); 43*d52a580bSJunchao Zhang static PetscErrorCode MatSetFromOptions_SeqAIJHIPSPARSE(Mat, PetscOptionItems PetscOptionsObject); 44*d52a580bSJunchao Zhang static PetscErrorCode MatAXPY_SeqAIJHIPSPARSE(Mat, PetscScalar, Mat, MatStructure); 45*d52a580bSJunchao Zhang static 
PetscErrorCode MatScale_SeqAIJHIPSPARSE(Mat, PetscScalar); 46*d52a580bSJunchao Zhang static PetscErrorCode MatMult_SeqAIJHIPSPARSE(Mat, Vec, Vec); 47*d52a580bSJunchao Zhang static PetscErrorCode MatMultAdd_SeqAIJHIPSPARSE(Mat, Vec, Vec, Vec); 48*d52a580bSJunchao Zhang static PetscErrorCode MatMultTranspose_SeqAIJHIPSPARSE(Mat, Vec, Vec); 49*d52a580bSJunchao Zhang static PetscErrorCode MatMultTransposeAdd_SeqAIJHIPSPARSE(Mat, Vec, Vec, Vec); 50*d52a580bSJunchao Zhang static PetscErrorCode MatMultHermitianTranspose_SeqAIJHIPSPARSE(Mat, Vec, Vec); 51*d52a580bSJunchao Zhang static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJHIPSPARSE(Mat, Vec, Vec, Vec); 52*d52a580bSJunchao Zhang static PetscErrorCode MatMultAddKernel_SeqAIJHIPSPARSE(Mat, Vec, Vec, Vec, PetscBool, PetscBool); 53*d52a580bSJunchao Zhang static PetscErrorCode CsrMatrix_Destroy(CsrMatrix **); 54*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSETriFactorStruct **); 55*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSEMultStruct **, MatHIPSPARSEStorageFormat); 56*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSETriFactors_Destroy(Mat_SeqAIJHIPSPARSETriFactors **); 57*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSE_Destroy(Mat); 58*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSECopyFromGPU(Mat); 59*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEILUAnalysisAndCopyToGPU(Mat); 60*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEInvalidateTranspose(Mat, PetscBool); 61*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJCopySubArray_SeqAIJHIPSPARSE(Mat, PetscInt, const PetscInt[], PetscScalar[]); 62*d52a580bSJunchao Zhang static PetscErrorCode MatBindToCPU_SeqAIJHIPSPARSE(Mat, PetscBool); 63*d52a580bSJunchao Zhang static PetscErrorCode MatSetPreallocationCOO_SeqAIJHIPSPARSE(Mat, PetscCount, PetscInt[], PetscInt[]); 
64*d52a580bSJunchao Zhang static PetscErrorCode MatSetValuesCOO_SeqAIJHIPSPARSE(Mat, const PetscScalar[], InsertMode); 65*d52a580bSJunchao Zhang 66*d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatProductSetFromOptions_SeqAIJ_SeqDense(Mat); 67*d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJHIPSPARSE(Mat, MatType, MatReuse, Mat *); 68*d52a580bSJunchao Zhang 69*d52a580bSJunchao Zhang /* 70*d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSESetStream(Mat A, const hipStream_t stream) 71*d52a580bSJunchao Zhang { 72*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE*)A->spptr; 73*d52a580bSJunchao Zhang 74*d52a580bSJunchao Zhang PetscFunctionBegin; 75*d52a580bSJunchao Zhang PetscCheck(hipsparsestruct, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing spptr"); 76*d52a580bSJunchao Zhang hipsparsestruct->stream = stream; 77*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetStream(hipsparsestruct->handle, hipsparsestruct->stream)); 78*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 79*d52a580bSJunchao Zhang } 80*d52a580bSJunchao Zhang 81*d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSESetHandle(Mat A, const hipsparseHandle_t handle) 82*d52a580bSJunchao Zhang { 83*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE*)A->spptr; 84*d52a580bSJunchao Zhang 85*d52a580bSJunchao Zhang PetscFunctionBegin; 86*d52a580bSJunchao Zhang PetscCheck(hipsparsestruct, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing spptr"); 87*d52a580bSJunchao Zhang if (hipsparsestruct->handle != handle) { 88*d52a580bSJunchao Zhang if (hipsparsestruct->handle) PetscCallHIPSPARSE(hipsparseDestroy(hipsparsestruct->handle)); 89*d52a580bSJunchao Zhang hipsparsestruct->handle = handle; 90*d52a580bSJunchao Zhang } 91*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetPointerMode(hipsparsestruct->handle, HIPSPARSE_POINTER_MODE_DEVICE)); 92*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 93*d52a580bSJunchao 
Zhang } 94*d52a580bSJunchao Zhang 95*d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSEClearHandle(Mat A) 96*d52a580bSJunchao Zhang { 97*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE*)A->spptr; 98*d52a580bSJunchao Zhang PetscBool flg; 99*d52a580bSJunchao Zhang 100*d52a580bSJunchao Zhang PetscFunctionBegin; 101*d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg)); 102*d52a580bSJunchao Zhang if (!flg || !hipsparsestruct) PetscFunctionReturn(PETSC_SUCCESS); 103*d52a580bSJunchao Zhang if (hipsparsestruct->handle) hipsparsestruct->handle = 0; 104*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 105*d52a580bSJunchao Zhang } 106*d52a580bSJunchao Zhang */ 107*d52a580bSJunchao Zhang 108*d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatHIPSPARSESetFormat_SeqAIJHIPSPARSE(Mat A, MatHIPSPARSEFormatOperation op, MatHIPSPARSEStorageFormat format) 109*d52a580bSJunchao Zhang { 110*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr; 111*d52a580bSJunchao Zhang 112*d52a580bSJunchao Zhang PetscFunctionBegin; 113*d52a580bSJunchao Zhang switch (op) { 114*d52a580bSJunchao Zhang case MAT_HIPSPARSE_MULT: 115*d52a580bSJunchao Zhang hipsparsestruct->format = format; 116*d52a580bSJunchao Zhang break; 117*d52a580bSJunchao Zhang case MAT_HIPSPARSE_ALL: 118*d52a580bSJunchao Zhang hipsparsestruct->format = format; 119*d52a580bSJunchao Zhang break; 120*d52a580bSJunchao Zhang default: 121*d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "unsupported operation %d for MatHIPSPARSEFormatOperation. 
MAT_HIPSPARSE_MULT and MAT_HIPSPARSE_ALL are currently supported.", op); 122*d52a580bSJunchao Zhang } 123*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 124*d52a580bSJunchao Zhang } 125*d52a580bSJunchao Zhang 126*d52a580bSJunchao Zhang /*@ 127*d52a580bSJunchao Zhang MatHIPSPARSESetFormat - Sets the storage format of `MATSEQHIPSPARSE` matrices for a particular 128*d52a580bSJunchao Zhang operation. Only the `MatMult()` operation can use different GPU storage formats 129*d52a580bSJunchao Zhang 130*d52a580bSJunchao Zhang Not Collective 131*d52a580bSJunchao Zhang 132*d52a580bSJunchao Zhang Input Parameters: 133*d52a580bSJunchao Zhang + A - Matrix of type `MATSEQAIJHIPSPARSE` 134*d52a580bSJunchao Zhang . op - `MatHIPSPARSEFormatOperation`. `MATSEQAIJHIPSPARSE` matrices support `MAT_HIPSPARSE_MULT` and `MAT_HIPSPARSE_ALL`. 135*d52a580bSJunchao Zhang `MATMPIAIJHIPSPARSE` matrices support `MAT_HIPSPARSE_MULT_DIAG`, `MAT_HIPSPARSE_MULT_OFFDIAG`, and `MAT_HIPSPARSE_ALL`. 136*d52a580bSJunchao Zhang - format - `MatHIPSPARSEStorageFormat` (one of `MAT_HIPSPARSE_CSR`, `MAT_HIPSPARSE_ELL`, `MAT_HIPSPARSE_HYB`.) 
137*d52a580bSJunchao Zhang 138*d52a580bSJunchao Zhang Level: intermediate 139*d52a580bSJunchao Zhang 140*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MATSEQAIJHIPSPARSE`, `MatHIPSPARSEStorageFormat`, `MatHIPSPARSEFormatOperation` 141*d52a580bSJunchao Zhang @*/ 142*d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSESetFormat(Mat A, MatHIPSPARSEFormatOperation op, MatHIPSPARSEStorageFormat format) 143*d52a580bSJunchao Zhang { 144*d52a580bSJunchao Zhang PetscFunctionBegin; 145*d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 146*d52a580bSJunchao Zhang PetscTryMethod(A, "MatHIPSPARSESetFormat_C", (Mat, MatHIPSPARSEFormatOperation, MatHIPSPARSEStorageFormat), (A, op, format)); 147*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 148*d52a580bSJunchao Zhang } 149*d52a580bSJunchao Zhang 150*d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatHIPSPARSESetUseCPUSolve_SeqAIJHIPSPARSE(Mat A, PetscBool use_cpu) 151*d52a580bSJunchao Zhang { 152*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr; 153*d52a580bSJunchao Zhang 154*d52a580bSJunchao Zhang PetscFunctionBegin; 155*d52a580bSJunchao Zhang hipsparsestruct->use_cpu_solve = use_cpu; 156*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 157*d52a580bSJunchao Zhang } 158*d52a580bSJunchao Zhang 159*d52a580bSJunchao Zhang /*@ 160*d52a580bSJunchao Zhang MatHIPSPARSESetUseCPUSolve - Sets use CPU `MatSolve()`. 
161*d52a580bSJunchao Zhang 162*d52a580bSJunchao Zhang Input Parameters: 163*d52a580bSJunchao Zhang + A - Matrix of type `MATSEQAIJHIPSPARSE` 164*d52a580bSJunchao Zhang - use_cpu - set flag for using the built-in CPU `MatSolve()` 165*d52a580bSJunchao Zhang 166*d52a580bSJunchao Zhang Level: intermediate 167*d52a580bSJunchao Zhang 168*d52a580bSJunchao Zhang Notes: 169*d52a580bSJunchao Zhang The hipSparse LU solver currently computes the factors with the built-in CPU method 170*d52a580bSJunchao Zhang and moves the factors to the GPU for the solve. We have observed better performance keeping the data on the CPU and computing the solve there. 171*d52a580bSJunchao Zhang This method to specifies if the solve is done on the CPU or GPU (GPU is the default). 172*d52a580bSJunchao Zhang 173*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSolve()`, `MATSEQAIJHIPSPARSE`, `MatHIPSPARSEStorageFormat`, `MatHIPSPARSEFormatOperation` 174*d52a580bSJunchao Zhang @*/ 175*d52a580bSJunchao Zhang PetscErrorCode MatHIPSPARSESetUseCPUSolve(Mat A, PetscBool use_cpu) 176*d52a580bSJunchao Zhang { 177*d52a580bSJunchao Zhang PetscFunctionBegin; 178*d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 179*d52a580bSJunchao Zhang PetscTryMethod(A, "MatHIPSPARSESetUseCPUSolve_C", (Mat, PetscBool), (A, use_cpu)); 180*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 181*d52a580bSJunchao Zhang } 182*d52a580bSJunchao Zhang 183*d52a580bSJunchao Zhang static PetscErrorCode MatSetOption_SeqAIJHIPSPARSE(Mat A, MatOption op, PetscBool flg) 184*d52a580bSJunchao Zhang { 185*d52a580bSJunchao Zhang PetscFunctionBegin; 186*d52a580bSJunchao Zhang switch (op) { 187*d52a580bSJunchao Zhang case MAT_FORM_EXPLICIT_TRANSPOSE: 188*d52a580bSJunchao Zhang /* need to destroy the transpose matrix if present to prevent from logic errors if flg is set to true later */ 189*d52a580bSJunchao Zhang if (A->form_explicit_transpose && !flg) PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, 
PETSC_TRUE)); 190*d52a580bSJunchao Zhang A->form_explicit_transpose = flg; 191*d52a580bSJunchao Zhang break; 192*d52a580bSJunchao Zhang default: 193*d52a580bSJunchao Zhang PetscCall(MatSetOption_SeqAIJ(A, op, flg)); 194*d52a580bSJunchao Zhang break; 195*d52a580bSJunchao Zhang } 196*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 197*d52a580bSJunchao Zhang } 198*d52a580bSJunchao Zhang 199*d52a580bSJunchao Zhang static PetscErrorCode MatLUFactorNumeric_SeqAIJHIPSPARSE(Mat B, Mat A, const MatFactorInfo *info) 200*d52a580bSJunchao Zhang { 201*d52a580bSJunchao Zhang PetscBool row_identity, col_identity; 202*d52a580bSJunchao Zhang Mat_SeqAIJ *b = (Mat_SeqAIJ *)B->data; 203*d52a580bSJunchao Zhang IS isrow = b->row, iscol = b->col; 204*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)B->spptr; 205*d52a580bSJunchao Zhang 206*d52a580bSJunchao Zhang PetscFunctionBegin; 207*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A)); 208*d52a580bSJunchao Zhang PetscCall(MatLUFactorNumeric_SeqAIJ(B, A, info)); 209*d52a580bSJunchao Zhang B->offloadmask = PETSC_OFFLOAD_CPU; 210*d52a580bSJunchao Zhang /* determine which version of MatSolve needs to be used. 
*/ 211*d52a580bSJunchao Zhang PetscCall(ISIdentity(isrow, &row_identity)); 212*d52a580bSJunchao Zhang PetscCall(ISIdentity(iscol, &col_identity)); 213*d52a580bSJunchao Zhang if (!hipsparsestruct->use_cpu_solve) { 214*d52a580bSJunchao Zhang if (row_identity && col_identity) { 215*d52a580bSJunchao Zhang B->ops->solve = MatSolve_SeqAIJHIPSPARSE_NaturalOrdering; 216*d52a580bSJunchao Zhang B->ops->solvetranspose = MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering; 217*d52a580bSJunchao Zhang } else { 218*d52a580bSJunchao Zhang B->ops->solve = MatSolve_SeqAIJHIPSPARSE; 219*d52a580bSJunchao Zhang B->ops->solvetranspose = MatSolveTranspose_SeqAIJHIPSPARSE; 220*d52a580bSJunchao Zhang } 221*d52a580bSJunchao Zhang } 222*d52a580bSJunchao Zhang B->ops->matsolve = NULL; 223*d52a580bSJunchao Zhang B->ops->matsolvetranspose = NULL; 224*d52a580bSJunchao Zhang 225*d52a580bSJunchao Zhang /* get the triangular factors */ 226*d52a580bSJunchao Zhang if (!hipsparsestruct->use_cpu_solve) PetscCall(MatSeqAIJHIPSPARSEILUAnalysisAndCopyToGPU(B)); 227*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 228*d52a580bSJunchao Zhang } 229*d52a580bSJunchao Zhang 230*d52a580bSJunchao Zhang static PetscErrorCode MatSetFromOptions_SeqAIJHIPSPARSE(Mat A, PetscOptionItems PetscOptionsObject) 231*d52a580bSJunchao Zhang { 232*d52a580bSJunchao Zhang MatHIPSPARSEStorageFormat format; 233*d52a580bSJunchao Zhang PetscBool flg; 234*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr; 235*d52a580bSJunchao Zhang 236*d52a580bSJunchao Zhang PetscFunctionBegin; 237*d52a580bSJunchao Zhang PetscOptionsHeadBegin(PetscOptionsObject, "SeqAIJHIPSPARSE options"); 238*d52a580bSJunchao Zhang if (A->factortype == MAT_FACTOR_NONE) { 239*d52a580bSJunchao Zhang PetscCall(PetscOptionsEnum("-mat_hipsparse_mult_storage_format", "sets storage format of (seq)aijhipsparse gpu matrices for SpMV", "MatHIPSPARSESetFormat", MatHIPSPARSEStorageFormats, (PetscEnum)hipsparsestruct->format, 
(PetscEnum *)&format, &flg)); 240*d52a580bSJunchao Zhang if (flg) PetscCall(MatHIPSPARSESetFormat(A, MAT_HIPSPARSE_MULT, format)); 241*d52a580bSJunchao Zhang PetscCall(PetscOptionsEnum("-mat_hipsparse_storage_format", "sets storage format of (seq)aijhipsparse gpu matrices for SpMV and TriSolve", "MatHIPSPARSESetFormat", MatHIPSPARSEStorageFormats, (PetscEnum)hipsparsestruct->format, (PetscEnum *)&format, &flg)); 242*d52a580bSJunchao Zhang if (flg) PetscCall(MatHIPSPARSESetFormat(A, MAT_HIPSPARSE_ALL, format)); 243*d52a580bSJunchao Zhang PetscCall(PetscOptionsBool("-mat_hipsparse_use_cpu_solve", "Use CPU (I)LU solve", "MatHIPSPARSESetUseCPUSolve", hipsparsestruct->use_cpu_solve, &hipsparsestruct->use_cpu_solve, &flg)); 244*d52a580bSJunchao Zhang if (flg) PetscCall(MatHIPSPARSESetUseCPUSolve(A, hipsparsestruct->use_cpu_solve)); 245*d52a580bSJunchao Zhang PetscCall( 246*d52a580bSJunchao Zhang PetscOptionsEnum("-mat_hipsparse_spmv_alg", "sets hipSPARSE algorithm used in sparse-mat dense-vector multiplication (SpMV)", "hipsparseSpMVAlg_t", MatHIPSPARSESpMVAlgorithms, (PetscEnum)hipsparsestruct->spmvAlg, (PetscEnum *)&hipsparsestruct->spmvAlg, &flg)); 247*d52a580bSJunchao Zhang /* If user did use this option, check its consistency with hipSPARSE, since PetscOptionsEnum() sets enum values based on their position in MatHIPSPARSESpMVAlgorithms[] */ 248*d52a580bSJunchao Zhang PetscCheck(!flg || HIPSPARSE_CSRMV_ALG1 == 2, PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSPARSE enum hipsparseSpMVAlg_t has been changed but PETSc has not been updated accordingly"); 249*d52a580bSJunchao Zhang PetscCall( 250*d52a580bSJunchao Zhang PetscOptionsEnum("-mat_hipsparse_spmm_alg", "sets hipSPARSE algorithm used in sparse-mat dense-mat multiplication (SpMM)", "hipsparseSpMMAlg_t", MatHIPSPARSESpMMAlgorithms, (PetscEnum)hipsparsestruct->spmmAlg, (PetscEnum *)&hipsparsestruct->spmmAlg, &flg)); 251*d52a580bSJunchao Zhang PetscCheck(!flg || HIPSPARSE_SPMM_CSR_ALG1 == 4, PETSC_COMM_SELF, PETSC_ERR_SUP, 
"hipSPARSE enum hipsparseSpMMAlg_t has been changed but PETSc has not been updated accordingly"); 252*d52a580bSJunchao Zhang /* 253*d52a580bSJunchao Zhang PetscCall(PetscOptionsEnum("-mat_hipsparse_csr2csc_alg", "sets hipSPARSE algorithm used in converting CSR matrices to CSC matrices", "hipsparseCsr2CscAlg_t", MatHIPSPARSECsr2CscAlgorithms, (PetscEnum)hipsparsestruct->csr2cscAlg, (PetscEnum*)&hipsparsestruct->csr2cscAlg, &flg)); 254*d52a580bSJunchao Zhang PetscCheck(!flg || HIPSPARSE_CSR2CSC_ALG1 == 1, PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSPARSE enum hipsparseCsr2CscAlg_t has been changed but PETSc has not been updated accordingly"); 255*d52a580bSJunchao Zhang */ 256*d52a580bSJunchao Zhang } 257*d52a580bSJunchao Zhang PetscOptionsHeadEnd(); 258*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 259*d52a580bSJunchao Zhang } 260*d52a580bSJunchao Zhang 261*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEBuildILULowerTriMatrix(Mat A) 262*d52a580bSJunchao Zhang { 263*d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data; 264*d52a580bSJunchao Zhang PetscInt n = A->rmap->n; 265*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr; 266*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr; 267*d52a580bSJunchao Zhang const PetscInt *ai = a->i, *aj = a->j, *vi; 268*d52a580bSJunchao Zhang const MatScalar *aa = a->a, *v; 269*d52a580bSJunchao Zhang PetscInt *AiLo, *AjLo; 270*d52a580bSJunchao Zhang PetscInt i, nz, nzLower, offset, rowOffset; 271*d52a580bSJunchao Zhang 272*d52a580bSJunchao Zhang PetscFunctionBegin; 273*d52a580bSJunchao Zhang if (!n) PetscFunctionReturn(PETSC_SUCCESS); 274*d52a580bSJunchao Zhang if (A->offloadmask == PETSC_OFFLOAD_UNALLOCATED || A->offloadmask == PETSC_OFFLOAD_CPU) { 275*d52a580bSJunchao Zhang try { 276*d52a580bSJunchao Zhang /* first figure out the number of 
nonzeros in the lower triangular matrix including 1's on the diagonal. */ 277*d52a580bSJunchao Zhang nzLower = n + ai[n] - ai[1]; 278*d52a580bSJunchao Zhang if (!loTriFactor) { 279*d52a580bSJunchao Zhang PetscScalar *AALo; 280*d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AALo, nzLower * sizeof(PetscScalar))); 281*d52a580bSJunchao Zhang 282*d52a580bSJunchao Zhang /* Allocate Space for the lower triangular matrix */ 283*d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AiLo, (n + 1) * sizeof(PetscInt))); 284*d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AjLo, nzLower * sizeof(PetscInt))); 285*d52a580bSJunchao Zhang 286*d52a580bSJunchao Zhang /* Fill the lower triangular matrix */ 287*d52a580bSJunchao Zhang AiLo[0] = (PetscInt)0; 288*d52a580bSJunchao Zhang AiLo[n] = nzLower; 289*d52a580bSJunchao Zhang AjLo[0] = (PetscInt)0; 290*d52a580bSJunchao Zhang AALo[0] = (MatScalar)1.0; 291*d52a580bSJunchao Zhang v = aa; 292*d52a580bSJunchao Zhang vi = aj; 293*d52a580bSJunchao Zhang offset = 1; 294*d52a580bSJunchao Zhang rowOffset = 1; 295*d52a580bSJunchao Zhang for (i = 1; i < n; i++) { 296*d52a580bSJunchao Zhang nz = ai[i + 1] - ai[i]; 297*d52a580bSJunchao Zhang /* additional 1 for the term on the diagonal */ 298*d52a580bSJunchao Zhang AiLo[i] = rowOffset; 299*d52a580bSJunchao Zhang rowOffset += nz + 1; 300*d52a580bSJunchao Zhang 301*d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&AjLo[offset], vi, nz)); 302*d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&AALo[offset], v, nz)); 303*d52a580bSJunchao Zhang offset += nz; 304*d52a580bSJunchao Zhang AjLo[offset] = (PetscInt)i; 305*d52a580bSJunchao Zhang AALo[offset] = (MatScalar)1.0; 306*d52a580bSJunchao Zhang offset += 1; 307*d52a580bSJunchao Zhang v += nz; 308*d52a580bSJunchao Zhang vi += nz; 309*d52a580bSJunchao Zhang } 310*d52a580bSJunchao Zhang 311*d52a580bSJunchao Zhang /* allocate space for the triangular factor information */ 312*d52a580bSJunchao Zhang 
PetscCall(PetscNew(&loTriFactor)); 313*d52a580bSJunchao Zhang loTriFactor->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL; 314*d52a580bSJunchao Zhang /* Create the matrix description */ 315*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&loTriFactor->descr)); 316*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(loTriFactor->descr, HIPSPARSE_INDEX_BASE_ZERO)); 317*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(loTriFactor->descr, HIPSPARSE_MATRIX_TYPE_GENERAL)); 318*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatFillMode(loTriFactor->descr, HIPSPARSE_FILL_MODE_LOWER)); 319*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatDiagType(loTriFactor->descr, HIPSPARSE_DIAG_TYPE_UNIT)); 320*d52a580bSJunchao Zhang 321*d52a580bSJunchao Zhang /* set the operation */ 322*d52a580bSJunchao Zhang loTriFactor->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE; 323*d52a580bSJunchao Zhang 324*d52a580bSJunchao Zhang /* set the matrix */ 325*d52a580bSJunchao Zhang loTriFactor->csrMat = new CsrMatrix; 326*d52a580bSJunchao Zhang loTriFactor->csrMat->num_rows = n; 327*d52a580bSJunchao Zhang loTriFactor->csrMat->num_cols = n; 328*d52a580bSJunchao Zhang loTriFactor->csrMat->num_entries = nzLower; 329*d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets = new THRUSTINTARRAY32(n + 1); 330*d52a580bSJunchao Zhang loTriFactor->csrMat->column_indices = new THRUSTINTARRAY32(nzLower); 331*d52a580bSJunchao Zhang loTriFactor->csrMat->values = new THRUSTARRAY(nzLower); 332*d52a580bSJunchao Zhang 333*d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->assign(AiLo, AiLo + n + 1); 334*d52a580bSJunchao Zhang loTriFactor->csrMat->column_indices->assign(AjLo, AjLo + nzLower); 335*d52a580bSJunchao Zhang loTriFactor->csrMat->values->assign(AALo, AALo + nzLower); 336*d52a580bSJunchao Zhang 337*d52a580bSJunchao Zhang /* Create the solve analysis information */ 338*d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, 
A, 0, 0, 0)); 339*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&loTriFactor->solveInfo)); 340*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, loTriFactor->descr, loTriFactor->csrMat->values->data().get(), 341*d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, &loTriFactor->solveBufferSize)); 342*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&loTriFactor->solveBuffer, loTriFactor->solveBufferSize)); 343*d52a580bSJunchao Zhang 344*d52a580bSJunchao Zhang /* perform the solve analysis */ 345*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, loTriFactor->descr, loTriFactor->csrMat->values->data().get(), 346*d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, loTriFactor->solvePolicy, loTriFactor->solveBuffer)); 347*d52a580bSJunchao Zhang 348*d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP()); 349*d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0)); 350*d52a580bSJunchao Zhang 351*d52a580bSJunchao Zhang /* assign the pointer */ 352*d52a580bSJunchao Zhang ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->loTriFactorPtr = loTriFactor; 353*d52a580bSJunchao Zhang loTriFactor->AA_h = AALo; 354*d52a580bSJunchao Zhang PetscCallHIP(hipHostFree(AiLo)); 355*d52a580bSJunchao Zhang PetscCallHIP(hipHostFree(AjLo)); 356*d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu((n + 1 + nzLower) * sizeof(int) + nzLower * sizeof(PetscScalar))); 357*d52a580bSJunchao Zhang } else { /* update values only */ 358*d52a580bSJunchao Zhang if (!loTriFactor->AA_h) PetscCallHIP(hipHostMalloc((void 
**)&loTriFactor->AA_h, nzLower * sizeof(PetscScalar))); 359*d52a580bSJunchao Zhang /* Fill the lower triangular matrix */ 360*d52a580bSJunchao Zhang loTriFactor->AA_h[0] = 1.0; 361*d52a580bSJunchao Zhang v = aa; 362*d52a580bSJunchao Zhang vi = aj; 363*d52a580bSJunchao Zhang offset = 1; 364*d52a580bSJunchao Zhang for (i = 1; i < n; i++) { 365*d52a580bSJunchao Zhang nz = ai[i + 1] - ai[i]; 366*d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&loTriFactor->AA_h[offset], v, nz)); 367*d52a580bSJunchao Zhang offset += nz; 368*d52a580bSJunchao Zhang loTriFactor->AA_h[offset] = 1.0; 369*d52a580bSJunchao Zhang offset += 1; 370*d52a580bSJunchao Zhang v += nz; 371*d52a580bSJunchao Zhang } 372*d52a580bSJunchao Zhang loTriFactor->csrMat->values->assign(loTriFactor->AA_h, loTriFactor->AA_h + nzLower); 373*d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(nzLower * sizeof(PetscScalar))); 374*d52a580bSJunchao Zhang } 375*d52a580bSJunchao Zhang } catch (char *ex) { 376*d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex); 377*d52a580bSJunchao Zhang } 378*d52a580bSJunchao Zhang } 379*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 380*d52a580bSJunchao Zhang } 381*d52a580bSJunchao Zhang 382*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEBuildILUUpperTriMatrix(Mat A) 383*d52a580bSJunchao Zhang { 384*d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data; 385*d52a580bSJunchao Zhang PetscInt n = A->rmap->n; 386*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr; 387*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr; 388*d52a580bSJunchao Zhang const PetscInt *aj = a->j, *adiag, *vi; 389*d52a580bSJunchao Zhang const MatScalar *aa = a->a, *v; 390*d52a580bSJunchao Zhang PetscInt *AiUp, *AjUp; 391*d52a580bSJunchao Zhang PetscInt i, nz, nzUpper, offset; 
392*d52a580bSJunchao Zhang 393*d52a580bSJunchao Zhang PetscFunctionBegin; 394*d52a580bSJunchao Zhang if (!n) PetscFunctionReturn(PETSC_SUCCESS); 395*d52a580bSJunchao Zhang PetscCall(MatGetDiagonalMarkers_SeqAIJ(A, &adiag, NULL)); 396*d52a580bSJunchao Zhang if (A->offloadmask == PETSC_OFFLOAD_UNALLOCATED || A->offloadmask == PETSC_OFFLOAD_CPU) { 397*d52a580bSJunchao Zhang try { 398*d52a580bSJunchao Zhang /* next, figure out the number of nonzeros in the upper triangular matrix. */ 399*d52a580bSJunchao Zhang nzUpper = adiag[0] - adiag[n]; 400*d52a580bSJunchao Zhang if (!upTriFactor) { 401*d52a580bSJunchao Zhang PetscScalar *AAUp; 402*d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AAUp, nzUpper * sizeof(PetscScalar))); 403*d52a580bSJunchao Zhang 404*d52a580bSJunchao Zhang /* Allocate Space for the upper triangular matrix */ 405*d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AiUp, (n + 1) * sizeof(PetscInt))); 406*d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AjUp, nzUpper * sizeof(PetscInt))); 407*d52a580bSJunchao Zhang 408*d52a580bSJunchao Zhang /* Fill the upper triangular matrix */ 409*d52a580bSJunchao Zhang AiUp[0] = (PetscInt)0; 410*d52a580bSJunchao Zhang AiUp[n] = nzUpper; 411*d52a580bSJunchao Zhang offset = nzUpper; 412*d52a580bSJunchao Zhang for (i = n - 1; i >= 0; i--) { 413*d52a580bSJunchao Zhang v = aa + adiag[i + 1] + 1; 414*d52a580bSJunchao Zhang vi = aj + adiag[i + 1] + 1; 415*d52a580bSJunchao Zhang nz = adiag[i] - adiag[i + 1] - 1; /* number of elements NOT on the diagonal */ 416*d52a580bSJunchao Zhang offset -= (nz + 1); /* decrement the offset */ 417*d52a580bSJunchao Zhang 418*d52a580bSJunchao Zhang /* first, set the diagonal elements */ 419*d52a580bSJunchao Zhang AjUp[offset] = (PetscInt)i; 420*d52a580bSJunchao Zhang AAUp[offset] = (MatScalar)1. 
/ v[nz]; 421*d52a580bSJunchao Zhang AiUp[i] = AiUp[i + 1] - (nz + 1); 422*d52a580bSJunchao Zhang 423*d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&AjUp[offset + 1], vi, nz)); 424*d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&AAUp[offset + 1], v, nz)); 425*d52a580bSJunchao Zhang } 426*d52a580bSJunchao Zhang 427*d52a580bSJunchao Zhang /* allocate space for the triangular factor information */ 428*d52a580bSJunchao Zhang PetscCall(PetscNew(&upTriFactor)); 429*d52a580bSJunchao Zhang upTriFactor->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL; 430*d52a580bSJunchao Zhang 431*d52a580bSJunchao Zhang /* Create the matrix description */ 432*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&upTriFactor->descr)); 433*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(upTriFactor->descr, HIPSPARSE_INDEX_BASE_ZERO)); 434*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(upTriFactor->descr, HIPSPARSE_MATRIX_TYPE_GENERAL)); 435*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatFillMode(upTriFactor->descr, HIPSPARSE_FILL_MODE_UPPER)); 436*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatDiagType(upTriFactor->descr, HIPSPARSE_DIAG_TYPE_NON_UNIT)); 437*d52a580bSJunchao Zhang 438*d52a580bSJunchao Zhang /* set the operation */ 439*d52a580bSJunchao Zhang upTriFactor->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE; 440*d52a580bSJunchao Zhang 441*d52a580bSJunchao Zhang /* set the matrix */ 442*d52a580bSJunchao Zhang upTriFactor->csrMat = new CsrMatrix; 443*d52a580bSJunchao Zhang upTriFactor->csrMat->num_rows = n; 444*d52a580bSJunchao Zhang upTriFactor->csrMat->num_cols = n; 445*d52a580bSJunchao Zhang upTriFactor->csrMat->num_entries = nzUpper; 446*d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets = new THRUSTINTARRAY32(n + 1); 447*d52a580bSJunchao Zhang upTriFactor->csrMat->column_indices = new THRUSTINTARRAY32(nzUpper); 448*d52a580bSJunchao Zhang upTriFactor->csrMat->values = new THRUSTARRAY(nzUpper); 449*d52a580bSJunchao 
Zhang upTriFactor->csrMat->row_offsets->assign(AiUp, AiUp + n + 1); 450*d52a580bSJunchao Zhang upTriFactor->csrMat->column_indices->assign(AjUp, AjUp + nzUpper); 451*d52a580bSJunchao Zhang upTriFactor->csrMat->values->assign(AAUp, AAUp + nzUpper); 452*d52a580bSJunchao Zhang 453*d52a580bSJunchao Zhang /* Create the solve analysis information */ 454*d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0)); 455*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&upTriFactor->solveInfo)); 456*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, upTriFactor->descr, upTriFactor->csrMat->values->data().get(), 457*d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, &upTriFactor->solveBufferSize)); 458*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&upTriFactor->solveBuffer, upTriFactor->solveBufferSize)); 459*d52a580bSJunchao Zhang 460*d52a580bSJunchao Zhang /* perform the solve analysis */ 461*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, upTriFactor->descr, upTriFactor->csrMat->values->data().get(), 462*d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, upTriFactor->solvePolicy, upTriFactor->solveBuffer)); 463*d52a580bSJunchao Zhang 464*d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP()); 465*d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0)); 466*d52a580bSJunchao Zhang 467*d52a580bSJunchao Zhang /* assign the pointer */ 468*d52a580bSJunchao Zhang ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->upTriFactorPtr = upTriFactor; 
469*d52a580bSJunchao Zhang upTriFactor->AA_h = AAUp; 470*d52a580bSJunchao Zhang PetscCallHIP(hipHostFree(AiUp)); 471*d52a580bSJunchao Zhang PetscCallHIP(hipHostFree(AjUp)); 472*d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu((n + 1 + nzUpper) * sizeof(int) + nzUpper * sizeof(PetscScalar))); 473*d52a580bSJunchao Zhang } else { 474*d52a580bSJunchao Zhang if (!upTriFactor->AA_h) PetscCallHIP(hipHostMalloc((void **)&upTriFactor->AA_h, nzUpper * sizeof(PetscScalar))); 475*d52a580bSJunchao Zhang /* Fill the upper triangular matrix */ 476*d52a580bSJunchao Zhang offset = nzUpper; 477*d52a580bSJunchao Zhang for (i = n - 1; i >= 0; i--) { 478*d52a580bSJunchao Zhang v = aa + adiag[i + 1] + 1; 479*d52a580bSJunchao Zhang nz = adiag[i] - adiag[i + 1] - 1; /* number of elements NOT on the diagonal */ 480*d52a580bSJunchao Zhang offset -= (nz + 1); /* decrement the offset */ 481*d52a580bSJunchao Zhang 482*d52a580bSJunchao Zhang /* first, set the diagonal elements */ 483*d52a580bSJunchao Zhang upTriFactor->AA_h[offset] = 1. 
/ v[nz]; 484*d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&upTriFactor->AA_h[offset + 1], v, nz)); 485*d52a580bSJunchao Zhang } 486*d52a580bSJunchao Zhang upTriFactor->csrMat->values->assign(upTriFactor->AA_h, upTriFactor->AA_h + nzUpper); 487*d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(nzUpper * sizeof(PetscScalar))); 488*d52a580bSJunchao Zhang } 489*d52a580bSJunchao Zhang } catch (char *ex) { 490*d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex); 491*d52a580bSJunchao Zhang } 492*d52a580bSJunchao Zhang } 493*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 494*d52a580bSJunchao Zhang } 495*d52a580bSJunchao Zhang 496*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEILUAnalysisAndCopyToGPU(Mat A) 497*d52a580bSJunchao Zhang { 498*d52a580bSJunchao Zhang PetscBool row_identity, col_identity; 499*d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data; 500*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr; 501*d52a580bSJunchao Zhang IS isrow = a->row, iscol = a->icol; 502*d52a580bSJunchao Zhang PetscInt n = A->rmap->n; 503*d52a580bSJunchao Zhang 504*d52a580bSJunchao Zhang PetscFunctionBegin; 505*d52a580bSJunchao Zhang PetscCheck(hipsparseTriFactors, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing hipsparseTriFactors"); 506*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEBuildILULowerTriMatrix(A)); 507*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEBuildILUUpperTriMatrix(A)); 508*d52a580bSJunchao Zhang 509*d52a580bSJunchao Zhang if (!hipsparseTriFactors->workVector) hipsparseTriFactors->workVector = new THRUSTARRAY(n); 510*d52a580bSJunchao Zhang hipsparseTriFactors->nnz = a->nz; 511*d52a580bSJunchao Zhang 512*d52a580bSJunchao Zhang A->offloadmask = PETSC_OFFLOAD_BOTH; 513*d52a580bSJunchao Zhang /* lower triangular indices */ 514*d52a580bSJunchao Zhang PetscCall(ISIdentity(isrow, &row_identity)); 515*d52a580bSJunchao Zhang if 
(!row_identity && !hipsparseTriFactors->rpermIndices) { 516*d52a580bSJunchao Zhang const PetscInt *r; 517*d52a580bSJunchao Zhang 518*d52a580bSJunchao Zhang PetscCall(ISGetIndices(isrow, &r)); 519*d52a580bSJunchao Zhang hipsparseTriFactors->rpermIndices = new THRUSTINTARRAY(n); 520*d52a580bSJunchao Zhang hipsparseTriFactors->rpermIndices->assign(r, r + n); 521*d52a580bSJunchao Zhang PetscCall(ISRestoreIndices(isrow, &r)); 522*d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt))); 523*d52a580bSJunchao Zhang } 524*d52a580bSJunchao Zhang /* upper triangular indices */ 525*d52a580bSJunchao Zhang PetscCall(ISIdentity(iscol, &col_identity)); 526*d52a580bSJunchao Zhang if (!col_identity && !hipsparseTriFactors->cpermIndices) { 527*d52a580bSJunchao Zhang const PetscInt *c; 528*d52a580bSJunchao Zhang 529*d52a580bSJunchao Zhang PetscCall(ISGetIndices(iscol, &c)); 530*d52a580bSJunchao Zhang hipsparseTriFactors->cpermIndices = new THRUSTINTARRAY(n); 531*d52a580bSJunchao Zhang hipsparseTriFactors->cpermIndices->assign(c, c + n); 532*d52a580bSJunchao Zhang PetscCall(ISRestoreIndices(iscol, &c)); 533*d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt))); 534*d52a580bSJunchao Zhang } 535*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 536*d52a580bSJunchao Zhang } 537*d52a580bSJunchao Zhang 538*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEBuildICCTriMatrices(Mat A) 539*d52a580bSJunchao Zhang { 540*d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data; 541*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr; 542*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr; 543*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr; 544*d52a580bSJunchao Zhang PetscInt 
*AiUp, *AjUp; 545*d52a580bSJunchao Zhang PetscScalar *AAUp; 546*d52a580bSJunchao Zhang PetscScalar *AALo; 547*d52a580bSJunchao Zhang PetscInt nzUpper = a->nz, n = A->rmap->n, i, offset, nz, j; 548*d52a580bSJunchao Zhang Mat_SeqSBAIJ *b = (Mat_SeqSBAIJ *)A->data; 549*d52a580bSJunchao Zhang const PetscInt *ai = b->i, *aj = b->j, *vj; 550*d52a580bSJunchao Zhang const MatScalar *aa = b->a, *v; 551*d52a580bSJunchao Zhang 552*d52a580bSJunchao Zhang PetscFunctionBegin; 553*d52a580bSJunchao Zhang if (!n) PetscFunctionReturn(PETSC_SUCCESS); 554*d52a580bSJunchao Zhang if (A->offloadmask == PETSC_OFFLOAD_UNALLOCATED || A->offloadmask == PETSC_OFFLOAD_CPU) { 555*d52a580bSJunchao Zhang try { 556*d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AAUp, nzUpper * sizeof(PetscScalar))); 557*d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AALo, nzUpper * sizeof(PetscScalar))); 558*d52a580bSJunchao Zhang if (!upTriFactor && !loTriFactor) { 559*d52a580bSJunchao Zhang /* Allocate Space for the upper triangular matrix */ 560*d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AiUp, (n + 1) * sizeof(PetscInt))); 561*d52a580bSJunchao Zhang PetscCallHIP(hipHostMalloc((void **)&AjUp, nzUpper * sizeof(PetscInt))); 562*d52a580bSJunchao Zhang 563*d52a580bSJunchao Zhang /* Fill the upper triangular matrix */ 564*d52a580bSJunchao Zhang AiUp[0] = (PetscInt)0; 565*d52a580bSJunchao Zhang AiUp[n] = nzUpper; 566*d52a580bSJunchao Zhang offset = 0; 567*d52a580bSJunchao Zhang for (i = 0; i < n; i++) { 568*d52a580bSJunchao Zhang /* set the pointers */ 569*d52a580bSJunchao Zhang v = aa + ai[i]; 570*d52a580bSJunchao Zhang vj = aj + ai[i]; 571*d52a580bSJunchao Zhang nz = ai[i + 1] - ai[i] - 1; /* exclude diag[i] */ 572*d52a580bSJunchao Zhang 573*d52a580bSJunchao Zhang /* first, set the diagonal elements */ 574*d52a580bSJunchao Zhang AjUp[offset] = (PetscInt)i; 575*d52a580bSJunchao Zhang AAUp[offset] = (MatScalar)1.0 / v[nz]; 576*d52a580bSJunchao Zhang AiUp[i] = offset; 
577*d52a580bSJunchao Zhang AALo[offset] = (MatScalar)1.0 / v[nz]; 578*d52a580bSJunchao Zhang 579*d52a580bSJunchao Zhang offset += 1; 580*d52a580bSJunchao Zhang if (nz > 0) { 581*d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&AjUp[offset], vj, nz)); 582*d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&AAUp[offset], v, nz)); 583*d52a580bSJunchao Zhang for (j = offset; j < offset + nz; j++) { 584*d52a580bSJunchao Zhang AAUp[j] = -AAUp[j]; 585*d52a580bSJunchao Zhang AALo[j] = AAUp[j] / v[nz]; 586*d52a580bSJunchao Zhang } 587*d52a580bSJunchao Zhang offset += nz; 588*d52a580bSJunchao Zhang } 589*d52a580bSJunchao Zhang } 590*d52a580bSJunchao Zhang 591*d52a580bSJunchao Zhang /* allocate space for the triangular factor information */ 592*d52a580bSJunchao Zhang PetscCall(PetscNew(&upTriFactor)); 593*d52a580bSJunchao Zhang upTriFactor->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL; 594*d52a580bSJunchao Zhang 595*d52a580bSJunchao Zhang /* Create the matrix description */ 596*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&upTriFactor->descr)); 597*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(upTriFactor->descr, HIPSPARSE_INDEX_BASE_ZERO)); 598*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(upTriFactor->descr, HIPSPARSE_MATRIX_TYPE_GENERAL)); 599*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatFillMode(upTriFactor->descr, HIPSPARSE_FILL_MODE_UPPER)); 600*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatDiagType(upTriFactor->descr, HIPSPARSE_DIAG_TYPE_UNIT)); 601*d52a580bSJunchao Zhang 602*d52a580bSJunchao Zhang /* set the matrix */ 603*d52a580bSJunchao Zhang upTriFactor->csrMat = new CsrMatrix; 604*d52a580bSJunchao Zhang upTriFactor->csrMat->num_rows = A->rmap->n; 605*d52a580bSJunchao Zhang upTriFactor->csrMat->num_cols = A->cmap->n; 606*d52a580bSJunchao Zhang upTriFactor->csrMat->num_entries = a->nz; 607*d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets = new THRUSTINTARRAY32(A->rmap->n + 1); 
608*d52a580bSJunchao Zhang upTriFactor->csrMat->column_indices = new THRUSTINTARRAY32(a->nz); 609*d52a580bSJunchao Zhang upTriFactor->csrMat->values = new THRUSTARRAY(a->nz); 610*d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets->assign(AiUp, AiUp + A->rmap->n + 1); 611*d52a580bSJunchao Zhang upTriFactor->csrMat->column_indices->assign(AjUp, AjUp + a->nz); 612*d52a580bSJunchao Zhang upTriFactor->csrMat->values->assign(AAUp, AAUp + a->nz); 613*d52a580bSJunchao Zhang 614*d52a580bSJunchao Zhang /* set the operation */ 615*d52a580bSJunchao Zhang upTriFactor->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE; 616*d52a580bSJunchao Zhang 617*d52a580bSJunchao Zhang /* Create the solve analysis information */ 618*d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0)); 619*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&upTriFactor->solveInfo)); 620*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, upTriFactor->descr, upTriFactor->csrMat->values->data().get(), 621*d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, &upTriFactor->solveBufferSize)); 622*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&upTriFactor->solveBuffer, upTriFactor->solveBufferSize)); 623*d52a580bSJunchao Zhang 624*d52a580bSJunchao Zhang /* perform the solve analysis */ 625*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, upTriFactor->descr, upTriFactor->csrMat->values->data().get(), 626*d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, upTriFactor->solvePolicy, upTriFactor->solveBuffer)); 
627*d52a580bSJunchao Zhang 628*d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP()); 629*d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0)); 630*d52a580bSJunchao Zhang 631*d52a580bSJunchao Zhang /* assign the pointer */ 632*d52a580bSJunchao Zhang ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->upTriFactorPtr = upTriFactor; 633*d52a580bSJunchao Zhang 634*d52a580bSJunchao Zhang /* allocate space for the triangular factor information */ 635*d52a580bSJunchao Zhang PetscCall(PetscNew(&loTriFactor)); 636*d52a580bSJunchao Zhang loTriFactor->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL; 637*d52a580bSJunchao Zhang 638*d52a580bSJunchao Zhang /* Create the matrix description */ 639*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&loTriFactor->descr)); 640*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(loTriFactor->descr, HIPSPARSE_INDEX_BASE_ZERO)); 641*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(loTriFactor->descr, HIPSPARSE_MATRIX_TYPE_GENERAL)); 642*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatFillMode(loTriFactor->descr, HIPSPARSE_FILL_MODE_UPPER)); 643*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatDiagType(loTriFactor->descr, HIPSPARSE_DIAG_TYPE_NON_UNIT)); 644*d52a580bSJunchao Zhang 645*d52a580bSJunchao Zhang /* set the operation */ 646*d52a580bSJunchao Zhang loTriFactor->solveOp = HIPSPARSE_OPERATION_TRANSPOSE; 647*d52a580bSJunchao Zhang 648*d52a580bSJunchao Zhang /* set the matrix */ 649*d52a580bSJunchao Zhang loTriFactor->csrMat = new CsrMatrix; 650*d52a580bSJunchao Zhang loTriFactor->csrMat->num_rows = A->rmap->n; 651*d52a580bSJunchao Zhang loTriFactor->csrMat->num_cols = A->cmap->n; 652*d52a580bSJunchao Zhang loTriFactor->csrMat->num_entries = a->nz; 653*d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets = new THRUSTINTARRAY32(A->rmap->n + 1); 654*d52a580bSJunchao Zhang loTriFactor->csrMat->column_indices = new THRUSTINTARRAY32(a->nz); 
655*d52a580bSJunchao Zhang loTriFactor->csrMat->values = new THRUSTARRAY(a->nz); 656*d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->assign(AiUp, AiUp + A->rmap->n + 1); 657*d52a580bSJunchao Zhang loTriFactor->csrMat->column_indices->assign(AjUp, AjUp + a->nz); 658*d52a580bSJunchao Zhang loTriFactor->csrMat->values->assign(AALo, AALo + a->nz); 659*d52a580bSJunchao Zhang 660*d52a580bSJunchao Zhang /* Create the solve analysis information */ 661*d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0)); 662*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&loTriFactor->solveInfo)); 663*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, loTriFactor->descr, loTriFactor->csrMat->values->data().get(), 664*d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, &loTriFactor->solveBufferSize)); 665*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&loTriFactor->solveBuffer, loTriFactor->solveBufferSize)); 666*d52a580bSJunchao Zhang 667*d52a580bSJunchao Zhang /* perform the solve analysis */ 668*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, loTriFactor->descr, loTriFactor->csrMat->values->data().get(), 669*d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, loTriFactor->solvePolicy, loTriFactor->solveBuffer)); 670*d52a580bSJunchao Zhang 671*d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP()); 672*d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0)); 673*d52a580bSJunchao Zhang 674*d52a580bSJunchao Zhang /* assign the pointer */ 
675*d52a580bSJunchao Zhang ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->loTriFactorPtr = loTriFactor; 676*d52a580bSJunchao Zhang 677*d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(2 * (((A->rmap->n + 1) + (a->nz)) * sizeof(int) + (a->nz) * sizeof(PetscScalar)))); 678*d52a580bSJunchao Zhang PetscCallHIP(hipHostFree(AiUp)); 679*d52a580bSJunchao Zhang PetscCallHIP(hipHostFree(AjUp)); 680*d52a580bSJunchao Zhang } else { 681*d52a580bSJunchao Zhang /* Fill the upper triangular matrix */ 682*d52a580bSJunchao Zhang offset = 0; 683*d52a580bSJunchao Zhang for (i = 0; i < n; i++) { 684*d52a580bSJunchao Zhang /* set the pointers */ 685*d52a580bSJunchao Zhang v = aa + ai[i]; 686*d52a580bSJunchao Zhang nz = ai[i + 1] - ai[i] - 1; /* exclude diag[i] */ 687*d52a580bSJunchao Zhang 688*d52a580bSJunchao Zhang /* first, set the diagonal elements */ 689*d52a580bSJunchao Zhang AAUp[offset] = 1.0 / v[nz]; 690*d52a580bSJunchao Zhang AALo[offset] = 1.0 / v[nz]; 691*d52a580bSJunchao Zhang 692*d52a580bSJunchao Zhang offset += 1; 693*d52a580bSJunchao Zhang if (nz > 0) { 694*d52a580bSJunchao Zhang PetscCall(PetscArraycpy(&AAUp[offset], v, nz)); 695*d52a580bSJunchao Zhang for (j = offset; j < offset + nz; j++) { 696*d52a580bSJunchao Zhang AAUp[j] = -AAUp[j]; 697*d52a580bSJunchao Zhang AALo[j] = AAUp[j] / v[nz]; 698*d52a580bSJunchao Zhang } 699*d52a580bSJunchao Zhang offset += nz; 700*d52a580bSJunchao Zhang } 701*d52a580bSJunchao Zhang } 702*d52a580bSJunchao Zhang PetscCheck(upTriFactor, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing hipsparseTriFactors"); 703*d52a580bSJunchao Zhang PetscCheck(loTriFactor, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing hipsparseTriFactors"); 704*d52a580bSJunchao Zhang upTriFactor->csrMat->values->assign(AAUp, AAUp + a->nz); 705*d52a580bSJunchao Zhang loTriFactor->csrMat->values->assign(AALo, AALo + a->nz); 706*d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(2 * (a->nz) * sizeof(PetscScalar))); 707*d52a580bSJunchao Zhang } 708*d52a580bSJunchao Zhang 
PetscCallHIP(hipHostFree(AAUp)); 709*d52a580bSJunchao Zhang PetscCallHIP(hipHostFree(AALo)); 710*d52a580bSJunchao Zhang } catch (char *ex) { 711*d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex); 712*d52a580bSJunchao Zhang } 713*d52a580bSJunchao Zhang } 714*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 715*d52a580bSJunchao Zhang } 716*d52a580bSJunchao Zhang 717*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEICCAnalysisAndCopyToGPU(Mat A) 718*d52a580bSJunchao Zhang { 719*d52a580bSJunchao Zhang PetscBool perm_identity; 720*d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data; 721*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr; 722*d52a580bSJunchao Zhang IS ip = a->row; 723*d52a580bSJunchao Zhang PetscInt n = A->rmap->n; 724*d52a580bSJunchao Zhang 725*d52a580bSJunchao Zhang PetscFunctionBegin; 726*d52a580bSJunchao Zhang PetscCheck(hipsparseTriFactors, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing hipsparseTriFactors"); 727*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEBuildICCTriMatrices(A)); 728*d52a580bSJunchao Zhang if (!hipsparseTriFactors->workVector) hipsparseTriFactors->workVector = new THRUSTARRAY(n); 729*d52a580bSJunchao Zhang hipsparseTriFactors->nnz = (a->nz - n) * 2 + n; 730*d52a580bSJunchao Zhang 731*d52a580bSJunchao Zhang A->offloadmask = PETSC_OFFLOAD_BOTH; 732*d52a580bSJunchao Zhang /* lower triangular indices */ 733*d52a580bSJunchao Zhang PetscCall(ISIdentity(ip, &perm_identity)); 734*d52a580bSJunchao Zhang if (!perm_identity) { 735*d52a580bSJunchao Zhang IS iip; 736*d52a580bSJunchao Zhang const PetscInt *irip, *rip; 737*d52a580bSJunchao Zhang 738*d52a580bSJunchao Zhang PetscCall(ISInvertPermutation(ip, PETSC_DECIDE, &iip)); 739*d52a580bSJunchao Zhang PetscCall(ISGetIndices(iip, &irip)); 740*d52a580bSJunchao Zhang PetscCall(ISGetIndices(ip, &rip)); 741*d52a580bSJunchao Zhang 
hipsparseTriFactors->rpermIndices = new THRUSTINTARRAY(n); 742*d52a580bSJunchao Zhang hipsparseTriFactors->cpermIndices = new THRUSTINTARRAY(n); 743*d52a580bSJunchao Zhang hipsparseTriFactors->rpermIndices->assign(rip, rip + n); 744*d52a580bSJunchao Zhang hipsparseTriFactors->cpermIndices->assign(irip, irip + n); 745*d52a580bSJunchao Zhang PetscCall(ISRestoreIndices(iip, &irip)); 746*d52a580bSJunchao Zhang PetscCall(ISDestroy(&iip)); 747*d52a580bSJunchao Zhang PetscCall(ISRestoreIndices(ip, &rip)); 748*d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(2. * n * sizeof(PetscInt))); 749*d52a580bSJunchao Zhang } 750*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 751*d52a580bSJunchao Zhang } 752*d52a580bSJunchao Zhang 753*d52a580bSJunchao Zhang static PetscErrorCode MatCholeskyFactorNumeric_SeqAIJHIPSPARSE(Mat B, Mat A, const MatFactorInfo *info) 754*d52a580bSJunchao Zhang { 755*d52a580bSJunchao Zhang PetscBool perm_identity; 756*d52a580bSJunchao Zhang Mat_SeqAIJ *b = (Mat_SeqAIJ *)B->data; 757*d52a580bSJunchao Zhang IS ip = b->row; 758*d52a580bSJunchao Zhang 759*d52a580bSJunchao Zhang PetscFunctionBegin; 760*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A)); 761*d52a580bSJunchao Zhang PetscCall(MatCholeskyFactorNumeric_SeqAIJ(B, A, info)); 762*d52a580bSJunchao Zhang B->offloadmask = PETSC_OFFLOAD_CPU; 763*d52a580bSJunchao Zhang /* determine which version of MatSolve needs to be used. 
*/ 764*d52a580bSJunchao Zhang PetscCall(ISIdentity(ip, &perm_identity)); 765*d52a580bSJunchao Zhang if (perm_identity) { 766*d52a580bSJunchao Zhang B->ops->solve = MatSolve_SeqAIJHIPSPARSE_NaturalOrdering; 767*d52a580bSJunchao Zhang B->ops->solvetranspose = MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering; 768*d52a580bSJunchao Zhang B->ops->matsolve = NULL; 769*d52a580bSJunchao Zhang B->ops->matsolvetranspose = NULL; 770*d52a580bSJunchao Zhang } else { 771*d52a580bSJunchao Zhang B->ops->solve = MatSolve_SeqAIJHIPSPARSE; 772*d52a580bSJunchao Zhang B->ops->solvetranspose = MatSolveTranspose_SeqAIJHIPSPARSE; 773*d52a580bSJunchao Zhang B->ops->matsolve = NULL; 774*d52a580bSJunchao Zhang B->ops->matsolvetranspose = NULL; 775*d52a580bSJunchao Zhang } 776*d52a580bSJunchao Zhang 777*d52a580bSJunchao Zhang /* get the triangular factors */ 778*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEICCAnalysisAndCopyToGPU(B)); 779*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 780*d52a580bSJunchao Zhang } 781*d52a580bSJunchao Zhang 782*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEAnalyzeTransposeForSolve(Mat A) 783*d52a580bSJunchao Zhang { 784*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr; 785*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr; 786*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr; 787*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactorT; 788*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactorT; 789*d52a580bSJunchao Zhang hipsparseIndexBase_t indexBase; 790*d52a580bSJunchao Zhang hipsparseMatrixType_t matrixType; 791*d52a580bSJunchao Zhang hipsparseFillMode_t fillMode; 792*d52a580bSJunchao Zhang hipsparseDiagType_t diagType; 
793*d52a580bSJunchao Zhang 794*d52a580bSJunchao Zhang PetscFunctionBegin; 795*d52a580bSJunchao Zhang /* allocate space for the transpose of the lower triangular factor */ 796*d52a580bSJunchao Zhang PetscCall(PetscNew(&loTriFactorT)); 797*d52a580bSJunchao Zhang loTriFactorT->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL; 798*d52a580bSJunchao Zhang 799*d52a580bSJunchao Zhang /* set the matrix descriptors of the lower triangular factor */ 800*d52a580bSJunchao Zhang matrixType = hipsparseGetMatType(loTriFactor->descr); 801*d52a580bSJunchao Zhang indexBase = hipsparseGetMatIndexBase(loTriFactor->descr); 802*d52a580bSJunchao Zhang fillMode = hipsparseGetMatFillMode(loTriFactor->descr) == HIPSPARSE_FILL_MODE_UPPER ? HIPSPARSE_FILL_MODE_LOWER : HIPSPARSE_FILL_MODE_UPPER; 803*d52a580bSJunchao Zhang diagType = hipsparseGetMatDiagType(loTriFactor->descr); 804*d52a580bSJunchao Zhang 805*d52a580bSJunchao Zhang /* Create the matrix description */ 806*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&loTriFactorT->descr)); 807*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(loTriFactorT->descr, indexBase)); 808*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(loTriFactorT->descr, matrixType)); 809*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatFillMode(loTriFactorT->descr, fillMode)); 810*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatDiagType(loTriFactorT->descr, diagType)); 811*d52a580bSJunchao Zhang 812*d52a580bSJunchao Zhang /* set the operation */ 813*d52a580bSJunchao Zhang loTriFactorT->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE; 814*d52a580bSJunchao Zhang 815*d52a580bSJunchao Zhang /* allocate GPU space for the CSC of the lower triangular factor*/ 816*d52a580bSJunchao Zhang loTriFactorT->csrMat = new CsrMatrix; 817*d52a580bSJunchao Zhang loTriFactorT->csrMat->num_rows = loTriFactor->csrMat->num_cols; 818*d52a580bSJunchao Zhang loTriFactorT->csrMat->num_cols = loTriFactor->csrMat->num_rows; 
819*d52a580bSJunchao Zhang loTriFactorT->csrMat->num_entries = loTriFactor->csrMat->num_entries; 820*d52a580bSJunchao Zhang loTriFactorT->csrMat->row_offsets = new THRUSTINTARRAY32(loTriFactorT->csrMat->num_rows + 1); 821*d52a580bSJunchao Zhang loTriFactorT->csrMat->column_indices = new THRUSTINTARRAY32(loTriFactorT->csrMat->num_entries); 822*d52a580bSJunchao Zhang loTriFactorT->csrMat->values = new THRUSTARRAY(loTriFactorT->csrMat->num_entries); 823*d52a580bSJunchao Zhang 824*d52a580bSJunchao Zhang /* compute the transpose of the lower triangular factor, i.e. the CSC */ 825*d52a580bSJunchao Zhang /* Csr2cscEx2 is not implemented in ROCm-5.2.0 and is planned for implementation in hipsparse with future releases of ROCm 826*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 2, 0) 827*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCsr2cscEx2_bufferSize(hipsparseTriFactors->handle, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_cols, loTriFactor->csrMat->num_entries, loTriFactor->csrMat->values->data().get(), 828*d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactorT->csrMat->values->data().get(), loTriFactorT->csrMat->row_offsets->data().get(), 829*d52a580bSJunchao Zhang loTriFactorT->csrMat->column_indices->data().get(), hipsparse_scalartype, HIPSPARSE_ACTION_NUMERIC, indexBase, HIPSPARSE_CSR2CSC_ALG1, &loTriFactor->csr2cscBufferSize)); 830*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&loTriFactor->csr2cscBuffer, loTriFactor->csr2cscBufferSize)); 831*d52a580bSJunchao Zhang #endif 832*d52a580bSJunchao Zhang */ 833*d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0)); 834*d52a580bSJunchao Zhang 835*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr2csc(hipsparseTriFactors->handle, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_cols, loTriFactor->csrMat->num_entries, loTriFactor->csrMat->values->data().get(), 
loTriFactor->csrMat->row_offsets->data().get(), 836*d52a580bSJunchao Zhang loTriFactor->csrMat->column_indices->data().get(), loTriFactorT->csrMat->values->data().get(), 837*d52a580bSJunchao Zhang #if 0 /* when Csr2cscEx2 is implemented in hipSparse PETSC_PKG_HIP_VERSION_GE(5, 2, 0)*/ 838*d52a580bSJunchao Zhang loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(), 839*d52a580bSJunchao Zhang hipsparse_scalartype, HIPSPARSE_ACTION_NUMERIC, indexBase, HIPSPARSE_CSR2CSC_ALG1, loTriFactor->csr2cscBuffer)); 840*d52a580bSJunchao Zhang #else 841*d52a580bSJunchao Zhang loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->csrMat->row_offsets->data().get(), HIPSPARSE_ACTION_NUMERIC, indexBase)); 842*d52a580bSJunchao Zhang #endif 843*d52a580bSJunchao Zhang 844*d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP()); 845*d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0)); 846*d52a580bSJunchao Zhang 847*d52a580bSJunchao Zhang /* Create the solve analysis information */ 848*d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0)); 849*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&loTriFactorT->solveInfo)); 850*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, loTriFactorT->solveOp, loTriFactorT->csrMat->num_rows, loTriFactorT->csrMat->num_entries, loTriFactorT->descr, loTriFactorT->csrMat->values->data().get(), 851*d52a580bSJunchao Zhang loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->solveInfo, &loTriFactorT->solveBufferSize)); 852*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&loTriFactorT->solveBuffer, loTriFactorT->solveBufferSize)); 853*d52a580bSJunchao Zhang 854*d52a580bSJunchao Zhang /* perform the solve analysis */ 855*d52a580bSJunchao Zhang 
PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, loTriFactorT->solveOp, loTriFactorT->csrMat->num_rows, loTriFactorT->csrMat->num_entries, loTriFactorT->descr, loTriFactorT->csrMat->values->data().get(), 856*d52a580bSJunchao Zhang loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->solveInfo, loTriFactorT->solvePolicy, loTriFactorT->solveBuffer)); 857*d52a580bSJunchao Zhang 858*d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP()); 859*d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0)); 860*d52a580bSJunchao Zhang 861*d52a580bSJunchao Zhang /* assign the pointer */ 862*d52a580bSJunchao Zhang ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->loTriFactorPtrTranspose = loTriFactorT; 863*d52a580bSJunchao Zhang 864*d52a580bSJunchao Zhang /*********************************************/ 865*d52a580bSJunchao Zhang /* Now the Transpose of the Upper Tri Factor */ 866*d52a580bSJunchao Zhang /*********************************************/ 867*d52a580bSJunchao Zhang 868*d52a580bSJunchao Zhang /* allocate space for the transpose of the upper triangular factor */ 869*d52a580bSJunchao Zhang PetscCall(PetscNew(&upTriFactorT)); 870*d52a580bSJunchao Zhang upTriFactorT->solvePolicy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL; 871*d52a580bSJunchao Zhang 872*d52a580bSJunchao Zhang /* set the matrix descriptors of the upper triangular factor */ 873*d52a580bSJunchao Zhang matrixType = hipsparseGetMatType(upTriFactor->descr); 874*d52a580bSJunchao Zhang indexBase = hipsparseGetMatIndexBase(upTriFactor->descr); 875*d52a580bSJunchao Zhang fillMode = hipsparseGetMatFillMode(upTriFactor->descr) == HIPSPARSE_FILL_MODE_UPPER ? 
HIPSPARSE_FILL_MODE_LOWER : HIPSPARSE_FILL_MODE_UPPER; 876*d52a580bSJunchao Zhang diagType = hipsparseGetMatDiagType(upTriFactor->descr); 877*d52a580bSJunchao Zhang 878*d52a580bSJunchao Zhang /* Create the matrix description */ 879*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&upTriFactorT->descr)); 880*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(upTriFactorT->descr, indexBase)); 881*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(upTriFactorT->descr, matrixType)); 882*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatFillMode(upTriFactorT->descr, fillMode)); 883*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatDiagType(upTriFactorT->descr, diagType)); 884*d52a580bSJunchao Zhang 885*d52a580bSJunchao Zhang /* set the operation */ 886*d52a580bSJunchao Zhang upTriFactorT->solveOp = HIPSPARSE_OPERATION_NON_TRANSPOSE; 887*d52a580bSJunchao Zhang 888*d52a580bSJunchao Zhang /* allocate GPU space for the CSC of the upper triangular factor*/ 889*d52a580bSJunchao Zhang upTriFactorT->csrMat = new CsrMatrix; 890*d52a580bSJunchao Zhang upTriFactorT->csrMat->num_rows = upTriFactor->csrMat->num_cols; 891*d52a580bSJunchao Zhang upTriFactorT->csrMat->num_cols = upTriFactor->csrMat->num_rows; 892*d52a580bSJunchao Zhang upTriFactorT->csrMat->num_entries = upTriFactor->csrMat->num_entries; 893*d52a580bSJunchao Zhang upTriFactorT->csrMat->row_offsets = new THRUSTINTARRAY32(upTriFactorT->csrMat->num_rows + 1); 894*d52a580bSJunchao Zhang upTriFactorT->csrMat->column_indices = new THRUSTINTARRAY32(upTriFactorT->csrMat->num_entries); 895*d52a580bSJunchao Zhang upTriFactorT->csrMat->values = new THRUSTARRAY(upTriFactorT->csrMat->num_entries); 896*d52a580bSJunchao Zhang 897*d52a580bSJunchao Zhang /* compute the transpose of the upper triangular factor, i.e. 
the CSC */ 898*d52a580bSJunchao Zhang /* Csr2cscEx2 is not implemented in ROCm-5.2.0 and is planned for implementation in hipsparse with future releases of ROCm 899*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 2, 0) 900*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCsr2cscEx2_bufferSize(hipsparseTriFactors->handle, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_cols, upTriFactor->csrMat->num_entries, upTriFactor->csrMat->values->data().get(), 901*d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactorT->csrMat->values->data().get(), upTriFactorT->csrMat->row_offsets->data().get(), 902*d52a580bSJunchao Zhang upTriFactorT->csrMat->column_indices->data().get(), hipsparse_scalartype, HIPSPARSE_ACTION_NUMERIC, indexBase, HIPSPARSE_CSR2CSC_ALG1, &upTriFactor->csr2cscBufferSize)); 903*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&upTriFactor->csr2cscBuffer, upTriFactor->csr2cscBufferSize)); 904*d52a580bSJunchao Zhang #endif 905*d52a580bSJunchao Zhang */ 906*d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0)); 907*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr2csc(hipsparseTriFactors->handle, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_cols, upTriFactor->csrMat->num_entries, upTriFactor->csrMat->values->data().get(), upTriFactor->csrMat->row_offsets->data().get(), 908*d52a580bSJunchao Zhang upTriFactor->csrMat->column_indices->data().get(), upTriFactorT->csrMat->values->data().get(), 909*d52a580bSJunchao Zhang #if 0 /* when Csr2cscEx2 is implemented in hipSparse PETSC_PKG_HIP_VERSION_GE(5, 2, 0)*/ 910*d52a580bSJunchao Zhang upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(), 911*d52a580bSJunchao Zhang hipsparse_scalartype, HIPSPARSE_ACTION_NUMERIC, indexBase, HIPSPARSE_CSR2CSC_ALG1, upTriFactor->csr2cscBuffer)); 912*d52a580bSJunchao Zhang #else 
913*d52a580bSJunchao Zhang upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->csrMat->row_offsets->data().get(), HIPSPARSE_ACTION_NUMERIC, indexBase)); 914*d52a580bSJunchao Zhang #endif 915*d52a580bSJunchao Zhang 916*d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP()); 917*d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0)); 918*d52a580bSJunchao Zhang 919*d52a580bSJunchao Zhang /* Create the solve analysis information */ 920*d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0)); 921*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsrsvInfo(&upTriFactorT->solveInfo)); 922*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_buffsize(hipsparseTriFactors->handle, upTriFactorT->solveOp, upTriFactorT->csrMat->num_rows, upTriFactorT->csrMat->num_entries, upTriFactorT->descr, upTriFactorT->csrMat->values->data().get(), 923*d52a580bSJunchao Zhang upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->solveInfo, &upTriFactorT->solveBufferSize)); 924*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&upTriFactorT->solveBuffer, upTriFactorT->solveBufferSize)); 925*d52a580bSJunchao Zhang 926*d52a580bSJunchao Zhang /* perform the solve analysis */ 927*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_analysis(hipsparseTriFactors->handle, upTriFactorT->solveOp, upTriFactorT->csrMat->num_rows, upTriFactorT->csrMat->num_entries, upTriFactorT->descr, upTriFactorT->csrMat->values->data().get(), 928*d52a580bSJunchao Zhang upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->solveInfo, upTriFactorT->solvePolicy, upTriFactorT->solveBuffer)); 929*d52a580bSJunchao Zhang 930*d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP()); 931*d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSESolveAnalysis, A, 0, 0, 0)); 932*d52a580bSJunchao Zhang 
933*d52a580bSJunchao Zhang /* assign the pointer */ 934*d52a580bSJunchao Zhang ((Mat_SeqAIJHIPSPARSETriFactors *)A->spptr)->upTriFactorPtrTranspose = upTriFactorT; 935*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 936*d52a580bSJunchao Zhang } 937*d52a580bSJunchao Zhang 938*d52a580bSJunchao Zhang struct PetscScalarToPetscInt { 939*d52a580bSJunchao Zhang __host__ __device__ PetscInt operator()(PetscScalar s) { return (PetscInt)PetscRealPart(s); } 940*d52a580bSJunchao Zhang }; 941*d52a580bSJunchao Zhang 942*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEFormExplicitTranspose(Mat A) 943*d52a580bSJunchao Zhang { 944*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr; 945*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSEMultStruct *matstruct, *matstructT; 946*d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data; 947*d52a580bSJunchao Zhang hipsparseIndexBase_t indexBase; 948*d52a580bSJunchao Zhang 949*d52a580bSJunchao Zhang PetscFunctionBegin; 950*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A)); 951*d52a580bSJunchao Zhang matstruct = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->mat; 952*d52a580bSJunchao Zhang PetscCheck(matstruct, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing mat struct"); 953*d52a580bSJunchao Zhang matstructT = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->matTranspose; 954*d52a580bSJunchao Zhang PetscCheck(!A->transupdated || matstructT, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing matTranspose struct"); 955*d52a580bSJunchao Zhang if (A->transupdated) PetscFunctionReturn(PETSC_SUCCESS); 956*d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0)); 957*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 958*d52a580bSJunchao Zhang if (hipsparsestruct->format != MAT_HIPSPARSE_CSR) PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_TRUE)); 959*d52a580bSJunchao Zhang if (!hipsparsestruct->matTranspose) { /* create 
hipsparse matrix */ 960*d52a580bSJunchao Zhang matstructT = new Mat_SeqAIJHIPSPARSEMultStruct; 961*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&matstructT->descr)); 962*d52a580bSJunchao Zhang indexBase = hipsparseGetMatIndexBase(matstruct->descr); 963*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(matstructT->descr, indexBase)); 964*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(matstructT->descr, HIPSPARSE_MATRIX_TYPE_GENERAL)); 965*d52a580bSJunchao Zhang 966*d52a580bSJunchao Zhang /* set alpha and beta */ 967*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&matstructT->alpha_one, sizeof(PetscScalar))); 968*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&matstructT->beta_zero, sizeof(PetscScalar))); 969*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&matstructT->beta_one, sizeof(PetscScalar))); 970*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(matstructT->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice)); 971*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(matstructT->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice)); 972*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(matstructT->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice)); 973*d52a580bSJunchao Zhang 974*d52a580bSJunchao Zhang if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) { 975*d52a580bSJunchao Zhang CsrMatrix *matrixT = new CsrMatrix; 976*d52a580bSJunchao Zhang matstructT->mat = matrixT; 977*d52a580bSJunchao Zhang matrixT->num_rows = A->cmap->n; 978*d52a580bSJunchao Zhang matrixT->num_cols = A->rmap->n; 979*d52a580bSJunchao Zhang matrixT->num_entries = a->nz; 980*d52a580bSJunchao Zhang matrixT->row_offsets = new THRUSTINTARRAY32(matrixT->num_rows + 1); 981*d52a580bSJunchao Zhang matrixT->column_indices = new THRUSTINTARRAY32(a->nz); 982*d52a580bSJunchao Zhang matrixT->values = new THRUSTARRAY(a->nz); 983*d52a580bSJunchao Zhang 
984*d52a580bSJunchao Zhang if (!hipsparsestruct->rowoffsets_gpu) hipsparsestruct->rowoffsets_gpu = new THRUSTINTARRAY32(A->rmap->n + 1); 985*d52a580bSJunchao Zhang hipsparsestruct->rowoffsets_gpu->assign(a->i, a->i + A->rmap->n + 1); 986*d52a580bSJunchao Zhang 987*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&matstructT->matDescr, matrixT->num_rows, matrixT->num_cols, matrixT->num_entries, matrixT->row_offsets->data().get(), matrixT->column_indices->data().get(), matrixT->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, /* row offset, col idx type due to THRUSTINTARRAY32 */ 988*d52a580bSJunchao Zhang indexBase, hipsparse_scalartype)); 989*d52a580bSJunchao Zhang } else if (hipsparsestruct->format == MAT_HIPSPARSE_ELL || hipsparsestruct->format == MAT_HIPSPARSE_HYB) { 990*d52a580bSJunchao Zhang CsrMatrix *temp = new CsrMatrix; 991*d52a580bSJunchao Zhang CsrMatrix *tempT = new CsrMatrix; 992*d52a580bSJunchao Zhang /* First convert HYB to CSR */ 993*d52a580bSJunchao Zhang temp->num_rows = A->rmap->n; 994*d52a580bSJunchao Zhang temp->num_cols = A->cmap->n; 995*d52a580bSJunchao Zhang temp->num_entries = a->nz; 996*d52a580bSJunchao Zhang temp->row_offsets = new THRUSTINTARRAY32(A->rmap->n + 1); 997*d52a580bSJunchao Zhang temp->column_indices = new THRUSTINTARRAY32(a->nz); 998*d52a580bSJunchao Zhang temp->values = new THRUSTARRAY(a->nz); 999*d52a580bSJunchao Zhang 1000*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_hyb2csr(hipsparsestruct->handle, matstruct->descr, (hipsparseHybMat_t)matstruct->mat, temp->values->data().get(), temp->row_offsets->data().get(), temp->column_indices->data().get())); 1001*d52a580bSJunchao Zhang 1002*d52a580bSJunchao Zhang /* Next, convert CSR to CSC (i.e. 
the matrix transpose) */ 1003*d52a580bSJunchao Zhang tempT->num_rows = A->rmap->n; 1004*d52a580bSJunchao Zhang tempT->num_cols = A->cmap->n; 1005*d52a580bSJunchao Zhang tempT->num_entries = a->nz; 1006*d52a580bSJunchao Zhang tempT->row_offsets = new THRUSTINTARRAY32(A->rmap->n + 1); 1007*d52a580bSJunchao Zhang tempT->column_indices = new THRUSTINTARRAY32(a->nz); 1008*d52a580bSJunchao Zhang tempT->values = new THRUSTARRAY(a->nz); 1009*d52a580bSJunchao Zhang 1010*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr2csc(hipsparsestruct->handle, temp->num_rows, temp->num_cols, temp->num_entries, temp->values->data().get(), temp->row_offsets->data().get(), temp->column_indices->data().get(), tempT->values->data().get(), 1011*d52a580bSJunchao Zhang tempT->column_indices->data().get(), tempT->row_offsets->data().get(), HIPSPARSE_ACTION_NUMERIC, indexBase)); 1012*d52a580bSJunchao Zhang 1013*d52a580bSJunchao Zhang /* Last, convert CSC to HYB */ 1014*d52a580bSJunchao Zhang hipsparseHybMat_t hybMat; 1015*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateHybMat(&hybMat)); 1016*d52a580bSJunchao Zhang hipsparseHybPartition_t partition = hipsparsestruct->format == MAT_HIPSPARSE_ELL ? 
HIPSPARSE_HYB_PARTITION_MAX : HIPSPARSE_HYB_PARTITION_AUTO; 1017*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr2hyb(hipsparsestruct->handle, A->rmap->n, A->cmap->n, matstructT->descr, tempT->values->data().get(), tempT->row_offsets->data().get(), tempT->column_indices->data().get(), hybMat, 0, partition)); 1018*d52a580bSJunchao Zhang 1019*d52a580bSJunchao Zhang /* assign the pointer */ 1020*d52a580bSJunchao Zhang matstructT->mat = hybMat; 1021*d52a580bSJunchao Zhang A->transupdated = PETSC_TRUE; 1022*d52a580bSJunchao Zhang /* delete temporaries */ 1023*d52a580bSJunchao Zhang if (tempT) { 1024*d52a580bSJunchao Zhang if (tempT->values) delete (THRUSTARRAY *)tempT->values; 1025*d52a580bSJunchao Zhang if (tempT->column_indices) delete (THRUSTINTARRAY32 *)tempT->column_indices; 1026*d52a580bSJunchao Zhang if (tempT->row_offsets) delete (THRUSTINTARRAY32 *)tempT->row_offsets; 1027*d52a580bSJunchao Zhang delete (CsrMatrix *)tempT; 1028*d52a580bSJunchao Zhang } 1029*d52a580bSJunchao Zhang if (temp) { 1030*d52a580bSJunchao Zhang if (temp->values) delete (THRUSTARRAY *)temp->values; 1031*d52a580bSJunchao Zhang if (temp->column_indices) delete (THRUSTINTARRAY32 *)temp->column_indices; 1032*d52a580bSJunchao Zhang if (temp->row_offsets) delete (THRUSTINTARRAY32 *)temp->row_offsets; 1033*d52a580bSJunchao Zhang delete (CsrMatrix *)temp; 1034*d52a580bSJunchao Zhang } 1035*d52a580bSJunchao Zhang } 1036*d52a580bSJunchao Zhang } 1037*d52a580bSJunchao Zhang if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) { /* transpose mat struct may be already present, update data */ 1038*d52a580bSJunchao Zhang CsrMatrix *matrix = (CsrMatrix *)matstruct->mat; 1039*d52a580bSJunchao Zhang CsrMatrix *matrixT = (CsrMatrix *)matstructT->mat; 1040*d52a580bSJunchao Zhang PetscCheck(matrix, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrix"); 1041*d52a580bSJunchao Zhang PetscCheck(matrix->row_offsets, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrix rows"); 1042*d52a580bSJunchao Zhang 
PetscCheck(matrix->column_indices, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrix cols"); 1043*d52a580bSJunchao Zhang PetscCheck(matrix->values, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrix values"); 1044*d52a580bSJunchao Zhang PetscCheck(matrixT, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrixT"); 1045*d52a580bSJunchao Zhang PetscCheck(matrixT->row_offsets, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrixT rows"); 1046*d52a580bSJunchao Zhang PetscCheck(matrixT->column_indices, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrixT cols"); 1047*d52a580bSJunchao Zhang PetscCheck(matrixT->values, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CsrMatrixT values"); 1048*d52a580bSJunchao Zhang if (!hipsparsestruct->rowoffsets_gpu) { /* this may be absent when we did not construct the transpose with csr2csc */ 1049*d52a580bSJunchao Zhang hipsparsestruct->rowoffsets_gpu = new THRUSTINTARRAY32(A->rmap->n + 1); 1050*d52a580bSJunchao Zhang hipsparsestruct->rowoffsets_gpu->assign(a->i, a->i + A->rmap->n + 1); 1051*d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu((A->rmap->n + 1) * sizeof(PetscInt))); 1052*d52a580bSJunchao Zhang } 1053*d52a580bSJunchao Zhang if (!hipsparsestruct->csr2csc_i) { 1054*d52a580bSJunchao Zhang THRUSTARRAY csr2csc_a(matrix->num_entries); 1055*d52a580bSJunchao Zhang PetscCallThrust(thrust::sequence(thrust::device, csr2csc_a.begin(), csr2csc_a.end(), 0.0)); 1056*d52a580bSJunchao Zhang 1057*d52a580bSJunchao Zhang indexBase = hipsparseGetMatIndexBase(matstruct->descr); 1058*d52a580bSJunchao Zhang if (matrix->num_entries) { 1059*d52a580bSJunchao Zhang /* This routine is known to give errors with CUDA-11, but works fine with CUDA-10 1060*d52a580bSJunchao Zhang Need to verify this for ROCm. 
1061*d52a580bSJunchao Zhang */ 1062*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr2csc(hipsparsestruct->handle, A->rmap->n, A->cmap->n, matrix->num_entries, csr2csc_a.data().get(), hipsparsestruct->rowoffsets_gpu->data().get(), matrix->column_indices->data().get(), matrixT->values->data().get(), 1063*d52a580bSJunchao Zhang matrixT->column_indices->data().get(), matrixT->row_offsets->data().get(), HIPSPARSE_ACTION_NUMERIC, indexBase)); 1064*d52a580bSJunchao Zhang } else { 1065*d52a580bSJunchao Zhang matrixT->row_offsets->assign(matrixT->row_offsets->size(), indexBase); 1066*d52a580bSJunchao Zhang } 1067*d52a580bSJunchao Zhang 1068*d52a580bSJunchao Zhang hipsparsestruct->csr2csc_i = new THRUSTINTARRAY(matrix->num_entries); 1069*d52a580bSJunchao Zhang PetscCallThrust(thrust::transform(thrust::device, matrixT->values->begin(), matrixT->values->end(), hipsparsestruct->csr2csc_i->begin(), PetscScalarToPetscInt())); 1070*d52a580bSJunchao Zhang } 1071*d52a580bSJunchao Zhang PetscCallThrust( 1072*d52a580bSJunchao Zhang thrust::copy(thrust::device, thrust::make_permutation_iterator(matrix->values->begin(), hipsparsestruct->csr2csc_i->begin()), thrust::make_permutation_iterator(matrix->values->begin(), hipsparsestruct->csr2csc_i->end()), matrixT->values->begin())); 1073*d52a580bSJunchao Zhang } 1074*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 1075*d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSEGenerateTranspose, A, 0, 0, 0)); 1076*d52a580bSJunchao Zhang /* the compressed row indices is not used for matTranspose */ 1077*d52a580bSJunchao Zhang matstructT->cprowIndices = NULL; 1078*d52a580bSJunchao Zhang /* assign the pointer */ 1079*d52a580bSJunchao Zhang ((Mat_SeqAIJHIPSPARSE *)A->spptr)->matTranspose = matstructT; 1080*d52a580bSJunchao Zhang A->transupdated = PETSC_TRUE; 1081*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1082*d52a580bSJunchao Zhang } 1083*d52a580bSJunchao Zhang 1084*d52a580bSJunchao Zhang /* Why do we need to 
analyze the transposed matrix again? Can't we just use op(A) = HIPSPARSE_OPERATION_TRANSPOSE in MatSolve_SeqAIJHIPSPARSE? */ 1085*d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE(Mat A, Vec bb, Vec xx) 1086*d52a580bSJunchao Zhang { 1087*d52a580bSJunchao Zhang PetscInt n = xx->map->n; 1088*d52a580bSJunchao Zhang const PetscScalar *barray; 1089*d52a580bSJunchao Zhang PetscScalar *xarray; 1090*d52a580bSJunchao Zhang thrust::device_ptr<const PetscScalar> bGPU; 1091*d52a580bSJunchao Zhang thrust::device_ptr<PetscScalar> xGPU; 1092*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr; 1093*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtrTranspose; 1094*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtrTranspose; 1095*d52a580bSJunchao Zhang THRUSTARRAY *tempGPU = (THRUSTARRAY *)hipsparseTriFactors->workVector; 1096*d52a580bSJunchao Zhang 1097*d52a580bSJunchao Zhang PetscFunctionBegin; 1098*d52a580bSJunchao Zhang /* Analyze the matrix and create the transpose ... 
on the fly */ 1099*d52a580bSJunchao Zhang if (!loTriFactorT && !upTriFactorT) { 1100*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEAnalyzeTransposeForSolve(A)); 1101*d52a580bSJunchao Zhang loTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtrTranspose; 1102*d52a580bSJunchao Zhang upTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtrTranspose; 1103*d52a580bSJunchao Zhang } 1104*d52a580bSJunchao Zhang 1105*d52a580bSJunchao Zhang /* Get the GPU pointers */ 1106*d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayWrite(xx, &xarray)); 1107*d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayRead(bb, &barray)); 1108*d52a580bSJunchao Zhang xGPU = thrust::device_pointer_cast(xarray); 1109*d52a580bSJunchao Zhang bGPU = thrust::device_pointer_cast(barray); 1110*d52a580bSJunchao Zhang 1111*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 1112*d52a580bSJunchao Zhang /* First, reorder with the row permutation */ 1113*d52a580bSJunchao Zhang thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), thrust::make_permutation_iterator(bGPU, hipsparseTriFactors->rpermIndices->begin()), thrust::make_permutation_iterator(bGPU + n, hipsparseTriFactors->rpermIndices->end()), xGPU); 1114*d52a580bSJunchao Zhang 1115*d52a580bSJunchao Zhang /* First, solve U */ 1116*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, upTriFactorT->solveOp, upTriFactorT->csrMat->num_rows, upTriFactorT->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, upTriFactorT->descr, upTriFactorT->csrMat->values->data().get(), 1117*d52a580bSJunchao Zhang upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->solveInfo, xarray, tempGPU->data().get(), upTriFactorT->solvePolicy, upTriFactorT->solveBuffer)); 1118*d52a580bSJunchao Zhang 1119*d52a580bSJunchao Zhang /* Then, solve L */ 1120*d52a580bSJunchao Zhang 
PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, loTriFactorT->solveOp, loTriFactorT->csrMat->num_rows, loTriFactorT->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, loTriFactorT->descr, loTriFactorT->csrMat->values->data().get(), 1121*d52a580bSJunchao Zhang loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->solveInfo, tempGPU->data().get(), xarray, loTriFactorT->solvePolicy, loTriFactorT->solveBuffer)); 1122*d52a580bSJunchao Zhang 1123*d52a580bSJunchao Zhang /* Last, copy the solution, xGPU, into a temporary with the column permutation ... can't be done in place. */ 1124*d52a580bSJunchao Zhang thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), thrust::make_permutation_iterator(xGPU, hipsparseTriFactors->cpermIndices->begin()), thrust::make_permutation_iterator(xGPU + n, hipsparseTriFactors->cpermIndices->end()), tempGPU->begin()); 1125*d52a580bSJunchao Zhang 1126*d52a580bSJunchao Zhang /* Copy the temporary to the full solution. 
*/ 1127*d52a580bSJunchao Zhang thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), tempGPU->begin(), tempGPU->end(), xGPU); 1128*d52a580bSJunchao Zhang 1129*d52a580bSJunchao Zhang /* restore */ 1130*d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayRead(bb, &barray)); 1131*d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayWrite(xx, &xarray)); 1132*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 1133*d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(2.0 * hipsparseTriFactors->nnz - A->cmap->n)); 1134*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1135*d52a580bSJunchao Zhang } 1136*d52a580bSJunchao Zhang 1137*d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE_NaturalOrdering(Mat A, Vec bb, Vec xx) 1138*d52a580bSJunchao Zhang { 1139*d52a580bSJunchao Zhang const PetscScalar *barray; 1140*d52a580bSJunchao Zhang PetscScalar *xarray; 1141*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr; 1142*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtrTranspose; 1143*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtrTranspose; 1144*d52a580bSJunchao Zhang THRUSTARRAY *tempGPU = (THRUSTARRAY *)hipsparseTriFactors->workVector; 1145*d52a580bSJunchao Zhang 1146*d52a580bSJunchao Zhang PetscFunctionBegin; 1147*d52a580bSJunchao Zhang /* Analyze the matrix and create the transpose ... 
on the fly */ 1148*d52a580bSJunchao Zhang if (!loTriFactorT && !upTriFactorT) { 1149*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEAnalyzeTransposeForSolve(A)); 1150*d52a580bSJunchao Zhang loTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtrTranspose; 1151*d52a580bSJunchao Zhang upTriFactorT = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtrTranspose; 1152*d52a580bSJunchao Zhang } 1153*d52a580bSJunchao Zhang 1154*d52a580bSJunchao Zhang /* Get the GPU pointers */ 1155*d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayWrite(xx, &xarray)); 1156*d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayRead(bb, &barray)); 1157*d52a580bSJunchao Zhang 1158*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 1159*d52a580bSJunchao Zhang /* First, solve U */ 1160*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, upTriFactorT->solveOp, upTriFactorT->csrMat->num_rows, upTriFactorT->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, upTriFactorT->descr, upTriFactorT->csrMat->values->data().get(), 1161*d52a580bSJunchao Zhang upTriFactorT->csrMat->row_offsets->data().get(), upTriFactorT->csrMat->column_indices->data().get(), upTriFactorT->solveInfo, barray, tempGPU->data().get(), upTriFactorT->solvePolicy, upTriFactorT->solveBuffer)); 1162*d52a580bSJunchao Zhang 1163*d52a580bSJunchao Zhang /* Then, solve L */ 1164*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, loTriFactorT->solveOp, loTriFactorT->csrMat->num_rows, loTriFactorT->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, loTriFactorT->descr, loTriFactorT->csrMat->values->data().get(), 1165*d52a580bSJunchao Zhang loTriFactorT->csrMat->row_offsets->data().get(), loTriFactorT->csrMat->column_indices->data().get(), loTriFactorT->solveInfo, tempGPU->data().get(), xarray, loTriFactorT->solvePolicy, loTriFactorT->solveBuffer)); 1166*d52a580bSJunchao Zhang 1167*d52a580bSJunchao Zhang /* restore */ 
1168*d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayRead(bb, &barray)); 1169*d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayWrite(xx, &xarray)); 1170*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 1171*d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(2.0 * hipsparseTriFactors->nnz - A->cmap->n)); 1172*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1173*d52a580bSJunchao Zhang } 1174*d52a580bSJunchao Zhang 1175*d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE(Mat A, Vec bb, Vec xx) 1176*d52a580bSJunchao Zhang { 1177*d52a580bSJunchao Zhang const PetscScalar *barray; 1178*d52a580bSJunchao Zhang PetscScalar *xarray; 1179*d52a580bSJunchao Zhang thrust::device_ptr<const PetscScalar> bGPU; 1180*d52a580bSJunchao Zhang thrust::device_ptr<PetscScalar> xGPU; 1181*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr; 1182*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr; 1183*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr; 1184*d52a580bSJunchao Zhang THRUSTARRAY *tempGPU = (THRUSTARRAY *)hipsparseTriFactors->workVector; 1185*d52a580bSJunchao Zhang 1186*d52a580bSJunchao Zhang PetscFunctionBegin; 1187*d52a580bSJunchao Zhang /* Get the GPU pointers */ 1188*d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayWrite(xx, &xarray)); 1189*d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayRead(bb, &barray)); 1190*d52a580bSJunchao Zhang xGPU = thrust::device_pointer_cast(xarray); 1191*d52a580bSJunchao Zhang bGPU = thrust::device_pointer_cast(barray); 1192*d52a580bSJunchao Zhang 1193*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 1194*d52a580bSJunchao Zhang /* First, reorder with the row permutation */ 1195*d52a580bSJunchao Zhang 
thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), thrust::make_permutation_iterator(bGPU, hipsparseTriFactors->rpermIndices->begin()), thrust::make_permutation_iterator(bGPU, hipsparseTriFactors->rpermIndices->end()), tempGPU->begin()); 1196*d52a580bSJunchao Zhang 1197*d52a580bSJunchao Zhang /* Next, solve L */ 1198*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, loTriFactor->descr, loTriFactor->csrMat->values->data().get(), 1199*d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, tempGPU->data().get(), xarray, loTriFactor->solvePolicy, loTriFactor->solveBuffer)); 1200*d52a580bSJunchao Zhang 1201*d52a580bSJunchao Zhang /* Then, solve U */ 1202*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, upTriFactor->descr, upTriFactor->csrMat->values->data().get(), 1203*d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, xarray, tempGPU->data().get(), upTriFactor->solvePolicy, upTriFactor->solveBuffer)); 1204*d52a580bSJunchao Zhang 1205*d52a580bSJunchao Zhang /* Last, reorder with the column permutation */ 1206*d52a580bSJunchao Zhang thrust::copy(thrust::hip::par.on(PetscDefaultHipStream), thrust::make_permutation_iterator(tempGPU->begin(), hipsparseTriFactors->cpermIndices->begin()), thrust::make_permutation_iterator(tempGPU->begin(), hipsparseTriFactors->cpermIndices->end()), xGPU); 1207*d52a580bSJunchao Zhang 1208*d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayRead(bb, &barray)); 1209*d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayWrite(xx, &xarray)); 1210*d52a580bSJunchao 
Zhang PetscCall(PetscLogGpuTimeEnd()); 1211*d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(2.0 * hipsparseTriFactors->nnz - A->cmap->n)); 1212*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1213*d52a580bSJunchao Zhang } 1214*d52a580bSJunchao Zhang 1215*d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_NaturalOrdering(Mat A, Vec bb, Vec xx) 1216*d52a580bSJunchao Zhang { 1217*d52a580bSJunchao Zhang const PetscScalar *barray; 1218*d52a580bSJunchao Zhang PetscScalar *xarray; 1219*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr; 1220*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->loTriFactorPtr; 1221*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactor = (Mat_SeqAIJHIPSPARSETriFactorStruct *)hipsparseTriFactors->upTriFactorPtr; 1222*d52a580bSJunchao Zhang THRUSTARRAY *tempGPU = (THRUSTARRAY *)hipsparseTriFactors->workVector; 1223*d52a580bSJunchao Zhang 1224*d52a580bSJunchao Zhang PetscFunctionBegin; 1225*d52a580bSJunchao Zhang /* Get the GPU pointers */ 1226*d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayWrite(xx, &xarray)); 1227*d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayRead(bb, &barray)); 1228*d52a580bSJunchao Zhang 1229*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 1230*d52a580bSJunchao Zhang /* First, solve L */ 1231*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, loTriFactor->solveOp, loTriFactor->csrMat->num_rows, loTriFactor->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, loTriFactor->descr, loTriFactor->csrMat->values->data().get(), 1232*d52a580bSJunchao Zhang loTriFactor->csrMat->row_offsets->data().get(), loTriFactor->csrMat->column_indices->data().get(), loTriFactor->solveInfo, barray, tempGPU->data().get(), loTriFactor->solvePolicy, loTriFactor->solveBuffer)); 1233*d52a580bSJunchao Zhang 
1234*d52a580bSJunchao Zhang /* Next, solve U */ 1235*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrsv_solve(hipsparseTriFactors->handle, upTriFactor->solveOp, upTriFactor->csrMat->num_rows, upTriFactor->csrMat->num_entries, &PETSC_HIPSPARSE_ONE, upTriFactor->descr, upTriFactor->csrMat->values->data().get(), 1236*d52a580bSJunchao Zhang upTriFactor->csrMat->row_offsets->data().get(), upTriFactor->csrMat->column_indices->data().get(), upTriFactor->solveInfo, tempGPU->data().get(), xarray, upTriFactor->solvePolicy, upTriFactor->solveBuffer)); 1237*d52a580bSJunchao Zhang 1238*d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayRead(bb, &barray)); 1239*d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayWrite(xx, &xarray)); 1240*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 1241*d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(2.0 * hipsparseTriFactors->nnz - A->cmap->n)); 1242*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1243*d52a580bSJunchao Zhang } 1244*d52a580bSJunchao Zhang 1245*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0) 1246*d52a580bSJunchao Zhang /* hipsparseSpSV_solve() and related functions first appeared in ROCm-4.5.0*/ 1247*d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_ILU0(Mat fact, Vec b, Vec x) 1248*d52a580bSJunchao Zhang { 1249*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr; 1250*d52a580bSJunchao Zhang Mat_SeqAIJ *aij = (Mat_SeqAIJ *)fact->data; 1251*d52a580bSJunchao Zhang const PetscScalar *barray; 1252*d52a580bSJunchao Zhang PetscScalar *xarray; 1253*d52a580bSJunchao Zhang 1254*d52a580bSJunchao Zhang PetscFunctionBegin; 1255*d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayWrite(x, &xarray)); 1256*d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayRead(b, &barray)); 1257*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 1258*d52a580bSJunchao Zhang 1259*d52a580bSJunchao Zhang /* Solve L*y = b */ 1260*d52a580bSJunchao Zhang 
PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray)); 1261*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y)); 1262*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0) 1263*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */ 1264*d52a580bSJunchao Zhang fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L)); // hipsparseSpSV_solve() secretely uses the external buffer used in hipsparseSpSV_analysis()! 1265*d52a580bSJunchao Zhang #else 1266*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */ 1267*d52a580bSJunchao Zhang fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L)); // hipsparseSpSV_solve() secretely uses the external buffer used in hipsparseSpSV_analysis()! 
1268*d52a580bSJunchao Zhang #endif 1269*d52a580bSJunchao Zhang /* Solve U*x = y */ 1270*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray)); 1271*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0) 1272*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* U X = Y */ 1273*d52a580bSJunchao Zhang fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U)); 1274*d52a580bSJunchao Zhang #else 1275*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* U X = Y */ 1276*d52a580bSJunchao Zhang fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U, fs->spsvBuffer_U)); 1277*d52a580bSJunchao Zhang #endif 1278*d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayRead(b, &barray)); 1279*d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayWrite(x, &xarray)); 1280*d52a580bSJunchao Zhang 1281*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 1282*d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(2.0 * aij->nz - fact->rmap->n)); 1283*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1284*d52a580bSJunchao Zhang } 1285*d52a580bSJunchao Zhang 1286*d52a580bSJunchao Zhang static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE_ILU0(Mat fact, Vec b, Vec x) 1287*d52a580bSJunchao Zhang { 1288*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr; 1289*d52a580bSJunchao Zhang Mat_SeqAIJ *aij = (Mat_SeqAIJ *)fact->data; 1290*d52a580bSJunchao Zhang const PetscScalar *barray; 1291*d52a580bSJunchao Zhang PetscScalar *xarray; 1292*d52a580bSJunchao Zhang 1293*d52a580bSJunchao Zhang PetscFunctionBegin; 1294*d52a580bSJunchao Zhang if (!fs->createdTransposeSpSVDescr) { /* 
Call MatSolveTranspose() for the first time */ 1295*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_Lt)); 1296*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* The matrix is still L. We only do transpose solve with it */ 1297*d52a580bSJunchao Zhang fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, &fs->spsvBufferSize_Lt)); 1298*d52a580bSJunchao Zhang 1299*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_Ut)); 1300*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut, &fs->spsvBufferSize_Ut)); 1301*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_Lt, fs->spsvBufferSize_Lt)); 1302*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_Ut, fs->spsvBufferSize_Ut)); 1303*d52a580bSJunchao Zhang fs->createdTransposeSpSVDescr = PETSC_TRUE; 1304*d52a580bSJunchao Zhang } 1305*d52a580bSJunchao Zhang 1306*d52a580bSJunchao Zhang if (!fs->updatedTransposeSpSVAnalysis) { 1307*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt)); 1308*d52a580bSJunchao Zhang 1309*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut, fs->spsvBuffer_Ut)); 1310*d52a580bSJunchao Zhang fs->updatedTransposeSpSVAnalysis = PETSC_TRUE; 
1311*d52a580bSJunchao Zhang } 1312*d52a580bSJunchao Zhang 1313*d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayWrite(x, &xarray)); 1314*d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayRead(b, &barray)); 1315*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 1316*d52a580bSJunchao Zhang 1317*d52a580bSJunchao Zhang /* Solve Ut*y = b */ 1318*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray)); 1319*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y)); 1320*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0) 1321*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* Ut Y = X */ 1322*d52a580bSJunchao Zhang fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut)); 1323*d52a580bSJunchao Zhang #else 1324*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* Ut Y = X */ 1325*d52a580bSJunchao Zhang fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut, fs->spsvBuffer_Ut)); 1326*d52a580bSJunchao Zhang #endif 1327*d52a580bSJunchao Zhang /* Solve Lt*x = y */ 1328*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray)); 1329*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0) 1330*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */ 1331*d52a580bSJunchao Zhang fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt)); 1332*d52a580bSJunchao Zhang #else 1333*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, 
HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */ 1334*d52a580bSJunchao Zhang fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt)); 1335*d52a580bSJunchao Zhang #endif 1336*d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayRead(b, &barray)); 1337*d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayWrite(x, &xarray)); 1338*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 1339*d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(2.0 * aij->nz - fact->rmap->n)); 1340*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1341*d52a580bSJunchao Zhang } 1342*d52a580bSJunchao Zhang 1343*d52a580bSJunchao Zhang static PetscErrorCode MatILUFactorNumeric_SeqAIJHIPSPARSE_ILU0(Mat fact, Mat A, const MatFactorInfo *info) 1344*d52a580bSJunchao Zhang { 1345*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr; 1346*d52a580bSJunchao Zhang Mat_SeqAIJ *aij = (Mat_SeqAIJ *)fact->data; 1347*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr; 1348*d52a580bSJunchao Zhang CsrMatrix *Acsr; 1349*d52a580bSJunchao Zhang PetscInt m, nz; 1350*d52a580bSJunchao Zhang PetscBool flg; 1351*d52a580bSJunchao Zhang 1352*d52a580bSJunchao Zhang PetscFunctionBegin; 1353*d52a580bSJunchao Zhang if (PetscDefined(USE_DEBUG)) { 1354*d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg)); 1355*d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Expected MATSEQAIJHIPSPARSE, but input is %s", ((PetscObject)A)->type_name); 1356*d52a580bSJunchao Zhang } 1357*d52a580bSJunchao Zhang 1358*d52a580bSJunchao Zhang /* Copy A's value to fact */ 1359*d52a580bSJunchao Zhang m = fact->rmap->n; 1360*d52a580bSJunchao Zhang nz = aij->nz; 1361*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A)); 1362*d52a580bSJunchao Zhang Acsr = (CsrMatrix 
*)Acusp->mat->mat; 1363*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpyAsync(fs->csrVal, Acsr->values->data().get(), sizeof(PetscScalar) * nz, hipMemcpyDeviceToDevice, PetscDefaultHipStream)); 1364*d52a580bSJunchao Zhang 1365*d52a580bSJunchao Zhang /* Factorize fact inplace */ 1366*d52a580bSJunchao Zhang if (m) 1367*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrilu02(fs->handle, m, nz, /* hipsparseXcsrilu02 errors out with empty matrices (m=0) */ 1368*d52a580bSJunchao Zhang fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ilu0Info_M, fs->policy_M, fs->factBuffer_M)); 1369*d52a580bSJunchao Zhang if (PetscDefined(USE_DEBUG)) { 1370*d52a580bSJunchao Zhang int numerical_zero; 1371*d52a580bSJunchao Zhang hipsparseStatus_t status; 1372*d52a580bSJunchao Zhang status = hipsparseXcsrilu02_zeroPivot(fs->handle, fs->ilu0Info_M, &numerical_zero); 1373*d52a580bSJunchao Zhang PetscAssert(HIPSPARSE_STATUS_ZERO_PIVOT != status, PETSC_COMM_SELF, PETSC_ERR_USER_INPUT, "Numerical zero pivot detected in csrilu02: A(%d,%d) is zero", numerical_zero, numerical_zero); 1374*d52a580bSJunchao Zhang } 1375*d52a580bSJunchao Zhang 1376*d52a580bSJunchao Zhang /* hipsparseSpSV_analysis() is numeric, i.e., it requires valid matrix values, therefore, we do it after hipsparseXcsrilu02() */ 1377*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L)); 1378*d52a580bSJunchao Zhang 1379*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U, fs->spsvBuffer_U)); 1380*d52a580bSJunchao Zhang 1381*d52a580bSJunchao Zhang /* L, U values have changed, reset the flag to indicate we 
need to redo hipsparseSpSV_analysis() for transpose solve */ 1382*d52a580bSJunchao Zhang fs->updatedTransposeSpSVAnalysis = PETSC_FALSE; 1383*d52a580bSJunchao Zhang 1384*d52a580bSJunchao Zhang fact->offloadmask = PETSC_OFFLOAD_GPU; 1385*d52a580bSJunchao Zhang fact->ops->solve = MatSolve_SeqAIJHIPSPARSE_ILU0; 1386*d52a580bSJunchao Zhang fact->ops->solvetranspose = MatSolveTranspose_SeqAIJHIPSPARSE_ILU0; 1387*d52a580bSJunchao Zhang fact->ops->matsolve = NULL; 1388*d52a580bSJunchao Zhang fact->ops->matsolvetranspose = NULL; 1389*d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(fs->numericFactFlops)); 1390*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1391*d52a580bSJunchao Zhang } 1392*d52a580bSJunchao Zhang 1393*d52a580bSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJHIPSPARSE_ILU0(Mat fact, Mat A, IS isrow, IS iscol, const MatFactorInfo *info) 1394*d52a580bSJunchao Zhang { 1395*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr; 1396*d52a580bSJunchao Zhang Mat_SeqAIJ *aij = (Mat_SeqAIJ *)fact->data; 1397*d52a580bSJunchao Zhang PetscInt m, nz; 1398*d52a580bSJunchao Zhang 1399*d52a580bSJunchao Zhang PetscFunctionBegin; 1400*d52a580bSJunchao Zhang if (PetscDefined(USE_DEBUG)) { 1401*d52a580bSJunchao Zhang PetscBool flg, diagDense; 1402*d52a580bSJunchao Zhang 1403*d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg)); 1404*d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Expected MATSEQAIJHIPSPARSE, but input is %s", ((PetscObject)A)->type_name); 1405*d52a580bSJunchao Zhang PetscCheck(A->rmap->n == A->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must be square matrix, rows %" PetscInt_FMT " columns %" PetscInt_FMT, A->rmap->n, A->cmap->n); 1406*d52a580bSJunchao Zhang PetscCall(MatGetDiagonalMarkers_SeqAIJ(A, NULL, &diagDense)); 1407*d52a580bSJunchao Zhang PetscCheck(diagDense, PETSC_COMM_SELF, 
PETSC_ERR_ARG_WRONGSTATE, "Matrix is missing diagonal entries"); 1408*d52a580bSJunchao Zhang } 1409*d52a580bSJunchao Zhang 1410*d52a580bSJunchao Zhang /* Free the old stale stuff */ 1411*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&fs)); 1412*d52a580bSJunchao Zhang 1413*d52a580bSJunchao Zhang /* Copy over A's meta data to fact. Note that we also allocated fact's i,j,a on host, 1414*d52a580bSJunchao Zhang but they will not be used. Allocate them just for easy debugging. 1415*d52a580bSJunchao Zhang */ 1416*d52a580bSJunchao Zhang PetscCall(MatDuplicateNoCreate_SeqAIJ(fact, A, MAT_DO_NOT_COPY_VALUES, PETSC_TRUE /*malloc*/)); 1417*d52a580bSJunchao Zhang 1418*d52a580bSJunchao Zhang fact->offloadmask = PETSC_OFFLOAD_BOTH; 1419*d52a580bSJunchao Zhang fact->factortype = MAT_FACTOR_ILU; 1420*d52a580bSJunchao Zhang fact->info.factor_mallocs = 0; 1421*d52a580bSJunchao Zhang fact->info.fill_ratio_given = info->fill; 1422*d52a580bSJunchao Zhang fact->info.fill_ratio_needed = 1.0; 1423*d52a580bSJunchao Zhang 1424*d52a580bSJunchao Zhang aij->row = NULL; 1425*d52a580bSJunchao Zhang aij->col = NULL; 1426*d52a580bSJunchao Zhang 1427*d52a580bSJunchao Zhang /* ====================================================================== */ 1428*d52a580bSJunchao Zhang /* Copy A's i, j to fact and also allocate the value array of fact. 
*/ 1429*d52a580bSJunchao Zhang /* We'll do in-place factorization on fact */ 1430*d52a580bSJunchao Zhang /* ====================================================================== */ 1431*d52a580bSJunchao Zhang const int *Ai, *Aj; 1432*d52a580bSJunchao Zhang 1433*d52a580bSJunchao Zhang m = fact->rmap->n; 1434*d52a580bSJunchao Zhang nz = aij->nz; 1435*d52a580bSJunchao Zhang 1436*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->csrRowPtr, sizeof(int) * (m + 1))); 1437*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->csrColIdx, sizeof(int) * nz)); 1438*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->csrVal, sizeof(PetscScalar) * nz)); 1439*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEGetIJ(A, PETSC_FALSE, &Ai, &Aj)); /* Do not use compressed Ai */ 1440*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpyAsync(fs->csrRowPtr, Ai, sizeof(int) * (m + 1), hipMemcpyDeviceToDevice, PetscDefaultHipStream)); 1441*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpyAsync(fs->csrColIdx, Aj, sizeof(int) * nz, hipMemcpyDeviceToDevice, PetscDefaultHipStream)); 1442*d52a580bSJunchao Zhang 1443*d52a580bSJunchao Zhang /* ====================================================================== */ 1444*d52a580bSJunchao Zhang /* Create descriptors for M, L, U */ 1445*d52a580bSJunchao Zhang /* ====================================================================== */ 1446*d52a580bSJunchao Zhang hipsparseFillMode_t fillMode; 1447*d52a580bSJunchao Zhang hipsparseDiagType_t diagType; 1448*d52a580bSJunchao Zhang 1449*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&fs->matDescr_M)); 1450*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(fs->matDescr_M, HIPSPARSE_INDEX_BASE_ZERO)); 1451*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(fs->matDescr_M, HIPSPARSE_MATRIX_TYPE_GENERAL)); 1452*d52a580bSJunchao Zhang 1453*d52a580bSJunchao Zhang /* 
https://docs.amd.com/bundle/hipSPARSE-Documentation---hipSPARSE-documentation/page/usermanual.html/#hipsparse_8h_1a79e036b6c0680cb37e2aa53d3542a054 1454*d52a580bSJunchao Zhang hipsparseDiagType_t: This type indicates if the matrix diagonal entries are unity. The diagonal elements are always 1455*d52a580bSJunchao Zhang assumed to be present, but if HIPSPARSE_DIAG_TYPE_UNIT is passed to an API routine, then the routine assumes that 1456*d52a580bSJunchao Zhang all diagonal entries are unity and will not read or modify those entries. Note that in this case the routine 1457*d52a580bSJunchao Zhang assumes the diagonal entries are equal to one, regardless of what those entries are actually set to in memory. 1458*d52a580bSJunchao Zhang */ 1459*d52a580bSJunchao Zhang fillMode = HIPSPARSE_FILL_MODE_LOWER; 1460*d52a580bSJunchao Zhang diagType = HIPSPARSE_DIAG_TYPE_UNIT; 1461*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&fs->spMatDescr_L, m, m, nz, fs->csrRowPtr, fs->csrColIdx, fs->csrVal, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype)); 1462*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_L, HIPSPARSE_SPMAT_FILL_MODE, &fillMode, sizeof(fillMode))); 1463*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_L, HIPSPARSE_SPMAT_DIAG_TYPE, &diagType, sizeof(diagType))); 1464*d52a580bSJunchao Zhang 1465*d52a580bSJunchao Zhang fillMode = HIPSPARSE_FILL_MODE_UPPER; 1466*d52a580bSJunchao Zhang diagType = HIPSPARSE_DIAG_TYPE_NON_UNIT; 1467*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&fs->spMatDescr_U, m, m, nz, fs->csrRowPtr, fs->csrColIdx, fs->csrVal, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype)); 1468*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_U, HIPSPARSE_SPMAT_FILL_MODE, &fillMode, sizeof(fillMode))); 1469*d52a580bSJunchao Zhang 
PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_U, HIPSPARSE_SPMAT_DIAG_TYPE, &diagType, sizeof(diagType))); 1470*d52a580bSJunchao Zhang 1471*d52a580bSJunchao Zhang /* ========================================================================= */ 1472*d52a580bSJunchao Zhang /* Query buffer sizes for csrilu0, SpSV and allocate buffers */ 1473*d52a580bSJunchao Zhang /* ========================================================================= */ 1474*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsrilu02Info(&fs->ilu0Info_M)); 1475*d52a580bSJunchao Zhang if (m) 1476*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrilu02_bufferSize(fs->handle, m, nz, /* hipsparseXcsrilu02 errors out with empty matrices (m=0) */ 1477*d52a580bSJunchao Zhang fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ilu0Info_M, &fs->factBufferSize_M)); 1478*d52a580bSJunchao Zhang 1479*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->X, sizeof(PetscScalar) * m)); 1480*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->Y, sizeof(PetscScalar) * m)); 1481*d52a580bSJunchao Zhang 1482*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateDnVec(&fs->dnVecDescr_X, m, fs->X, hipsparse_scalartype)); 1483*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateDnVec(&fs->dnVecDescr_Y, m, fs->Y, hipsparse_scalartype)); 1484*d52a580bSJunchao Zhang 1485*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_L)); 1486*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, &fs->spsvBufferSize_L)); 1487*d52a580bSJunchao Zhang 1488*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_U)); 1489*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, 
HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U, &fs->spsvBufferSize_U)); 1490*d52a580bSJunchao Zhang 1491*d52a580bSJunchao Zhang /* It appears spsvBuffer_L/U can not be shared (i.e., the same) for our case, but factBuffer_M can share with either of spsvBuffer_L/U. 1492*d52a580bSJunchao Zhang To save memory, we make factBuffer_M share with the bigger of spsvBuffer_L/U. 1493*d52a580bSJunchao Zhang */ 1494*d52a580bSJunchao Zhang if (fs->spsvBufferSize_L > fs->spsvBufferSize_U) { 1495*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->factBuffer_M, PetscMax(fs->spsvBufferSize_L, (size_t)fs->factBufferSize_M))); 1496*d52a580bSJunchao Zhang fs->spsvBuffer_L = fs->factBuffer_M; 1497*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_U, fs->spsvBufferSize_U)); 1498*d52a580bSJunchao Zhang } else { 1499*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->factBuffer_M, PetscMax(fs->spsvBufferSize_U, (size_t)fs->factBufferSize_M))); 1500*d52a580bSJunchao Zhang fs->spsvBuffer_U = fs->factBuffer_M; 1501*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_L, fs->spsvBufferSize_L)); 1502*d52a580bSJunchao Zhang } 1503*d52a580bSJunchao Zhang 1504*d52a580bSJunchao Zhang /* ========================================================================== */ 1505*d52a580bSJunchao Zhang /* Perform analysis of ilu0 on M, SpSv on L and U */ 1506*d52a580bSJunchao Zhang /* The lower(upper) triangular part of M has the same sparsity pattern as L(U)*/ 1507*d52a580bSJunchao Zhang /* ========================================================================== */ 1508*d52a580bSJunchao Zhang int structural_zero; 1509*d52a580bSJunchao Zhang 1510*d52a580bSJunchao Zhang fs->policy_M = HIPSPARSE_SOLVE_POLICY_USE_LEVEL; 1511*d52a580bSJunchao Zhang if (m) 1512*d52a580bSJunchao Zhang 
PetscCallHIPSPARSE(hipsparseXcsrilu02_analysis(fs->handle, m, nz, /* hipsparseXcsrilu02 errors out with empty matrices (m=0) */ 1513*d52a580bSJunchao Zhang fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ilu0Info_M, fs->policy_M, fs->factBuffer_M)); 1514*d52a580bSJunchao Zhang if (PetscDefined(USE_DEBUG)) { 1515*d52a580bSJunchao Zhang /* Function hipsparseXcsrilu02_zeroPivot() is a blocking call. It calls hipDeviceSynchronize() to make sure all previous kernels are done. */ 1516*d52a580bSJunchao Zhang hipsparseStatus_t status; 1517*d52a580bSJunchao Zhang status = hipsparseXcsrilu02_zeroPivot(fs->handle, fs->ilu0Info_M, &structural_zero); 1518*d52a580bSJunchao Zhang PetscCheck(HIPSPARSE_STATUS_ZERO_PIVOT != status, PETSC_COMM_SELF, PETSC_ERR_USER_INPUT, "Structural zero pivot detected in csrilu02: A(%d,%d) is missing", structural_zero, structural_zero); 1519*d52a580bSJunchao Zhang } 1520*d52a580bSJunchao Zhang 1521*d52a580bSJunchao Zhang /* Estimate FLOPs of the numeric factorization */ 1522*d52a580bSJunchao Zhang { 1523*d52a580bSJunchao Zhang Mat_SeqAIJ *Aseq = (Mat_SeqAIJ *)A->data; 1524*d52a580bSJunchao Zhang PetscInt *Ai, nzRow, nzLeft; 1525*d52a580bSJunchao Zhang PetscLogDouble flops = 0.0; 1526*d52a580bSJunchao Zhang const PetscInt *Adiag; 1527*d52a580bSJunchao Zhang 1528*d52a580bSJunchao Zhang PetscCall(MatGetDiagonalMarkers_SeqAIJ(A, &Adiag, NULL)); 1529*d52a580bSJunchao Zhang Ai = Aseq->i; 1530*d52a580bSJunchao Zhang for (PetscInt i = 0; i < m; i++) { 1531*d52a580bSJunchao Zhang if (Ai[i] < Adiag[i] && Adiag[i] < Ai[i + 1]) { /* There are nonzeros left to the diagonal of row i */ 1532*d52a580bSJunchao Zhang nzRow = Ai[i + 1] - Ai[i]; 1533*d52a580bSJunchao Zhang nzLeft = Adiag[i] - Ai[i]; 1534*d52a580bSJunchao Zhang /* We want to eliminate nonzeros left to the diagonal one by one. Assume each time, nonzeros right 1535*d52a580bSJunchao Zhang and include the eliminated one will be updated, which incurs a multiplication and an addition. 
1536*d52a580bSJunchao Zhang */ 1537*d52a580bSJunchao Zhang nzLeft = (nzRow - 1) / 2; 1538*d52a580bSJunchao Zhang flops += nzLeft * (2.0 * nzRow - nzLeft + 1); 1539*d52a580bSJunchao Zhang } 1540*d52a580bSJunchao Zhang } 1541*d52a580bSJunchao Zhang fs->numericFactFlops = flops; 1542*d52a580bSJunchao Zhang } 1543*d52a580bSJunchao Zhang fact->ops->lufactornumeric = MatILUFactorNumeric_SeqAIJHIPSPARSE_ILU0; 1544*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1545*d52a580bSJunchao Zhang } 1546*d52a580bSJunchao Zhang 1547*d52a580bSJunchao Zhang static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_ICC0(Mat fact, Vec b, Vec x) 1548*d52a580bSJunchao Zhang { 1549*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr; 1550*d52a580bSJunchao Zhang Mat_SeqAIJ *aij = (Mat_SeqAIJ *)fact->data; 1551*d52a580bSJunchao Zhang const PetscScalar *barray; 1552*d52a580bSJunchao Zhang PetscScalar *xarray; 1553*d52a580bSJunchao Zhang 1554*d52a580bSJunchao Zhang PetscFunctionBegin; 1555*d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayWrite(x, &xarray)); 1556*d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayRead(b, &barray)); 1557*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 1558*d52a580bSJunchao Zhang 1559*d52a580bSJunchao Zhang /* Solve L*y = b */ 1560*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray)); 1561*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y)); 1562*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0) 1563*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */ 1564*d52a580bSJunchao Zhang fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L)); 1565*d52a580bSJunchao Zhang #else 1566*d52a580bSJunchao Zhang 
PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */ 1567*d52a580bSJunchao Zhang fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L)); 1568*d52a580bSJunchao Zhang #endif 1569*d52a580bSJunchao Zhang /* Solve Lt*x = y */ 1570*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray)); 1571*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0) || PETSC_PKG_HIP_VERSION_GE(6, 0, 0) 1572*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */ 1573*d52a580bSJunchao Zhang fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt)); 1574*d52a580bSJunchao Zhang #else 1575*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */ 1576*d52a580bSJunchao Zhang fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt)); 1577*d52a580bSJunchao Zhang #endif 1578*d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayRead(b, &barray)); 1579*d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayWrite(x, &xarray)); 1580*d52a580bSJunchao Zhang 1581*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 1582*d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(2.0 * aij->nz - fact->rmap->n)); 1583*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1584*d52a580bSJunchao Zhang } 1585*d52a580bSJunchao Zhang 1586*d52a580bSJunchao Zhang static PetscErrorCode MatICCFactorNumeric_SeqAIJHIPSPARSE_ICC0(Mat fact, Mat A, const MatFactorInfo *info) 1587*d52a580bSJunchao Zhang { 1588*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr; 
1589*d52a580bSJunchao Zhang Mat_SeqAIJ *aij = (Mat_SeqAIJ *)fact->data; 1590*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr; 1591*d52a580bSJunchao Zhang CsrMatrix *Acsr; 1592*d52a580bSJunchao Zhang PetscInt m, nz; 1593*d52a580bSJunchao Zhang PetscBool flg; 1594*d52a580bSJunchao Zhang 1595*d52a580bSJunchao Zhang PetscFunctionBegin; 1596*d52a580bSJunchao Zhang if (PetscDefined(USE_DEBUG)) { 1597*d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg)); 1598*d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Expected MATSEQAIJHIPSPARSE, but input is %s", ((PetscObject)A)->type_name); 1599*d52a580bSJunchao Zhang } 1600*d52a580bSJunchao Zhang 1601*d52a580bSJunchao Zhang /* Copy A's value to fact */ 1602*d52a580bSJunchao Zhang m = fact->rmap->n; 1603*d52a580bSJunchao Zhang nz = aij->nz; 1604*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A)); 1605*d52a580bSJunchao Zhang Acsr = (CsrMatrix *)Acusp->mat->mat; 1606*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpyAsync(fs->csrVal, Acsr->values->data().get(), sizeof(PetscScalar) * nz, hipMemcpyDeviceToDevice, PetscDefaultHipStream)); 1607*d52a580bSJunchao Zhang 1608*d52a580bSJunchao Zhang /* Factorize fact inplace */ 1609*d52a580bSJunchao Zhang /* Function csric02() only takes the lower triangular part of matrix A to perform factorization. 1610*d52a580bSJunchao Zhang The matrix type must be HIPSPARSE_MATRIX_TYPE_GENERAL, the fill mode and diagonal type are ignored, 1611*d52a580bSJunchao Zhang and the strictly upper triangular part is ignored and never touched. It does not matter if A is Hermitian or not. 1612*d52a580bSJunchao Zhang In other words, from the point of view of csric02() A is Hermitian and only the lower triangular part is provided. 
1613*d52a580bSJunchao Zhang */ 1614*d52a580bSJunchao Zhang if (m) PetscCallHIPSPARSE(hipsparseXcsric02(fs->handle, m, nz, fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ic0Info_M, fs->policy_M, fs->factBuffer_M)); 1615*d52a580bSJunchao Zhang if (PetscDefined(USE_DEBUG)) { 1616*d52a580bSJunchao Zhang int numerical_zero; 1617*d52a580bSJunchao Zhang hipsparseStatus_t status; 1618*d52a580bSJunchao Zhang status = hipsparseXcsric02_zeroPivot(fs->handle, fs->ic0Info_M, &numerical_zero); 1619*d52a580bSJunchao Zhang PetscAssert(HIPSPARSE_STATUS_ZERO_PIVOT != status, PETSC_COMM_SELF, PETSC_ERR_USER_INPUT, "Numerical zero pivot detected in csric02: A(%d,%d) is zero", numerical_zero, numerical_zero); 1620*d52a580bSJunchao Zhang } 1621*d52a580bSJunchao Zhang 1622*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L)); 1623*d52a580bSJunchao Zhang 1624*d52a580bSJunchao Zhang /* Note that hipsparse reports this error if we use double and HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE 1625*d52a580bSJunchao Zhang ** On entry to hipsparseSpSV_analysis(): conjugate transpose (opA) is not supported for matA data type, current -> CUDA_R_64F 1626*d52a580bSJunchao Zhang */ 1627*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_analysis(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt)); 1628*d52a580bSJunchao Zhang 1629*d52a580bSJunchao Zhang fact->offloadmask = PETSC_OFFLOAD_GPU; 1630*d52a580bSJunchao Zhang fact->ops->solve = MatSolve_SeqAIJHIPSPARSE_ICC0; 1631*d52a580bSJunchao Zhang fact->ops->solvetranspose = MatSolve_SeqAIJHIPSPARSE_ICC0; 1632*d52a580bSJunchao Zhang fact->ops->matsolve = NULL; 
1633*d52a580bSJunchao Zhang fact->ops->matsolvetranspose = NULL; 1634*d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(fs->numericFactFlops)); 1635*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1636*d52a580bSJunchao Zhang } 1637*d52a580bSJunchao Zhang 1638*d52a580bSJunchao Zhang static PetscErrorCode MatICCFactorSymbolic_SeqAIJHIPSPARSE_ICC0(Mat fact, Mat A, IS perm, const MatFactorInfo *info) 1639*d52a580bSJunchao Zhang { 1640*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)fact->spptr; 1641*d52a580bSJunchao Zhang Mat_SeqAIJ *aij = (Mat_SeqAIJ *)fact->data; 1642*d52a580bSJunchao Zhang PetscInt m, nz; 1643*d52a580bSJunchao Zhang 1644*d52a580bSJunchao Zhang PetscFunctionBegin; 1645*d52a580bSJunchao Zhang if (PetscDefined(USE_DEBUG)) { 1646*d52a580bSJunchao Zhang PetscBool flg, diagDense; 1647*d52a580bSJunchao Zhang 1648*d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg)); 1649*d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Expected MATSEQAIJHIPSPARSE, but input is %s", ((PetscObject)A)->type_name); 1650*d52a580bSJunchao Zhang PetscCheck(A->rmap->n == A->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must be square matrix, rows %" PetscInt_FMT " columns %" PetscInt_FMT, A->rmap->n, A->cmap->n); 1651*d52a580bSJunchao Zhang PetscCall(MatGetDiagonalMarkers_SeqAIJ(A, NULL, &diagDense)); 1652*d52a580bSJunchao Zhang PetscCheck(diagDense, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Matrix is missing diagonal entries"); 1653*d52a580bSJunchao Zhang } 1654*d52a580bSJunchao Zhang 1655*d52a580bSJunchao Zhang /* Free the old stale stuff */ 1656*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&fs)); 1657*d52a580bSJunchao Zhang 1658*d52a580bSJunchao Zhang /* Copy over A's meta data to fact. Note that we also allocated fact's i,j,a on host, 1659*d52a580bSJunchao Zhang but they will not be used. 
Allocate them just for easy debugging. 1660*d52a580bSJunchao Zhang */ 1661*d52a580bSJunchao Zhang PetscCall(MatDuplicateNoCreate_SeqAIJ(fact, A, MAT_DO_NOT_COPY_VALUES, PETSC_TRUE /*malloc*/)); 1662*d52a580bSJunchao Zhang 1663*d52a580bSJunchao Zhang fact->offloadmask = PETSC_OFFLOAD_BOTH; 1664*d52a580bSJunchao Zhang fact->factortype = MAT_FACTOR_ICC; 1665*d52a580bSJunchao Zhang fact->info.factor_mallocs = 0; 1666*d52a580bSJunchao Zhang fact->info.fill_ratio_given = info->fill; 1667*d52a580bSJunchao Zhang fact->info.fill_ratio_needed = 1.0; 1668*d52a580bSJunchao Zhang 1669*d52a580bSJunchao Zhang aij->row = NULL; 1670*d52a580bSJunchao Zhang aij->col = NULL; 1671*d52a580bSJunchao Zhang 1672*d52a580bSJunchao Zhang /* ====================================================================== */ 1673*d52a580bSJunchao Zhang /* Copy A's i, j to fact and also allocate the value array of fact. */ 1674*d52a580bSJunchao Zhang /* We'll do in-place factorization on fact */ 1675*d52a580bSJunchao Zhang /* ====================================================================== */ 1676*d52a580bSJunchao Zhang const int *Ai, *Aj; 1677*d52a580bSJunchao Zhang 1678*d52a580bSJunchao Zhang m = fact->rmap->n; 1679*d52a580bSJunchao Zhang nz = aij->nz; 1680*d52a580bSJunchao Zhang 1681*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->csrRowPtr, sizeof(int) * (m + 1))); 1682*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->csrColIdx, sizeof(int) * nz)); 1683*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->csrVal, sizeof(PetscScalar) * nz)); 1684*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEGetIJ(A, PETSC_FALSE, &Ai, &Aj)); /* Do not use compressed Ai */ 1685*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpyAsync(fs->csrRowPtr, Ai, sizeof(int) * (m + 1), hipMemcpyDeviceToDevice, PetscDefaultHipStream)); 1686*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpyAsync(fs->csrColIdx, Aj, sizeof(int) * nz, hipMemcpyDeviceToDevice, PetscDefaultHipStream)); 1687*d52a580bSJunchao 
Zhang 1688*d52a580bSJunchao Zhang /* ====================================================================== */ 1689*d52a580bSJunchao Zhang /* Create mat descriptors for M, L */ 1690*d52a580bSJunchao Zhang /* ====================================================================== */ 1691*d52a580bSJunchao Zhang hipsparseFillMode_t fillMode; 1692*d52a580bSJunchao Zhang hipsparseDiagType_t diagType; 1693*d52a580bSJunchao Zhang 1694*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&fs->matDescr_M)); 1695*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(fs->matDescr_M, HIPSPARSE_INDEX_BASE_ZERO)); 1696*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(fs->matDescr_M, HIPSPARSE_MATRIX_TYPE_GENERAL)); 1697*d52a580bSJunchao Zhang 1698*d52a580bSJunchao Zhang /* https://docs.amd.com/bundle/hipSPARSE-Documentation---hipSPARSE-documentation/page/usermanual.html/#hipsparse_8h_1a79e036b6c0680cb37e2aa53d3542a054 1699*d52a580bSJunchao Zhang hipsparseDiagType_t: This type indicates if the matrix diagonal entries are unity. The diagonal elements are always 1700*d52a580bSJunchao Zhang assumed to be present, but if HIPSPARSE_DIAG_TYPE_UNIT is passed to an API routine, then the routine assumes that 1701*d52a580bSJunchao Zhang all diagonal entries are unity and will not read or modify those entries. Note that in this case the routine 1702*d52a580bSJunchao Zhang assumes the diagonal entries are equal to one, regardless of what those entries are actually set to in memory. 
1703*d52a580bSJunchao Zhang */ 1704*d52a580bSJunchao Zhang fillMode = HIPSPARSE_FILL_MODE_LOWER; 1705*d52a580bSJunchao Zhang diagType = HIPSPARSE_DIAG_TYPE_NON_UNIT; 1706*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&fs->spMatDescr_L, m, m, nz, fs->csrRowPtr, fs->csrColIdx, fs->csrVal, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype)); 1707*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_L, HIPSPARSE_SPMAT_FILL_MODE, &fillMode, sizeof(fillMode))); 1708*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMatSetAttribute(fs->spMatDescr_L, HIPSPARSE_SPMAT_DIAG_TYPE, &diagType, sizeof(diagType))); 1709*d52a580bSJunchao Zhang 1710*d52a580bSJunchao Zhang /* ========================================================================= */ 1711*d52a580bSJunchao Zhang /* Query buffer sizes for csric0, SpSV of L and Lt, and allocate buffers */ 1712*d52a580bSJunchao Zhang /* ========================================================================= */ 1713*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsric02Info(&fs->ic0Info_M)); 1714*d52a580bSJunchao Zhang if (m) PetscCallHIPSPARSE(hipsparseXcsric02_bufferSize(fs->handle, m, nz, fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ic0Info_M, &fs->factBufferSize_M)); 1715*d52a580bSJunchao Zhang 1716*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->X, sizeof(PetscScalar) * m)); 1717*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->Y, sizeof(PetscScalar) * m)); 1718*d52a580bSJunchao Zhang 1719*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateDnVec(&fs->dnVecDescr_X, m, fs->X, hipsparse_scalartype)); 1720*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateDnVec(&fs->dnVecDescr_Y, m, fs->Y, hipsparse_scalartype)); 1721*d52a580bSJunchao Zhang 1722*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_L)); 1723*d52a580bSJunchao Zhang 
PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, &fs->spsvBufferSize_L)); 1724*d52a580bSJunchao Zhang 1725*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_createDescr(&fs->spsvDescr_Lt)); 1726*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_bufferSize(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, &fs->spsvBufferSize_Lt)); 1727*d52a580bSJunchao Zhang 1728*d52a580bSJunchao Zhang /* To save device memory, we make the factorization buffer share with one of the solver buffer. 1729*d52a580bSJunchao Zhang See also comments in `MatILUFactorSymbolic_SeqAIJHIPSPARSE_ILU0()`. 1730*d52a580bSJunchao Zhang */ 1731*d52a580bSJunchao Zhang if (fs->spsvBufferSize_L > fs->spsvBufferSize_Lt) { 1732*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->factBuffer_M, PetscMax(fs->spsvBufferSize_L, (size_t)fs->factBufferSize_M))); 1733*d52a580bSJunchao Zhang fs->spsvBuffer_L = fs->factBuffer_M; 1734*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_Lt, fs->spsvBufferSize_Lt)); 1735*d52a580bSJunchao Zhang } else { 1736*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->factBuffer_M, PetscMax(fs->spsvBufferSize_Lt, (size_t)fs->factBufferSize_M))); 1737*d52a580bSJunchao Zhang fs->spsvBuffer_Lt = fs->factBuffer_M; 1738*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&fs->spsvBuffer_L, fs->spsvBufferSize_L)); 1739*d52a580bSJunchao Zhang } 1740*d52a580bSJunchao Zhang 1741*d52a580bSJunchao Zhang /* ========================================================================== */ 1742*d52a580bSJunchao Zhang /* Perform analysis of ic0 on M */ 1743*d52a580bSJunchao Zhang /* The lower triangular part of M has the same 
sparsity pattern as L */ 1744*d52a580bSJunchao Zhang /* ========================================================================== */ 1745*d52a580bSJunchao Zhang int structural_zero; 1746*d52a580bSJunchao Zhang 1747*d52a580bSJunchao Zhang fs->policy_M = HIPSPARSE_SOLVE_POLICY_USE_LEVEL; 1748*d52a580bSJunchao Zhang if (m) PetscCallHIPSPARSE(hipsparseXcsric02_analysis(fs->handle, m, nz, fs->matDescr_M, fs->csrVal, fs->csrRowPtr, fs->csrColIdx, fs->ic0Info_M, fs->policy_M, fs->factBuffer_M)); 1749*d52a580bSJunchao Zhang if (PetscDefined(USE_DEBUG)) { 1750*d52a580bSJunchao Zhang hipsparseStatus_t status; 1751*d52a580bSJunchao Zhang /* Function hipsparseXcsric02_zeroPivot() is a blocking call. It calls hipDeviceSynchronize() to make sure all previous kernels are done. */ 1752*d52a580bSJunchao Zhang status = hipsparseXcsric02_zeroPivot(fs->handle, fs->ic0Info_M, &structural_zero); 1753*d52a580bSJunchao Zhang PetscCheck(HIPSPARSE_STATUS_ZERO_PIVOT != status, PETSC_COMM_SELF, PETSC_ERR_USER_INPUT, "Structural zero pivot detected in csric02: A(%d,%d) is missing", structural_zero, structural_zero); 1754*d52a580bSJunchao Zhang } 1755*d52a580bSJunchao Zhang 1756*d52a580bSJunchao Zhang /* Estimate FLOPs of the numeric factorization */ 1757*d52a580bSJunchao Zhang { 1758*d52a580bSJunchao Zhang Mat_SeqAIJ *Aseq = (Mat_SeqAIJ *)A->data; 1759*d52a580bSJunchao Zhang PetscInt *Ai, nzRow, nzLeft; 1760*d52a580bSJunchao Zhang PetscLogDouble flops = 0.0; 1761*d52a580bSJunchao Zhang 1762*d52a580bSJunchao Zhang Ai = Aseq->i; 1763*d52a580bSJunchao Zhang for (PetscInt i = 0; i < m; i++) { 1764*d52a580bSJunchao Zhang nzRow = Ai[i + 1] - Ai[i]; 1765*d52a580bSJunchao Zhang if (nzRow > 1) { 1766*d52a580bSJunchao Zhang /* We want to eliminate nonzeros left to the diagonal one by one. Assume each time, nonzeros right 1767*d52a580bSJunchao Zhang and include the eliminated one will be updated, which incurs a multiplication and an addition. 
1768*d52a580bSJunchao Zhang */ 1769*d52a580bSJunchao Zhang nzLeft = (nzRow - 1) / 2; 1770*d52a580bSJunchao Zhang flops += nzLeft * (2.0 * nzRow - nzLeft + 1); 1771*d52a580bSJunchao Zhang } 1772*d52a580bSJunchao Zhang } 1773*d52a580bSJunchao Zhang fs->numericFactFlops = flops; 1774*d52a580bSJunchao Zhang } 1775*d52a580bSJunchao Zhang fact->ops->choleskyfactornumeric = MatICCFactorNumeric_SeqAIJHIPSPARSE_ICC0; 1776*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1777*d52a580bSJunchao Zhang } 1778*d52a580bSJunchao Zhang #endif 1779*d52a580bSJunchao Zhang 1780*d52a580bSJunchao Zhang static PetscErrorCode MatILUFactorSymbolic_SeqAIJHIPSPARSE(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info) 1781*d52a580bSJunchao Zhang { 1782*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)B->spptr; 1783*d52a580bSJunchao Zhang 1784*d52a580bSJunchao Zhang PetscFunctionBegin; 1785*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0) 1786*d52a580bSJunchao Zhang PetscBool row_identity = PETSC_FALSE, col_identity = PETSC_FALSE; 1787*d52a580bSJunchao Zhang if (!info->factoronhost) { 1788*d52a580bSJunchao Zhang PetscCall(ISIdentity(isrow, &row_identity)); 1789*d52a580bSJunchao Zhang PetscCall(ISIdentity(iscol, &col_identity)); 1790*d52a580bSJunchao Zhang } 1791*d52a580bSJunchao Zhang if (!info->levels && row_identity && col_identity) PetscCall(MatILUFactorSymbolic_SeqAIJHIPSPARSE_ILU0(B, A, isrow, iscol, info)); 1792*d52a580bSJunchao Zhang else 1793*d52a580bSJunchao Zhang #endif 1794*d52a580bSJunchao Zhang { 1795*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&hipsparseTriFactors)); 1796*d52a580bSJunchao Zhang PetscCall(MatILUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info)); 1797*d52a580bSJunchao Zhang B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJHIPSPARSE; 1798*d52a580bSJunchao Zhang } 1799*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1800*d52a580bSJunchao Zhang } 
1801*d52a580bSJunchao Zhang 1802*d52a580bSJunchao Zhang static PetscErrorCode MatLUFactorSymbolic_SeqAIJHIPSPARSE(Mat B, Mat A, IS isrow, IS iscol, const MatFactorInfo *info) 1803*d52a580bSJunchao Zhang { 1804*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)B->spptr; 1805*d52a580bSJunchao Zhang 1806*d52a580bSJunchao Zhang PetscFunctionBegin; 1807*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&hipsparseTriFactors)); 1808*d52a580bSJunchao Zhang PetscCall(MatLUFactorSymbolic_SeqAIJ(B, A, isrow, iscol, info)); 1809*d52a580bSJunchao Zhang B->ops->lufactornumeric = MatLUFactorNumeric_SeqAIJHIPSPARSE; 1810*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1811*d52a580bSJunchao Zhang } 1812*d52a580bSJunchao Zhang 1813*d52a580bSJunchao Zhang static PetscErrorCode MatICCFactorSymbolic_SeqAIJHIPSPARSE(Mat B, Mat A, IS perm, const MatFactorInfo *info) 1814*d52a580bSJunchao Zhang { 1815*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)B->spptr; 1816*d52a580bSJunchao Zhang 1817*d52a580bSJunchao Zhang PetscFunctionBegin; 1818*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0) 1819*d52a580bSJunchao Zhang PetscBool perm_identity = PETSC_FALSE; 1820*d52a580bSJunchao Zhang if (!info->factoronhost) PetscCall(ISIdentity(perm, &perm_identity)); 1821*d52a580bSJunchao Zhang if (!info->levels && perm_identity) PetscCall(MatICCFactorSymbolic_SeqAIJHIPSPARSE_ICC0(B, A, perm, info)); 1822*d52a580bSJunchao Zhang else 1823*d52a580bSJunchao Zhang #endif 1824*d52a580bSJunchao Zhang { 1825*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&hipsparseTriFactors)); 1826*d52a580bSJunchao Zhang PetscCall(MatICCFactorSymbolic_SeqAIJ(B, A, perm, info)); 1827*d52a580bSJunchao Zhang B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJHIPSPARSE; 1828*d52a580bSJunchao Zhang } 1829*d52a580bSJunchao Zhang 
PetscFunctionReturn(PETSC_SUCCESS); 1830*d52a580bSJunchao Zhang } 1831*d52a580bSJunchao Zhang 1832*d52a580bSJunchao Zhang static PetscErrorCode MatCholeskyFactorSymbolic_SeqAIJHIPSPARSE(Mat B, Mat A, IS perm, const MatFactorInfo *info) 1833*d52a580bSJunchao Zhang { 1834*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *hipsparseTriFactors = (Mat_SeqAIJHIPSPARSETriFactors *)B->spptr; 1835*d52a580bSJunchao Zhang 1836*d52a580bSJunchao Zhang PetscFunctionBegin; 1837*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(&hipsparseTriFactors)); 1838*d52a580bSJunchao Zhang PetscCall(MatCholeskyFactorSymbolic_SeqAIJ(B, A, perm, info)); 1839*d52a580bSJunchao Zhang B->ops->choleskyfactornumeric = MatCholeskyFactorNumeric_SeqAIJHIPSPARSE; 1840*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1841*d52a580bSJunchao Zhang } 1842*d52a580bSJunchao Zhang 1843*d52a580bSJunchao Zhang static PetscErrorCode MatFactorGetSolverType_seqaij_hipsparse(Mat A, MatSolverType *type) 1844*d52a580bSJunchao Zhang { 1845*d52a580bSJunchao Zhang PetscFunctionBegin; 1846*d52a580bSJunchao Zhang *type = MATSOLVERHIPSPARSE; 1847*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1848*d52a580bSJunchao Zhang } 1849*d52a580bSJunchao Zhang 1850*d52a580bSJunchao Zhang /*MC 1851*d52a580bSJunchao Zhang MATSOLVERHIPSPARSE = "hipsparse" - A matrix type providing triangular solvers for sequential matrices 1852*d52a580bSJunchao Zhang on a single GPU of type, `MATSEQAIJHIPSPARSE`. Currently supported 1853*d52a580bSJunchao Zhang algorithms are ILU(k) and ICC(k). Typically, deeper factorizations (larger k) results in poorer 1854*d52a580bSJunchao Zhang performance in the triangular solves. Full LU, and Cholesky decompositions can be solved through the 1855*d52a580bSJunchao Zhang HipSPARSE triangular solve algorithm. However, the performance can be quite poor and thus these 1856*d52a580bSJunchao Zhang algorithms are not recommended. 
This class does NOT support direct solver operations. 1857*d52a580bSJunchao Zhang 1858*d52a580bSJunchao Zhang Level: beginner 1859*d52a580bSJunchao Zhang 1860*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MATSEQAIJHIPSPARSE`, `PCFactorSetMatSolverType()`, `MatSolverType`, `MatCreateSeqAIJHIPSPARSE()`, `MATAIJHIPSPARSE`, `MatCreateAIJHIPSPARSE()`, `MatHIPSPARSESetFormat()`, `MatHIPSPARSEStorageFormat`, `MatHIPSPARSEFormatOperation` 1861*d52a580bSJunchao Zhang M*/ 1862*d52a580bSJunchao Zhang 1863*d52a580bSJunchao Zhang PETSC_EXTERN PetscErrorCode MatGetFactor_seqaijhipsparse_hipsparse(Mat A, MatFactorType ftype, Mat *B) 1864*d52a580bSJunchao Zhang { 1865*d52a580bSJunchao Zhang PetscInt n = A->rmap->n; 1866*d52a580bSJunchao Zhang 1867*d52a580bSJunchao Zhang PetscFunctionBegin; 1868*d52a580bSJunchao Zhang PetscCall(MatCreate(PetscObjectComm((PetscObject)A), B)); 1869*d52a580bSJunchao Zhang PetscCall(MatSetSizes(*B, n, n, n, n)); 1870*d52a580bSJunchao Zhang (*B)->factortype = ftype; 1871*d52a580bSJunchao Zhang PetscCall(MatSetType(*B, MATSEQAIJHIPSPARSE)); 1872*d52a580bSJunchao Zhang 1873*d52a580bSJunchao Zhang if (A->boundtocpu && A->bindingpropagates) PetscCall(MatBindToCPU(*B, PETSC_TRUE)); 1874*d52a580bSJunchao Zhang if (ftype == MAT_FACTOR_LU || ftype == MAT_FACTOR_ILU || ftype == MAT_FACTOR_ILUDT) { 1875*d52a580bSJunchao Zhang PetscCall(MatSetBlockSizesFromMats(*B, A, A)); 1876*d52a580bSJunchao Zhang if (!A->boundtocpu) { 1877*d52a580bSJunchao Zhang (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJHIPSPARSE; 1878*d52a580bSJunchao Zhang (*B)->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqAIJHIPSPARSE; 1879*d52a580bSJunchao Zhang } else { 1880*d52a580bSJunchao Zhang (*B)->ops->ilufactorsymbolic = MatILUFactorSymbolic_SeqAIJ; 1881*d52a580bSJunchao Zhang (*B)->ops->lufactorsymbolic = MatLUFactorSymbolic_SeqAIJ; 1882*d52a580bSJunchao Zhang } 1883*d52a580bSJunchao Zhang PetscCall(PetscStrallocpy(MATORDERINGND, (char 
**)&(*B)->preferredordering[MAT_FACTOR_LU])); 1884*d52a580bSJunchao Zhang PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILU])); 1885*d52a580bSJunchao Zhang PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ILUDT])); 1886*d52a580bSJunchao Zhang } else if (ftype == MAT_FACTOR_CHOLESKY || ftype == MAT_FACTOR_ICC) { 1887*d52a580bSJunchao Zhang if (!A->boundtocpu) { 1888*d52a580bSJunchao Zhang (*B)->ops->iccfactorsymbolic = MatICCFactorSymbolic_SeqAIJHIPSPARSE; 1889*d52a580bSJunchao Zhang (*B)->ops->choleskyfactorsymbolic = MatCholeskyFactorSymbolic_SeqAIJHIPSPARSE; 1890*d52a580bSJunchao Zhang } else { 1891*d52a580bSJunchao Zhang (*B)->ops->iccfactorsymbolic = MatICCFactorSymbolic_SeqAIJ; 1892*d52a580bSJunchao Zhang (*B)->ops->choleskyfactorsymbolic = MatCholeskyFactorSymbolic_SeqAIJ; 1893*d52a580bSJunchao Zhang } 1894*d52a580bSJunchao Zhang PetscCall(PetscStrallocpy(MATORDERINGND, (char **)&(*B)->preferredordering[MAT_FACTOR_CHOLESKY])); 1895*d52a580bSJunchao Zhang PetscCall(PetscStrallocpy(MATORDERINGNATURAL, (char **)&(*B)->preferredordering[MAT_FACTOR_ICC])); 1896*d52a580bSJunchao Zhang } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Factor type not supported for HIPSPARSE Matrix Types"); 1897*d52a580bSJunchao Zhang 1898*d52a580bSJunchao Zhang PetscCall(MatSeqAIJSetPreallocation(*B, MAT_SKIP_ALLOCATION, NULL)); 1899*d52a580bSJunchao Zhang (*B)->canuseordering = PETSC_TRUE; 1900*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)*B, "MatFactorGetSolverType_C", MatFactorGetSolverType_seqaij_hipsparse)); 1901*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1902*d52a580bSJunchao Zhang } 1903*d52a580bSJunchao Zhang 1904*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSECopyFromGPU(Mat A) 1905*d52a580bSJunchao Zhang { 1906*d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data; 1907*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp = 
(Mat_SeqAIJHIPSPARSE *)A->spptr; 1908*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0) 1909*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = (Mat_SeqAIJHIPSPARSETriFactors *)A->spptr; 1910*d52a580bSJunchao Zhang #endif 1911*d52a580bSJunchao Zhang 1912*d52a580bSJunchao Zhang PetscFunctionBegin; 1913*d52a580bSJunchao Zhang if (A->offloadmask == PETSC_OFFLOAD_GPU) { 1914*d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSECopyFromGPU, A, 0, 0, 0)); 1915*d52a580bSJunchao Zhang if (A->factortype == MAT_FACTOR_NONE) { 1916*d52a580bSJunchao Zhang CsrMatrix *matrix = (CsrMatrix *)cusp->mat->mat; 1917*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(a->a, matrix->values->data().get(), a->nz * sizeof(PetscScalar), hipMemcpyDeviceToHost)); 1918*d52a580bSJunchao Zhang } 1919*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0) 1920*d52a580bSJunchao Zhang else if (fs->csrVal) { 1921*d52a580bSJunchao Zhang /* We have a factorized matrix on device and are able to copy it to host */ 1922*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(a->a, fs->csrVal, a->nz * sizeof(PetscScalar), hipMemcpyDeviceToHost)); 1923*d52a580bSJunchao Zhang } 1924*d52a580bSJunchao Zhang #endif 1925*d52a580bSJunchao Zhang else 1926*d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "No support for copying this type of factorized matrix from device to host"); 1927*d52a580bSJunchao Zhang PetscCall(PetscLogGpuToCpu(a->nz * sizeof(PetscScalar))); 1928*d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSECopyFromGPU, A, 0, 0, 0)); 1929*d52a580bSJunchao Zhang A->offloadmask = PETSC_OFFLOAD_BOTH; 1930*d52a580bSJunchao Zhang } 1931*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1932*d52a580bSJunchao Zhang } 1933*d52a580bSJunchao Zhang 1934*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJGetArray_SeqAIJHIPSPARSE(Mat A, PetscScalar *array[]) 1935*d52a580bSJunchao Zhang { 1936*d52a580bSJunchao Zhang PetscFunctionBegin; 1937*d52a580bSJunchao 
Zhang PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A)); 1938*d52a580bSJunchao Zhang *array = ((Mat_SeqAIJ *)A->data)->a; 1939*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1940*d52a580bSJunchao Zhang } 1941*d52a580bSJunchao Zhang 1942*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJRestoreArray_SeqAIJHIPSPARSE(Mat A, PetscScalar *array[]) 1943*d52a580bSJunchao Zhang { 1944*d52a580bSJunchao Zhang PetscFunctionBegin; 1945*d52a580bSJunchao Zhang A->offloadmask = PETSC_OFFLOAD_CPU; 1946*d52a580bSJunchao Zhang *array = NULL; 1947*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1948*d52a580bSJunchao Zhang } 1949*d52a580bSJunchao Zhang 1950*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJGetArrayRead_SeqAIJHIPSPARSE(Mat A, const PetscScalar *array[]) 1951*d52a580bSJunchao Zhang { 1952*d52a580bSJunchao Zhang PetscFunctionBegin; 1953*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A)); 1954*d52a580bSJunchao Zhang *array = ((Mat_SeqAIJ *)A->data)->a; 1955*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1956*d52a580bSJunchao Zhang } 1957*d52a580bSJunchao Zhang 1958*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJRestoreArrayRead_SeqAIJHIPSPARSE(Mat A, const PetscScalar *array[]) 1959*d52a580bSJunchao Zhang { 1960*d52a580bSJunchao Zhang PetscFunctionBegin; 1961*d52a580bSJunchao Zhang *array = NULL; 1962*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1963*d52a580bSJunchao Zhang } 1964*d52a580bSJunchao Zhang 1965*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJGetArrayWrite_SeqAIJHIPSPARSE(Mat A, PetscScalar *array[]) 1966*d52a580bSJunchao Zhang { 1967*d52a580bSJunchao Zhang PetscFunctionBegin; 1968*d52a580bSJunchao Zhang *array = ((Mat_SeqAIJ *)A->data)->a; 1969*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1970*d52a580bSJunchao Zhang } 1971*d52a580bSJunchao Zhang 1972*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJRestoreArrayWrite_SeqAIJHIPSPARSE(Mat A, PetscScalar *array[]) 
1973*d52a580bSJunchao Zhang { 1974*d52a580bSJunchao Zhang PetscFunctionBegin; 1975*d52a580bSJunchao Zhang A->offloadmask = PETSC_OFFLOAD_CPU; 1976*d52a580bSJunchao Zhang *array = NULL; 1977*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 1978*d52a580bSJunchao Zhang } 1979*d52a580bSJunchao Zhang 1980*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJGetCSRAndMemType_SeqAIJHIPSPARSE(Mat A, const PetscInt **i, const PetscInt **j, PetscScalar **a, PetscMemType *mtype) 1981*d52a580bSJunchao Zhang { 1982*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp; 1983*d52a580bSJunchao Zhang CsrMatrix *matrix; 1984*d52a580bSJunchao Zhang 1985*d52a580bSJunchao Zhang PetscFunctionBegin; 1986*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A)); 1987*d52a580bSJunchao Zhang PetscCheck(A->factortype == MAT_FACTOR_NONE, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "Not for factored matrix"); 1988*d52a580bSJunchao Zhang cusp = static_cast<Mat_SeqAIJHIPSPARSE *>(A->spptr); 1989*d52a580bSJunchao Zhang PetscCheck(cusp != NULL, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONGSTATE, "cusp is NULL"); 1990*d52a580bSJunchao Zhang matrix = (CsrMatrix *)cusp->mat->mat; 1991*d52a580bSJunchao Zhang 1992*d52a580bSJunchao Zhang if (i) { 1993*d52a580bSJunchao Zhang #if !defined(PETSC_USE_64BIT_INDICES) 1994*d52a580bSJunchao Zhang *i = matrix->row_offsets->data().get(); 1995*d52a580bSJunchao Zhang #else 1996*d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSparse does not supported 64-bit indices"); 1997*d52a580bSJunchao Zhang #endif 1998*d52a580bSJunchao Zhang } 1999*d52a580bSJunchao Zhang if (j) { 2000*d52a580bSJunchao Zhang #if !defined(PETSC_USE_64BIT_INDICES) 2001*d52a580bSJunchao Zhang *j = matrix->column_indices->data().get(); 2002*d52a580bSJunchao Zhang #else 2003*d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSparse does not supported 64-bit indices"); 2004*d52a580bSJunchao Zhang #endif 2005*d52a580bSJunchao Zhang 
} 2006*d52a580bSJunchao Zhang if (a) *a = matrix->values->data().get(); 2007*d52a580bSJunchao Zhang if (mtype) *mtype = PETSC_MEMTYPE_HIP; 2008*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 2009*d52a580bSJunchao Zhang } 2010*d52a580bSJunchao Zhang 2011*d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatSeqAIJHIPSPARSECopyToGPU(Mat A) 2012*d52a580bSJunchao Zhang { 2013*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr; 2014*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSEMultStruct *matstruct = hipsparsestruct->mat; 2015*d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data; 2016*d52a580bSJunchao Zhang PetscBool both = PETSC_TRUE; 2017*d52a580bSJunchao Zhang PetscInt m = A->rmap->n, *ii, *ridx, tmp; 2018*d52a580bSJunchao Zhang 2019*d52a580bSJunchao Zhang PetscFunctionBegin; 2020*d52a580bSJunchao Zhang PetscCheck(!A->boundtocpu, PETSC_COMM_SELF, PETSC_ERR_GPU, "Cannot copy to GPU"); 2021*d52a580bSJunchao Zhang if (A->offloadmask == PETSC_OFFLOAD_UNALLOCATED || A->offloadmask == PETSC_OFFLOAD_CPU) { 2022*d52a580bSJunchao Zhang if (A->nonzerostate == hipsparsestruct->nonzerostate && hipsparsestruct->format == MAT_HIPSPARSE_CSR) { /* Copy values only */ 2023*d52a580bSJunchao Zhang CsrMatrix *matrix; 2024*d52a580bSJunchao Zhang matrix = (CsrMatrix *)hipsparsestruct->mat->mat; 2025*d52a580bSJunchao Zhang 2026*d52a580bSJunchao Zhang PetscCheck(!a->nz || a->a, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CSR values"); 2027*d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSECopyToGPU, A, 0, 0, 0)); 2028*d52a580bSJunchao Zhang matrix->values->assign(a->a, a->a + a->nz); 2029*d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP()); 2030*d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(a->nz * sizeof(PetscScalar))); 2031*d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSECopyToGPU, A, 0, 0, 0)); 2032*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_FALSE)); 
2033*d52a580bSJunchao Zhang } else { 2034*d52a580bSJunchao Zhang PetscInt nnz; 2035*d52a580bSJunchao Zhang PetscCall(PetscLogEventBegin(MAT_HIPSPARSECopyToGPU, A, 0, 0, 0)); 2036*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&hipsparsestruct->mat, hipsparsestruct->format)); 2037*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_TRUE)); 2038*d52a580bSJunchao Zhang delete hipsparsestruct->workVector; 2039*d52a580bSJunchao Zhang delete hipsparsestruct->rowoffsets_gpu; 2040*d52a580bSJunchao Zhang hipsparsestruct->workVector = NULL; 2041*d52a580bSJunchao Zhang hipsparsestruct->rowoffsets_gpu = NULL; 2042*d52a580bSJunchao Zhang try { 2043*d52a580bSJunchao Zhang if (a->compressedrow.use) { 2044*d52a580bSJunchao Zhang m = a->compressedrow.nrows; 2045*d52a580bSJunchao Zhang ii = a->compressedrow.i; 2046*d52a580bSJunchao Zhang ridx = a->compressedrow.rindex; 2047*d52a580bSJunchao Zhang } else { 2048*d52a580bSJunchao Zhang m = A->rmap->n; 2049*d52a580bSJunchao Zhang ii = a->i; 2050*d52a580bSJunchao Zhang ridx = NULL; 2051*d52a580bSJunchao Zhang } 2052*d52a580bSJunchao Zhang PetscCheck(ii, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CSR row data"); 2053*d52a580bSJunchao Zhang if (!a->a) { 2054*d52a580bSJunchao Zhang nnz = ii[m]; 2055*d52a580bSJunchao Zhang both = PETSC_FALSE; 2056*d52a580bSJunchao Zhang } else nnz = a->nz; 2057*d52a580bSJunchao Zhang PetscCheck(!nnz || a->j, PETSC_COMM_SELF, PETSC_ERR_GPU, "Missing CSR column data"); 2058*d52a580bSJunchao Zhang 2059*d52a580bSJunchao Zhang /* create hipsparse matrix */ 2060*d52a580bSJunchao Zhang hipsparsestruct->nrows = m; 2061*d52a580bSJunchao Zhang matstruct = new Mat_SeqAIJHIPSPARSEMultStruct; 2062*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&matstruct->descr)); 2063*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(matstruct->descr, HIPSPARSE_INDEX_BASE_ZERO)); 2064*d52a580bSJunchao Zhang 
PetscCallHIPSPARSE(hipsparseSetMatType(matstruct->descr, HIPSPARSE_MATRIX_TYPE_GENERAL)); 2065*d52a580bSJunchao Zhang 2066*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&matstruct->alpha_one, sizeof(PetscScalar))); 2067*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&matstruct->beta_zero, sizeof(PetscScalar))); 2068*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&matstruct->beta_one, sizeof(PetscScalar))); 2069*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(matstruct->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice)); 2070*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(matstruct->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice)); 2071*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(matstruct->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice)); 2072*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetPointerMode(hipsparsestruct->handle, HIPSPARSE_POINTER_MODE_DEVICE)); 2073*d52a580bSJunchao Zhang 2074*d52a580bSJunchao Zhang /* Build a hybrid/ellpack matrix if this option is chosen for the storage */ 2075*d52a580bSJunchao Zhang if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) { 2076*d52a580bSJunchao Zhang /* set the matrix */ 2077*d52a580bSJunchao Zhang CsrMatrix *mat = new CsrMatrix; 2078*d52a580bSJunchao Zhang mat->num_rows = m; 2079*d52a580bSJunchao Zhang mat->num_cols = A->cmap->n; 2080*d52a580bSJunchao Zhang mat->num_entries = nnz; 2081*d52a580bSJunchao Zhang mat->row_offsets = new THRUSTINTARRAY32(m + 1); 2082*d52a580bSJunchao Zhang mat->column_indices = new THRUSTINTARRAY32(nnz); 2083*d52a580bSJunchao Zhang mat->values = new THRUSTARRAY(nnz); 2084*d52a580bSJunchao Zhang mat->row_offsets->assign(ii, ii + m + 1); 2085*d52a580bSJunchao Zhang mat->column_indices->assign(a->j, a->j + nnz); 2086*d52a580bSJunchao Zhang if (a->a) mat->values->assign(a->a, a->a + nnz); 2087*d52a580bSJunchao Zhang 2088*d52a580bSJunchao Zhang /* assign the pointer */ 
2089*d52a580bSJunchao Zhang matstruct->mat = mat; 2090*d52a580bSJunchao Zhang if (mat->num_rows) { /* hipsparse errors on empty matrices! */ 2091*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&matstruct->matDescr, mat->num_rows, mat->num_cols, mat->num_entries, mat->row_offsets->data().get(), mat->column_indices->data().get(), mat->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, /* row offset, col idx types due to THRUSTINTARRAY32 */ 2092*d52a580bSJunchao Zhang HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype)); 2093*d52a580bSJunchao Zhang } 2094*d52a580bSJunchao Zhang } else if (hipsparsestruct->format == MAT_HIPSPARSE_ELL || hipsparsestruct->format == MAT_HIPSPARSE_HYB) { 2095*d52a580bSJunchao Zhang CsrMatrix *mat = new CsrMatrix; 2096*d52a580bSJunchao Zhang mat->num_rows = m; 2097*d52a580bSJunchao Zhang mat->num_cols = A->cmap->n; 2098*d52a580bSJunchao Zhang mat->num_entries = nnz; 2099*d52a580bSJunchao Zhang mat->row_offsets = new THRUSTINTARRAY32(m + 1); 2100*d52a580bSJunchao Zhang mat->column_indices = new THRUSTINTARRAY32(nnz); 2101*d52a580bSJunchao Zhang mat->values = new THRUSTARRAY(nnz); 2102*d52a580bSJunchao Zhang mat->row_offsets->assign(ii, ii + m + 1); 2103*d52a580bSJunchao Zhang mat->column_indices->assign(a->j, a->j + nnz); 2104*d52a580bSJunchao Zhang if (a->a) mat->values->assign(a->a, a->a + nnz); 2105*d52a580bSJunchao Zhang 2106*d52a580bSJunchao Zhang hipsparseHybMat_t hybMat; 2107*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateHybMat(&hybMat)); 2108*d52a580bSJunchao Zhang hipsparseHybPartition_t partition = hipsparsestruct->format == MAT_HIPSPARSE_ELL ? 
HIPSPARSE_HYB_PARTITION_MAX : HIPSPARSE_HYB_PARTITION_AUTO; 2109*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr2hyb(hipsparsestruct->handle, mat->num_rows, mat->num_cols, matstruct->descr, mat->values->data().get(), mat->row_offsets->data().get(), mat->column_indices->data().get(), hybMat, 0, partition)); 2110*d52a580bSJunchao Zhang /* assign the pointer */ 2111*d52a580bSJunchao Zhang matstruct->mat = hybMat; 2112*d52a580bSJunchao Zhang 2113*d52a580bSJunchao Zhang if (mat) { 2114*d52a580bSJunchao Zhang if (mat->values) delete (THRUSTARRAY *)mat->values; 2115*d52a580bSJunchao Zhang if (mat->column_indices) delete (THRUSTINTARRAY32 *)mat->column_indices; 2116*d52a580bSJunchao Zhang if (mat->row_offsets) delete (THRUSTINTARRAY32 *)mat->row_offsets; 2117*d52a580bSJunchao Zhang delete (CsrMatrix *)mat; 2118*d52a580bSJunchao Zhang } 2119*d52a580bSJunchao Zhang } 2120*d52a580bSJunchao Zhang 2121*d52a580bSJunchao Zhang /* assign the compressed row indices */ 2122*d52a580bSJunchao Zhang if (a->compressedrow.use) { 2123*d52a580bSJunchao Zhang hipsparsestruct->workVector = new THRUSTARRAY(m); 2124*d52a580bSJunchao Zhang matstruct->cprowIndices = new THRUSTINTARRAY(m); 2125*d52a580bSJunchao Zhang matstruct->cprowIndices->assign(ridx, ridx + m); 2126*d52a580bSJunchao Zhang tmp = m; 2127*d52a580bSJunchao Zhang } else { 2128*d52a580bSJunchao Zhang hipsparsestruct->workVector = NULL; 2129*d52a580bSJunchao Zhang matstruct->cprowIndices = NULL; 2130*d52a580bSJunchao Zhang tmp = 0; 2131*d52a580bSJunchao Zhang } 2132*d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu(((m + 1) + (a->nz)) * sizeof(int) + tmp * sizeof(PetscInt) + (3 + (a->nz)) * sizeof(PetscScalar))); 2133*d52a580bSJunchao Zhang 2134*d52a580bSJunchao Zhang /* assign the pointer */ 2135*d52a580bSJunchao Zhang hipsparsestruct->mat = matstruct; 2136*d52a580bSJunchao Zhang } catch (char *ex) { 2137*d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex); 2138*d52a580bSJunchao Zhang } 
2139*d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP()); 2140*d52a580bSJunchao Zhang PetscCall(PetscLogEventEnd(MAT_HIPSPARSECopyToGPU, A, 0, 0, 0)); 2141*d52a580bSJunchao Zhang hipsparsestruct->nonzerostate = A->nonzerostate; 2142*d52a580bSJunchao Zhang } 2143*d52a580bSJunchao Zhang if (both) A->offloadmask = PETSC_OFFLOAD_BOTH; 2144*d52a580bSJunchao Zhang } 2145*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 2146*d52a580bSJunchao Zhang } 2147*d52a580bSJunchao Zhang 2148*d52a580bSJunchao Zhang struct VecHIPPlusEquals { 2149*d52a580bSJunchao Zhang template <typename Tuple> 2150*d52a580bSJunchao Zhang __host__ __device__ void operator()(Tuple t) 2151*d52a580bSJunchao Zhang { 2152*d52a580bSJunchao Zhang thrust::get<1>(t) = thrust::get<1>(t) + thrust::get<0>(t); 2153*d52a580bSJunchao Zhang } 2154*d52a580bSJunchao Zhang }; 2155*d52a580bSJunchao Zhang 2156*d52a580bSJunchao Zhang struct VecHIPEquals { 2157*d52a580bSJunchao Zhang template <typename Tuple> 2158*d52a580bSJunchao Zhang __host__ __device__ void operator()(Tuple t) 2159*d52a580bSJunchao Zhang { 2160*d52a580bSJunchao Zhang thrust::get<1>(t) = thrust::get<0>(t); 2161*d52a580bSJunchao Zhang } 2162*d52a580bSJunchao Zhang }; 2163*d52a580bSJunchao Zhang 2164*d52a580bSJunchao Zhang struct VecHIPEqualsReverse { 2165*d52a580bSJunchao Zhang template <typename Tuple> 2166*d52a580bSJunchao Zhang __host__ __device__ void operator()(Tuple t) 2167*d52a580bSJunchao Zhang { 2168*d52a580bSJunchao Zhang thrust::get<0>(t) = thrust::get<1>(t); 2169*d52a580bSJunchao Zhang } 2170*d52a580bSJunchao Zhang }; 2171*d52a580bSJunchao Zhang 2172*d52a580bSJunchao Zhang struct MatProductCtx_MatMatHipsparse { 2173*d52a580bSJunchao Zhang PetscBool cisdense; 2174*d52a580bSJunchao Zhang PetscScalar *Bt; 2175*d52a580bSJunchao Zhang Mat X; 2176*d52a580bSJunchao Zhang PetscBool reusesym; /* Hipsparse does not have split symbolic and numeric phases for sparse matmat operations */ 2177*d52a580bSJunchao Zhang PetscLogDouble flops; 
2178*d52a580bSJunchao Zhang CsrMatrix *Bcsr; 2179*d52a580bSJunchao Zhang hipsparseSpMatDescr_t matSpBDescr; 2180*d52a580bSJunchao Zhang PetscBool initialized; /* C = alpha op(A) op(B) + beta C */ 2181*d52a580bSJunchao Zhang hipsparseDnMatDescr_t matBDescr; 2182*d52a580bSJunchao Zhang hipsparseDnMatDescr_t matCDescr; 2183*d52a580bSJunchao Zhang PetscInt Blda, Clda; /* Record leading dimensions of B and C here to detect changes*/ 2184*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0) 2185*d52a580bSJunchao Zhang void *dBuffer4, *dBuffer5; 2186*d52a580bSJunchao Zhang #endif 2187*d52a580bSJunchao Zhang size_t mmBufferSize; 2188*d52a580bSJunchao Zhang void *mmBuffer, *mmBuffer2; /* SpGEMM WorkEstimation buffer */ 2189*d52a580bSJunchao Zhang hipsparseSpGEMMDescr_t spgemmDesc; 2190*d52a580bSJunchao Zhang }; 2191*d52a580bSJunchao Zhang 2192*d52a580bSJunchao Zhang static PetscErrorCode MatProductCtxDestroy_MatMatHipsparse(PetscCtxRt data) 2193*d52a580bSJunchao Zhang { 2194*d52a580bSJunchao Zhang MatProductCtx_MatMatHipsparse *mmdata = *(MatProductCtx_MatMatHipsparse **)data; 2195*d52a580bSJunchao Zhang 2196*d52a580bSJunchao Zhang PetscFunctionBegin; 2197*d52a580bSJunchao Zhang PetscCallHIP(hipFree(mmdata->Bt)); 2198*d52a580bSJunchao Zhang delete mmdata->Bcsr; 2199*d52a580bSJunchao Zhang if (mmdata->matSpBDescr) PetscCallHIPSPARSE(hipsparseDestroySpMat(mmdata->matSpBDescr)); 2200*d52a580bSJunchao Zhang if (mmdata->matBDescr) PetscCallHIPSPARSE(hipsparseDestroyDnMat(mmdata->matBDescr)); 2201*d52a580bSJunchao Zhang if (mmdata->matCDescr) PetscCallHIPSPARSE(hipsparseDestroyDnMat(mmdata->matCDescr)); 2202*d52a580bSJunchao Zhang if (mmdata->spgemmDesc) PetscCallHIPSPARSE(hipsparseSpGEMM_destroyDescr(mmdata->spgemmDesc)); 2203*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0) 2204*d52a580bSJunchao Zhang if (mmdata->dBuffer4) PetscCallHIP(hipFree(mmdata->dBuffer4)); 2205*d52a580bSJunchao Zhang if (mmdata->dBuffer5) PetscCallHIP(hipFree(mmdata->dBuffer5)); 
2206*d52a580bSJunchao Zhang #endif 2207*d52a580bSJunchao Zhang if (mmdata->mmBuffer) PetscCallHIP(hipFree(mmdata->mmBuffer)); 2208*d52a580bSJunchao Zhang if (mmdata->mmBuffer2) PetscCallHIP(hipFree(mmdata->mmBuffer2)); 2209*d52a580bSJunchao Zhang PetscCall(MatDestroy(&mmdata->X)); 2210*d52a580bSJunchao Zhang PetscCall(PetscFree(*(void **)data)); 2211*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 2212*d52a580bSJunchao Zhang } 2213*d52a580bSJunchao Zhang 2214*d52a580bSJunchao Zhang static PetscErrorCode MatProductNumeric_SeqAIJHIPSPARSE_SeqDENSEHIP(Mat C) 2215*d52a580bSJunchao Zhang { 2216*d52a580bSJunchao Zhang Mat_Product *product = C->product; 2217*d52a580bSJunchao Zhang Mat A, B; 2218*d52a580bSJunchao Zhang PetscInt m, n, blda, clda; 2219*d52a580bSJunchao Zhang PetscBool flg, biship; 2220*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp; 2221*d52a580bSJunchao Zhang hipsparseOperation_t opA; 2222*d52a580bSJunchao Zhang const PetscScalar *barray; 2223*d52a580bSJunchao Zhang PetscScalar *carray; 2224*d52a580bSJunchao Zhang MatProductCtx_MatMatHipsparse *mmdata; 2225*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSEMultStruct *mat; 2226*d52a580bSJunchao Zhang CsrMatrix *csrmat; 2227*d52a580bSJunchao Zhang 2228*d52a580bSJunchao Zhang PetscFunctionBegin; 2229*d52a580bSJunchao Zhang MatCheckProduct(C, 1); 2230*d52a580bSJunchao Zhang PetscCheck(C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Product data empty"); 2231*d52a580bSJunchao Zhang mmdata = (MatProductCtx_MatMatHipsparse *)product->data; 2232*d52a580bSJunchao Zhang A = product->A; 2233*d52a580bSJunchao Zhang B = product->B; 2234*d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg)); 2235*d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Not for type %s", ((PetscObject)A)->type_name); 2236*d52a580bSJunchao Zhang /* currently CopyToGpu does not copy if the matrix is bound to CPU 2237*d52a580bSJunchao Zhang 
Instead of silently accepting the wrong answer, I prefer to raise the error */ 2238*d52a580bSJunchao Zhang PetscCheck(!A->boundtocpu, PetscObjectComm((PetscObject)A), PETSC_ERR_ARG_WRONG, "Cannot bind to CPU a HIPSPARSE matrix between MatProductSymbolic and MatProductNumeric phases"); 2239*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A)); 2240*d52a580bSJunchao Zhang cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr; 2241*d52a580bSJunchao Zhang switch (product->type) { 2242*d52a580bSJunchao Zhang case MATPRODUCT_AB: 2243*d52a580bSJunchao Zhang case MATPRODUCT_PtAP: 2244*d52a580bSJunchao Zhang mat = cusp->mat; 2245*d52a580bSJunchao Zhang opA = HIPSPARSE_OPERATION_NON_TRANSPOSE; 2246*d52a580bSJunchao Zhang m = A->rmap->n; 2247*d52a580bSJunchao Zhang n = B->cmap->n; 2248*d52a580bSJunchao Zhang break; 2249*d52a580bSJunchao Zhang case MATPRODUCT_AtB: 2250*d52a580bSJunchao Zhang if (!A->form_explicit_transpose) { 2251*d52a580bSJunchao Zhang mat = cusp->mat; 2252*d52a580bSJunchao Zhang opA = HIPSPARSE_OPERATION_TRANSPOSE; 2253*d52a580bSJunchao Zhang } else { 2254*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(A)); 2255*d52a580bSJunchao Zhang mat = cusp->matTranspose; 2256*d52a580bSJunchao Zhang opA = HIPSPARSE_OPERATION_NON_TRANSPOSE; 2257*d52a580bSJunchao Zhang } 2258*d52a580bSJunchao Zhang m = A->cmap->n; 2259*d52a580bSJunchao Zhang n = B->cmap->n; 2260*d52a580bSJunchao Zhang break; 2261*d52a580bSJunchao Zhang case MATPRODUCT_ABt: 2262*d52a580bSJunchao Zhang case MATPRODUCT_RARt: 2263*d52a580bSJunchao Zhang mat = cusp->mat; 2264*d52a580bSJunchao Zhang opA = HIPSPARSE_OPERATION_NON_TRANSPOSE; 2265*d52a580bSJunchao Zhang m = A->rmap->n; 2266*d52a580bSJunchao Zhang n = B->rmap->n; 2267*d52a580bSJunchao Zhang break; 2268*d52a580bSJunchao Zhang default: 2269*d52a580bSJunchao Zhang SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Unsupported product type %s", MatProductTypes[product->type]); 2270*d52a580bSJunchao Zhang } 
2271*d52a580bSJunchao Zhang PetscCheck(mat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing Mat_SeqAIJHIPSPARSEMultStruct"); 2272*d52a580bSJunchao Zhang csrmat = (CsrMatrix *)mat->mat; 2273*d52a580bSJunchao Zhang /* if the user passed a CPU matrix, copy the data to the GPU */ 2274*d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)B, MATSEQDENSEHIP, &biship)); 2275*d52a580bSJunchao Zhang if (!biship) PetscCall(MatConvert(B, MATSEQDENSEHIP, MAT_INPLACE_MATRIX, &B)); 2276*d52a580bSJunchao Zhang PetscCall(MatDenseGetArrayReadAndMemType(B, &barray, nullptr)); 2277*d52a580bSJunchao Zhang PetscCall(MatDenseGetLDA(B, &blda)); 2278*d52a580bSJunchao Zhang if (product->type == MATPRODUCT_RARt || product->type == MATPRODUCT_PtAP) { 2279*d52a580bSJunchao Zhang PetscCall(MatDenseGetArrayWriteAndMemType(mmdata->X, &carray, nullptr)); 2280*d52a580bSJunchao Zhang PetscCall(MatDenseGetLDA(mmdata->X, &clda)); 2281*d52a580bSJunchao Zhang } else { 2282*d52a580bSJunchao Zhang PetscCall(MatDenseGetArrayWriteAndMemType(C, &carray, nullptr)); 2283*d52a580bSJunchao Zhang PetscCall(MatDenseGetLDA(C, &clda)); 2284*d52a580bSJunchao Zhang } 2285*d52a580bSJunchao Zhang 2286*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 2287*d52a580bSJunchao Zhang hipsparseOperation_t opB = (product->type == MATPRODUCT_ABt || product->type == MATPRODUCT_RARt) ? 
HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE; 2288*d52a580bSJunchao Zhang /* (re)allocate mmBuffer if not initialized or LDAs are different */ 2289*d52a580bSJunchao Zhang if (!mmdata->initialized || mmdata->Blda != blda || mmdata->Clda != clda) { 2290*d52a580bSJunchao Zhang size_t mmBufferSize; 2291*d52a580bSJunchao Zhang if (mmdata->initialized && mmdata->Blda != blda) { 2292*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDestroyDnMat(mmdata->matBDescr)); 2293*d52a580bSJunchao Zhang mmdata->matBDescr = NULL; 2294*d52a580bSJunchao Zhang } 2295*d52a580bSJunchao Zhang if (!mmdata->matBDescr) { 2296*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateDnMat(&mmdata->matBDescr, B->rmap->n, B->cmap->n, blda, (void *)barray, hipsparse_scalartype, HIPSPARSE_ORDER_COL)); 2297*d52a580bSJunchao Zhang mmdata->Blda = blda; 2298*d52a580bSJunchao Zhang } 2299*d52a580bSJunchao Zhang if (mmdata->initialized && mmdata->Clda != clda) { 2300*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDestroyDnMat(mmdata->matCDescr)); 2301*d52a580bSJunchao Zhang mmdata->matCDescr = NULL; 2302*d52a580bSJunchao Zhang } 2303*d52a580bSJunchao Zhang if (!mmdata->matCDescr) { /* matCDescr is for C or mmdata->X */ 2304*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateDnMat(&mmdata->matCDescr, m, n, clda, (void *)carray, hipsparse_scalartype, HIPSPARSE_ORDER_COL)); 2305*d52a580bSJunchao Zhang mmdata->Clda = clda; 2306*d52a580bSJunchao Zhang } 2307*d52a580bSJunchao Zhang if (!mat->matDescr) { 2308*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&mat->matDescr, csrmat->num_rows, csrmat->num_cols, csrmat->num_entries, csrmat->row_offsets->data().get(), csrmat->column_indices->data().get(), csrmat->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, /* row offset, col idx types due to THRUSTINTARRAY32 */ 2309*d52a580bSJunchao Zhang HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype)); 2310*d52a580bSJunchao Zhang } 2311*d52a580bSJunchao Zhang 
PetscCallHIPSPARSE(hipsparseSpMM_bufferSize(cusp->handle, opA, opB, mat->alpha_one, mat->matDescr, mmdata->matBDescr, mat->beta_zero, mmdata->matCDescr, hipsparse_scalartype, cusp->spmmAlg, &mmBufferSize)); 2312*d52a580bSJunchao Zhang if ((mmdata->mmBuffer && mmdata->mmBufferSize < mmBufferSize) || !mmdata->mmBuffer) { 2313*d52a580bSJunchao Zhang PetscCallHIP(hipFree(mmdata->mmBuffer)); 2314*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&mmdata->mmBuffer, mmBufferSize)); 2315*d52a580bSJunchao Zhang mmdata->mmBufferSize = mmBufferSize; 2316*d52a580bSJunchao Zhang } 2317*d52a580bSJunchao Zhang mmdata->initialized = PETSC_TRUE; 2318*d52a580bSJunchao Zhang } else { 2319*d52a580bSJunchao Zhang /* to be safe, always update pointers of the mats */ 2320*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMatSetValues(mat->matDescr, csrmat->values->data().get())); 2321*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnMatSetValues(mmdata->matBDescr, (void *)barray)); 2322*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnMatSetValues(mmdata->matCDescr, (void *)carray)); 2323*d52a580bSJunchao Zhang } 2324*d52a580bSJunchao Zhang 2325*d52a580bSJunchao Zhang /* do hipsparseSpMM, which supports transpose on B */ 2326*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMM(cusp->handle, opA, opB, mat->alpha_one, mat->matDescr, mmdata->matBDescr, mat->beta_zero, mmdata->matCDescr, hipsparse_scalartype, cusp->spmmAlg, mmdata->mmBuffer)); 2327*d52a580bSJunchao Zhang 2328*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 2329*d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(n * 2.0 * csrmat->num_entries)); 2330*d52a580bSJunchao Zhang PetscCall(MatDenseRestoreArrayReadAndMemType(B, &barray)); 2331*d52a580bSJunchao Zhang if (product->type == MATPRODUCT_RARt) { 2332*d52a580bSJunchao Zhang PetscCall(MatDenseRestoreArrayWriteAndMemType(mmdata->X, &carray)); 2333*d52a580bSJunchao Zhang PetscCall(MatMatMultNumeric_SeqDenseHIP_SeqDenseHIP_Internal(B, mmdata->X, C, PETSC_FALSE, 
PETSC_FALSE)); 2334*d52a580bSJunchao Zhang } else if (product->type == MATPRODUCT_PtAP) { 2335*d52a580bSJunchao Zhang PetscCall(MatDenseRestoreArrayWriteAndMemType(mmdata->X, &carray)); 2336*d52a580bSJunchao Zhang PetscCall(MatMatMultNumeric_SeqDenseHIP_SeqDenseHIP_Internal(B, mmdata->X, C, PETSC_TRUE, PETSC_FALSE)); 2337*d52a580bSJunchao Zhang } else PetscCall(MatDenseRestoreArrayWriteAndMemType(C, &carray)); 2338*d52a580bSJunchao Zhang if (mmdata->cisdense) PetscCall(MatConvert(C, MATSEQDENSE, MAT_INPLACE_MATRIX, &C)); 2339*d52a580bSJunchao Zhang if (!biship) PetscCall(MatConvert(B, MATSEQDENSE, MAT_INPLACE_MATRIX, &B)); 2340*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 2341*d52a580bSJunchao Zhang } 2342*d52a580bSJunchao Zhang 2343*d52a580bSJunchao Zhang static PetscErrorCode MatProductSymbolic_SeqAIJHIPSPARSE_SeqDENSEHIP(Mat C) 2344*d52a580bSJunchao Zhang { 2345*d52a580bSJunchao Zhang Mat_Product *product = C->product; 2346*d52a580bSJunchao Zhang Mat A, B; 2347*d52a580bSJunchao Zhang PetscInt m, n; 2348*d52a580bSJunchao Zhang PetscBool cisdense, flg; 2349*d52a580bSJunchao Zhang MatProductCtx_MatMatHipsparse *mmdata; 2350*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp; 2351*d52a580bSJunchao Zhang 2352*d52a580bSJunchao Zhang PetscFunctionBegin; 2353*d52a580bSJunchao Zhang MatCheckProduct(C, 1); 2354*d52a580bSJunchao Zhang PetscCheck(!C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Product data not empty"); 2355*d52a580bSJunchao Zhang A = product->A; 2356*d52a580bSJunchao Zhang B = product->B; 2357*d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg)); 2358*d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for type %s", ((PetscObject)A)->type_name); 2359*d52a580bSJunchao Zhang cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr; 2360*d52a580bSJunchao Zhang PetscCheck(cusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for 
MAT_HIPSPARSE_CSR format"); 2361*d52a580bSJunchao Zhang switch (product->type) { 2362*d52a580bSJunchao Zhang case MATPRODUCT_AB: 2363*d52a580bSJunchao Zhang m = A->rmap->n; 2364*d52a580bSJunchao Zhang n = B->cmap->n; 2365*d52a580bSJunchao Zhang break; 2366*d52a580bSJunchao Zhang case MATPRODUCT_AtB: 2367*d52a580bSJunchao Zhang m = A->cmap->n; 2368*d52a580bSJunchao Zhang n = B->cmap->n; 2369*d52a580bSJunchao Zhang break; 2370*d52a580bSJunchao Zhang case MATPRODUCT_ABt: 2371*d52a580bSJunchao Zhang m = A->rmap->n; 2372*d52a580bSJunchao Zhang n = B->rmap->n; 2373*d52a580bSJunchao Zhang break; 2374*d52a580bSJunchao Zhang case MATPRODUCT_PtAP: 2375*d52a580bSJunchao Zhang m = B->cmap->n; 2376*d52a580bSJunchao Zhang n = B->cmap->n; 2377*d52a580bSJunchao Zhang break; 2378*d52a580bSJunchao Zhang case MATPRODUCT_RARt: 2379*d52a580bSJunchao Zhang m = B->rmap->n; 2380*d52a580bSJunchao Zhang n = B->rmap->n; 2381*d52a580bSJunchao Zhang break; 2382*d52a580bSJunchao Zhang default: 2383*d52a580bSJunchao Zhang SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Unsupported product type %s", MatProductTypes[product->type]); 2384*d52a580bSJunchao Zhang } 2385*d52a580bSJunchao Zhang PetscCall(MatSetSizes(C, m, n, m, n)); 2386*d52a580bSJunchao Zhang /* if C is of type MATSEQDENSE (CPU), perform the operation on the GPU and then copy on the CPU */ 2387*d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)C, MATSEQDENSE, &cisdense)); 2388*d52a580bSJunchao Zhang PetscCall(MatSetType(C, MATSEQDENSEHIP)); 2389*d52a580bSJunchao Zhang 2390*d52a580bSJunchao Zhang /* product data */ 2391*d52a580bSJunchao Zhang PetscCall(PetscNew(&mmdata)); 2392*d52a580bSJunchao Zhang mmdata->cisdense = cisdense; 2393*d52a580bSJunchao Zhang /* for these products we need intermediate storage */ 2394*d52a580bSJunchao Zhang if (product->type == MATPRODUCT_RARt || product->type == MATPRODUCT_PtAP) { 2395*d52a580bSJunchao Zhang PetscCall(MatCreate(PetscObjectComm((PetscObject)C), &mmdata->X)); 
2396*d52a580bSJunchao Zhang PetscCall(MatSetType(mmdata->X, MATSEQDENSEHIP)); 2397*d52a580bSJunchao Zhang /* do not preallocate, since the first call to MatDenseHIPGetArray will preallocate on the GPU for us */ 2398*d52a580bSJunchao Zhang if (product->type == MATPRODUCT_RARt) PetscCall(MatSetSizes(mmdata->X, A->rmap->n, B->rmap->n, A->rmap->n, B->rmap->n)); 2399*d52a580bSJunchao Zhang else PetscCall(MatSetSizes(mmdata->X, A->rmap->n, B->cmap->n, A->rmap->n, B->cmap->n)); 2400*d52a580bSJunchao Zhang } 2401*d52a580bSJunchao Zhang C->product->data = mmdata; 2402*d52a580bSJunchao Zhang C->product->destroy = MatProductCtxDestroy_MatMatHipsparse; 2403*d52a580bSJunchao Zhang C->ops->productnumeric = MatProductNumeric_SeqAIJHIPSPARSE_SeqDENSEHIP; 2404*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 2405*d52a580bSJunchao Zhang } 2406*d52a580bSJunchao Zhang 2407*d52a580bSJunchao Zhang static PetscErrorCode MatProductNumeric_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE(Mat C) 2408*d52a580bSJunchao Zhang { 2409*d52a580bSJunchao Zhang Mat_Product *product = C->product; 2410*d52a580bSJunchao Zhang Mat A, B; 2411*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *Acusp, *Bcusp, *Ccusp; 2412*d52a580bSJunchao Zhang Mat_SeqAIJ *c = (Mat_SeqAIJ *)C->data; 2413*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSEMultStruct *Amat, *Bmat, *Cmat; 2414*d52a580bSJunchao Zhang CsrMatrix *Acsr, *Bcsr, *Ccsr; 2415*d52a580bSJunchao Zhang PetscBool flg; 2416*d52a580bSJunchao Zhang MatProductType ptype; 2417*d52a580bSJunchao Zhang MatProductCtx_MatMatHipsparse *mmdata; 2418*d52a580bSJunchao Zhang hipsparseSpMatDescr_t BmatSpDescr; 2419*d52a580bSJunchao Zhang hipsparseOperation_t opA = HIPSPARSE_OPERATION_NON_TRANSPOSE, opB = HIPSPARSE_OPERATION_NON_TRANSPOSE; /* hipSPARSE spgemm doesn't support transpose yet */ 2420*d52a580bSJunchao Zhang 2421*d52a580bSJunchao Zhang PetscFunctionBegin; 2422*d52a580bSJunchao Zhang MatCheckProduct(C, 1); 2423*d52a580bSJunchao Zhang PetscCheck(C->product->data, 
PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Product data empty"); 2424*d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)C, MATSEQAIJHIPSPARSE, &flg)); 2425*d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for C of type %s", ((PetscObject)C)->type_name); 2426*d52a580bSJunchao Zhang mmdata = (MatProductCtx_MatMatHipsparse *)C->product->data; 2427*d52a580bSJunchao Zhang A = product->A; 2428*d52a580bSJunchao Zhang B = product->B; 2429*d52a580bSJunchao Zhang if (mmdata->reusesym) { /* this happens when api_user is true, meaning that the matrix values have been already computed in the MatProductSymbolic phase */ 2430*d52a580bSJunchao Zhang mmdata->reusesym = PETSC_FALSE; 2431*d52a580bSJunchao Zhang Ccusp = (Mat_SeqAIJHIPSPARSE *)C->spptr; 2432*d52a580bSJunchao Zhang PetscCheck(Ccusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format"); 2433*d52a580bSJunchao Zhang Cmat = Ccusp->mat; 2434*d52a580bSJunchao Zhang PetscCheck(Cmat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing C mult struct for product type %s", MatProductTypes[C->product->type]); 2435*d52a580bSJunchao Zhang Ccsr = (CsrMatrix *)Cmat->mat; 2436*d52a580bSJunchao Zhang PetscCheck(Ccsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing C CSR struct"); 2437*d52a580bSJunchao Zhang goto finalize; 2438*d52a580bSJunchao Zhang } 2439*d52a580bSJunchao Zhang if (!c->nz) goto finalize; 2440*d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg)); 2441*d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for type %s", ((PetscObject)A)->type_name); 2442*d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)B, MATSEQAIJHIPSPARSE, &flg)); 2443*d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for B of type %s", ((PetscObject)B)->type_name); 
2444*d52a580bSJunchao Zhang PetscCheck(!A->boundtocpu, PetscObjectComm((PetscObject)C), PETSC_ERR_ARG_WRONG, "Cannot bind to CPU a HIPSPARSE matrix between MatProductSymbolic and MatProductNumeric phases"); 2445*d52a580bSJunchao Zhang PetscCheck(!B->boundtocpu, PetscObjectComm((PetscObject)C), PETSC_ERR_ARG_WRONG, "Cannot bind to CPU a HIPSPARSE matrix between MatProductSymbolic and MatProductNumeric phases"); 2446*d52a580bSJunchao Zhang Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr; 2447*d52a580bSJunchao Zhang Bcusp = (Mat_SeqAIJHIPSPARSE *)B->spptr; 2448*d52a580bSJunchao Zhang Ccusp = (Mat_SeqAIJHIPSPARSE *)C->spptr; 2449*d52a580bSJunchao Zhang PetscCheck(Acusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format"); 2450*d52a580bSJunchao Zhang PetscCheck(Bcusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format"); 2451*d52a580bSJunchao Zhang PetscCheck(Ccusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format"); 2452*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A)); 2453*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(B)); 2454*d52a580bSJunchao Zhang 2455*d52a580bSJunchao Zhang ptype = product->type; 2456*d52a580bSJunchao Zhang if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) { 2457*d52a580bSJunchao Zhang ptype = MATPRODUCT_AB; 2458*d52a580bSJunchao Zhang PetscCheck(product->symbolic_used_the_fact_A_is_symmetric, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Symbolic should have been built using the fact that A is symmetric"); 2459*d52a580bSJunchao Zhang } 2460*d52a580bSJunchao Zhang if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) { 2461*d52a580bSJunchao Zhang ptype = MATPRODUCT_AB; 2462*d52a580bSJunchao Zhang PetscCheck(product->symbolic_used_the_fact_B_is_symmetric, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Symbolic 
should have been built using the fact that B is symmetric"); 2463*d52a580bSJunchao Zhang } 2464*d52a580bSJunchao Zhang switch (ptype) { 2465*d52a580bSJunchao Zhang case MATPRODUCT_AB: 2466*d52a580bSJunchao Zhang Amat = Acusp->mat; 2467*d52a580bSJunchao Zhang Bmat = Bcusp->mat; 2468*d52a580bSJunchao Zhang break; 2469*d52a580bSJunchao Zhang case MATPRODUCT_AtB: 2470*d52a580bSJunchao Zhang Amat = Acusp->matTranspose; 2471*d52a580bSJunchao Zhang Bmat = Bcusp->mat; 2472*d52a580bSJunchao Zhang break; 2473*d52a580bSJunchao Zhang case MATPRODUCT_ABt: 2474*d52a580bSJunchao Zhang Amat = Acusp->mat; 2475*d52a580bSJunchao Zhang Bmat = Bcusp->matTranspose; 2476*d52a580bSJunchao Zhang break; 2477*d52a580bSJunchao Zhang default: 2478*d52a580bSJunchao Zhang SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Unsupported product type %s", MatProductTypes[product->type]); 2479*d52a580bSJunchao Zhang } 2480*d52a580bSJunchao Zhang Cmat = Ccusp->mat; 2481*d52a580bSJunchao Zhang PetscCheck(Amat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing A mult struct for product type %s", MatProductTypes[ptype]); 2482*d52a580bSJunchao Zhang PetscCheck(Bmat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing B mult struct for product type %s", MatProductTypes[ptype]); 2483*d52a580bSJunchao Zhang PetscCheck(Cmat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing C mult struct for product type %s", MatProductTypes[ptype]); 2484*d52a580bSJunchao Zhang Acsr = (CsrMatrix *)Amat->mat; 2485*d52a580bSJunchao Zhang Bcsr = mmdata->Bcsr ? 
mmdata->Bcsr : (CsrMatrix *)Bmat->mat; /* B may be in compressed row storage */ 2486*d52a580bSJunchao Zhang Ccsr = (CsrMatrix *)Cmat->mat; 2487*d52a580bSJunchao Zhang PetscCheck(Acsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing A CSR struct"); 2488*d52a580bSJunchao Zhang PetscCheck(Bcsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing B CSR struct"); 2489*d52a580bSJunchao Zhang PetscCheck(Ccsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing C CSR struct"); 2490*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 2491*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 0, 0) 2492*d52a580bSJunchao Zhang BmatSpDescr = mmdata->Bcsr ? mmdata->matSpBDescr : Bmat->matDescr; /* B may be in compressed row storage */ 2493*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetPointerMode(Ccusp->handle, HIPSPARSE_POINTER_MODE_DEVICE)); 2494*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0) 2495*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMMreuse_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc)); 2496*d52a580bSJunchao Zhang #else 2497*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMM_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &mmdata->mmBufferSize, mmdata->mmBuffer)); 2498*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMM_copy(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc)); 2499*d52a580bSJunchao Zhang #endif 2500*d52a580bSJunchao Zhang #else 2501*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr_spgemm(Ccusp->handle, opA, opB, Acsr->num_rows, Bcsr->num_cols, Acsr->num_cols, Amat->descr, Acsr->num_entries, 
Acsr->values->data().get(), Acsr->row_offsets->data().get(), Acsr->column_indices->data().get(), Bmat->descr, 2502*d52a580bSJunchao Zhang Bcsr->num_entries, Bcsr->values->data().get(), Bcsr->row_offsets->data().get(), Bcsr->column_indices->data().get(), Cmat->descr, Ccsr->values->data().get(), Ccsr->row_offsets->data().get(), 2503*d52a580bSJunchao Zhang Ccsr->column_indices->data().get())); 2504*d52a580bSJunchao Zhang #endif 2505*d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(mmdata->flops)); 2506*d52a580bSJunchao Zhang PetscCallHIP(WaitForHIP()); 2507*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 2508*d52a580bSJunchao Zhang C->offloadmask = PETSC_OFFLOAD_GPU; 2509*d52a580bSJunchao Zhang finalize: 2510*d52a580bSJunchao Zhang /* shorter version of MatAssemblyEnd_SeqAIJ */ 2511*d52a580bSJunchao Zhang PetscCall(PetscInfo(C, "Matrix size: %" PetscInt_FMT " X %" PetscInt_FMT "; storage space: 0 unneeded, %" PetscInt_FMT " used\n", C->rmap->n, C->cmap->n, c->nz)); 2512*d52a580bSJunchao Zhang PetscCall(PetscInfo(C, "Number of mallocs during MatSetValues() is 0\n")); 2513*d52a580bSJunchao Zhang PetscCall(PetscInfo(C, "Maximum nonzeros in any row is %" PetscInt_FMT "\n", c->rmax)); 2514*d52a580bSJunchao Zhang c->reallocs = 0; 2515*d52a580bSJunchao Zhang C->info.mallocs += 0; 2516*d52a580bSJunchao Zhang C->info.nz_unneeded = 0; 2517*d52a580bSJunchao Zhang C->assembled = C->was_assembled = PETSC_TRUE; 2518*d52a580bSJunchao Zhang C->num_ass++; 2519*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 2520*d52a580bSJunchao Zhang } 2521*d52a580bSJunchao Zhang 2522*d52a580bSJunchao Zhang static PetscErrorCode MatProductSymbolic_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE(Mat C) 2523*d52a580bSJunchao Zhang { 2524*d52a580bSJunchao Zhang Mat_Product *product = C->product; 2525*d52a580bSJunchao Zhang Mat A, B; 2526*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *Acusp, *Bcusp, *Ccusp; 2527*d52a580bSJunchao Zhang Mat_SeqAIJ *a, *b, *c; 2528*d52a580bSJunchao Zhang 
Mat_SeqAIJHIPSPARSEMultStruct *Amat, *Bmat, *Cmat; 2529*d52a580bSJunchao Zhang CsrMatrix *Acsr, *Bcsr, *Ccsr; 2530*d52a580bSJunchao Zhang PetscInt i, j, m, n, k; 2531*d52a580bSJunchao Zhang PetscBool flg; 2532*d52a580bSJunchao Zhang MatProductType ptype; 2533*d52a580bSJunchao Zhang MatProductCtx_MatMatHipsparse *mmdata; 2534*d52a580bSJunchao Zhang PetscLogDouble flops; 2535*d52a580bSJunchao Zhang PetscBool biscompressed, ciscompressed; 2536*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 0, 0) 2537*d52a580bSJunchao Zhang int64_t C_num_rows1, C_num_cols1, C_nnz1; 2538*d52a580bSJunchao Zhang hipsparseSpMatDescr_t BmatSpDescr; 2539*d52a580bSJunchao Zhang #else 2540*d52a580bSJunchao Zhang int cnz; 2541*d52a580bSJunchao Zhang #endif 2542*d52a580bSJunchao Zhang hipsparseOperation_t opA = HIPSPARSE_OPERATION_NON_TRANSPOSE, opB = HIPSPARSE_OPERATION_NON_TRANSPOSE; /* hipSPARSE spgemm doesn't support transpose yet */ 2543*d52a580bSJunchao Zhang 2544*d52a580bSJunchao Zhang PetscFunctionBegin; 2545*d52a580bSJunchao Zhang MatCheckProduct(C, 1); 2546*d52a580bSJunchao Zhang PetscCheck(!C->product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Product data not empty"); 2547*d52a580bSJunchao Zhang A = product->A; 2548*d52a580bSJunchao Zhang B = product->B; 2549*d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)A, MATSEQAIJHIPSPARSE, &flg)); 2550*d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for type %s", ((PetscObject)A)->type_name); 2551*d52a580bSJunchao Zhang PetscCall(PetscObjectTypeCompare((PetscObject)B, MATSEQAIJHIPSPARSE, &flg)); 2552*d52a580bSJunchao Zhang PetscCheck(flg, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Not for B of type %s", ((PetscObject)B)->type_name); 2553*d52a580bSJunchao Zhang a = (Mat_SeqAIJ *)A->data; 2554*d52a580bSJunchao Zhang b = (Mat_SeqAIJ *)B->data; 2555*d52a580bSJunchao Zhang /* product data */ 2556*d52a580bSJunchao Zhang PetscCall(PetscNew(&mmdata)); 
2557*d52a580bSJunchao Zhang C->product->data = mmdata; 2558*d52a580bSJunchao Zhang C->product->destroy = MatProductCtxDestroy_MatMatHipsparse; 2559*d52a580bSJunchao Zhang 2560*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A)); 2561*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(B)); 2562*d52a580bSJunchao Zhang Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr; /* Access spptr after MatSeqAIJHIPSPARSECopyToGPU, not before */ 2563*d52a580bSJunchao Zhang Bcusp = (Mat_SeqAIJHIPSPARSE *)B->spptr; 2564*d52a580bSJunchao Zhang PetscCheck(Acusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format"); 2565*d52a580bSJunchao Zhang PetscCheck(Bcusp->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Only for MAT_HIPSPARSE_CSR format"); 2566*d52a580bSJunchao Zhang 2567*d52a580bSJunchao Zhang ptype = product->type; 2568*d52a580bSJunchao Zhang if (A->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_AtB) { 2569*d52a580bSJunchao Zhang ptype = MATPRODUCT_AB; 2570*d52a580bSJunchao Zhang product->symbolic_used_the_fact_A_is_symmetric = PETSC_TRUE; 2571*d52a580bSJunchao Zhang } 2572*d52a580bSJunchao Zhang if (B->symmetric == PETSC_BOOL3_TRUE && ptype == MATPRODUCT_ABt) { 2573*d52a580bSJunchao Zhang ptype = MATPRODUCT_AB; 2574*d52a580bSJunchao Zhang product->symbolic_used_the_fact_B_is_symmetric = PETSC_TRUE; 2575*d52a580bSJunchao Zhang } 2576*d52a580bSJunchao Zhang biscompressed = PETSC_FALSE; 2577*d52a580bSJunchao Zhang ciscompressed = PETSC_FALSE; 2578*d52a580bSJunchao Zhang switch (ptype) { 2579*d52a580bSJunchao Zhang case MATPRODUCT_AB: 2580*d52a580bSJunchao Zhang m = A->rmap->n; 2581*d52a580bSJunchao Zhang n = B->cmap->n; 2582*d52a580bSJunchao Zhang k = A->cmap->n; 2583*d52a580bSJunchao Zhang Amat = Acusp->mat; 2584*d52a580bSJunchao Zhang Bmat = Bcusp->mat; 2585*d52a580bSJunchao Zhang if (a->compressedrow.use) ciscompressed = PETSC_TRUE; 2586*d52a580bSJunchao Zhang if 
(b->compressedrow.use) biscompressed = PETSC_TRUE; 2587*d52a580bSJunchao Zhang break; 2588*d52a580bSJunchao Zhang case MATPRODUCT_AtB: 2589*d52a580bSJunchao Zhang m = A->cmap->n; 2590*d52a580bSJunchao Zhang n = B->cmap->n; 2591*d52a580bSJunchao Zhang k = A->rmap->n; 2592*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(A)); 2593*d52a580bSJunchao Zhang Amat = Acusp->matTranspose; 2594*d52a580bSJunchao Zhang Bmat = Bcusp->mat; 2595*d52a580bSJunchao Zhang if (b->compressedrow.use) biscompressed = PETSC_TRUE; 2596*d52a580bSJunchao Zhang break; 2597*d52a580bSJunchao Zhang case MATPRODUCT_ABt: 2598*d52a580bSJunchao Zhang m = A->rmap->n; 2599*d52a580bSJunchao Zhang n = B->rmap->n; 2600*d52a580bSJunchao Zhang k = A->cmap->n; 2601*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(B)); 2602*d52a580bSJunchao Zhang Amat = Acusp->mat; 2603*d52a580bSJunchao Zhang Bmat = Bcusp->matTranspose; 2604*d52a580bSJunchao Zhang if (a->compressedrow.use) ciscompressed = PETSC_TRUE; 2605*d52a580bSJunchao Zhang break; 2606*d52a580bSJunchao Zhang default: 2607*d52a580bSJunchao Zhang SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Unsupported product type %s", MatProductTypes[product->type]); 2608*d52a580bSJunchao Zhang } 2609*d52a580bSJunchao Zhang 2610*d52a580bSJunchao Zhang /* create hipsparse matrix */ 2611*d52a580bSJunchao Zhang PetscCall(MatSetSizes(C, m, n, m, n)); 2612*d52a580bSJunchao Zhang PetscCall(MatSetType(C, MATSEQAIJHIPSPARSE)); 2613*d52a580bSJunchao Zhang c = (Mat_SeqAIJ *)C->data; 2614*d52a580bSJunchao Zhang Ccusp = (Mat_SeqAIJHIPSPARSE *)C->spptr; 2615*d52a580bSJunchao Zhang Cmat = new Mat_SeqAIJHIPSPARSEMultStruct; 2616*d52a580bSJunchao Zhang Ccsr = new CsrMatrix; 2617*d52a580bSJunchao Zhang 2618*d52a580bSJunchao Zhang c->compressedrow.use = ciscompressed; 2619*d52a580bSJunchao Zhang if (c->compressedrow.use) { /* if a is in compressed row, than c will be in compressed row format */ 2620*d52a580bSJunchao Zhang 
c->compressedrow.nrows = a->compressedrow.nrows; 2621*d52a580bSJunchao Zhang PetscCall(PetscMalloc2(c->compressedrow.nrows + 1, &c->compressedrow.i, c->compressedrow.nrows, &c->compressedrow.rindex)); 2622*d52a580bSJunchao Zhang PetscCall(PetscArraycpy(c->compressedrow.rindex, a->compressedrow.rindex, c->compressedrow.nrows)); 2623*d52a580bSJunchao Zhang Ccusp->workVector = new THRUSTARRAY(c->compressedrow.nrows); 2624*d52a580bSJunchao Zhang Cmat->cprowIndices = new THRUSTINTARRAY(c->compressedrow.nrows); 2625*d52a580bSJunchao Zhang Cmat->cprowIndices->assign(c->compressedrow.rindex, c->compressedrow.rindex + c->compressedrow.nrows); 2626*d52a580bSJunchao Zhang } else { 2627*d52a580bSJunchao Zhang c->compressedrow.nrows = 0; 2628*d52a580bSJunchao Zhang c->compressedrow.i = NULL; 2629*d52a580bSJunchao Zhang c->compressedrow.rindex = NULL; 2630*d52a580bSJunchao Zhang Ccusp->workVector = NULL; 2631*d52a580bSJunchao Zhang Cmat->cprowIndices = NULL; 2632*d52a580bSJunchao Zhang } 2633*d52a580bSJunchao Zhang Ccusp->nrows = ciscompressed ? 
c->compressedrow.nrows : m; 2634*d52a580bSJunchao Zhang Ccusp->mat = Cmat; 2635*d52a580bSJunchao Zhang Ccusp->mat->mat = Ccsr; 2636*d52a580bSJunchao Zhang Ccsr->num_rows = Ccusp->nrows; 2637*d52a580bSJunchao Zhang Ccsr->num_cols = n; 2638*d52a580bSJunchao Zhang Ccsr->row_offsets = new THRUSTINTARRAY32(Ccusp->nrows + 1); 2639*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&Cmat->descr)); 2640*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(Cmat->descr, HIPSPARSE_INDEX_BASE_ZERO)); 2641*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(Cmat->descr, HIPSPARSE_MATRIX_TYPE_GENERAL)); 2642*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&Cmat->alpha_one, sizeof(PetscScalar))); 2643*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&Cmat->beta_zero, sizeof(PetscScalar))); 2644*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&Cmat->beta_one, sizeof(PetscScalar))); 2645*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(Cmat->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice)); 2646*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(Cmat->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice)); 2647*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(Cmat->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice)); 2648*d52a580bSJunchao Zhang if (!Ccsr->num_rows || !Ccsr->num_cols || !a->nz || !b->nz) { /* hipsparse raise errors in different calls when matrices have zero rows/columns! 
*/ 2649*d52a580bSJunchao Zhang thrust::fill(thrust::device, Ccsr->row_offsets->begin(), Ccsr->row_offsets->end(), 0); 2650*d52a580bSJunchao Zhang c->nz = 0; 2651*d52a580bSJunchao Zhang Ccsr->column_indices = new THRUSTINTARRAY32(c->nz); 2652*d52a580bSJunchao Zhang Ccsr->values = new THRUSTARRAY(c->nz); 2653*d52a580bSJunchao Zhang goto finalizesym; 2654*d52a580bSJunchao Zhang } 2655*d52a580bSJunchao Zhang 2656*d52a580bSJunchao Zhang PetscCheck(Amat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing A mult struct for product type %s", MatProductTypes[ptype]); 2657*d52a580bSJunchao Zhang PetscCheck(Bmat, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing B mult struct for product type %s", MatProductTypes[ptype]); 2658*d52a580bSJunchao Zhang Acsr = (CsrMatrix *)Amat->mat; 2659*d52a580bSJunchao Zhang if (!biscompressed) { 2660*d52a580bSJunchao Zhang Bcsr = (CsrMatrix *)Bmat->mat; 2661*d52a580bSJunchao Zhang BmatSpDescr = Bmat->matDescr; 2662*d52a580bSJunchao Zhang } else { /* we need to use row offsets for the full matrix */ 2663*d52a580bSJunchao Zhang CsrMatrix *cBcsr = (CsrMatrix *)Bmat->mat; 2664*d52a580bSJunchao Zhang Bcsr = new CsrMatrix; 2665*d52a580bSJunchao Zhang Bcsr->num_rows = B->rmap->n; 2666*d52a580bSJunchao Zhang Bcsr->num_cols = cBcsr->num_cols; 2667*d52a580bSJunchao Zhang Bcsr->num_entries = cBcsr->num_entries; 2668*d52a580bSJunchao Zhang Bcsr->column_indices = cBcsr->column_indices; 2669*d52a580bSJunchao Zhang Bcsr->values = cBcsr->values; 2670*d52a580bSJunchao Zhang if (!Bcusp->rowoffsets_gpu) { 2671*d52a580bSJunchao Zhang Bcusp->rowoffsets_gpu = new THRUSTINTARRAY32(B->rmap->n + 1); 2672*d52a580bSJunchao Zhang Bcusp->rowoffsets_gpu->assign(b->i, b->i + B->rmap->n + 1); 2673*d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu((B->rmap->n + 1) * sizeof(PetscInt))); 2674*d52a580bSJunchao Zhang } 2675*d52a580bSJunchao Zhang Bcsr->row_offsets = Bcusp->rowoffsets_gpu; 2676*d52a580bSJunchao Zhang mmdata->Bcsr = Bcsr; 2677*d52a580bSJunchao Zhang 
if (Bcsr->num_rows && Bcsr->num_cols) { 2678*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&mmdata->matSpBDescr, Bcsr->num_rows, Bcsr->num_cols, Bcsr->num_entries, Bcsr->row_offsets->data().get(), Bcsr->column_indices->data().get(), Bcsr->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype)); 2679*d52a580bSJunchao Zhang } 2680*d52a580bSJunchao Zhang BmatSpDescr = mmdata->matSpBDescr; 2681*d52a580bSJunchao Zhang } 2682*d52a580bSJunchao Zhang PetscCheck(Acsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing A CSR struct"); 2683*d52a580bSJunchao Zhang PetscCheck(Bcsr, PetscObjectComm((PetscObject)C), PETSC_ERR_GPU, "Missing B CSR struct"); 2684*d52a580bSJunchao Zhang /* precompute flops count */ 2685*d52a580bSJunchao Zhang if (ptype == MATPRODUCT_AB) { 2686*d52a580bSJunchao Zhang for (i = 0, flops = 0; i < A->rmap->n; i++) { 2687*d52a580bSJunchao Zhang const PetscInt st = a->i[i]; 2688*d52a580bSJunchao Zhang const PetscInt en = a->i[i + 1]; 2689*d52a580bSJunchao Zhang for (j = st; j < en; j++) { 2690*d52a580bSJunchao Zhang const PetscInt brow = a->j[j]; 2691*d52a580bSJunchao Zhang flops += 2. * (b->i[brow + 1] - b->i[brow]); 2692*d52a580bSJunchao Zhang } 2693*d52a580bSJunchao Zhang } 2694*d52a580bSJunchao Zhang } else if (ptype == MATPRODUCT_AtB) { 2695*d52a580bSJunchao Zhang for (i = 0, flops = 0; i < A->rmap->n; i++) { 2696*d52a580bSJunchao Zhang const PetscInt anzi = a->i[i + 1] - a->i[i]; 2697*d52a580bSJunchao Zhang const PetscInt bnzi = b->i[i + 1] - b->i[i]; 2698*d52a580bSJunchao Zhang flops += (2. 
* anzi) * bnzi; 2699*d52a580bSJunchao Zhang } 2700*d52a580bSJunchao Zhang } else flops = 0.; /* TODO */ 2701*d52a580bSJunchao Zhang 2702*d52a580bSJunchao Zhang mmdata->flops = flops; 2703*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 2704*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 0, 0) 2705*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetPointerMode(Ccusp->handle, HIPSPARSE_POINTER_MODE_DEVICE)); 2706*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&Cmat->matDescr, Ccsr->num_rows, Ccsr->num_cols, 0, Ccsr->row_offsets->data().get(), NULL, NULL, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype)); 2707*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMM_createDescr(&mmdata->spgemmDesc)); 2708*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0) 2709*d52a580bSJunchao Zhang { 2710*d52a580bSJunchao Zhang /* hipsparseSpGEMMreuse has more reasonable APIs than hipsparseSpGEMM, so we prefer to use it. 
2711*d52a580bSJunchao Zhang We follow the sample code at https://github.com/ROCmSoftwarePlatform/hipSPARSE/blob/develop/clients/include/testing_spgemmreuse_csr.hpp 2712*d52a580bSJunchao Zhang */ 2713*d52a580bSJunchao Zhang void *dBuffer1 = NULL; 2714*d52a580bSJunchao Zhang void *dBuffer2 = NULL; 2715*d52a580bSJunchao Zhang void *dBuffer3 = NULL; 2716*d52a580bSJunchao Zhang /* dBuffer4, dBuffer5 are needed by hipsparseSpGEMMreuse_compute, and therefore are stored in mmdata */ 2717*d52a580bSJunchao Zhang size_t bufferSize1 = 0; 2718*d52a580bSJunchao Zhang size_t bufferSize2 = 0; 2719*d52a580bSJunchao Zhang size_t bufferSize3 = 0; 2720*d52a580bSJunchao Zhang size_t bufferSize4 = 0; 2721*d52a580bSJunchao Zhang size_t bufferSize5 = 0; 2722*d52a580bSJunchao Zhang 2723*d52a580bSJunchao Zhang /* ask bufferSize1 bytes for external memory */ 2724*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMMreuse_workEstimation(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize1, NULL)); 2725*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&dBuffer1, bufferSize1)); 2726*d52a580bSJunchao Zhang /* inspect the matrices A and B to understand the memory requirement for the next step */ 2727*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMMreuse_workEstimation(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize1, dBuffer1)); 2728*d52a580bSJunchao Zhang 2729*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMMreuse_nnz(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize2, NULL, &bufferSize3, NULL, &bufferSize4, NULL)); 2730*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&dBuffer2, bufferSize2)); 2731*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&dBuffer3, bufferSize3)); 2732*d52a580bSJunchao Zhang 
PetscCallHIP(hipMalloc((void **)&mmdata->dBuffer4, bufferSize4)); 2733*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMMreuse_nnz(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize2, dBuffer2, &bufferSize3, dBuffer3, &bufferSize4, mmdata->dBuffer4)); 2734*d52a580bSJunchao Zhang PetscCallHIP(hipFree(dBuffer1)); 2735*d52a580bSJunchao Zhang PetscCallHIP(hipFree(dBuffer2)); 2736*d52a580bSJunchao Zhang 2737*d52a580bSJunchao Zhang /* get matrix C non-zero entries C_nnz1 */ 2738*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMatGetSize(Cmat->matDescr, &C_num_rows1, &C_num_cols1, &C_nnz1)); 2739*d52a580bSJunchao Zhang c->nz = (PetscInt)C_nnz1; 2740*d52a580bSJunchao Zhang /* allocate matrix C */ 2741*d52a580bSJunchao Zhang Ccsr->column_indices = new THRUSTINTARRAY32(c->nz); 2742*d52a580bSJunchao Zhang PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */ 2743*d52a580bSJunchao Zhang Ccsr->values = new THRUSTARRAY(c->nz); 2744*d52a580bSJunchao Zhang PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */ 2745*d52a580bSJunchao Zhang /* update matC with the new pointers */ 2746*d52a580bSJunchao Zhang if (c->nz) { /* 5.5.1 has a bug with nz = 0, exposed by mat_tests_ex123_2_hypre */ 2747*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCsrSetPointers(Cmat->matDescr, Ccsr->row_offsets->data().get(), Ccsr->column_indices->data().get(), Ccsr->values->data().get())); 2748*d52a580bSJunchao Zhang 2749*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMMreuse_copy(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufferSize5, NULL)); 2750*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&mmdata->dBuffer5, bufferSize5)); 2751*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMMreuse_copy(Ccusp->handle, opA, opB, Amat->matDescr, BmatSpDescr, Cmat->matDescr, HIPSPARSE_SPGEMM_DEFAULT, 
mmdata->spgemmDesc, &bufferSize5, mmdata->dBuffer5)); 2752*d52a580bSJunchao Zhang PetscCallHIP(hipFree(dBuffer3)); 2753*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMMreuse_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc)); 2754*d52a580bSJunchao Zhang } 2755*d52a580bSJunchao Zhang PetscCall(PetscInfo(C, "Buffer sizes for type %s, result %" PetscInt_FMT " x %" PetscInt_FMT " (k %" PetscInt_FMT ", nzA %" PetscInt_FMT ", nzB %" PetscInt_FMT ", nzC %" PetscInt_FMT ") are: %ldKB %ldKB\n", MatProductTypes[ptype], m, n, k, a->nz, b->nz, c->nz, bufferSize4 / 1024, bufferSize5 / 1024)); 2756*d52a580bSJunchao Zhang } 2757*d52a580bSJunchao Zhang #else 2758*d52a580bSJunchao Zhang size_t bufSize2; 2759*d52a580bSJunchao Zhang /* ask bufferSize bytes for external memory */ 2760*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMM_workEstimation(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufSize2, NULL)); 2761*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&mmdata->mmBuffer2, bufSize2)); 2762*d52a580bSJunchao Zhang /* inspect the matrices A and B to understand the memory requirement for the next step */ 2763*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMM_workEstimation(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &bufSize2, mmdata->mmBuffer2)); 2764*d52a580bSJunchao Zhang /* ask bufferSize again bytes for external memory */ 2765*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMM_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, 
&mmdata->mmBufferSize, NULL)); 2766*d52a580bSJunchao Zhang /* Similar to CUSPARSE, we need both buffers to perform the operations properly! 2767*d52a580bSJunchao Zhang mmdata->mmBuffer2 does not appear anywhere in the compute/copy API 2768*d52a580bSJunchao Zhang it only appears for the workEstimation stuff, but it seems it is needed in compute, so probably the address 2769*d52a580bSJunchao Zhang is stored in the descriptor! What a messy API... */ 2770*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&mmdata->mmBuffer, mmdata->mmBufferSize)); 2771*d52a580bSJunchao Zhang /* compute the intermediate product of A * B */ 2772*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMM_compute(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc, &mmdata->mmBufferSize, mmdata->mmBuffer)); 2773*d52a580bSJunchao Zhang /* get matrix C non-zero entries C_nnz1 */ 2774*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMatGetSize(Cmat->matDescr, &C_num_rows1, &C_num_cols1, &C_nnz1)); 2775*d52a580bSJunchao Zhang c->nz = (PetscInt)C_nnz1; 2776*d52a580bSJunchao Zhang PetscCall(PetscInfo(C, "Buffer sizes for type %s, result %" PetscInt_FMT " x %" PetscInt_FMT " (k %" PetscInt_FMT ", nzA %" PetscInt_FMT ", nzB %" PetscInt_FMT ", nzC %" PetscInt_FMT ") are: %ldKB %ldKB\n", MatProductTypes[ptype], m, n, k, a->nz, b->nz, c->nz, bufSize2 / 1024, 2777*d52a580bSJunchao Zhang mmdata->mmBufferSize / 1024)); 2778*d52a580bSJunchao Zhang Ccsr->column_indices = new THRUSTINTARRAY32(c->nz); 2779*d52a580bSJunchao Zhang PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */ 2780*d52a580bSJunchao Zhang Ccsr->values = new THRUSTARRAY(c->nz); 2781*d52a580bSJunchao Zhang PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */ 2782*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCsrSetPointers(Cmat->matDescr, Ccsr->row_offsets->data().get(), 
Ccsr->column_indices->data().get(), Ccsr->values->data().get())); 2783*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpGEMM_copy(Ccusp->handle, opA, opB, Cmat->alpha_one, Amat->matDescr, BmatSpDescr, Cmat->beta_zero, Cmat->matDescr, hipsparse_scalartype, HIPSPARSE_SPGEMM_DEFAULT, mmdata->spgemmDesc)); 2784*d52a580bSJunchao Zhang #endif 2785*d52a580bSJunchao Zhang #else 2786*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetPointerMode(Ccusp->handle, HIPSPARSE_POINTER_MODE_HOST)); 2787*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsrgemmNnz(Ccusp->handle, opA, opB, Acsr->num_rows, Bcsr->num_cols, Acsr->num_cols, Amat->descr, Acsr->num_entries, Acsr->row_offsets->data().get(), Acsr->column_indices->data().get(), Bmat->descr, Bcsr->num_entries, 2788*d52a580bSJunchao Zhang Bcsr->row_offsets->data().get(), Bcsr->column_indices->data().get(), Cmat->descr, Ccsr->row_offsets->data().get(), &cnz)); 2789*d52a580bSJunchao Zhang c->nz = cnz; 2790*d52a580bSJunchao Zhang Ccsr->column_indices = new THRUSTINTARRAY32(c->nz); 2791*d52a580bSJunchao Zhang PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */ 2792*d52a580bSJunchao Zhang Ccsr->values = new THRUSTARRAY(c->nz); 2793*d52a580bSJunchao Zhang PetscCallHIP(hipPeekAtLastError()); /* catch out of memory errors */ 2794*d52a580bSJunchao Zhang 2795*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetPointerMode(Ccusp->handle, HIPSPARSE_POINTER_MODE_DEVICE)); 2796*d52a580bSJunchao Zhang /* with the old gemm interface (removed from 11.0 on) we cannot compute the symbolic factorization only. 2797*d52a580bSJunchao Zhang I have tried using the gemm2 interface (alpha * A * B + beta * D), which allows to do symbolic by passing NULL for values, but it seems quite buggy when 2798*d52a580bSJunchao Zhang D is NULL, despite the fact that CUSPARSE documentation claims it is supported! 
*/ 2799*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr_spgemm(Ccusp->handle, opA, opB, Acsr->num_rows, Bcsr->num_cols, Acsr->num_cols, Amat->descr, Acsr->num_entries, Acsr->values->data().get(), Acsr->row_offsets->data().get(), Acsr->column_indices->data().get(), Bmat->descr, 2800*d52a580bSJunchao Zhang Bcsr->num_entries, Bcsr->values->data().get(), Bcsr->row_offsets->data().get(), Bcsr->column_indices->data().get(), Cmat->descr, Ccsr->values->data().get(), Ccsr->row_offsets->data().get(), 2801*d52a580bSJunchao Zhang Ccsr->column_indices->data().get())); 2802*d52a580bSJunchao Zhang #endif 2803*d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(mmdata->flops)); 2804*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 2805*d52a580bSJunchao Zhang finalizesym: 2806*d52a580bSJunchao Zhang c->free_a = PETSC_TRUE; 2807*d52a580bSJunchao Zhang PetscCall(PetscShmgetAllocateArray(c->nz, sizeof(PetscInt), (void **)&c->j)); 2808*d52a580bSJunchao Zhang PetscCall(PetscShmgetAllocateArray(m + 1, sizeof(PetscInt), (void **)&c->i)); 2809*d52a580bSJunchao Zhang c->free_ij = PETSC_TRUE; 2810*d52a580bSJunchao Zhang if (PetscDefined(USE_64BIT_INDICES)) { /* 32 to 64-bit conversion on the GPU and then copy to host (lazy) */ 2811*d52a580bSJunchao Zhang PetscInt *d_i = c->i; 2812*d52a580bSJunchao Zhang THRUSTINTARRAY ii(Ccsr->row_offsets->size()); 2813*d52a580bSJunchao Zhang THRUSTINTARRAY jj(Ccsr->column_indices->size()); 2814*d52a580bSJunchao Zhang ii = *Ccsr->row_offsets; 2815*d52a580bSJunchao Zhang jj = *Ccsr->column_indices; 2816*d52a580bSJunchao Zhang if (ciscompressed) d_i = c->compressedrow.i; 2817*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(d_i, ii.data().get(), Ccsr->row_offsets->size() * sizeof(PetscInt), hipMemcpyDeviceToHost)); 2818*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(c->j, jj.data().get(), Ccsr->column_indices->size() * sizeof(PetscInt), hipMemcpyDeviceToHost)); 2819*d52a580bSJunchao Zhang } else { 2820*d52a580bSJunchao Zhang PetscInt *d_i = c->i; 
2821*d52a580bSJunchao Zhang if (ciscompressed) d_i = c->compressedrow.i; 2822*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(d_i, Ccsr->row_offsets->data().get(), Ccsr->row_offsets->size() * sizeof(PetscInt), hipMemcpyDeviceToHost)); 2823*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(c->j, Ccsr->column_indices->data().get(), Ccsr->column_indices->size() * sizeof(PetscInt), hipMemcpyDeviceToHost)); 2824*d52a580bSJunchao Zhang } 2825*d52a580bSJunchao Zhang if (ciscompressed) { /* need to expand host row offsets */ 2826*d52a580bSJunchao Zhang PetscInt r = 0; 2827*d52a580bSJunchao Zhang c->i[0] = 0; 2828*d52a580bSJunchao Zhang for (k = 0; k < c->compressedrow.nrows; k++) { 2829*d52a580bSJunchao Zhang const PetscInt next = c->compressedrow.rindex[k]; 2830*d52a580bSJunchao Zhang const PetscInt old = c->compressedrow.i[k]; 2831*d52a580bSJunchao Zhang for (; r < next; r++) c->i[r + 1] = old; 2832*d52a580bSJunchao Zhang } 2833*d52a580bSJunchao Zhang for (; r < m; r++) c->i[r + 1] = c->compressedrow.i[c->compressedrow.nrows]; 2834*d52a580bSJunchao Zhang } 2835*d52a580bSJunchao Zhang PetscCall(PetscLogGpuToCpu((Ccsr->column_indices->size() + Ccsr->row_offsets->size()) * sizeof(PetscInt))); 2836*d52a580bSJunchao Zhang PetscCall(PetscMalloc1(m, &c->ilen)); 2837*d52a580bSJunchao Zhang PetscCall(PetscMalloc1(m, &c->imax)); 2838*d52a580bSJunchao Zhang c->maxnz = c->nz; 2839*d52a580bSJunchao Zhang c->nonzerorowcnt = 0; 2840*d52a580bSJunchao Zhang c->rmax = 0; 2841*d52a580bSJunchao Zhang for (k = 0; k < m; k++) { 2842*d52a580bSJunchao Zhang const PetscInt nn = c->i[k + 1] - c->i[k]; 2843*d52a580bSJunchao Zhang c->ilen[k] = c->imax[k] = nn; 2844*d52a580bSJunchao Zhang c->nonzerorowcnt += (PetscInt)!!nn; 2845*d52a580bSJunchao Zhang c->rmax = PetscMax(c->rmax, nn); 2846*d52a580bSJunchao Zhang } 2847*d52a580bSJunchao Zhang PetscCall(PetscMalloc1(c->nz, &c->a)); 2848*d52a580bSJunchao Zhang Ccsr->num_entries = c->nz; 2849*d52a580bSJunchao Zhang 2850*d52a580bSJunchao Zhang 
C->nonzerostate++; 2851*d52a580bSJunchao Zhang PetscCall(PetscLayoutSetUp(C->rmap)); 2852*d52a580bSJunchao Zhang PetscCall(PetscLayoutSetUp(C->cmap)); 2853*d52a580bSJunchao Zhang Ccusp->nonzerostate = C->nonzerostate; 2854*d52a580bSJunchao Zhang C->offloadmask = PETSC_OFFLOAD_UNALLOCATED; 2855*d52a580bSJunchao Zhang C->preallocated = PETSC_TRUE; 2856*d52a580bSJunchao Zhang C->assembled = PETSC_FALSE; 2857*d52a580bSJunchao Zhang C->was_assembled = PETSC_FALSE; 2858*d52a580bSJunchao Zhang if (product->api_user && A->offloadmask == PETSC_OFFLOAD_BOTH && B->offloadmask == PETSC_OFFLOAD_BOTH) { /* flag the matrix C values as computed, so that the numeric phase will only call MatAssembly */ 2859*d52a580bSJunchao Zhang mmdata->reusesym = PETSC_TRUE; 2860*d52a580bSJunchao Zhang C->offloadmask = PETSC_OFFLOAD_GPU; 2861*d52a580bSJunchao Zhang } 2862*d52a580bSJunchao Zhang C->ops->productnumeric = MatProductNumeric_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE; 2863*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 2864*d52a580bSJunchao Zhang } 2865*d52a580bSJunchao Zhang 2866*d52a580bSJunchao Zhang /* handles sparse or dense B */ 2867*d52a580bSJunchao Zhang static PetscErrorCode MatProductSetFromOptions_SeqAIJHIPSPARSE(Mat mat) 2868*d52a580bSJunchao Zhang { 2869*d52a580bSJunchao Zhang Mat_Product *product = mat->product; 2870*d52a580bSJunchao Zhang PetscBool isdense = PETSC_FALSE, Biscusp = PETSC_FALSE, Ciscusp = PETSC_TRUE; 2871*d52a580bSJunchao Zhang 2872*d52a580bSJunchao Zhang PetscFunctionBegin; 2873*d52a580bSJunchao Zhang MatCheckProduct(mat, 1); 2874*d52a580bSJunchao Zhang PetscCall(PetscObjectBaseTypeCompare((PetscObject)product->B, MATSEQDENSE, &isdense)); 2875*d52a580bSJunchao Zhang if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, MATSEQAIJHIPSPARSE, &Biscusp)); 2876*d52a580bSJunchao Zhang if (product->type == MATPRODUCT_ABC) { 2877*d52a580bSJunchao Zhang Ciscusp = PETSC_FALSE; 2878*d52a580bSJunchao Zhang 
if (!product->C->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->C, MATSEQAIJHIPSPARSE, &Ciscusp)); /* for ABC the third factor must also be an unbound MATSEQAIJHIPSPARSE to stay on the GPU */
  }
  /* B (and C for ABC) can use the GPU path; still let the user force the CPU implementation.
     The option name depends on how the product was requested: the user-level API (product->api_user, e.g.
     -matmatmult_backend_cpu for MatMatMult()) or the MatProduct interface (-mat_product_algorithm_backend_cpu). */
  if (Biscusp && Ciscusp) { /* we can always select the CPU backend */
    PetscBool usecpu = PETSC_FALSE;
    switch (product->type) {
    case MATPRODUCT_AB:
      if (product->api_user) {
        PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
        PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
        PetscOptionsEnd();
      } else {
        PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
        PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
        PetscOptionsEnd();
      }
      break;
    case MATPRODUCT_AtB:
      if (product->api_user) {
        PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
        PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
        PetscOptionsEnd();
      } else {
        PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
        PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
        PetscOptionsEnd();
      }
      break;
    case MATPRODUCT_PtAP:
      if (product->api_user) {
        PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
        PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
        PetscOptionsEnd();
      } else {
        PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
        PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
        PetscOptionsEnd();
      }
      break;
    case MATPRODUCT_RARt:
      if (product->api_user) {
        PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatRARt", "Mat");
        PetscCall(PetscOptionsBool("-matrart_backend_cpu", "Use CPU code", "MatRARt", usecpu, &usecpu, NULL));
        PetscOptionsEnd();
      } else {
        PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_RARt", "Mat");
        PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatRARt", usecpu, &usecpu, NULL));
        PetscOptionsEnd();
      }
      break;
    case MATPRODUCT_ABC:
      if (product->api_user) {
        PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMatMult", "Mat");
        PetscCall(PetscOptionsBool("-matmatmatmult_backend_cpu", "Use CPU code", "MatMatMatMult", usecpu, &usecpu, NULL));
        PetscOptionsEnd();
      } else {
        PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_ABC", "Mat");
        PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMatMult", usecpu, &usecpu, NULL));
        PetscOptionsEnd();
      }
      break;
    default: /* NOTE(review): MATPRODUCT_ABt (and any other type) has no *_backend_cpu option here */
      break;
    }
    /* forcing the CPU: B is MATSEQAIJHIPSPARSE here (not dense), so clearing these sends the product to the SeqAIJ fallback below */
    if (usecpu) Biscusp = Ciscusp = PETSC_FALSE;
  }
  /* dispatch */
  if (isdense) { /* B has base type MATSEQDENSE */
    switch (product->type) {
    case MATPRODUCT_AB:
    case MATPRODUCT_AtB:
    case MATPRODUCT_ABt:
    case MATPRODUCT_PtAP:
    case MATPRODUCT_RARt:
      /* A bound to the CPU: use the host SeqAIJ x SeqDense setup; otherwise the HIP sparse x dense symbolic routine */
      if (product->A->boundtocpu) PetscCall(MatProductSetFromOptions_SeqAIJ_SeqDense(mat));
      else mat->ops->productsymbolic = MatProductSymbolic_SeqAIJHIPSPARSE_SeqDENSEHIP;
      break;
    case MATPRODUCT_ABC:
      mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
      break;
    default:
      break;
    }
  } else if (Biscusp && Ciscusp) {
    switch (product->type) {
    case MATPRODUCT_AB:
    case MATPRODUCT_AtB:
    case MATPRODUCT_ABt:
      /* sparse x sparse on the GPU via hipSPARSE SpGEMM */
      mat->ops->productsymbolic = MatProductSymbolic_SeqAIJHIPSPARSE_SeqAIJHIPSPARSE;
      break;
    case MATPRODUCT_PtAP:
    case MATPRODUCT_RARt:
    case MATPRODUCT_ABC:
      /* triple products go through the generic MatProductSymbolic_ABC_Basic() */
      mat->ops->productsymbolic = MatProductSymbolic_ABC_Basic;
      break;
    default:
      break;
    }
  } else PetscCall(MatProductSetFromOptions_SeqAIJ(mat)); /* fallback for AIJ */
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* y = A*x: forwards to MatMultAddKernel_SeqAIJHIPSPARSE with no add vector (NULL) and both trailing flags
   PETSC_FALSE (presumably the transpose/Hermitian selectors -- TODO confirm against the kernel's signature) */
static PetscErrorCode MatMult_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy)
{
  PetscFunctionBegin;
  PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, NULL, yy, PETSC_FALSE, PETSC_FALSE));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* z = A*x + y: same kernel as MatMult_SeqAIJHIPSPARSE but passing yy as the vector to add */
static PetscErrorCode MatMultAdd_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy, Vec zz)
{
  PetscFunctionBegin;
  PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, yy, zz, PETSC_FALSE, PETSC_FALSE));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* y = A^H*x: same kernel with both trailing flags set to PETSC_TRUE */
static PetscErrorCode MatMultHermitianTranspose_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy)
{
  PetscFunctionBegin;
  PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, NULL, yy, PETSC_TRUE, PETSC_TRUE));
PetscFunctionReturn(PETSC_SUCCESS); 2998*d52a580bSJunchao Zhang } 2999*d52a580bSJunchao Zhang 3000*d52a580bSJunchao Zhang static PetscErrorCode MatMultHermitianTransposeAdd_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy, Vec zz) 3001*d52a580bSJunchao Zhang { 3002*d52a580bSJunchao Zhang PetscFunctionBegin; 3003*d52a580bSJunchao Zhang PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, yy, zz, PETSC_TRUE, PETSC_TRUE)); 3004*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3005*d52a580bSJunchao Zhang } 3006*d52a580bSJunchao Zhang 3007*d52a580bSJunchao Zhang static PetscErrorCode MatMultTranspose_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy) 3008*d52a580bSJunchao Zhang { 3009*d52a580bSJunchao Zhang PetscFunctionBegin; 3010*d52a580bSJunchao Zhang PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, NULL, yy, PETSC_TRUE, PETSC_FALSE)); 3011*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3012*d52a580bSJunchao Zhang } 3013*d52a580bSJunchao Zhang 3014*d52a580bSJunchao Zhang __global__ static void ScatterAdd(PetscInt n, PetscInt *idx, const PetscScalar *x, PetscScalar *y) 3015*d52a580bSJunchao Zhang { 3016*d52a580bSJunchao Zhang int i = blockIdx.x * blockDim.x + threadIdx.x; 3017*d52a580bSJunchao Zhang if (i < n) y[idx[i]] += x[i]; 3018*d52a580bSJunchao Zhang } 3019*d52a580bSJunchao Zhang 3020*d52a580bSJunchao Zhang /* z = op(A) x + y. 
If trans & !herm, op = ^T; if trans & herm, op = ^H; if !trans, op = no-op */ 3021*d52a580bSJunchao Zhang static PetscErrorCode MatMultAddKernel_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy, Vec zz, PetscBool trans, PetscBool herm) 3022*d52a580bSJunchao Zhang { 3023*d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data; 3024*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *hipsparsestruct = (Mat_SeqAIJHIPSPARSE *)A->spptr; 3025*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSEMultStruct *matstruct; 3026*d52a580bSJunchao Zhang PetscScalar *xarray, *zarray, *dptr, *beta, *xptr; 3027*d52a580bSJunchao Zhang hipsparseOperation_t opA = HIPSPARSE_OPERATION_NON_TRANSPOSE; 3028*d52a580bSJunchao Zhang PetscBool compressed; 3029*d52a580bSJunchao Zhang PetscInt nx, ny; 3030*d52a580bSJunchao Zhang 3031*d52a580bSJunchao Zhang PetscFunctionBegin; 3032*d52a580bSJunchao Zhang PetscCheck(!herm || trans, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "Hermitian and not transpose not supported"); 3033*d52a580bSJunchao Zhang if (!a->nz) { 3034*d52a580bSJunchao Zhang if (yy) PetscCall(VecSeq_HIP::Copy(yy, zz)); 3035*d52a580bSJunchao Zhang else PetscCall(VecSeq_HIP::Set(zz, 0)); 3036*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3037*d52a580bSJunchao Zhang } 3038*d52a580bSJunchao Zhang /* The line below is necessary due to the operations that modify the matrix on the CPU (axpy, scale, etc) */ 3039*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A)); 3040*d52a580bSJunchao Zhang if (!trans) { 3041*d52a580bSJunchao Zhang matstruct = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->mat; 3042*d52a580bSJunchao Zhang PetscCheck(matstruct, PetscObjectComm((PetscObject)A), PETSC_ERR_GPU, "SeqAIJHIPSPARSE does not have a 'mat' (need to fix)"); 3043*d52a580bSJunchao Zhang } else { 3044*d52a580bSJunchao Zhang if (herm || !A->form_explicit_transpose) { 3045*d52a580bSJunchao Zhang opA = herm ? 
HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE : HIPSPARSE_OPERATION_TRANSPOSE; 3046*d52a580bSJunchao Zhang matstruct = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->mat; 3047*d52a580bSJunchao Zhang } else { 3048*d52a580bSJunchao Zhang if (!hipsparsestruct->matTranspose) PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(A)); 3049*d52a580bSJunchao Zhang matstruct = (Mat_SeqAIJHIPSPARSEMultStruct *)hipsparsestruct->matTranspose; 3050*d52a580bSJunchao Zhang } 3051*d52a580bSJunchao Zhang } 3052*d52a580bSJunchao Zhang /* Does the matrix use compressed rows (i.e., drop zero rows)? */ 3053*d52a580bSJunchao Zhang compressed = matstruct->cprowIndices ? PETSC_TRUE : PETSC_FALSE; 3054*d52a580bSJunchao Zhang try { 3055*d52a580bSJunchao Zhang PetscCall(VecHIPGetArrayRead(xx, (const PetscScalar **)&xarray)); 3056*d52a580bSJunchao Zhang if (yy == zz) PetscCall(VecHIPGetArray(zz, &zarray)); /* read & write zz, so need to get up-to-date zarray on GPU */ 3057*d52a580bSJunchao Zhang else PetscCall(VecHIPGetArrayWrite(zz, &zarray)); /* write zz, so no need to init zarray on GPU */ 3058*d52a580bSJunchao Zhang 3059*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 3060*d52a580bSJunchao Zhang if (opA == HIPSPARSE_OPERATION_NON_TRANSPOSE) { 3061*d52a580bSJunchao Zhang /* z = A x + beta y. 3062*d52a580bSJunchao Zhang If A is compressed (with less rows), then Ax is shorter than the full z, so we need a work vector to store Ax. 3063*d52a580bSJunchao Zhang When A is non-compressed, and z = y, we can set beta=1 to compute y = Ax + y in one call. 3064*d52a580bSJunchao Zhang */ 3065*d52a580bSJunchao Zhang xptr = xarray; 3066*d52a580bSJunchao Zhang dptr = compressed ? hipsparsestruct->workVector->data().get() : zarray; 3067*d52a580bSJunchao Zhang beta = (yy == zz && !compressed) ? matstruct->beta_one : matstruct->beta_zero; 3068*d52a580bSJunchao Zhang /* Get length of x, y for y=Ax. 
ny might be shorter than the work vector's allocated length, since the work vector is 3069*d52a580bSJunchao Zhang allocated to accommodate different uses. So we get the length info directly from mat. 3070*d52a580bSJunchao Zhang */ 3071*d52a580bSJunchao Zhang if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) { 3072*d52a580bSJunchao Zhang CsrMatrix *mat = (CsrMatrix *)matstruct->mat; 3073*d52a580bSJunchao Zhang nx = mat->num_cols; 3074*d52a580bSJunchao Zhang ny = mat->num_rows; 3075*d52a580bSJunchao Zhang } 3076*d52a580bSJunchao Zhang } else { 3077*d52a580bSJunchao Zhang /* z = A^T x + beta y 3078*d52a580bSJunchao Zhang If A is compressed, then we need a work vector as the shorter version of x to compute A^T x. 3079*d52a580bSJunchao Zhang Note A^Tx is of full length, so we set beta to 1.0 if y exists. 3080*d52a580bSJunchao Zhang */ 3081*d52a580bSJunchao Zhang xptr = compressed ? hipsparsestruct->workVector->data().get() : xarray; 3082*d52a580bSJunchao Zhang dptr = zarray; 3083*d52a580bSJunchao Zhang beta = yy ? 
matstruct->beta_one : matstruct->beta_zero; 3084*d52a580bSJunchao Zhang if (compressed) { /* Scatter x to work vector */ 3085*d52a580bSJunchao Zhang thrust::device_ptr<PetscScalar> xarr = thrust::device_pointer_cast(xarray); 3086*d52a580bSJunchao Zhang thrust::for_each( 3087*d52a580bSJunchao Zhang #if PetscDefined(HAVE_THRUST_ASYNC) 3088*d52a580bSJunchao Zhang thrust::hip::par.on(PetscDefaultHipStream), 3089*d52a580bSJunchao Zhang #endif 3090*d52a580bSJunchao Zhang thrust::make_zip_iterator(thrust::make_tuple(hipsparsestruct->workVector->begin(), thrust::make_permutation_iterator(xarr, matstruct->cprowIndices->begin()))), 3091*d52a580bSJunchao Zhang thrust::make_zip_iterator(thrust::make_tuple(hipsparsestruct->workVector->begin(), thrust::make_permutation_iterator(xarr, matstruct->cprowIndices->begin()))) + matstruct->cprowIndices->size(), VecHIPEqualsReverse()); 3092*d52a580bSJunchao Zhang } 3093*d52a580bSJunchao Zhang if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) { 3094*d52a580bSJunchao Zhang CsrMatrix *mat = (CsrMatrix *)matstruct->mat; 3095*d52a580bSJunchao Zhang nx = mat->num_rows; 3096*d52a580bSJunchao Zhang ny = mat->num_cols; 3097*d52a580bSJunchao Zhang } 3098*d52a580bSJunchao Zhang } 3099*d52a580bSJunchao Zhang /* csr_spmv does y = alpha op(A) x + beta y */ 3100*d52a580bSJunchao Zhang if (hipsparsestruct->format == MAT_HIPSPARSE_CSR) { 3101*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(5, 1, 0) 3102*d52a580bSJunchao Zhang PetscCheck(opA >= 0 && opA <= 2, PETSC_COMM_SELF, PETSC_ERR_SUP, "hipSPARSE API on hipsparseOperation_t has changed and PETSc has not been updated accordingly"); 3103*d52a580bSJunchao Zhang if (!matstruct->hipSpMV[opA].initialized) { /* built on demand */ 3104*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateDnVec(&matstruct->hipSpMV[opA].vecXDescr, nx, xptr, hipsparse_scalartype)); 3105*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateDnVec(&matstruct->hipSpMV[opA].vecYDescr, ny, dptr, hipsparse_scalartype)); 
3106*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMV_bufferSize(hipsparsestruct->handle, opA, matstruct->alpha_one, matstruct->matDescr, matstruct->hipSpMV[opA].vecXDescr, beta, matstruct->hipSpMV[opA].vecYDescr, hipsparse_scalartype, hipsparsestruct->spmvAlg, 3107*d52a580bSJunchao Zhang &matstruct->hipSpMV[opA].spmvBufferSize)); 3108*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&matstruct->hipSpMV[opA].spmvBuffer, matstruct->hipSpMV[opA].spmvBufferSize)); 3109*d52a580bSJunchao Zhang matstruct->hipSpMV[opA].initialized = PETSC_TRUE; 3110*d52a580bSJunchao Zhang } else { 3111*d52a580bSJunchao Zhang /* x, y's value pointers might change between calls, but their shape is kept, so we just update pointers */ 3112*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(matstruct->hipSpMV[opA].vecXDescr, xptr)); 3113*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDnVecSetValues(matstruct->hipSpMV[opA].vecYDescr, dptr)); 3114*d52a580bSJunchao Zhang } 3115*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpMV(hipsparsestruct->handle, opA, matstruct->alpha_one, matstruct->matDescr, /* built in MatSeqAIJHIPSPARSECopyToGPU() or MatSeqAIJHIPSPARSEFormExplicitTranspose() */ 3116*d52a580bSJunchao Zhang matstruct->hipSpMV[opA].vecXDescr, beta, matstruct->hipSpMV[opA].vecYDescr, hipsparse_scalartype, hipsparsestruct->spmvAlg, matstruct->hipSpMV[opA].spmvBuffer)); 3117*d52a580bSJunchao Zhang #else 3118*d52a580bSJunchao Zhang CsrMatrix *mat = (CsrMatrix *)matstruct->mat; 3119*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr_spmv(hipsparsestruct->handle, opA, mat->num_rows, mat->num_cols, mat->num_entries, matstruct->alpha_one, matstruct->descr, mat->values->data().get(), mat->row_offsets->data().get(), mat->column_indices->data().get(), xptr, beta, dptr)); 3120*d52a580bSJunchao Zhang #endif 3121*d52a580bSJunchao Zhang } else { 3122*d52a580bSJunchao Zhang if (hipsparsestruct->nrows) { 3123*d52a580bSJunchao Zhang hipsparseHybMat_t hybMat = 
(hipsparseHybMat_t)matstruct->mat; 3124*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_hyb_spmv(hipsparsestruct->handle, opA, matstruct->alpha_one, matstruct->descr, hybMat, xptr, beta, dptr)); 3125*d52a580bSJunchao Zhang } 3126*d52a580bSJunchao Zhang } 3127*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 3128*d52a580bSJunchao Zhang 3129*d52a580bSJunchao Zhang if (opA == HIPSPARSE_OPERATION_NON_TRANSPOSE) { 3130*d52a580bSJunchao Zhang if (yy) { /* MatMultAdd: zz = A*xx + yy */ 3131*d52a580bSJunchao Zhang if (compressed) { /* A is compressed. We first copy yy to zz, then ScatterAdd the work vector to zz */ 3132*d52a580bSJunchao Zhang PetscCall(VecSeq_HIP::Copy(yy, zz)); /* zz = yy */ 3133*d52a580bSJunchao Zhang } else if (zz != yy) { /* A is not compressed. zz already contains A*xx, and we just need to add yy */ 3134*d52a580bSJunchao Zhang PetscCall(VecSeq_HIP::AXPY(zz, 1.0, yy)); /* zz += yy */ 3135*d52a580bSJunchao Zhang } 3136*d52a580bSJunchao Zhang } else if (compressed) { /* MatMult: zz = A*xx. A is compressed, so we zero zz first, then ScatterAdd the work vector to zz */ 3137*d52a580bSJunchao Zhang PetscCall(VecSeq_HIP::Set(zz, 0)); 3138*d52a580bSJunchao Zhang } 3139*d52a580bSJunchao Zhang 3140*d52a580bSJunchao Zhang /* ScatterAdd the result from work vector into the full vector when A is compressed */ 3141*d52a580bSJunchao Zhang if (compressed) { 3142*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 3143*d52a580bSJunchao Zhang /* I wanted to make this for_each asynchronous but failed. thrust::async::for_each() returns an event (internally registered) 3144*d52a580bSJunchao Zhang and in the destructor of the scope, it will call hipStreamSynchronize() on this stream. One has to store all events to 3145*d52a580bSJunchao Zhang prevent that. So I just add a ScatterAdd kernel. 
3146*d52a580bSJunchao Zhang */ 3147*d52a580bSJunchao Zhang #if 0 3148*d52a580bSJunchao Zhang thrust::device_ptr<PetscScalar> zptr = thrust::device_pointer_cast(zarray); 3149*d52a580bSJunchao Zhang thrust::async::for_each(thrust::hip::par.on(hipsparsestruct->stream), 3150*d52a580bSJunchao Zhang thrust::make_zip_iterator(thrust::make_tuple(hipsparsestruct->workVector->begin(), thrust::make_permutation_iterator(zptr, matstruct->cprowIndices->begin()))), 3151*d52a580bSJunchao Zhang thrust::make_zip_iterator(thrust::make_tuple(hipsparsestruct->workVector->begin(), thrust::make_permutation_iterator(zptr, matstruct->cprowIndices->begin()))) + matstruct->cprowIndices->size(), 3152*d52a580bSJunchao Zhang VecHIPPlusEquals()); 3153*d52a580bSJunchao Zhang #else 3154*d52a580bSJunchao Zhang PetscInt n = matstruct->cprowIndices->size(); 3155*d52a580bSJunchao Zhang hipLaunchKernelGGL(ScatterAdd, dim3((n + 255) / 256), dim3(256), 0, PetscDefaultHipStream, n, matstruct->cprowIndices->data().get(), hipsparsestruct->workVector->data().get(), zarray); 3156*d52a580bSJunchao Zhang #endif 3157*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 3158*d52a580bSJunchao Zhang } 3159*d52a580bSJunchao Zhang } else { 3160*d52a580bSJunchao Zhang if (yy && yy != zz) PetscCall(VecSeq_HIP::AXPY(zz, 1.0, yy)); /* zz += yy */ 3161*d52a580bSJunchao Zhang } 3162*d52a580bSJunchao Zhang PetscCall(VecHIPRestoreArrayRead(xx, (const PetscScalar **)&xarray)); 3163*d52a580bSJunchao Zhang if (yy == zz) PetscCall(VecHIPRestoreArray(zz, &zarray)); 3164*d52a580bSJunchao Zhang else PetscCall(VecHIPRestoreArrayWrite(zz, &zarray)); 3165*d52a580bSJunchao Zhang } catch (char *ex) { 3166*d52a580bSJunchao Zhang SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "HIPSPARSE error: %s", ex); 3167*d52a580bSJunchao Zhang } 3168*d52a580bSJunchao Zhang if (yy) PetscCall(PetscLogGpuFlops(2.0 * a->nz)); 3169*d52a580bSJunchao Zhang else PetscCall(PetscLogGpuFlops(2.0 * a->nz - a->nonzerorowcnt)); 3170*d52a580bSJunchao Zhang 
PetscFunctionReturn(PETSC_SUCCESS); 3171*d52a580bSJunchao Zhang } 3172*d52a580bSJunchao Zhang 3173*d52a580bSJunchao Zhang static PetscErrorCode MatMultTransposeAdd_SeqAIJHIPSPARSE(Mat A, Vec xx, Vec yy, Vec zz) 3174*d52a580bSJunchao Zhang { 3175*d52a580bSJunchao Zhang PetscFunctionBegin; 3176*d52a580bSJunchao Zhang PetscCall(MatMultAddKernel_SeqAIJHIPSPARSE(A, xx, yy, zz, PETSC_TRUE, PETSC_FALSE)); 3177*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3178*d52a580bSJunchao Zhang } 3179*d52a580bSJunchao Zhang 3180*d52a580bSJunchao Zhang static PetscErrorCode MatAssemblyEnd_SeqAIJHIPSPARSE(Mat A, MatAssemblyType mode) 3181*d52a580bSJunchao Zhang { 3182*d52a580bSJunchao Zhang PetscFunctionBegin; 3183*d52a580bSJunchao Zhang PetscCall(MatAssemblyEnd_SeqAIJ(A, mode)); 3184*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3185*d52a580bSJunchao Zhang } 3186*d52a580bSJunchao Zhang 3187*d52a580bSJunchao Zhang /*@ 3188*d52a580bSJunchao Zhang MatCreateSeqAIJHIPSPARSE - Creates a sparse matrix in `MATAIJHIPSPARSE` (compressed row) format. 3189*d52a580bSJunchao Zhang This matrix will ultimately pushed down to AMD GPUs and use the HIPSPARSE library for calculations. 3190*d52a580bSJunchao Zhang 3191*d52a580bSJunchao Zhang Collective 3192*d52a580bSJunchao Zhang 3193*d52a580bSJunchao Zhang Input Parameters: 3194*d52a580bSJunchao Zhang + comm - MPI communicator, set to `PETSC_COMM_SELF` 3195*d52a580bSJunchao Zhang . m - number of rows 3196*d52a580bSJunchao Zhang . n - number of columns 3197*d52a580bSJunchao Zhang . nz - number of nonzeros per row (same for all rows), ignored if `nnz` is set 3198*d52a580bSJunchao Zhang - nnz - array containing the number of nonzeros in the various rows (possibly different for each row) or `NULL` 3199*d52a580bSJunchao Zhang 3200*d52a580bSJunchao Zhang Output Parameter: 3201*d52a580bSJunchao Zhang . 
A - the matrix 3202*d52a580bSJunchao Zhang 3203*d52a580bSJunchao Zhang Level: intermediate 3204*d52a580bSJunchao Zhang 3205*d52a580bSJunchao Zhang Notes: 3206*d52a580bSJunchao Zhang It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`, 3207*d52a580bSJunchao Zhang `MatXXXXSetPreallocation()` paradgm instead of this routine directly. 3208*d52a580bSJunchao Zhang [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation`] 3209*d52a580bSJunchao Zhang 3210*d52a580bSJunchao Zhang The AIJ format (compressed row storage), is fully compatible with standard Fortran 3211*d52a580bSJunchao Zhang storage. That is, the stored row and column indices can begin at 3212*d52a580bSJunchao Zhang either one (as in Fortran) or zero. 3213*d52a580bSJunchao Zhang 3214*d52a580bSJunchao Zhang Specify the preallocated storage with either `nz` or `nnz` (not both). 3215*d52a580bSJunchao Zhang Set `nz` = `PETSC_DEFAULT` and `nnz` = `NULL` for PETSc to control dynamic memory 3216*d52a580bSJunchao Zhang allocation. 
3217*d52a580bSJunchao Zhang 3218*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`, `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`, `MATSEQAIJHIPSPARSE`, `MATAIJHIPSPARSE` 3219*d52a580bSJunchao Zhang @*/ 3220*d52a580bSJunchao Zhang PetscErrorCode MatCreateSeqAIJHIPSPARSE(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt nz, const PetscInt nnz[], Mat *A) 3221*d52a580bSJunchao Zhang { 3222*d52a580bSJunchao Zhang PetscFunctionBegin; 3223*d52a580bSJunchao Zhang PetscCall(MatCreate(comm, A)); 3224*d52a580bSJunchao Zhang PetscCall(MatSetSizes(*A, m, n, m, n)); 3225*d52a580bSJunchao Zhang PetscCall(MatSetType(*A, MATSEQAIJHIPSPARSE)); 3226*d52a580bSJunchao Zhang PetscCall(MatSeqAIJSetPreallocation_SeqAIJ(*A, nz, (PetscInt *)nnz)); 3227*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3228*d52a580bSJunchao Zhang } 3229*d52a580bSJunchao Zhang 3230*d52a580bSJunchao Zhang static PetscErrorCode MatDestroy_SeqAIJHIPSPARSE(Mat A) 3231*d52a580bSJunchao Zhang { 3232*d52a580bSJunchao Zhang PetscFunctionBegin; 3233*d52a580bSJunchao Zhang if (A->factortype == MAT_FACTOR_NONE) PetscCall(MatSeqAIJHIPSPARSE_Destroy(A)); 3234*d52a580bSJunchao Zhang else PetscCall(MatSeqAIJHIPSPARSETriFactors_Destroy((Mat_SeqAIJHIPSPARSETriFactors **)&A->spptr)); 3235*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSeqAIJCopySubArray_C", NULL)); 3236*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatHIPSPARSESetFormat_C", NULL)); 3237*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatHIPSPARSESetUseCPUSolve_C", NULL)); 3238*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdensehip_C", NULL)); 3239*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdense_C", NULL)); 3240*d52a580bSJunchao Zhang 
PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqaijhipsparse_C", NULL)); 3241*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatFactorGetSolverType_C", NULL)); 3242*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL)); 3243*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL)); 3244*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatConvert_seqaijhipsparse_hypre_C", NULL)); 3245*d52a580bSJunchao Zhang PetscCall(MatDestroy_SeqAIJ(A)); 3246*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3247*d52a580bSJunchao Zhang } 3248*d52a580bSJunchao Zhang 3249*d52a580bSJunchao Zhang static PetscErrorCode MatDuplicate_SeqAIJHIPSPARSE(Mat A, MatDuplicateOption cpvalues, Mat *B) 3250*d52a580bSJunchao Zhang { 3251*d52a580bSJunchao Zhang PetscFunctionBegin; 3252*d52a580bSJunchao Zhang PetscCall(MatDuplicate_SeqAIJ(A, cpvalues, B)); 3253*d52a580bSJunchao Zhang PetscCall(MatConvert_SeqAIJ_SeqAIJHIPSPARSE(*B, MATSEQAIJHIPSPARSE, MAT_INPLACE_MATRIX, B)); 3254*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3255*d52a580bSJunchao Zhang } 3256*d52a580bSJunchao Zhang 3257*d52a580bSJunchao Zhang static PetscErrorCode MatAXPY_SeqAIJHIPSPARSE(Mat Y, PetscScalar a, Mat X, MatStructure str) 3258*d52a580bSJunchao Zhang { 3259*d52a580bSJunchao Zhang Mat_SeqAIJ *x = (Mat_SeqAIJ *)X->data, *y = (Mat_SeqAIJ *)Y->data; 3260*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cy; 3261*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cx; 3262*d52a580bSJunchao Zhang PetscScalar *ay; 3263*d52a580bSJunchao Zhang const PetscScalar *ax; 3264*d52a580bSJunchao Zhang CsrMatrix *csry, *csrx; 3265*d52a580bSJunchao Zhang 3266*d52a580bSJunchao Zhang PetscFunctionBegin; 3267*d52a580bSJunchao Zhang cy = (Mat_SeqAIJHIPSPARSE *)Y->spptr; 3268*d52a580bSJunchao Zhang cx = (Mat_SeqAIJHIPSPARSE 
*)X->spptr; 3269*d52a580bSJunchao Zhang if (X->ops->axpy != Y->ops->axpy) { 3270*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(Y, PETSC_FALSE)); 3271*d52a580bSJunchao Zhang PetscCall(MatAXPY_SeqAIJ(Y, a, X, str)); 3272*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3273*d52a580bSJunchao Zhang } 3274*d52a580bSJunchao Zhang /* if we are here, it means both matrices are bound to GPU */ 3275*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(Y)); 3276*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(X)); 3277*d52a580bSJunchao Zhang PetscCheck(cy->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)Y), PETSC_ERR_GPU, "only MAT_HIPSPARSE_CSR supported"); 3278*d52a580bSJunchao Zhang PetscCheck(cx->format == MAT_HIPSPARSE_CSR, PetscObjectComm((PetscObject)X), PETSC_ERR_GPU, "only MAT_HIPSPARSE_CSR supported"); 3279*d52a580bSJunchao Zhang csry = (CsrMatrix *)cy->mat->mat; 3280*d52a580bSJunchao Zhang csrx = (CsrMatrix *)cx->mat->mat; 3281*d52a580bSJunchao Zhang /* see if we can turn this into a hipblas axpy */ 3282*d52a580bSJunchao Zhang if (str != SAME_NONZERO_PATTERN && x->nz == y->nz && !x->compressedrow.use && !y->compressedrow.use) { 3283*d52a580bSJunchao Zhang bool eq = thrust::equal(thrust::device, csry->row_offsets->begin(), csry->row_offsets->end(), csrx->row_offsets->begin()); 3284*d52a580bSJunchao Zhang if (eq) eq = thrust::equal(thrust::device, csry->column_indices->begin(), csry->column_indices->end(), csrx->column_indices->begin()); 3285*d52a580bSJunchao Zhang if (eq) str = SAME_NONZERO_PATTERN; 3286*d52a580bSJunchao Zhang } 3287*d52a580bSJunchao Zhang /* spgeam is buggy with one column */ 3288*d52a580bSJunchao Zhang if (Y->cmap->n == 1 && str != SAME_NONZERO_PATTERN) str = DIFFERENT_NONZERO_PATTERN; 3289*d52a580bSJunchao Zhang if (str == SUBSET_NONZERO_PATTERN) { 3290*d52a580bSJunchao Zhang PetscScalar b = 1.0; 3291*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0) 
3292*d52a580bSJunchao Zhang size_t bufferSize; 3293*d52a580bSJunchao Zhang void *buffer; 3294*d52a580bSJunchao Zhang #endif 3295*d52a580bSJunchao Zhang 3296*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEGetArrayRead(X, &ax)); 3297*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEGetArray(Y, &ay)); 3298*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetPointerMode(cy->handle, HIPSPARSE_POINTER_MODE_HOST)); 3299*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0) 3300*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr_spgeam_bufferSize(cy->handle, Y->rmap->n, Y->cmap->n, &a, cx->mat->descr, x->nz, ax, csrx->row_offsets->data().get(), csrx->column_indices->data().get(), &b, cy->mat->descr, y->nz, ay, csry->row_offsets->data().get(), 3301*d52a580bSJunchao Zhang csry->column_indices->data().get(), cy->mat->descr, ay, csry->row_offsets->data().get(), csry->column_indices->data().get(), &bufferSize)); 3302*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc(&buffer, bufferSize)); 3303*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 3304*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr_spgeam(cy->handle, Y->rmap->n, Y->cmap->n, &a, cx->mat->descr, x->nz, ax, csrx->row_offsets->data().get(), csrx->column_indices->data().get(), &b, cy->mat->descr, y->nz, ay, csry->row_offsets->data().get(), 3305*d52a580bSJunchao Zhang csry->column_indices->data().get(), cy->mat->descr, ay, csry->row_offsets->data().get(), csry->column_indices->data().get(), buffer)); 3306*d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(x->nz + y->nz)); 3307*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 3308*d52a580bSJunchao Zhang PetscCallHIP(hipFree(buffer)); 3309*d52a580bSJunchao Zhang #else 3310*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 3311*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparse_csr_spgeam(cy->handle, Y->rmap->n, Y->cmap->n, &a, cx->mat->descr, x->nz, ax, csrx->row_offsets->data().get(), csrx->column_indices->data().get(), &b, 
cy->mat->descr, y->nz, ay, csry->row_offsets->data().get(), 3312*d52a580bSJunchao Zhang csry->column_indices->data().get(), cy->mat->descr, ay, csry->row_offsets->data().get(), csry->column_indices->data().get())); 3313*d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(x->nz + y->nz)); 3314*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 3315*d52a580bSJunchao Zhang #endif 3316*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetPointerMode(cy->handle, HIPSPARSE_POINTER_MODE_DEVICE)); 3317*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSERestoreArrayRead(X, &ax)); 3318*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSERestoreArray(Y, &ay)); 3319*d52a580bSJunchao Zhang } else if (str == SAME_NONZERO_PATTERN) { 3320*d52a580bSJunchao Zhang hipblasHandle_t hipblasv2handle; 3321*d52a580bSJunchao Zhang PetscBLASInt one = 1, bnz = 1; 3322*d52a580bSJunchao Zhang 3323*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEGetArrayRead(X, &ax)); 3324*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEGetArray(Y, &ay)); 3325*d52a580bSJunchao Zhang PetscCall(PetscHIPBLASGetHandle(&hipblasv2handle)); 3326*d52a580bSJunchao Zhang PetscCall(PetscBLASIntCast(x->nz, &bnz)); 3327*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 3328*d52a580bSJunchao Zhang PetscCallHIPBLAS(hipblasXaxpy(hipblasv2handle, bnz, &a, ax, one, ay, one)); 3329*d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(2.0 * bnz)); 3330*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 3331*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSERestoreArrayRead(X, &ax)); 3332*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSERestoreArray(Y, &ay)); 3333*d52a580bSJunchao Zhang } else { 3334*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(Y, PETSC_FALSE)); 3335*d52a580bSJunchao Zhang PetscCall(MatAXPY_SeqAIJ(Y, a, X, str)); 3336*d52a580bSJunchao Zhang } 3337*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3338*d52a580bSJunchao Zhang } 3339*d52a580bSJunchao Zhang 
3340*d52a580bSJunchao Zhang static PetscErrorCode MatScale_SeqAIJHIPSPARSE(Mat Y, PetscScalar a) 3341*d52a580bSJunchao Zhang { 3342*d52a580bSJunchao Zhang Mat_SeqAIJ *y = (Mat_SeqAIJ *)Y->data; 3343*d52a580bSJunchao Zhang PetscScalar *ay; 3344*d52a580bSJunchao Zhang hipblasHandle_t hipblasv2handle; 3345*d52a580bSJunchao Zhang PetscBLASInt one = 1, bnz = 1; 3346*d52a580bSJunchao Zhang 3347*d52a580bSJunchao Zhang PetscFunctionBegin; 3348*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEGetArray(Y, &ay)); 3349*d52a580bSJunchao Zhang PetscCall(PetscHIPBLASGetHandle(&hipblasv2handle)); 3350*d52a580bSJunchao Zhang PetscCall(PetscBLASIntCast(y->nz, &bnz)); 3351*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 3352*d52a580bSJunchao Zhang PetscCallHIPBLAS(hipblasXscal(hipblasv2handle, bnz, &a, ay, one)); 3353*d52a580bSJunchao Zhang PetscCall(PetscLogGpuFlops(bnz)); 3354*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 3355*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSERestoreArray(Y, &ay)); 3356*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3357*d52a580bSJunchao Zhang } 3358*d52a580bSJunchao Zhang 3359*d52a580bSJunchao Zhang static PetscErrorCode MatZeroEntries_SeqAIJHIPSPARSE(Mat A) 3360*d52a580bSJunchao Zhang { 3361*d52a580bSJunchao Zhang PetscBool both = PETSC_FALSE; 3362*d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data; 3363*d52a580bSJunchao Zhang 3364*d52a580bSJunchao Zhang PetscFunctionBegin; 3365*d52a580bSJunchao Zhang if (A->factortype == MAT_FACTOR_NONE) { 3366*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *spptr = (Mat_SeqAIJHIPSPARSE *)A->spptr; 3367*d52a580bSJunchao Zhang if (spptr->mat) { 3368*d52a580bSJunchao Zhang CsrMatrix *matrix = (CsrMatrix *)spptr->mat->mat; 3369*d52a580bSJunchao Zhang if (matrix->values) { 3370*d52a580bSJunchao Zhang both = PETSC_TRUE; 3371*d52a580bSJunchao Zhang thrust::fill(thrust::device, matrix->values->begin(), matrix->values->end(), 0.); 3372*d52a580bSJunchao Zhang } 
3373*d52a580bSJunchao Zhang } 3374*d52a580bSJunchao Zhang if (spptr->matTranspose) { 3375*d52a580bSJunchao Zhang CsrMatrix *matrix = (CsrMatrix *)spptr->matTranspose->mat; 3376*d52a580bSJunchao Zhang if (matrix->values) thrust::fill(thrust::device, matrix->values->begin(), matrix->values->end(), 0.); 3377*d52a580bSJunchao Zhang } 3378*d52a580bSJunchao Zhang } 3379*d52a580bSJunchao Zhang //PetscCall(MatZeroEntries_SeqAIJ(A)); 3380*d52a580bSJunchao Zhang PetscCall(PetscArrayzero(a->a, a->i[A->rmap->n])); 3381*d52a580bSJunchao Zhang if (both) A->offloadmask = PETSC_OFFLOAD_BOTH; 3382*d52a580bSJunchao Zhang else A->offloadmask = PETSC_OFFLOAD_CPU; 3383*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3384*d52a580bSJunchao Zhang } 3385*d52a580bSJunchao Zhang 3386*d52a580bSJunchao Zhang static PetscErrorCode MatGetCurrentMemType_SeqAIJHIPSPARSE(PETSC_UNUSED Mat A, PetscMemType *m) 3387*d52a580bSJunchao Zhang { 3388*d52a580bSJunchao Zhang PetscFunctionBegin; 3389*d52a580bSJunchao Zhang *m = PETSC_MEMTYPE_HIP; 3390*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3391*d52a580bSJunchao Zhang } 3392*d52a580bSJunchao Zhang 3393*d52a580bSJunchao Zhang static PetscErrorCode MatBindToCPU_SeqAIJHIPSPARSE(Mat A, PetscBool flg) 3394*d52a580bSJunchao Zhang { 3395*d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data; 3396*d52a580bSJunchao Zhang 3397*d52a580bSJunchao Zhang PetscFunctionBegin; 3398*d52a580bSJunchao Zhang if (A->factortype != MAT_FACTOR_NONE) { 3399*d52a580bSJunchao Zhang A->boundtocpu = flg; 3400*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3401*d52a580bSJunchao Zhang } 3402*d52a580bSJunchao Zhang if (flg) { 3403*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyFromGPU(A)); 3404*d52a580bSJunchao Zhang 3405*d52a580bSJunchao Zhang A->ops->scale = MatScale_SeqAIJ; 3406*d52a580bSJunchao Zhang A->ops->axpy = MatAXPY_SeqAIJ; 3407*d52a580bSJunchao Zhang A->ops->zeroentries = MatZeroEntries_SeqAIJ; 3408*d52a580bSJunchao Zhang 
A->ops->mult = MatMult_SeqAIJ; 3409*d52a580bSJunchao Zhang A->ops->multadd = MatMultAdd_SeqAIJ; 3410*d52a580bSJunchao Zhang A->ops->multtranspose = MatMultTranspose_SeqAIJ; 3411*d52a580bSJunchao Zhang A->ops->multtransposeadd = MatMultTransposeAdd_SeqAIJ; 3412*d52a580bSJunchao Zhang A->ops->multhermitiantranspose = NULL; 3413*d52a580bSJunchao Zhang A->ops->multhermitiantransposeadd = NULL; 3414*d52a580bSJunchao Zhang A->ops->productsetfromoptions = MatProductSetFromOptions_SeqAIJ; 3415*d52a580bSJunchao Zhang A->ops->getcurrentmemtype = NULL; 3416*d52a580bSJunchao Zhang PetscCall(PetscMemzero(a->ops, sizeof(Mat_SeqAIJOps))); 3417*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSeqAIJCopySubArray_C", NULL)); 3418*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdensehip_C", NULL)); 3419*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdense_C", NULL)); 3420*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL)); 3421*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL)); 3422*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqaijhipsparse_C", NULL)); 3423*d52a580bSJunchao Zhang } else { 3424*d52a580bSJunchao Zhang A->ops->scale = MatScale_SeqAIJHIPSPARSE; 3425*d52a580bSJunchao Zhang A->ops->axpy = MatAXPY_SeqAIJHIPSPARSE; 3426*d52a580bSJunchao Zhang A->ops->zeroentries = MatZeroEntries_SeqAIJHIPSPARSE; 3427*d52a580bSJunchao Zhang A->ops->mult = MatMult_SeqAIJHIPSPARSE; 3428*d52a580bSJunchao Zhang A->ops->multadd = MatMultAdd_SeqAIJHIPSPARSE; 3429*d52a580bSJunchao Zhang A->ops->multtranspose = MatMultTranspose_SeqAIJHIPSPARSE; 3430*d52a580bSJunchao Zhang A->ops->multtransposeadd = MatMultTransposeAdd_SeqAIJHIPSPARSE; 
3431*d52a580bSJunchao Zhang A->ops->multhermitiantranspose = MatMultHermitianTranspose_SeqAIJHIPSPARSE; 3432*d52a580bSJunchao Zhang A->ops->multhermitiantransposeadd = MatMultHermitianTransposeAdd_SeqAIJHIPSPARSE; 3433*d52a580bSJunchao Zhang A->ops->productsetfromoptions = MatProductSetFromOptions_SeqAIJHIPSPARSE; 3434*d52a580bSJunchao Zhang A->ops->getcurrentmemtype = MatGetCurrentMemType_SeqAIJHIPSPARSE; 3435*d52a580bSJunchao Zhang a->ops->getarray = MatSeqAIJGetArray_SeqAIJHIPSPARSE; 3436*d52a580bSJunchao Zhang a->ops->restorearray = MatSeqAIJRestoreArray_SeqAIJHIPSPARSE; 3437*d52a580bSJunchao Zhang a->ops->getarrayread = MatSeqAIJGetArrayRead_SeqAIJHIPSPARSE; 3438*d52a580bSJunchao Zhang a->ops->restorearrayread = MatSeqAIJRestoreArrayRead_SeqAIJHIPSPARSE; 3439*d52a580bSJunchao Zhang a->ops->getarraywrite = MatSeqAIJGetArrayWrite_SeqAIJHIPSPARSE; 3440*d52a580bSJunchao Zhang a->ops->restorearraywrite = MatSeqAIJRestoreArrayWrite_SeqAIJHIPSPARSE; 3441*d52a580bSJunchao Zhang a->ops->getcsrandmemtype = MatSeqAIJGetCSRAndMemType_SeqAIJHIPSPARSE; 3442*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSeqAIJCopySubArray_C", MatSeqAIJCopySubArray_SeqAIJHIPSPARSE)); 3443*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdensehip_C", MatProductSetFromOptions_SeqAIJHIPSPARSE)); 3444*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatProductSetFromOptions_seqaijhipsparse_seqdense_C", MatProductSetFromOptions_SeqAIJHIPSPARSE)); 3445*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_SeqAIJHIPSPARSE)); 3446*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", MatSetValuesCOO_SeqAIJHIPSPARSE)); 3447*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)A, 
"MatProductSetFromOptions_seqaijhipsparse_seqaijhipsparse_C", MatProductSetFromOptions_SeqAIJHIPSPARSE)); 3448*d52a580bSJunchao Zhang } 3449*d52a580bSJunchao Zhang A->boundtocpu = flg; 3450*d52a580bSJunchao Zhang if (flg && a->inode.size_csr) a->inode.use = PETSC_TRUE; 3451*d52a580bSJunchao Zhang else a->inode.use = PETSC_FALSE; 3452*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3453*d52a580bSJunchao Zhang } 3454*d52a580bSJunchao Zhang 3455*d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatConvert_SeqAIJ_SeqAIJHIPSPARSE(Mat A, MatType mtype, MatReuse reuse, Mat *newmat) 3456*d52a580bSJunchao Zhang { 3457*d52a580bSJunchao Zhang Mat B; 3458*d52a580bSJunchao Zhang 3459*d52a580bSJunchao Zhang PetscFunctionBegin; 3460*d52a580bSJunchao Zhang PetscCall(PetscDeviceInitialize(PETSC_DEVICE_HIP)); /* first use of HIPSPARSE may be via MatConvert */ 3461*d52a580bSJunchao Zhang if (reuse == MAT_INITIAL_MATRIX) { 3462*d52a580bSJunchao Zhang PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat)); 3463*d52a580bSJunchao Zhang } else if (reuse == MAT_REUSE_MATRIX) { 3464*d52a580bSJunchao Zhang PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN)); 3465*d52a580bSJunchao Zhang } 3466*d52a580bSJunchao Zhang B = *newmat; 3467*d52a580bSJunchao Zhang PetscCall(PetscFree(B->defaultvectype)); 3468*d52a580bSJunchao Zhang PetscCall(PetscStrallocpy(VECHIP, &B->defaultvectype)); 3469*d52a580bSJunchao Zhang if (reuse != MAT_REUSE_MATRIX && !B->spptr) { 3470*d52a580bSJunchao Zhang if (B->factortype == MAT_FACTOR_NONE) { 3471*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *spptr; 3472*d52a580bSJunchao Zhang PetscCall(PetscNew(&spptr)); 3473*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreate(&spptr->handle)); 3474*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetStream(spptr->handle, PetscDefaultHipStream)); 3475*d52a580bSJunchao Zhang spptr->format = MAT_HIPSPARSE_CSR; 3476*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0) 3477*d52a580bSJunchao Zhang spptr->spmvAlg = 
HIPSPARSE_SPMV_CSR_ALG1; 3478*d52a580bSJunchao Zhang #else 3479*d52a580bSJunchao Zhang spptr->spmvAlg = HIPSPARSE_CSRMV_ALG1; /* default, since we only support csr */ 3480*d52a580bSJunchao Zhang #endif 3481*d52a580bSJunchao Zhang spptr->spmmAlg = HIPSPARSE_SPMM_CSR_ALG1; /* default, only support column-major dense matrix B */ 3482*d52a580bSJunchao Zhang //spptr->csr2cscAlg = HIPSPARSE_CSR2CSC_ALG1; 3483*d52a580bSJunchao Zhang 3484*d52a580bSJunchao Zhang B->spptr = spptr; 3485*d52a580bSJunchao Zhang } else { 3486*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *spptr; 3487*d52a580bSJunchao Zhang 3488*d52a580bSJunchao Zhang PetscCall(PetscNew(&spptr)); 3489*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreate(&spptr->handle)); 3490*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetStream(spptr->handle, PetscDefaultHipStream)); 3491*d52a580bSJunchao Zhang B->spptr = spptr; 3492*d52a580bSJunchao Zhang } 3493*d52a580bSJunchao Zhang B->offloadmask = PETSC_OFFLOAD_UNALLOCATED; 3494*d52a580bSJunchao Zhang } 3495*d52a580bSJunchao Zhang B->ops->assemblyend = MatAssemblyEnd_SeqAIJHIPSPARSE; 3496*d52a580bSJunchao Zhang B->ops->destroy = MatDestroy_SeqAIJHIPSPARSE; 3497*d52a580bSJunchao Zhang B->ops->setoption = MatSetOption_SeqAIJHIPSPARSE; 3498*d52a580bSJunchao Zhang B->ops->setfromoptions = MatSetFromOptions_SeqAIJHIPSPARSE; 3499*d52a580bSJunchao Zhang B->ops->bindtocpu = MatBindToCPU_SeqAIJHIPSPARSE; 3500*d52a580bSJunchao Zhang B->ops->duplicate = MatDuplicate_SeqAIJHIPSPARSE; 3501*d52a580bSJunchao Zhang B->ops->getcurrentmemtype = MatGetCurrentMemType_SeqAIJHIPSPARSE; 3502*d52a580bSJunchao Zhang 3503*d52a580bSJunchao Zhang PetscCall(MatBindToCPU_SeqAIJHIPSPARSE(B, PETSC_FALSE)); 3504*d52a580bSJunchao Zhang PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATSEQAIJHIPSPARSE)); 3505*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatHIPSPARSESetFormat_C", MatHIPSPARSESetFormat_SeqAIJHIPSPARSE)); 3506*d52a580bSJunchao Zhang #if 
defined(PETSC_HAVE_HYPRE) 3507*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatConvert_seqaijhipsparse_hypre_C", MatConvert_AIJ_HYPRE)); 3508*d52a580bSJunchao Zhang #endif 3509*d52a580bSJunchao Zhang PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatHIPSPARSESetUseCPUSolve_C", MatHIPSPARSESetUseCPUSolve_SeqAIJHIPSPARSE)); 3510*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3511*d52a580bSJunchao Zhang } 3512*d52a580bSJunchao Zhang 3513*d52a580bSJunchao Zhang PETSC_EXTERN PetscErrorCode MatCreate_SeqAIJHIPSPARSE(Mat B) 3514*d52a580bSJunchao Zhang { 3515*d52a580bSJunchao Zhang PetscFunctionBegin; 3516*d52a580bSJunchao Zhang PetscCall(MatCreate_SeqAIJ(B)); 3517*d52a580bSJunchao Zhang PetscCall(MatConvert_SeqAIJ_SeqAIJHIPSPARSE(B, MATSEQAIJHIPSPARSE, MAT_INPLACE_MATRIX, &B)); 3518*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3519*d52a580bSJunchao Zhang } 3520*d52a580bSJunchao Zhang 3521*d52a580bSJunchao Zhang /*MC 3522*d52a580bSJunchao Zhang MATSEQAIJHIPSPARSE - MATAIJHIPSPARSE = "(seq)aijhipsparse" - A matrix type to be used for sparse matrices on AMD GPUs 3523*d52a580bSJunchao Zhang 3524*d52a580bSJunchao Zhang A matrix type whose data resides on AMD GPUs. These matrices can be in either 3525*d52a580bSJunchao Zhang CSR, ELL, or Hybrid format. 3526*d52a580bSJunchao Zhang All matrix calculations are performed on AMD/NVIDIA GPUs using the HIPSPARSE library. 3527*d52a580bSJunchao Zhang 3528*d52a580bSJunchao Zhang Options Database Keys: 3529*d52a580bSJunchao Zhang + -mat_type aijhipsparse - sets the matrix type to `MATSEQAIJHIPSPARSE` 3530*d52a580bSJunchao Zhang . -mat_hipsparse_storage_format csr - sets the storage format of matrices (for `MatMult()` and factors in `MatSolve()`). 3531*d52a580bSJunchao Zhang Other options include ell (ellpack) or hyb (hybrid). 3532*d52a580bSJunchao Zhang . -mat_hipsparse_mult_storage_format csr - sets the storage format of matrices (for `MatMult()`). 
Other options include ell (ellpack) or hyb (hybrid). 3533*d52a580bSJunchao Zhang - -mat_hipsparse_use_cpu_solve - Do `MatSolve()` on the CPU 3534*d52a580bSJunchao Zhang 3535*d52a580bSJunchao Zhang Level: beginner 3536*d52a580bSJunchao Zhang 3537*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatCreateSeqAIJHIPSPARSE()`, `MATAIJHIPSPARSE`, `MatCreateAIJHIPSPARSE()`, `MatHIPSPARSESetFormat()`, `MatHIPSPARSEStorageFormat`, `MatHIPSPARSEFormatOperation` 3538*d52a580bSJunchao Zhang M*/ 3539*d52a580bSJunchao Zhang 3540*d52a580bSJunchao Zhang PETSC_INTERN PetscErrorCode MatSolverTypeRegister_HIPSPARSE(void) 3541*d52a580bSJunchao Zhang { 3542*d52a580bSJunchao Zhang PetscFunctionBegin; 3543*d52a580bSJunchao Zhang PetscCall(MatSolverTypeRegister(MATSOLVERHIPSPARSE, MATSEQAIJHIPSPARSE, MAT_FACTOR_LU, MatGetFactor_seqaijhipsparse_hipsparse)); 3544*d52a580bSJunchao Zhang PetscCall(MatSolverTypeRegister(MATSOLVERHIPSPARSE, MATSEQAIJHIPSPARSE, MAT_FACTOR_CHOLESKY, MatGetFactor_seqaijhipsparse_hipsparse)); 3545*d52a580bSJunchao Zhang PetscCall(MatSolverTypeRegister(MATSOLVERHIPSPARSE, MATSEQAIJHIPSPARSE, MAT_FACTOR_ILU, MatGetFactor_seqaijhipsparse_hipsparse)); 3546*d52a580bSJunchao Zhang PetscCall(MatSolverTypeRegister(MATSOLVERHIPSPARSE, MATSEQAIJHIPSPARSE, MAT_FACTOR_ICC, MatGetFactor_seqaijhipsparse_hipsparse)); 3547*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3548*d52a580bSJunchao Zhang } 3549*d52a580bSJunchao Zhang 3550*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSE_Destroy(Mat mat) 3551*d52a580bSJunchao Zhang { 3552*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp = static_cast<Mat_SeqAIJHIPSPARSE *>(mat->spptr); 3553*d52a580bSJunchao Zhang 3554*d52a580bSJunchao Zhang PetscFunctionBegin; 3555*d52a580bSJunchao Zhang if (cusp) { 3556*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&cusp->mat, cusp->format)); 3557*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&cusp->matTranspose, 
cusp->format)); 3558*d52a580bSJunchao Zhang delete cusp->workVector; 3559*d52a580bSJunchao Zhang delete cusp->rowoffsets_gpu; 3560*d52a580bSJunchao Zhang delete cusp->csr2csc_i; 3561*d52a580bSJunchao Zhang delete cusp->coords; 3562*d52a580bSJunchao Zhang if (cusp->handle) PetscCallHIPSPARSE(hipsparseDestroy(cusp->handle)); 3563*d52a580bSJunchao Zhang PetscCall(PetscFree(mat->spptr)); 3564*d52a580bSJunchao Zhang } 3565*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3566*d52a580bSJunchao Zhang } 3567*d52a580bSJunchao Zhang 3568*d52a580bSJunchao Zhang static PetscErrorCode CsrMatrix_Destroy(CsrMatrix **mat) 3569*d52a580bSJunchao Zhang { 3570*d52a580bSJunchao Zhang PetscFunctionBegin; 3571*d52a580bSJunchao Zhang if (*mat) { 3572*d52a580bSJunchao Zhang delete (*mat)->values; 3573*d52a580bSJunchao Zhang delete (*mat)->column_indices; 3574*d52a580bSJunchao Zhang delete (*mat)->row_offsets; 3575*d52a580bSJunchao Zhang delete *mat; 3576*d52a580bSJunchao Zhang *mat = 0; 3577*d52a580bSJunchao Zhang } 3578*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3579*d52a580bSJunchao Zhang } 3580*d52a580bSJunchao Zhang 3581*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSETriFactorStruct **trifactor) 3582*d52a580bSJunchao Zhang { 3583*d52a580bSJunchao Zhang PetscFunctionBegin; 3584*d52a580bSJunchao Zhang if (*trifactor) { 3585*d52a580bSJunchao Zhang if ((*trifactor)->descr) PetscCallHIPSPARSE(hipsparseDestroyMatDescr((*trifactor)->descr)); 3586*d52a580bSJunchao Zhang if ((*trifactor)->solveInfo) PetscCallHIPSPARSE(hipsparseDestroyCsrsvInfo((*trifactor)->solveInfo)); 3587*d52a580bSJunchao Zhang PetscCall(CsrMatrix_Destroy(&(*trifactor)->csrMat)); 3588*d52a580bSJunchao Zhang if ((*trifactor)->solveBuffer) PetscCallHIP(hipFree((*trifactor)->solveBuffer)); 3589*d52a580bSJunchao Zhang if ((*trifactor)->AA_h) PetscCallHIP(hipHostFree((*trifactor)->AA_h)); 3590*d52a580bSJunchao Zhang if ((*trifactor)->csr2cscBuffer) 
PetscCallHIP(hipFree((*trifactor)->csr2cscBuffer)); 3591*d52a580bSJunchao Zhang PetscCall(PetscFree(*trifactor)); 3592*d52a580bSJunchao Zhang } 3593*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3594*d52a580bSJunchao Zhang } 3595*d52a580bSJunchao Zhang 3596*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEMultStruct_Destroy(Mat_SeqAIJHIPSPARSEMultStruct **matstruct, MatHIPSPARSEStorageFormat format) 3597*d52a580bSJunchao Zhang { 3598*d52a580bSJunchao Zhang CsrMatrix *mat; 3599*d52a580bSJunchao Zhang 3600*d52a580bSJunchao Zhang PetscFunctionBegin; 3601*d52a580bSJunchao Zhang if (*matstruct) { 3602*d52a580bSJunchao Zhang if ((*matstruct)->mat) { 3603*d52a580bSJunchao Zhang if (format == MAT_HIPSPARSE_ELL || format == MAT_HIPSPARSE_HYB) { 3604*d52a580bSJunchao Zhang hipsparseHybMat_t hybMat = (hipsparseHybMat_t)(*matstruct)->mat; 3605*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDestroyHybMat(hybMat)); 3606*d52a580bSJunchao Zhang } else { 3607*d52a580bSJunchao Zhang mat = (CsrMatrix *)(*matstruct)->mat; 3608*d52a580bSJunchao Zhang PetscCall(CsrMatrix_Destroy(&mat)); 3609*d52a580bSJunchao Zhang } 3610*d52a580bSJunchao Zhang } 3611*d52a580bSJunchao Zhang if ((*matstruct)->descr) PetscCallHIPSPARSE(hipsparseDestroyMatDescr((*matstruct)->descr)); 3612*d52a580bSJunchao Zhang delete (*matstruct)->cprowIndices; 3613*d52a580bSJunchao Zhang if ((*matstruct)->alpha_one) PetscCallHIP(hipFree((*matstruct)->alpha_one)); 3614*d52a580bSJunchao Zhang if ((*matstruct)->beta_zero) PetscCallHIP(hipFree((*matstruct)->beta_zero)); 3615*d52a580bSJunchao Zhang if ((*matstruct)->beta_one) PetscCallHIP(hipFree((*matstruct)->beta_one)); 3616*d52a580bSJunchao Zhang 3617*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSEMultStruct *mdata = *matstruct; 3618*d52a580bSJunchao Zhang if (mdata->matDescr) PetscCallHIPSPARSE(hipsparseDestroySpMat(mdata->matDescr)); 3619*d52a580bSJunchao Zhang for (int i = 0; i < 3; i++) { 3620*d52a580bSJunchao Zhang if 
(mdata->hipSpMV[i].initialized) { 3621*d52a580bSJunchao Zhang PetscCallHIP(hipFree(mdata->hipSpMV[i].spmvBuffer)); 3622*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDestroyDnVec(mdata->hipSpMV[i].vecXDescr)); 3623*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDestroyDnVec(mdata->hipSpMV[i].vecYDescr)); 3624*d52a580bSJunchao Zhang } 3625*d52a580bSJunchao Zhang } 3626*d52a580bSJunchao Zhang delete *matstruct; 3627*d52a580bSJunchao Zhang *matstruct = NULL; 3628*d52a580bSJunchao Zhang } 3629*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3630*d52a580bSJunchao Zhang } 3631*d52a580bSJunchao Zhang 3632*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSETriFactors_Reset(Mat_SeqAIJHIPSPARSETriFactors_p *trifactors) 3633*d52a580bSJunchao Zhang { 3634*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSETriFactors *fs = *trifactors; 3635*d52a580bSJunchao Zhang 3636*d52a580bSJunchao Zhang PetscFunctionBegin; 3637*d52a580bSJunchao Zhang if (fs) { 3638*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&fs->loTriFactorPtr)); 3639*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&fs->upTriFactorPtr)); 3640*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&fs->loTriFactorPtrTranspose)); 3641*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&fs->upTriFactorPtrTranspose)); 3642*d52a580bSJunchao Zhang delete fs->rpermIndices; 3643*d52a580bSJunchao Zhang delete fs->cpermIndices; 3644*d52a580bSJunchao Zhang delete fs->workVector; 3645*d52a580bSJunchao Zhang fs->rpermIndices = NULL; 3646*d52a580bSJunchao Zhang fs->cpermIndices = NULL; 3647*d52a580bSJunchao Zhang fs->workVector = NULL; 3648*d52a580bSJunchao Zhang fs->init_dev_prop = PETSC_FALSE; 3649*d52a580bSJunchao Zhang #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0) 3650*d52a580bSJunchao Zhang PetscCallHIP(hipFree(fs->csrRowPtr)); 3651*d52a580bSJunchao Zhang PetscCallHIP(hipFree(fs->csrColIdx)); 3652*d52a580bSJunchao Zhang 
PetscCallHIP(hipFree(fs->csrVal)); 3653*d52a580bSJunchao Zhang PetscCallHIP(hipFree(fs->X)); 3654*d52a580bSJunchao Zhang PetscCallHIP(hipFree(fs->Y)); 3655*d52a580bSJunchao Zhang // PetscCallHIP(hipFree(fs->factBuffer_M)); /* No needed since factBuffer_M shares with one of spsvBuffer_L/U */ 3656*d52a580bSJunchao Zhang PetscCallHIP(hipFree(fs->spsvBuffer_L)); 3657*d52a580bSJunchao Zhang PetscCallHIP(hipFree(fs->spsvBuffer_U)); 3658*d52a580bSJunchao Zhang PetscCallHIP(hipFree(fs->spsvBuffer_Lt)); 3659*d52a580bSJunchao Zhang PetscCallHIP(hipFree(fs->spsvBuffer_Ut)); 3660*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDestroyMatDescr(fs->matDescr_M)); 3661*d52a580bSJunchao Zhang if (fs->spMatDescr_L) PetscCallHIPSPARSE(hipsparseDestroySpMat(fs->spMatDescr_L)); 3662*d52a580bSJunchao Zhang if (fs->spMatDescr_U) PetscCallHIPSPARSE(hipsparseDestroySpMat(fs->spMatDescr_U)); 3663*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_destroyDescr(fs->spsvDescr_L)); 3664*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_destroyDescr(fs->spsvDescr_Lt)); 3665*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_destroyDescr(fs->spsvDescr_U)); 3666*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSpSV_destroyDescr(fs->spsvDescr_Ut)); 3667*d52a580bSJunchao Zhang if (fs->dnVecDescr_X) PetscCallHIPSPARSE(hipsparseDestroyDnVec(fs->dnVecDescr_X)); 3668*d52a580bSJunchao Zhang if (fs->dnVecDescr_Y) PetscCallHIPSPARSE(hipsparseDestroyDnVec(fs->dnVecDescr_Y)); 3669*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDestroyCsrilu02Info(fs->ilu0Info_M)); 3670*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseDestroyCsric02Info(fs->ic0Info_M)); 3671*d52a580bSJunchao Zhang 3672*d52a580bSJunchao Zhang fs->createdTransposeSpSVDescr = PETSC_FALSE; 3673*d52a580bSJunchao Zhang fs->updatedTransposeSpSVAnalysis = PETSC_FALSE; 3674*d52a580bSJunchao Zhang #endif 3675*d52a580bSJunchao Zhang } 3676*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3677*d52a580bSJunchao Zhang } 
3678*d52a580bSJunchao Zhang 3679*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSETriFactors_Destroy(Mat_SeqAIJHIPSPARSETriFactors **trifactors) 3680*d52a580bSJunchao Zhang { 3681*d52a580bSJunchao Zhang hipsparseHandle_t handle; 3682*d52a580bSJunchao Zhang 3683*d52a580bSJunchao Zhang PetscFunctionBegin; 3684*d52a580bSJunchao Zhang if (*trifactors) { 3685*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSETriFactors_Reset(trifactors)); 3686*d52a580bSJunchao Zhang if ((handle = (*trifactors)->handle)) PetscCallHIPSPARSE(hipsparseDestroy(handle)); 3687*d52a580bSJunchao Zhang PetscCall(PetscFree(*trifactors)); 3688*d52a580bSJunchao Zhang } 3689*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3690*d52a580bSJunchao Zhang } 3691*d52a580bSJunchao Zhang 3692*d52a580bSJunchao Zhang struct IJCompare { 3693*d52a580bSJunchao Zhang __host__ __device__ inline bool operator()(const thrust::tuple<PetscInt, PetscInt> &t1, const thrust::tuple<PetscInt, PetscInt> &t2) 3694*d52a580bSJunchao Zhang { 3695*d52a580bSJunchao Zhang if (t1.get<0>() < t2.get<0>()) return true; 3696*d52a580bSJunchao Zhang if (t1.get<0>() == t2.get<0>()) return t1.get<1>() < t2.get<1>(); 3697*d52a580bSJunchao Zhang return false; 3698*d52a580bSJunchao Zhang } 3699*d52a580bSJunchao Zhang }; 3700*d52a580bSJunchao Zhang 3701*d52a580bSJunchao Zhang static PetscErrorCode MatSeqAIJHIPSPARSEInvalidateTranspose(Mat A, PetscBool destroy) 3702*d52a580bSJunchao Zhang { 3703*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr; 3704*d52a580bSJunchao Zhang 3705*d52a580bSJunchao Zhang PetscFunctionBegin; 3706*d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE); 3707*d52a580bSJunchao Zhang if (!cusp) PetscFunctionReturn(PETSC_SUCCESS); 3708*d52a580bSJunchao Zhang if (destroy) { 3709*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEMultStruct_Destroy(&cusp->matTranspose, cusp->format)); 3710*d52a580bSJunchao Zhang delete cusp->csr2csc_i; 3711*d52a580bSJunchao 
Zhang cusp->csr2csc_i = NULL; 3712*d52a580bSJunchao Zhang } 3713*d52a580bSJunchao Zhang A->transupdated = PETSC_FALSE; 3714*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3715*d52a580bSJunchao Zhang } 3716*d52a580bSJunchao Zhang 3717*d52a580bSJunchao Zhang static PetscErrorCode MatCOOStructDestroy_SeqAIJHIPSPARSE(PetscCtxRt data) 3718*d52a580bSJunchao Zhang { 3719*d52a580bSJunchao Zhang MatCOOStruct_SeqAIJ *coo = *(MatCOOStruct_SeqAIJ **)data; 3720*d52a580bSJunchao Zhang 3721*d52a580bSJunchao Zhang PetscFunctionBegin; 3722*d52a580bSJunchao Zhang PetscCallHIP(hipFree(coo->perm)); 3723*d52a580bSJunchao Zhang PetscCallHIP(hipFree(coo->jmap)); 3724*d52a580bSJunchao Zhang PetscCall(PetscFree(coo)); 3725*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3726*d52a580bSJunchao Zhang } 3727*d52a580bSJunchao Zhang 3728*d52a580bSJunchao Zhang static PetscErrorCode MatSetPreallocationCOO_SeqAIJHIPSPARSE(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) 3729*d52a580bSJunchao Zhang { 3730*d52a580bSJunchao Zhang PetscBool dev_ij = PETSC_FALSE; 3731*d52a580bSJunchao Zhang PetscMemType mtype = PETSC_MEMTYPE_HOST; 3732*d52a580bSJunchao Zhang PetscInt *i, *j; 3733*d52a580bSJunchao Zhang PetscContainer container_h; 3734*d52a580bSJunchao Zhang MatCOOStruct_SeqAIJ *coo_h, *coo_d; 3735*d52a580bSJunchao Zhang 3736*d52a580bSJunchao Zhang PetscFunctionBegin; 3737*d52a580bSJunchao Zhang PetscCall(PetscGetMemType(coo_i, &mtype)); 3738*d52a580bSJunchao Zhang if (PetscMemTypeDevice(mtype)) { 3739*d52a580bSJunchao Zhang dev_ij = PETSC_TRUE; 3740*d52a580bSJunchao Zhang PetscCall(PetscMalloc2(coo_n, &i, coo_n, &j)); 3741*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(i, coo_i, coo_n * sizeof(PetscInt), hipMemcpyDeviceToHost)); 3742*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(j, coo_j, coo_n * sizeof(PetscInt), hipMemcpyDeviceToHost)); 3743*d52a580bSJunchao Zhang } else { 3744*d52a580bSJunchao Zhang i = coo_i; 3745*d52a580bSJunchao Zhang j = coo_j; 
3746*d52a580bSJunchao Zhang } 3747*d52a580bSJunchao Zhang PetscCall(MatSetPreallocationCOO_SeqAIJ(mat, coo_n, i, j)); 3748*d52a580bSJunchao Zhang if (dev_ij) PetscCall(PetscFree2(i, j)); 3749*d52a580bSJunchao Zhang mat->offloadmask = PETSC_OFFLOAD_CPU; 3750*d52a580bSJunchao Zhang // Create the GPU memory 3751*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(mat)); 3752*d52a580bSJunchao Zhang 3753*d52a580bSJunchao Zhang // Copy the COO struct to device 3754*d52a580bSJunchao Zhang PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h)); 3755*d52a580bSJunchao Zhang PetscCall(PetscContainerGetPointer(container_h, &coo_h)); 3756*d52a580bSJunchao Zhang PetscCall(PetscMalloc1(1, &coo_d)); 3757*d52a580bSJunchao Zhang *coo_d = *coo_h; // do a shallow copy and then amend some fields that need to be different 3758*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&coo_d->jmap, (coo_h->nz + 1) * sizeof(PetscCount))); 3759*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(coo_d->jmap, coo_h->jmap, (coo_h->nz + 1) * sizeof(PetscCount), hipMemcpyHostToDevice)); 3760*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&coo_d->perm, coo_h->Atot * sizeof(PetscCount))); 3761*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(coo_d->perm, coo_h->perm, coo_h->Atot * sizeof(PetscCount), hipMemcpyHostToDevice)); 3762*d52a580bSJunchao Zhang 3763*d52a580bSJunchao Zhang // Put the COO struct in a container and then attach that to the matrix 3764*d52a580bSJunchao Zhang PetscCall(PetscObjectContainerCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", coo_d, MatCOOStructDestroy_SeqAIJHIPSPARSE)); 3765*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3766*d52a580bSJunchao Zhang } 3767*d52a580bSJunchao Zhang 3768*d52a580bSJunchao Zhang __global__ static void MatAddCOOValues(const PetscScalar kv[], PetscCount nnz, const PetscCount jmap[], const PetscCount perm[], InsertMode imode, PetscScalar a[]) 3769*d52a580bSJunchao 
Zhang { 3770*d52a580bSJunchao Zhang PetscCount i = blockIdx.x * blockDim.x + threadIdx.x; 3771*d52a580bSJunchao Zhang const PetscCount grid_size = gridDim.x * blockDim.x; 3772*d52a580bSJunchao Zhang for (; i < nnz; i += grid_size) { 3773*d52a580bSJunchao Zhang PetscScalar sum = 0.0; 3774*d52a580bSJunchao Zhang for (PetscCount k = jmap[i]; k < jmap[i + 1]; k++) sum += kv[perm[k]]; 3775*d52a580bSJunchao Zhang a[i] = (imode == INSERT_VALUES ? 0.0 : a[i]) + sum; 3776*d52a580bSJunchao Zhang } 3777*d52a580bSJunchao Zhang } 3778*d52a580bSJunchao Zhang 3779*d52a580bSJunchao Zhang static PetscErrorCode MatSetValuesCOO_SeqAIJHIPSPARSE(Mat A, const PetscScalar v[], InsertMode imode) 3780*d52a580bSJunchao Zhang { 3781*d52a580bSJunchao Zhang Mat_SeqAIJ *seq = (Mat_SeqAIJ *)A->data; 3782*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *dev = (Mat_SeqAIJHIPSPARSE *)A->spptr; 3783*d52a580bSJunchao Zhang PetscCount Annz = seq->nz; 3784*d52a580bSJunchao Zhang PetscMemType memtype; 3785*d52a580bSJunchao Zhang const PetscScalar *v1 = v; 3786*d52a580bSJunchao Zhang PetscScalar *Aa; 3787*d52a580bSJunchao Zhang PetscContainer container; 3788*d52a580bSJunchao Zhang MatCOOStruct_SeqAIJ *coo; 3789*d52a580bSJunchao Zhang 3790*d52a580bSJunchao Zhang PetscFunctionBegin; 3791*d52a580bSJunchao Zhang if (!dev->mat) PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A)); 3792*d52a580bSJunchao Zhang 3793*d52a580bSJunchao Zhang PetscCall(PetscObjectQuery((PetscObject)A, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container)); 3794*d52a580bSJunchao Zhang PetscCall(PetscContainerGetPointer(container, &coo)); 3795*d52a580bSJunchao Zhang 3796*d52a580bSJunchao Zhang PetscCall(PetscGetMemType(v, &memtype)); 3797*d52a580bSJunchao Zhang if (PetscMemTypeHost(memtype)) { /* If user gave v[] in host, we might need to copy it to device if any */ 3798*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&v1, coo->n * sizeof(PetscScalar))); 3799*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy((void *)v1, v, coo->n * 
sizeof(PetscScalar), hipMemcpyHostToDevice)); 3800*d52a580bSJunchao Zhang } 3801*d52a580bSJunchao Zhang 3802*d52a580bSJunchao Zhang if (imode == INSERT_VALUES) PetscCall(MatSeqAIJHIPSPARSEGetArrayWrite(A, &Aa)); 3803*d52a580bSJunchao Zhang else PetscCall(MatSeqAIJHIPSPARSEGetArray(A, &Aa)); 3804*d52a580bSJunchao Zhang 3805*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 3806*d52a580bSJunchao Zhang if (Annz) { 3807*d52a580bSJunchao Zhang hipLaunchKernelGGL(HIP_KERNEL_NAME(MatAddCOOValues), dim3((Annz + 255) / 256), dim3(256), 0, PetscDefaultHipStream, v1, Annz, coo->jmap, coo->perm, imode, Aa); 3808*d52a580bSJunchao Zhang PetscCallHIP(hipPeekAtLastError()); 3809*d52a580bSJunchao Zhang } 3810*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 3811*d52a580bSJunchao Zhang 3812*d52a580bSJunchao Zhang if (imode == INSERT_VALUES) PetscCall(MatSeqAIJHIPSPARSERestoreArrayWrite(A, &Aa)); 3813*d52a580bSJunchao Zhang else PetscCall(MatSeqAIJHIPSPARSERestoreArray(A, &Aa)); 3814*d52a580bSJunchao Zhang 3815*d52a580bSJunchao Zhang if (PetscMemTypeHost(memtype)) PetscCallHIP(hipFree((void *)v1)); 3816*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3817*d52a580bSJunchao Zhang } 3818*d52a580bSJunchao Zhang 3819*d52a580bSJunchao Zhang /*@C 3820*d52a580bSJunchao Zhang MatSeqAIJHIPSPARSEGetIJ - returns the device row storage `i` and `j` indices for `MATSEQAIJHIPSPARSE` matrices. 
3821*d52a580bSJunchao Zhang 3822*d52a580bSJunchao Zhang Not Collective 3823*d52a580bSJunchao Zhang 3824*d52a580bSJunchao Zhang Input Parameters: 3825*d52a580bSJunchao Zhang + A - the matrix 3826*d52a580bSJunchao Zhang - compressed - `PETSC_TRUE` or `PETSC_FALSE` indicating the matrix data structure should be always returned in compressed form 3827*d52a580bSJunchao Zhang 3828*d52a580bSJunchao Zhang Output Parameters: 3829*d52a580bSJunchao Zhang + i - the CSR row pointers 3830*d52a580bSJunchao Zhang - j - the CSR column indices 3831*d52a580bSJunchao Zhang 3832*d52a580bSJunchao Zhang Level: developer 3833*d52a580bSJunchao Zhang 3834*d52a580bSJunchao Zhang Note: 3835*d52a580bSJunchao Zhang When compressed is true, the CSR structure does not contain empty rows 3836*d52a580bSJunchao Zhang 3837*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSERestoreIJ()`, `MatSeqAIJHIPSPARSEGetArrayRead()` 3838*d52a580bSJunchao Zhang @*/ 3839*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEGetIJ(Mat A, PetscBool compressed, const int *i[], const int *j[]) 3840*d52a580bSJunchao Zhang { 3841*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr; 3842*d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data; 3843*d52a580bSJunchao Zhang CsrMatrix *csr; 3844*d52a580bSJunchao Zhang 3845*d52a580bSJunchao Zhang PetscFunctionBegin; 3846*d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 3847*d52a580bSJunchao Zhang if (!i || !j) PetscFunctionReturn(PETSC_SUCCESS); 3848*d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE); 3849*d52a580bSJunchao Zhang PetscCheck(cusp->format != MAT_HIPSPARSE_ELL && cusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented"); 3850*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A)); 3851*d52a580bSJunchao Zhang PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct"); 3852*d52a580bSJunchao Zhang 
csr = (CsrMatrix *)cusp->mat->mat; 3853*d52a580bSJunchao Zhang if (i) { 3854*d52a580bSJunchao Zhang if (!compressed && a->compressedrow.use) { /* need full row offset */ 3855*d52a580bSJunchao Zhang if (!cusp->rowoffsets_gpu) { 3856*d52a580bSJunchao Zhang cusp->rowoffsets_gpu = new THRUSTINTARRAY32(A->rmap->n + 1); 3857*d52a580bSJunchao Zhang cusp->rowoffsets_gpu->assign(a->i, a->i + A->rmap->n + 1); 3858*d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu((A->rmap->n + 1) * sizeof(PetscInt))); 3859*d52a580bSJunchao Zhang } 3860*d52a580bSJunchao Zhang *i = cusp->rowoffsets_gpu->data().get(); 3861*d52a580bSJunchao Zhang } else *i = csr->row_offsets->data().get(); 3862*d52a580bSJunchao Zhang } 3863*d52a580bSJunchao Zhang if (j) *j = csr->column_indices->data().get(); 3864*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3865*d52a580bSJunchao Zhang } 3866*d52a580bSJunchao Zhang 3867*d52a580bSJunchao Zhang /*@C 3868*d52a580bSJunchao Zhang MatSeqAIJHIPSPARSERestoreIJ - restore the device row storage `i` and `j` indices obtained with `MatSeqAIJHIPSPARSEGetIJ()` 3869*d52a580bSJunchao Zhang 3870*d52a580bSJunchao Zhang Not Collective 3871*d52a580bSJunchao Zhang 3872*d52a580bSJunchao Zhang Input Parameters: 3873*d52a580bSJunchao Zhang + A - the matrix 3874*d52a580bSJunchao Zhang . compressed - `PETSC_TRUE` or `PETSC_FALSE` indicating the matrix data structure should be always returned in compressed form 3875*d52a580bSJunchao Zhang . 
i - the CSR row pointers 3876*d52a580bSJunchao Zhang - j - the CSR column indices 3877*d52a580bSJunchao Zhang 3878*d52a580bSJunchao Zhang Level: developer 3879*d52a580bSJunchao Zhang 3880*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetIJ()` 3881*d52a580bSJunchao Zhang @*/ 3882*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSERestoreIJ(Mat A, PetscBool compressed, const int *i[], const int *j[]) 3883*d52a580bSJunchao Zhang { 3884*d52a580bSJunchao Zhang PetscFunctionBegin; 3885*d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 3886*d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE); 3887*d52a580bSJunchao Zhang if (i) *i = NULL; 3888*d52a580bSJunchao Zhang if (j) *j = NULL; 3889*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3890*d52a580bSJunchao Zhang } 3891*d52a580bSJunchao Zhang 3892*d52a580bSJunchao Zhang /*@C 3893*d52a580bSJunchao Zhang MatSeqAIJHIPSPARSEGetArrayRead - gives read-only access to the array where the device data for a `MATSEQAIJHIPSPARSE` matrix is stored 3894*d52a580bSJunchao Zhang 3895*d52a580bSJunchao Zhang Not Collective 3896*d52a580bSJunchao Zhang 3897*d52a580bSJunchao Zhang Input Parameter: 3898*d52a580bSJunchao Zhang . A - a `MATSEQAIJHIPSPARSE` matrix 3899*d52a580bSJunchao Zhang 3900*d52a580bSJunchao Zhang Output Parameter: 3901*d52a580bSJunchao Zhang . 
a - pointer to the device data 3902*d52a580bSJunchao Zhang 3903*d52a580bSJunchao Zhang Level: developer 3904*d52a580bSJunchao Zhang 3905*d52a580bSJunchao Zhang Note: 3906*d52a580bSJunchao Zhang May trigger host-device copies if the up-to-date matrix data is on host 3907*d52a580bSJunchao Zhang 3908*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArray()`, `MatSeqAIJHIPSPARSEGetArrayWrite()`, `MatSeqAIJHIPSPARSERestoreArrayRead()` 3909*d52a580bSJunchao Zhang @*/ 3910*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEGetArrayRead(Mat A, const PetscScalar *a[]) 3911*d52a580bSJunchao Zhang { 3912*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr; 3913*d52a580bSJunchao Zhang CsrMatrix *csr; 3914*d52a580bSJunchao Zhang 3915*d52a580bSJunchao Zhang PetscFunctionBegin; 3916*d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 3917*d52a580bSJunchao Zhang PetscAssertPointer(a, 2); 3918*d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE); 3919*d52a580bSJunchao Zhang PetscCheck(cusp->format != MAT_HIPSPARSE_ELL && cusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented"); 3920*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A)); 3921*d52a580bSJunchao Zhang PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct"); 3922*d52a580bSJunchao Zhang csr = (CsrMatrix *)cusp->mat->mat; 3923*d52a580bSJunchao Zhang PetscCheck(csr->values, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing HIP memory"); 3924*d52a580bSJunchao Zhang *a = csr->values->data().get(); 3925*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3926*d52a580bSJunchao Zhang } 3927*d52a580bSJunchao Zhang 3928*d52a580bSJunchao Zhang /*@C 3929*d52a580bSJunchao Zhang MatSeqAIJHIPSPARSERestoreArrayRead - restore the read-only access array obtained from `MatSeqAIJHIPSPARSEGetArrayRead()` 3930*d52a580bSJunchao Zhang 3931*d52a580bSJunchao Zhang Not 
Collective 3932*d52a580bSJunchao Zhang 3933*d52a580bSJunchao Zhang Input Parameters: 3934*d52a580bSJunchao Zhang + A - a `MATSEQAIJHIPSPARSE` matrix 3935*d52a580bSJunchao Zhang - a - pointer to the device data 3936*d52a580bSJunchao Zhang 3937*d52a580bSJunchao Zhang Level: developer 3938*d52a580bSJunchao Zhang 3939*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArrayRead()` 3940*d52a580bSJunchao Zhang @*/ 3941*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSERestoreArrayRead(Mat A, const PetscScalar *a[]) 3942*d52a580bSJunchao Zhang { 3943*d52a580bSJunchao Zhang PetscFunctionBegin; 3944*d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 3945*d52a580bSJunchao Zhang PetscAssertPointer(a, 2); 3946*d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE); 3947*d52a580bSJunchao Zhang *a = NULL; 3948*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3949*d52a580bSJunchao Zhang } 3950*d52a580bSJunchao Zhang 3951*d52a580bSJunchao Zhang /*@C 3952*d52a580bSJunchao Zhang MatSeqAIJHIPSPARSEGetArray - gives read-write access to the array where the device data for a `MATSEQAIJHIPSPARSE` matrix is stored 3953*d52a580bSJunchao Zhang 3954*d52a580bSJunchao Zhang Not Collective 3955*d52a580bSJunchao Zhang 3956*d52a580bSJunchao Zhang Input Parameter: 3957*d52a580bSJunchao Zhang . A - a `MATSEQAIJHIPSPARSE` matrix 3958*d52a580bSJunchao Zhang 3959*d52a580bSJunchao Zhang Output Parameter: 3960*d52a580bSJunchao Zhang . 
a - pointer to the device data 3961*d52a580bSJunchao Zhang 3962*d52a580bSJunchao Zhang Level: developer 3963*d52a580bSJunchao Zhang 3964*d52a580bSJunchao Zhang Note: 3965*d52a580bSJunchao Zhang May trigger host-device copies if up-to-date matrix data is on host 3966*d52a580bSJunchao Zhang 3967*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArrayRead()`, `MatSeqAIJHIPSPARSEGetArrayWrite()`, `MatSeqAIJHIPSPARSERestoreArray()` 3968*d52a580bSJunchao Zhang @*/ 3969*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEGetArray(Mat A, PetscScalar *a[]) 3970*d52a580bSJunchao Zhang { 3971*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr; 3972*d52a580bSJunchao Zhang CsrMatrix *csr; 3973*d52a580bSJunchao Zhang 3974*d52a580bSJunchao Zhang PetscFunctionBegin; 3975*d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 3976*d52a580bSJunchao Zhang PetscAssertPointer(a, 2); 3977*d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE); 3978*d52a580bSJunchao Zhang PetscCheck(cusp->format != MAT_HIPSPARSE_ELL && cusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented"); 3979*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A)); 3980*d52a580bSJunchao Zhang PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct"); 3981*d52a580bSJunchao Zhang csr = (CsrMatrix *)cusp->mat->mat; 3982*d52a580bSJunchao Zhang PetscCheck(csr->values, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing HIP memory"); 3983*d52a580bSJunchao Zhang *a = csr->values->data().get(); 3984*d52a580bSJunchao Zhang A->offloadmask = PETSC_OFFLOAD_GPU; 3985*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_FALSE)); 3986*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 3987*d52a580bSJunchao Zhang } 3988*d52a580bSJunchao Zhang /*@C 3989*d52a580bSJunchao Zhang MatSeqAIJHIPSPARSERestoreArray - restore the read-write access array 
obtained from `MatSeqAIJHIPSPARSEGetArray()` 3990*d52a580bSJunchao Zhang 3991*d52a580bSJunchao Zhang Not Collective 3992*d52a580bSJunchao Zhang 3993*d52a580bSJunchao Zhang Input Parameters: 3994*d52a580bSJunchao Zhang + A - a `MATSEQAIJHIPSPARSE` matrix 3995*d52a580bSJunchao Zhang - a - pointer to the device data 3996*d52a580bSJunchao Zhang 3997*d52a580bSJunchao Zhang Level: developer 3998*d52a580bSJunchao Zhang 3999*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArray()` 4000*d52a580bSJunchao Zhang @*/ 4001*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSERestoreArray(Mat A, PetscScalar *a[]) 4002*d52a580bSJunchao Zhang { 4003*d52a580bSJunchao Zhang PetscFunctionBegin; 4004*d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 4005*d52a580bSJunchao Zhang PetscAssertPointer(a, 2); 4006*d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE); 4007*d52a580bSJunchao Zhang PetscCall(PetscObjectStateIncrease((PetscObject)A)); 4008*d52a580bSJunchao Zhang *a = NULL; 4009*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 4010*d52a580bSJunchao Zhang } 4011*d52a580bSJunchao Zhang 4012*d52a580bSJunchao Zhang /*@C 4013*d52a580bSJunchao Zhang MatSeqAIJHIPSPARSEGetArrayWrite - gives write access to the array where the device data for a `MATSEQAIJHIPSPARSE` matrix is stored 4014*d52a580bSJunchao Zhang 4015*d52a580bSJunchao Zhang Not Collective 4016*d52a580bSJunchao Zhang 4017*d52a580bSJunchao Zhang Input Parameter: 4018*d52a580bSJunchao Zhang . A - a `MATSEQAIJHIPSPARSE` matrix 4019*d52a580bSJunchao Zhang 4020*d52a580bSJunchao Zhang Output Parameter: 4021*d52a580bSJunchao Zhang . 
a - pointer to the device data 4022*d52a580bSJunchao Zhang 4023*d52a580bSJunchao Zhang Level: developer 4024*d52a580bSJunchao Zhang 4025*d52a580bSJunchao Zhang Note: 4026*d52a580bSJunchao Zhang Does not trigger host-device copies and flags data validity on the GPU 4027*d52a580bSJunchao Zhang 4028*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArray()`, `MatSeqAIJHIPSPARSEGetArrayRead()`, `MatSeqAIJHIPSPARSERestoreArrayWrite()` 4029*d52a580bSJunchao Zhang @*/ 4030*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEGetArrayWrite(Mat A, PetscScalar *a[]) 4031*d52a580bSJunchao Zhang { 4032*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *cusp = (Mat_SeqAIJHIPSPARSE *)A->spptr; 4033*d52a580bSJunchao Zhang CsrMatrix *csr; 4034*d52a580bSJunchao Zhang 4035*d52a580bSJunchao Zhang PetscFunctionBegin; 4036*d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 4037*d52a580bSJunchao Zhang PetscAssertPointer(a, 2); 4038*d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE); 4039*d52a580bSJunchao Zhang PetscCheck(cusp->format != MAT_HIPSPARSE_ELL && cusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented"); 4040*d52a580bSJunchao Zhang PetscCheck(cusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct"); 4041*d52a580bSJunchao Zhang csr = (CsrMatrix *)cusp->mat->mat; 4042*d52a580bSJunchao Zhang PetscCheck(csr->values, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing HIP memory"); 4043*d52a580bSJunchao Zhang *a = csr->values->data().get(); 4044*d52a580bSJunchao Zhang A->offloadmask = PETSC_OFFLOAD_GPU; 4045*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(A, PETSC_FALSE)); 4046*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 4047*d52a580bSJunchao Zhang } 4048*d52a580bSJunchao Zhang 4049*d52a580bSJunchao Zhang /*@C 4050*d52a580bSJunchao Zhang MatSeqAIJHIPSPARSERestoreArrayWrite - restore the write-only access array obtained from 
`MatSeqAIJHIPSPARSEGetArrayWrite()` 4051*d52a580bSJunchao Zhang 4052*d52a580bSJunchao Zhang Not Collective 4053*d52a580bSJunchao Zhang 4054*d52a580bSJunchao Zhang Input Parameters: 4055*d52a580bSJunchao Zhang + A - a `MATSEQAIJHIPSPARSE` matrix 4056*d52a580bSJunchao Zhang - a - pointer to the device data 4057*d52a580bSJunchao Zhang 4058*d52a580bSJunchao Zhang Level: developer 4059*d52a580bSJunchao Zhang 4060*d52a580bSJunchao Zhang .seealso: [](ch_matrices), `Mat`, `MatSeqAIJHIPSPARSEGetArrayWrite()` 4061*d52a580bSJunchao Zhang @*/ 4062*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSERestoreArrayWrite(Mat A, PetscScalar *a[]) 4063*d52a580bSJunchao Zhang { 4064*d52a580bSJunchao Zhang PetscFunctionBegin; 4065*d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 4066*d52a580bSJunchao Zhang PetscAssertPointer(a, 2); 4067*d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE); 4068*d52a580bSJunchao Zhang PetscCall(PetscObjectStateIncrease((PetscObject)A)); 4069*d52a580bSJunchao Zhang *a = NULL; 4070*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 4071*d52a580bSJunchao Zhang } 4072*d52a580bSJunchao Zhang 4073*d52a580bSJunchao Zhang struct IJCompare4 { 4074*d52a580bSJunchao Zhang __host__ __device__ inline bool operator()(const thrust::tuple<int, int, PetscScalar, int> &t1, const thrust::tuple<int, int, PetscScalar, int> &t2) 4075*d52a580bSJunchao Zhang { 4076*d52a580bSJunchao Zhang if (t1.get<0>() < t2.get<0>()) return true; 4077*d52a580bSJunchao Zhang if (t1.get<0>() == t2.get<0>()) return t1.get<1>() < t2.get<1>(); 4078*d52a580bSJunchao Zhang return false; 4079*d52a580bSJunchao Zhang } 4080*d52a580bSJunchao Zhang }; 4081*d52a580bSJunchao Zhang 4082*d52a580bSJunchao Zhang struct Shift { 4083*d52a580bSJunchao Zhang int _shift; 4084*d52a580bSJunchao Zhang 4085*d52a580bSJunchao Zhang Shift(int shift) : _shift(shift) { } 4086*d52a580bSJunchao Zhang __host__ __device__ inline int operator()(const int &c) { return c + _shift; } 
4087*d52a580bSJunchao Zhang }; 4088*d52a580bSJunchao Zhang 4089*d52a580bSJunchao Zhang /* merges two SeqAIJHIPSPARSE matrices A, B by concatenating their rows. [A';B']' operation in MATLAB notation */ 4090*d52a580bSJunchao Zhang PetscErrorCode MatSeqAIJHIPSPARSEMergeMats(Mat A, Mat B, MatReuse reuse, Mat *C) 4091*d52a580bSJunchao Zhang { 4092*d52a580bSJunchao Zhang Mat_SeqAIJ *a = (Mat_SeqAIJ *)A->data, *b = (Mat_SeqAIJ *)B->data, *c; 4093*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSE *Acusp = (Mat_SeqAIJHIPSPARSE *)A->spptr, *Bcusp = (Mat_SeqAIJHIPSPARSE *)B->spptr, *Ccusp; 4094*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSEMultStruct *Cmat; 4095*d52a580bSJunchao Zhang CsrMatrix *Acsr, *Bcsr, *Ccsr; 4096*d52a580bSJunchao Zhang PetscInt Annz, Bnnz; 4097*d52a580bSJunchao Zhang PetscInt i, m, n, zero = 0; 4098*d52a580bSJunchao Zhang 4099*d52a580bSJunchao Zhang PetscFunctionBegin; 4100*d52a580bSJunchao Zhang PetscValidHeaderSpecific(A, MAT_CLASSID, 1); 4101*d52a580bSJunchao Zhang PetscValidHeaderSpecific(B, MAT_CLASSID, 2); 4102*d52a580bSJunchao Zhang PetscAssertPointer(C, 4); 4103*d52a580bSJunchao Zhang PetscCheckTypeName(A, MATSEQAIJHIPSPARSE); 4104*d52a580bSJunchao Zhang PetscCheckTypeName(B, MATSEQAIJHIPSPARSE); 4105*d52a580bSJunchao Zhang PetscCheck(A->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, A->rmap->n, B->rmap->n); 4106*d52a580bSJunchao Zhang PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported"); 4107*d52a580bSJunchao Zhang PetscCheck(Acusp->format != MAT_HIPSPARSE_ELL && Acusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented"); 4108*d52a580bSJunchao Zhang PetscCheck(Bcusp->format != MAT_HIPSPARSE_ELL && Bcusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented"); 4109*d52a580bSJunchao Zhang if (reuse == MAT_INITIAL_MATRIX) { 4110*d52a580bSJunchao Zhang m = A->rmap->n; 
4111*d52a580bSJunchao Zhang n = A->cmap->n + B->cmap->n; 4112*d52a580bSJunchao Zhang PetscCall(MatCreate(PETSC_COMM_SELF, C)); 4113*d52a580bSJunchao Zhang PetscCall(MatSetSizes(*C, m, n, m, n)); 4114*d52a580bSJunchao Zhang PetscCall(MatSetType(*C, MATSEQAIJHIPSPARSE)); 4115*d52a580bSJunchao Zhang c = (Mat_SeqAIJ *)(*C)->data; 4116*d52a580bSJunchao Zhang Ccusp = (Mat_SeqAIJHIPSPARSE *)(*C)->spptr; 4117*d52a580bSJunchao Zhang Cmat = new Mat_SeqAIJHIPSPARSEMultStruct; 4118*d52a580bSJunchao Zhang Ccsr = new CsrMatrix; 4119*d52a580bSJunchao Zhang Cmat->cprowIndices = NULL; 4120*d52a580bSJunchao Zhang c->compressedrow.use = PETSC_FALSE; 4121*d52a580bSJunchao Zhang c->compressedrow.nrows = 0; 4122*d52a580bSJunchao Zhang c->compressedrow.i = NULL; 4123*d52a580bSJunchao Zhang c->compressedrow.rindex = NULL; 4124*d52a580bSJunchao Zhang Ccusp->workVector = NULL; 4125*d52a580bSJunchao Zhang Ccusp->nrows = m; 4126*d52a580bSJunchao Zhang Ccusp->mat = Cmat; 4127*d52a580bSJunchao Zhang Ccusp->mat->mat = Ccsr; 4128*d52a580bSJunchao Zhang Ccsr->num_rows = m; 4129*d52a580bSJunchao Zhang Ccsr->num_cols = n; 4130*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&Cmat->descr)); 4131*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(Cmat->descr, HIPSPARSE_INDEX_BASE_ZERO)); 4132*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(Cmat->descr, HIPSPARSE_MATRIX_TYPE_GENERAL)); 4133*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&Cmat->alpha_one, sizeof(PetscScalar))); 4134*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&Cmat->beta_zero, sizeof(PetscScalar))); 4135*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&Cmat->beta_one, sizeof(PetscScalar))); 4136*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(Cmat->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice)); 4137*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(Cmat->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice)); 
4138*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(Cmat->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice)); 4139*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A)); 4140*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(B)); 4141*d52a580bSJunchao Zhang PetscCheck(Acusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct"); 4142*d52a580bSJunchao Zhang PetscCheck(Bcusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct"); 4143*d52a580bSJunchao Zhang 4144*d52a580bSJunchao Zhang Acsr = (CsrMatrix *)Acusp->mat->mat; 4145*d52a580bSJunchao Zhang Bcsr = (CsrMatrix *)Bcusp->mat->mat; 4146*d52a580bSJunchao Zhang Annz = (PetscInt)Acsr->column_indices->size(); 4147*d52a580bSJunchao Zhang Bnnz = (PetscInt)Bcsr->column_indices->size(); 4148*d52a580bSJunchao Zhang c->nz = Annz + Bnnz; 4149*d52a580bSJunchao Zhang Ccsr->row_offsets = new THRUSTINTARRAY32(m + 1); 4150*d52a580bSJunchao Zhang Ccsr->column_indices = new THRUSTINTARRAY32(c->nz); 4151*d52a580bSJunchao Zhang Ccsr->values = new THRUSTARRAY(c->nz); 4152*d52a580bSJunchao Zhang Ccsr->num_entries = c->nz; 4153*d52a580bSJunchao Zhang Ccusp->coords = new THRUSTINTARRAY(c->nz); 4154*d52a580bSJunchao Zhang if (c->nz) { 4155*d52a580bSJunchao Zhang auto Acoo = new THRUSTINTARRAY32(Annz); 4156*d52a580bSJunchao Zhang auto Bcoo = new THRUSTINTARRAY32(Bnnz); 4157*d52a580bSJunchao Zhang auto Ccoo = new THRUSTINTARRAY32(c->nz); 4158*d52a580bSJunchao Zhang THRUSTINTARRAY32 *Aroff, *Broff; 4159*d52a580bSJunchao Zhang 4160*d52a580bSJunchao Zhang if (a->compressedrow.use) { /* need full row offset */ 4161*d52a580bSJunchao Zhang if (!Acusp->rowoffsets_gpu) { 4162*d52a580bSJunchao Zhang Acusp->rowoffsets_gpu = new THRUSTINTARRAY32(A->rmap->n + 1); 4163*d52a580bSJunchao Zhang Acusp->rowoffsets_gpu->assign(a->i, a->i + A->rmap->n + 1); 4164*d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu((A->rmap->n + 1) * sizeof(PetscInt))); 
4165*d52a580bSJunchao Zhang } 4166*d52a580bSJunchao Zhang Aroff = Acusp->rowoffsets_gpu; 4167*d52a580bSJunchao Zhang } else Aroff = Acsr->row_offsets; 4168*d52a580bSJunchao Zhang if (b->compressedrow.use) { /* need full row offset */ 4169*d52a580bSJunchao Zhang if (!Bcusp->rowoffsets_gpu) { 4170*d52a580bSJunchao Zhang Bcusp->rowoffsets_gpu = new THRUSTINTARRAY32(B->rmap->n + 1); 4171*d52a580bSJunchao Zhang Bcusp->rowoffsets_gpu->assign(b->i, b->i + B->rmap->n + 1); 4172*d52a580bSJunchao Zhang PetscCall(PetscLogCpuToGpu((B->rmap->n + 1) * sizeof(PetscInt))); 4173*d52a580bSJunchao Zhang } 4174*d52a580bSJunchao Zhang Broff = Bcusp->rowoffsets_gpu; 4175*d52a580bSJunchao Zhang } else Broff = Bcsr->row_offsets; 4176*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 4177*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsr2coo(Acusp->handle, Aroff->data().get(), Annz, m, Acoo->data().get(), HIPSPARSE_INDEX_BASE_ZERO)); 4178*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcsr2coo(Bcusp->handle, Broff->data().get(), Bnnz, m, Bcoo->data().get(), HIPSPARSE_INDEX_BASE_ZERO)); 4179*d52a580bSJunchao Zhang /* Issues when using bool with large matrices on SUMMIT 10.2.89 */ 4180*d52a580bSJunchao Zhang auto Aperm = thrust::make_constant_iterator(1); 4181*d52a580bSJunchao Zhang auto Bperm = thrust::make_constant_iterator(0); 4182*d52a580bSJunchao Zhang auto Bcib = thrust::make_transform_iterator(Bcsr->column_indices->begin(), Shift(A->cmap->n)); 4183*d52a580bSJunchao Zhang auto Bcie = thrust::make_transform_iterator(Bcsr->column_indices->end(), Shift(A->cmap->n)); 4184*d52a580bSJunchao Zhang auto wPerm = new THRUSTINTARRAY32(Annz + Bnnz); 4185*d52a580bSJunchao Zhang auto Azb = thrust::make_zip_iterator(thrust::make_tuple(Acoo->begin(), Acsr->column_indices->begin(), Acsr->values->begin(), Aperm)); 4186*d52a580bSJunchao Zhang auto Aze = thrust::make_zip_iterator(thrust::make_tuple(Acoo->end(), Acsr->column_indices->end(), Acsr->values->end(), Aperm)); 
4187*d52a580bSJunchao Zhang auto Bzb = thrust::make_zip_iterator(thrust::make_tuple(Bcoo->begin(), Bcib, Bcsr->values->begin(), Bperm)); 4188*d52a580bSJunchao Zhang auto Bze = thrust::make_zip_iterator(thrust::make_tuple(Bcoo->end(), Bcie, Bcsr->values->end(), Bperm)); 4189*d52a580bSJunchao Zhang auto Czb = thrust::make_zip_iterator(thrust::make_tuple(Ccoo->begin(), Ccsr->column_indices->begin(), Ccsr->values->begin(), wPerm->begin())); 4190*d52a580bSJunchao Zhang auto p1 = Ccusp->coords->begin(); 4191*d52a580bSJunchao Zhang auto p2 = Ccusp->coords->begin(); 4192*d52a580bSJunchao Zhang thrust::advance(p2, Annz); 4193*d52a580bSJunchao Zhang PetscCallThrust(thrust::merge(thrust::device, Azb, Aze, Bzb, Bze, Czb, IJCompare4())); 4194*d52a580bSJunchao Zhang auto cci = thrust::make_counting_iterator(zero); 4195*d52a580bSJunchao Zhang auto cce = thrust::make_counting_iterator(c->nz); 4196*d52a580bSJunchao Zhang #if 0 //Errors on SUMMIT cuda 11.1.0 4197*d52a580bSJunchao Zhang PetscCallThrust(thrust::partition_copy(thrust::device, cci, cce, wPerm->begin(), p1, p2, thrust::identity<int>())); 4198*d52a580bSJunchao Zhang #else 4199*d52a580bSJunchao Zhang auto pred = thrust::identity<int>(); 4200*d52a580bSJunchao Zhang PetscCallThrust(thrust::copy_if(thrust::device, cci, cce, wPerm->begin(), p1, pred)); 4201*d52a580bSJunchao Zhang PetscCallThrust(thrust::remove_copy_if(thrust::device, cci, cce, wPerm->begin(), p2, pred)); 4202*d52a580bSJunchao Zhang #endif 4203*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseXcoo2csr(Ccusp->handle, Ccoo->data().get(), c->nz, m, Ccsr->row_offsets->data().get(), HIPSPARSE_INDEX_BASE_ZERO)); 4204*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 4205*d52a580bSJunchao Zhang delete wPerm; 4206*d52a580bSJunchao Zhang delete Acoo; 4207*d52a580bSJunchao Zhang delete Bcoo; 4208*d52a580bSJunchao Zhang delete Ccoo; 4209*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&Cmat->matDescr, Ccsr->num_rows, Ccsr->num_cols, Ccsr->num_entries, 
Ccsr->row_offsets->data().get(), Ccsr->column_indices->data().get(), Ccsr->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype)); 4210*d52a580bSJunchao Zhang 4211*d52a580bSJunchao Zhang if (A->form_explicit_transpose && B->form_explicit_transpose) { /* if A and B have the transpose, generate C transpose too */ 4212*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(A)); 4213*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEFormExplicitTranspose(B)); 4214*d52a580bSJunchao Zhang PetscBool AT = Acusp->matTranspose ? PETSC_TRUE : PETSC_FALSE, BT = Bcusp->matTranspose ? PETSC_TRUE : PETSC_FALSE; 4215*d52a580bSJunchao Zhang Mat_SeqAIJHIPSPARSEMultStruct *CmatT = new Mat_SeqAIJHIPSPARSEMultStruct; 4216*d52a580bSJunchao Zhang CsrMatrix *CcsrT = new CsrMatrix; 4217*d52a580bSJunchao Zhang CsrMatrix *AcsrT = AT ? (CsrMatrix *)Acusp->matTranspose->mat : NULL; 4218*d52a580bSJunchao Zhang CsrMatrix *BcsrT = BT ? (CsrMatrix *)Bcusp->matTranspose->mat : NULL; 4219*d52a580bSJunchao Zhang 4220*d52a580bSJunchao Zhang (*C)->form_explicit_transpose = PETSC_TRUE; 4221*d52a580bSJunchao Zhang (*C)->transupdated = PETSC_TRUE; 4222*d52a580bSJunchao Zhang Ccusp->rowoffsets_gpu = NULL; 4223*d52a580bSJunchao Zhang CmatT->cprowIndices = NULL; 4224*d52a580bSJunchao Zhang CmatT->mat = CcsrT; 4225*d52a580bSJunchao Zhang CcsrT->num_rows = n; 4226*d52a580bSJunchao Zhang CcsrT->num_cols = m; 4227*d52a580bSJunchao Zhang CcsrT->num_entries = c->nz; 4228*d52a580bSJunchao Zhang CcsrT->row_offsets = new THRUSTINTARRAY32(n + 1); 4229*d52a580bSJunchao Zhang CcsrT->column_indices = new THRUSTINTARRAY32(c->nz); 4230*d52a580bSJunchao Zhang CcsrT->values = new THRUSTARRAY(c->nz); 4231*d52a580bSJunchao Zhang 4232*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 4233*d52a580bSJunchao Zhang auto rT = CcsrT->row_offsets->begin(); 4234*d52a580bSJunchao Zhang if (AT) { 4235*d52a580bSJunchao Zhang rT = 
thrust::copy(AcsrT->row_offsets->begin(), AcsrT->row_offsets->end(), rT); 4236*d52a580bSJunchao Zhang thrust::advance(rT, -1); 4237*d52a580bSJunchao Zhang } 4238*d52a580bSJunchao Zhang if (BT) { 4239*d52a580bSJunchao Zhang auto titb = thrust::make_transform_iterator(BcsrT->row_offsets->begin(), Shift(a->nz)); 4240*d52a580bSJunchao Zhang auto tite = thrust::make_transform_iterator(BcsrT->row_offsets->end(), Shift(a->nz)); 4241*d52a580bSJunchao Zhang thrust::copy(titb, tite, rT); 4242*d52a580bSJunchao Zhang } 4243*d52a580bSJunchao Zhang auto cT = CcsrT->column_indices->begin(); 4244*d52a580bSJunchao Zhang if (AT) cT = thrust::copy(AcsrT->column_indices->begin(), AcsrT->column_indices->end(), cT); 4245*d52a580bSJunchao Zhang if (BT) thrust::copy(BcsrT->column_indices->begin(), BcsrT->column_indices->end(), cT); 4246*d52a580bSJunchao Zhang auto vT = CcsrT->values->begin(); 4247*d52a580bSJunchao Zhang if (AT) vT = thrust::copy(AcsrT->values->begin(), AcsrT->values->end(), vT); 4248*d52a580bSJunchao Zhang if (BT) thrust::copy(BcsrT->values->begin(), BcsrT->values->end(), vT); 4249*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 4250*d52a580bSJunchao Zhang 4251*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateMatDescr(&CmatT->descr)); 4252*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatIndexBase(CmatT->descr, HIPSPARSE_INDEX_BASE_ZERO)); 4253*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseSetMatType(CmatT->descr, HIPSPARSE_MATRIX_TYPE_GENERAL)); 4254*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&CmatT->alpha_one, sizeof(PetscScalar))); 4255*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&CmatT->beta_zero, sizeof(PetscScalar))); 4256*d52a580bSJunchao Zhang PetscCallHIP(hipMalloc((void **)&CmatT->beta_one, sizeof(PetscScalar))); 4257*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(CmatT->alpha_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice)); 4258*d52a580bSJunchao Zhang 
PetscCallHIP(hipMemcpy(CmatT->beta_zero, &PETSC_HIPSPARSE_ZERO, sizeof(PetscScalar), hipMemcpyHostToDevice)); 4259*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(CmatT->beta_one, &PETSC_HIPSPARSE_ONE, sizeof(PetscScalar), hipMemcpyHostToDevice)); 4260*d52a580bSJunchao Zhang 4261*d52a580bSJunchao Zhang PetscCallHIPSPARSE(hipsparseCreateCsr(&CmatT->matDescr, CcsrT->num_rows, CcsrT->num_cols, CcsrT->num_entries, CcsrT->row_offsets->data().get(), CcsrT->column_indices->data().get(), CcsrT->values->data().get(), HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, hipsparse_scalartype)); 4262*d52a580bSJunchao Zhang Ccusp->matTranspose = CmatT; 4263*d52a580bSJunchao Zhang } 4264*d52a580bSJunchao Zhang } 4265*d52a580bSJunchao Zhang 4266*d52a580bSJunchao Zhang c->free_a = PETSC_TRUE; 4267*d52a580bSJunchao Zhang PetscCall(PetscShmgetAllocateArray(c->nz, sizeof(PetscInt), (void **)&c->j)); 4268*d52a580bSJunchao Zhang PetscCall(PetscShmgetAllocateArray(m + 1, sizeof(PetscInt), (void **)&c->i)); 4269*d52a580bSJunchao Zhang c->free_ij = PETSC_TRUE; 4270*d52a580bSJunchao Zhang if (PetscDefined(USE_64BIT_INDICES)) { /* 32 to 64-bit conversion on the GPU and then copy to host (lazy) */ 4271*d52a580bSJunchao Zhang THRUSTINTARRAY ii(Ccsr->row_offsets->size()); 4272*d52a580bSJunchao Zhang THRUSTINTARRAY jj(Ccsr->column_indices->size()); 4273*d52a580bSJunchao Zhang ii = *Ccsr->row_offsets; 4274*d52a580bSJunchao Zhang jj = *Ccsr->column_indices; 4275*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(c->i, ii.data().get(), Ccsr->row_offsets->size() * sizeof(PetscInt), hipMemcpyDeviceToHost)); 4276*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(c->j, jj.data().get(), Ccsr->column_indices->size() * sizeof(PetscInt), hipMemcpyDeviceToHost)); 4277*d52a580bSJunchao Zhang } else { 4278*d52a580bSJunchao Zhang PetscCallHIP(hipMemcpy(c->i, Ccsr->row_offsets->data().get(), Ccsr->row_offsets->size() * sizeof(PetscInt), hipMemcpyDeviceToHost)); 4279*d52a580bSJunchao Zhang 
PetscCallHIP(hipMemcpy(c->j, Ccsr->column_indices->data().get(), Ccsr->column_indices->size() * sizeof(PetscInt), hipMemcpyDeviceToHost)); 4280*d52a580bSJunchao Zhang } 4281*d52a580bSJunchao Zhang PetscCall(PetscLogGpuToCpu((Ccsr->column_indices->size() + Ccsr->row_offsets->size()) * sizeof(PetscInt))); 4282*d52a580bSJunchao Zhang PetscCall(PetscMalloc1(m, &c->ilen)); 4283*d52a580bSJunchao Zhang PetscCall(PetscMalloc1(m, &c->imax)); 4284*d52a580bSJunchao Zhang c->maxnz = c->nz; 4285*d52a580bSJunchao Zhang c->nonzerorowcnt = 0; 4286*d52a580bSJunchao Zhang c->rmax = 0; 4287*d52a580bSJunchao Zhang for (i = 0; i < m; i++) { 4288*d52a580bSJunchao Zhang const PetscInt nn = c->i[i + 1] - c->i[i]; 4289*d52a580bSJunchao Zhang c->ilen[i] = c->imax[i] = nn; 4290*d52a580bSJunchao Zhang c->nonzerorowcnt += (PetscInt)!!nn; 4291*d52a580bSJunchao Zhang c->rmax = PetscMax(c->rmax, nn); 4292*d52a580bSJunchao Zhang } 4293*d52a580bSJunchao Zhang PetscCall(PetscMalloc1(c->nz, &c->a)); 4294*d52a580bSJunchao Zhang (*C)->nonzerostate++; 4295*d52a580bSJunchao Zhang PetscCall(PetscLayoutSetUp((*C)->rmap)); 4296*d52a580bSJunchao Zhang PetscCall(PetscLayoutSetUp((*C)->cmap)); 4297*d52a580bSJunchao Zhang Ccusp->nonzerostate = (*C)->nonzerostate; 4298*d52a580bSJunchao Zhang (*C)->preallocated = PETSC_TRUE; 4299*d52a580bSJunchao Zhang } else { 4300*d52a580bSJunchao Zhang PetscCheck((*C)->rmap->n == B->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Invalid number or rows %" PetscInt_FMT " != %" PetscInt_FMT, (*C)->rmap->n, B->rmap->n); 4301*d52a580bSJunchao Zhang c = (Mat_SeqAIJ *)(*C)->data; 4302*d52a580bSJunchao Zhang if (c->nz) { 4303*d52a580bSJunchao Zhang Ccusp = (Mat_SeqAIJHIPSPARSE *)(*C)->spptr; 4304*d52a580bSJunchao Zhang PetscCheck(Ccusp->coords, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing coords"); 4305*d52a580bSJunchao Zhang PetscCheck(Ccusp->format != MAT_HIPSPARSE_ELL && Ccusp->format != MAT_HIPSPARSE_HYB, PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented"); 4306*d52a580bSJunchao Zhang 
PetscCheck(Ccusp->nonzerostate == (*C)->nonzerostate, PETSC_COMM_SELF, PETSC_ERR_COR, "Wrong nonzerostate"); 4307*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(A)); 4308*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSECopyToGPU(B)); 4309*d52a580bSJunchao Zhang PetscCheck(Acusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct"); 4310*d52a580bSJunchao Zhang PetscCheck(Bcusp->mat, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing Mat_SeqAIJHIPSPARSEMultStruct"); 4311*d52a580bSJunchao Zhang Acsr = (CsrMatrix *)Acusp->mat->mat; 4312*d52a580bSJunchao Zhang Bcsr = (CsrMatrix *)Bcusp->mat->mat; 4313*d52a580bSJunchao Zhang Ccsr = (CsrMatrix *)Ccusp->mat->mat; 4314*d52a580bSJunchao Zhang PetscCheck(Acsr->num_entries == (PetscInt)Acsr->values->size(), PETSC_COMM_SELF, PETSC_ERR_COR, "A nnz %" PetscInt_FMT " != %" PetscInt_FMT, Acsr->num_entries, (PetscInt)Acsr->values->size()); 4315*d52a580bSJunchao Zhang PetscCheck(Bcsr->num_entries == (PetscInt)Bcsr->values->size(), PETSC_COMM_SELF, PETSC_ERR_COR, "B nnz %" PetscInt_FMT " != %" PetscInt_FMT, Bcsr->num_entries, (PetscInt)Bcsr->values->size()); 4316*d52a580bSJunchao Zhang PetscCheck(Ccsr->num_entries == (PetscInt)Ccsr->values->size(), PETSC_COMM_SELF, PETSC_ERR_COR, "C nnz %" PetscInt_FMT " != %" PetscInt_FMT, Ccsr->num_entries, (PetscInt)Ccsr->values->size()); 4317*d52a580bSJunchao Zhang PetscCheck(Ccsr->num_entries == Acsr->num_entries + Bcsr->num_entries, PETSC_COMM_SELF, PETSC_ERR_COR, "C nnz %" PetscInt_FMT " != %" PetscInt_FMT " + %" PetscInt_FMT, Ccsr->num_entries, Acsr->num_entries, Bcsr->num_entries); 4318*d52a580bSJunchao Zhang PetscCheck(Ccusp->coords->size() == Ccsr->values->size(), PETSC_COMM_SELF, PETSC_ERR_COR, "permSize %" PetscInt_FMT " != %" PetscInt_FMT, (PetscInt)Ccusp->coords->size(), (PetscInt)Ccsr->values->size()); 4319*d52a580bSJunchao Zhang auto pmid = Ccusp->coords->begin(); 4320*d52a580bSJunchao Zhang thrust::advance(pmid, Acsr->num_entries); 
4321*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeBegin()); 4322*d52a580bSJunchao Zhang auto zibait = thrust::make_zip_iterator(thrust::make_tuple(Acsr->values->begin(), thrust::make_permutation_iterator(Ccsr->values->begin(), Ccusp->coords->begin()))); 4323*d52a580bSJunchao Zhang auto zieait = thrust::make_zip_iterator(thrust::make_tuple(Acsr->values->end(), thrust::make_permutation_iterator(Ccsr->values->begin(), pmid))); 4324*d52a580bSJunchao Zhang thrust::for_each(zibait, zieait, VecHIPEquals()); 4325*d52a580bSJunchao Zhang auto zibbit = thrust::make_zip_iterator(thrust::make_tuple(Bcsr->values->begin(), thrust::make_permutation_iterator(Ccsr->values->begin(), pmid))); 4326*d52a580bSJunchao Zhang auto ziebit = thrust::make_zip_iterator(thrust::make_tuple(Bcsr->values->end(), thrust::make_permutation_iterator(Ccsr->values->begin(), Ccusp->coords->end()))); 4327*d52a580bSJunchao Zhang thrust::for_each(zibbit, ziebit, VecHIPEquals()); 4328*d52a580bSJunchao Zhang PetscCall(MatSeqAIJHIPSPARSEInvalidateTranspose(*C, PETSC_FALSE)); 4329*d52a580bSJunchao Zhang if (A->form_explicit_transpose && B->form_explicit_transpose && (*C)->form_explicit_transpose) { 4330*d52a580bSJunchao Zhang PetscCheck(Ccusp->matTranspose, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing transpose Mat_SeqAIJHIPSPARSEMultStruct"); 4331*d52a580bSJunchao Zhang PetscBool AT = Acusp->matTranspose ? PETSC_TRUE : PETSC_FALSE, BT = Bcusp->matTranspose ? PETSC_TRUE : PETSC_FALSE; 4332*d52a580bSJunchao Zhang CsrMatrix *AcsrT = AT ? (CsrMatrix *)Acusp->matTranspose->mat : NULL; 4333*d52a580bSJunchao Zhang CsrMatrix *BcsrT = BT ? 
(CsrMatrix *)Bcusp->matTranspose->mat : NULL; 4334*d52a580bSJunchao Zhang CsrMatrix *CcsrT = (CsrMatrix *)Ccusp->matTranspose->mat; 4335*d52a580bSJunchao Zhang auto vT = CcsrT->values->begin(); 4336*d52a580bSJunchao Zhang if (AT) vT = thrust::copy(AcsrT->values->begin(), AcsrT->values->end(), vT); 4337*d52a580bSJunchao Zhang if (BT) thrust::copy(BcsrT->values->begin(), BcsrT->values->end(), vT); 4338*d52a580bSJunchao Zhang (*C)->transupdated = PETSC_TRUE; 4339*d52a580bSJunchao Zhang } 4340*d52a580bSJunchao Zhang PetscCall(PetscLogGpuTimeEnd()); 4341*d52a580bSJunchao Zhang } 4342*d52a580bSJunchao Zhang } 4343*d52a580bSJunchao Zhang PetscCall(PetscObjectStateIncrease((PetscObject)*C)); 4344*d52a580bSJunchao Zhang (*C)->assembled = PETSC_TRUE; 4345*d52a580bSJunchao Zhang (*C)->was_assembled = PETSC_FALSE; 4346*d52a580bSJunchao Zhang (*C)->offloadmask = PETSC_OFFLOAD_GPU; 4347*d52a580bSJunchao Zhang PetscFunctionReturn(PETSC_SUCCESS); 4348*d52a580bSJunchao Zhang }

/*
  MatSeqAIJCopySubArray_SeqAIJHIPSPARSE - Copies n entries of the value array of A into v.

  Input Parameters:
+ A   - the matrix; its values are accessed on the device through MatSeqAIJHIPSPARSEGetArrayRead()
. n   - number of entries to copy
- idx - positions in the value array of A to copy, or NULL to copy the first n values contiguously.
        idx is copied to the device and that transfer is logged as CPU->GPU, so it is assumed to be host memory.

  Output Parameter:
. v - array of length n receiving the values; may be host or device memory (checked with isHipMem())

  Notes:
  With idx, the gather runs on the device: thrust::for_each applies VecHIPEquals over a zip of
  (av permuted by idx, destination), presumably giving v[k] = av[idx[k]]. When v is host memory,
  the gather writes into a device scratch buffer, which is then copied device-to-host.
  The scratch buffer is a thrust container on the stack (not new/delete), so it is released even
  when a HIP error makes PetscCallHIP() return early.
*/
static PetscErrorCode MatSeqAIJCopySubArray_SeqAIJHIPSPARSE(Mat A, PetscInt n, const PetscInt idx[], PetscScalar v[])
{
  bool               dmem;
  const PetscScalar *av;

  PetscFunctionBegin;
  dmem = isHipMem(v);
  PetscCall(MatSeqAIJHIPSPARSEGetArrayRead(A, &av));
  if (n && idx) {
    THRUSTINTARRAY widx(n);
    widx.assign(idx, idx + n);
    PetscCall(PetscLogCpuToGpu(n * sizeof(PetscInt)));

    // Device scratch destination, only sized when v lives on the host; owned by this scope so an
    // early error return from PetscCallHIP() below cannot leak it.
    THRUSTARRAY                     w;
    thrust::device_ptr<PetscScalar> dv;
    if (dmem) dv = thrust::device_pointer_cast(v);
    else {
      w.resize(n);
      dv = w.data();
    }
    thrust::device_ptr<const PetscScalar> dav = thrust::device_pointer_cast(av);

    auto zibit = thrust::make_zip_iterator(thrust::make_tuple(thrust::make_permutation_iterator(dav, widx.begin()), dv));
    auto zieit = thrust::make_zip_iterator(thrust::make_tuple(thrust::make_permutation_iterator(dav, widx.end()), dv + n));
    thrust::for_each(zibit, zieit, VecHIPEquals());
    if (!dmem) PetscCallHIP(hipMemcpy(v, w.data().get(), n * sizeof(PetscScalar), hipMemcpyDeviceToHost));
  } else PetscCallHIP(hipMemcpy(v, av, n * sizeof(PetscScalar), dmem ? hipMemcpyDeviceToDevice : hipMemcpyDeviceToHost));

  // For a host destination the n values were transferred device -> host, so log them as GPU->CPU.
  if (!dmem) PetscCall(PetscLogGpuToCpu(n * sizeof(PetscScalar)));
  PetscCall(MatSeqAIJHIPSPARSERestoreArrayRead(A, &av));
  PetscFunctionReturn(PETSC_SUCCESS);
}
4383