1*47d993e7Ssuyashtn /* Portions of this code are under: 2*47d993e7Ssuyashtn Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. 3*47d993e7Ssuyashtn */ 4*47d993e7Ssuyashtn #if !defined(HIPSPARSEMATIMPL) 5*47d993e7Ssuyashtn #define HIPSPARSEMATIMPL 6*47d993e7Ssuyashtn 7*47d993e7Ssuyashtn #include <petscpkg_version.h> 8*47d993e7Ssuyashtn #include <petsc/private/hipvecimpl.h> 9*47d993e7Ssuyashtn #include <petscaijdevice.h> 10*47d993e7Ssuyashtn 11*47d993e7Ssuyashtn #if PETSC_PKG_HIP_VERSION_GE(5, 2, 0) 12*47d993e7Ssuyashtn #include <hipsparse/hipsparse.h> 13*47d993e7Ssuyashtn #else /* PETSC_PKG_HIP_VERSION_GE(5,2,0) */ 14*47d993e7Ssuyashtn #include <hipsparse.h> 15*47d993e7Ssuyashtn #endif /* PETSC_PKG_HIP_VERSION_GE(5,2,0) */ 16*47d993e7Ssuyashtn #include "hip/hip_runtime.h" 17*47d993e7Ssuyashtn 18*47d993e7Ssuyashtn #include <algorithm> 19*47d993e7Ssuyashtn #include <vector> 20*47d993e7Ssuyashtn 21*47d993e7Ssuyashtn #include <thrust/device_vector.h> 22*47d993e7Ssuyashtn #include <thrust/device_ptr.h> 23*47d993e7Ssuyashtn #include <thrust/device_malloc_allocator.h> 24*47d993e7Ssuyashtn #include <thrust/transform.h> 25*47d993e7Ssuyashtn #include <thrust/functional.h> 26*47d993e7Ssuyashtn #include <thrust/sequence.h> 27*47d993e7Ssuyashtn #include <thrust/system/system_error.h> 28*47d993e7Ssuyashtn 29*47d993e7Ssuyashtn #define PetscCallThrust(body) \ 30*47d993e7Ssuyashtn do { \ 31*47d993e7Ssuyashtn try { \ 32*47d993e7Ssuyashtn body; \ 33*47d993e7Ssuyashtn } catch (thrust::system_error & e) { \ 34*47d993e7Ssuyashtn SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Error in Thrust %s", e.what()); \ 35*47d993e7Ssuyashtn } \ 36*47d993e7Ssuyashtn } while (0) 37*47d993e7Ssuyashtn 38*47d993e7Ssuyashtn #if defined(PETSC_USE_COMPLEX) 39*47d993e7Ssuyashtn #if defined(PETSC_USE_REAL_SINGLE) 40*47d993e7Ssuyashtn const hipComplex PETSC_HIPSPARSE_ONE = {1.0f, 0.0f}; 41*47d993e7Ssuyashtn const hipComplex PETSC_HIPSPARSE_ZERO = {0.0f, 0.0f}; 42*47d993e7Ssuyashtn #define hipsparseXcsrilu02_bufferSize(a, b, c, d, e, f, g, h, i) hipsparseCcsrilu02_bufferSize(a, b, c, d, (hipComplex *)e, f, g, h, i) 43*47d993e7Ssuyashtn #define hipsparseXcsrilu02_analysis(a, b, c, d, e, f, g, h, i, j) hipsparseCcsrilu02_analysis(a, b, c, d, (hipComplex *)e, f, g, h, i, j) 44*47d993e7Ssuyashtn #define hipsparseXcsrilu02(a, b, c, d, e, f, g, h, i, j) hipsparseCcsrilu02(a, b, c, d, (hipComplex *)e, f, g, h, i, j) 45*47d993e7Ssuyashtn #define hipsparseXcsric02_bufferSize(a, b, c, d, e, f, g, h, i) hipsparseCcsric02_bufferSize(a, b, c, d, (hipComplex *)e, f, g, h, i) 46*47d993e7Ssuyashtn #define hipsparseXcsric02_analysis(a, b, c, d, e, f, g, h, i, j) hipsparseCcsric02_analysis(a, b, c, d, (hipComplex *)e, f, g, h, i, j) 47*47d993e7Ssuyashtn #define hipsparseXcsric02(a, b, c, d, e, f, g, h, i, j) hipsparseCcsric02(a, b, c, d, (hipComplex *)e, f, g, h, i, j) 48*47d993e7Ssuyashtn #elif defined(PETSC_USE_REAL_DOUBLE) 49*47d993e7Ssuyashtn const hipDoubleComplex PETSC_HIPSPARSE_ONE = {1.0, 0.0}; 50*47d993e7Ssuyashtn const hipDoubleComplex PETSC_HIPSPARSE_ZERO = {0.0, 0.0}; 51*47d993e7Ssuyashtn #define hipsparseXcsrilu02_bufferSize(a, b, c, d, e, f, g, h, i) hipsparseZcsrilu02_bufferSize(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i) 52*47d993e7Ssuyashtn #define hipsparseXcsrilu02_analysis(a, b, c, d, e, f, g, h, i, j) hipsparseZcsrilu02_analysis(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i, j) 53*47d993e7Ssuyashtn #define hipsparseXcsrilu02(a, b, c, d, e, f, g, h, i, j) hipsparseZcsrilu02(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i, j) 54*47d993e7Ssuyashtn #define hipsparseXcsric02_bufferSize(a, b, c, d, e, f, g, h, i) hipsparseZcsric02_bufferSize(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i) 55*47d993e7Ssuyashtn #define hipsparseXcsric02_analysis(a, b, c, d, e, f, g, h, i, j) hipsparseZcsric02_analysis(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i, j) 56*47d993e7Ssuyashtn #define hipsparseXcsric02(a, b, c, d, e, f, g, h, i, j) hipsparseZcsric02(a, b, c, d, (hipDoubleComplex *)e, f, g, h, i, j) 57*47d993e7Ssuyashtn #endif /* Single or double */ 58*47d993e7Ssuyashtn #else /* not complex */ 59*47d993e7Ssuyashtn const PetscScalar PETSC_HIPSPARSE_ONE = 1.0; 60*47d993e7Ssuyashtn const PetscScalar PETSC_HIPSPARSE_ZERO = 0.0; 61*47d993e7Ssuyashtn #if defined(PETSC_USE_REAL_SINGLE) 62*47d993e7Ssuyashtn #define hipsparseXcsrilu02_bufferSize hipsparseScsrilu02_bufferSize 63*47d993e7Ssuyashtn #define hipsparseXcsrilu02_analysis hipsparseScsrilu02_analysis 64*47d993e7Ssuyashtn #define hipsparseXcsrilu02 hipsparseScsrilu02 65*47d993e7Ssuyashtn #define hipsparseXcsric02_bufferSize hipsparseScsric02_bufferSize 66*47d993e7Ssuyashtn #define hipsparseXcsric02_analysis hipsparseScsric02_analysis 67*47d993e7Ssuyashtn #define hipsparseXcsric02 hipsparseScsric02 68*47d993e7Ssuyashtn #elif defined(PETSC_USE_REAL_DOUBLE) 69*47d993e7Ssuyashtn #define hipsparseXcsrilu02_bufferSize hipsparseDcsrilu02_bufferSize 70*47d993e7Ssuyashtn #define hipsparseXcsrilu02_analysis hipsparseDcsrilu02_analysis 71*47d993e7Ssuyashtn #define hipsparseXcsrilu02 hipsparseDcsrilu02 72*47d993e7Ssuyashtn #define hipsparseXcsric02_bufferSize hipsparseDcsric02_bufferSize 73*47d993e7Ssuyashtn #define hipsparseXcsric02_analysis hipsparseDcsric02_analysis 74*47d993e7Ssuyashtn #define hipsparseXcsric02 hipsparseDcsric02 75*47d993e7Ssuyashtn #endif /* Single or double */ 76*47d993e7Ssuyashtn #endif /* complex or not */ 77*47d993e7Ssuyashtn 78*47d993e7Ssuyashtn #define csrsvInfo_t csrsv2Info_t 79*47d993e7Ssuyashtn #define hipsparseCreateCsrsvInfo hipsparseCreateCsrsv2Info 80*47d993e7Ssuyashtn #define hipsparseDestroyCsrsvInfo hipsparseDestroyCsrsv2Info 81*47d993e7Ssuyashtn #if defined(PETSC_USE_COMPLEX) 82*47d993e7Ssuyashtn #if defined(PETSC_USE_REAL_SINGLE) 83*47d993e7Ssuyashtn #define hipsparseXcsrsv_buffsize(a, b, c, d, e, f, g, h, i, j) hipsparseCcsrsv2_bufferSize(a, b, c, d, e, (hipComplex *)(f), g, h, i, j) 84*47d993e7Ssuyashtn #define hipsparseXcsrsv_analysis(a, b, c, d, e, f, g, h, i, j, k) hipsparseCcsrsv2_analysis(a, b, c, d, e, (const hipComplex *)(f), g, h, i, j, k) 85*47d993e7Ssuyashtn #define hipsparseXcsrsv_solve(a, b, c, d, e, f, g, h, i, j, k, l, m, n) hipsparseCcsrsv2_solve(a, b, c, d, (const hipComplex *)(e), f, (const hipComplex *)(g), h, i, j, (const hipComplex *)(k), (hipComplex *)(l), m, n) 86*47d993e7Ssuyashtn #elif defined(PETSC_USE_REAL_DOUBLE) 87*47d993e7Ssuyashtn #define hipsparseXcsrsv_buffsize(a, b, c, d, e, f, g, h, i, j) hipsparseZcsrsv2_bufferSize(a, b, c, d, e, (hipDoubleComplex *)(f), g, h, i, j) 88*47d993e7Ssuyashtn #define hipsparseXcsrsv_analysis(a, b, c, d, e, f, g, h, i, j, k) hipsparseZcsrsv2_analysis(a, b, c, d, e, (const hipDoubleComplex *)(f), g, h, i, j, k) 89*47d993e7Ssuyashtn #define hipsparseXcsrsv_solve(a, b, c, d, e, f, g, h, i, j, k, l, m, n) hipsparseZcsrsv2_solve(a, b, c, d, (const hipDoubleComplex *)(e), f, (const hipDoubleComplex *)(g), h, i, j, (const hipDoubleComplex *)(k), (hipDoubleComplex *)(l), m, n) 90*47d993e7Ssuyashtn #endif /* Single or double */ 91*47d993e7Ssuyashtn #else /* not complex */ 92*47d993e7Ssuyashtn #if defined(PETSC_USE_REAL_SINGLE) 93*47d993e7Ssuyashtn #define hipsparseXcsrsv_buffsize hipsparseScsrsv2_bufferSize 94*47d993e7Ssuyashtn #define hipsparseXcsrsv_analysis hipsparseScsrsv2_analysis 95*47d993e7Ssuyashtn #define hipsparseXcsrsv_solve hipsparseScsrsv2_solve 96*47d993e7Ssuyashtn #elif defined(PETSC_USE_REAL_DOUBLE) 97*47d993e7Ssuyashtn #define hipsparseXcsrsv_buffsize hipsparseDcsrsv2_bufferSize 98*47d993e7Ssuyashtn #define hipsparseXcsrsv_analysis hipsparseDcsrsv2_analysis 99*47d993e7Ssuyashtn #define hipsparseXcsrsv_solve hipsparseDcsrsv2_solve 100*47d993e7Ssuyashtn #endif /* Single or double */ 101*47d993e7Ssuyashtn #endif /* not complex */ 102*47d993e7Ssuyashtn 103*47d993e7Ssuyashtn #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0) 104*47d993e7Ssuyashtn // #define cusparse_csr2csc cusparseCsr2cscEx2 105*47d993e7Ssuyashtn #if defined(PETSC_USE_COMPLEX) 106*47d993e7Ssuyashtn #if defined(PETSC_USE_REAL_SINGLE) 107*47d993e7Ssuyashtn #define hipsparse_scalartype HIP_C_32F 108*47d993e7Ssuyashtn #define hipsparse_csr_spgeam(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) hipsparseCcsrgeam2(a, b, c, (hipComplex *)d, e, f, (hipComplex *)g, h, i, (hipComplex *)j, k, l, (hipComplex *)m, n, o, p, (hipComplex *)q, r, s, t) 109*47d993e7Ssuyashtn #define hipsparse_csr_spgeam_bufferSize(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) \ 110*47d993e7Ssuyashtn hipsparseCcsrgeam2_bufferSizeExt(a, b, c, (hipComplex *)d, e, f, (hipComplex *)g, h, i, (hipComplex *)j, k, l, (hipComplex *)m, n, o, p, (hipComplex *)q, r, s, t) 111*47d993e7Ssuyashtn #elif defined(PETSC_USE_REAL_DOUBLE) 112*47d993e7Ssuyashtn #define hipsparse_scalartype HIP_C_64F 113*47d993e7Ssuyashtn #define hipsparse_csr_spgeam(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) \ 114*47d993e7Ssuyashtn hipsparseZcsrgeam2(a, b, c, (hipDoubleComplex *)d, e, f, (hipDoubleComplex *)g, h, i, (hipDoubleComplex *)j, k, l, (hipDoubleComplex *)m, n, o, p, (hipDoubleComplex *)q, r, s, t) 115*47d993e7Ssuyashtn #define hipsparse_csr_spgeam_bufferSize(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) \ 116*47d993e7Ssuyashtn hipsparseZcsrgeam2_bufferSizeExt(a, b, c, (hipDoubleComplex *)d, e, f, (hipDoubleComplex *)g, h, i, (hipDoubleComplex *)j, k, l, (hipDoubleComplex *)m, n, o, p, (hipDoubleComplex *)q, r, s, t) 117*47d993e7Ssuyashtn #endif /* Single or double */ 118*47d993e7Ssuyashtn #else /* not complex */ 119*47d993e7Ssuyashtn #if defined(PETSC_USE_REAL_SINGLE) 120*47d993e7Ssuyashtn #define hipsparse_scalartype HIP_R_32F 121*47d993e7Ssuyashtn #define hipsparse_csr_spgeam hipsparseScsrgeam2 122*47d993e7Ssuyashtn #define hipsparse_csr_spgeam_bufferSize hipsparseScsrgeam2_bufferSizeExt 123*47d993e7Ssuyashtn #elif defined(PETSC_USE_REAL_DOUBLE) 124*47d993e7Ssuyashtn #define hipsparse_scalartype HIP_R_64F 125*47d993e7Ssuyashtn #define hipsparse_csr_spgeam hipsparseDcsrgeam2 126*47d993e7Ssuyashtn #define hipsparse_csr_spgeam_bufferSize hipsparseDcsrgeam2_bufferSizeExt 127*47d993e7Ssuyashtn #endif /* Single or double */ 128*47d993e7Ssuyashtn #endif /* complex or not */ 129*47d993e7Ssuyashtn #endif /* PETSC_PKG_HIP_VERSION_GE(4, 5, 0) */ 130*47d993e7Ssuyashtn 131*47d993e7Ssuyashtn #if defined(PETSC_USE_COMPLEX) 132*47d993e7Ssuyashtn #if defined(PETSC_USE_REAL_SINGLE) 133*47d993e7Ssuyashtn #define hipsparse_scalartype HIP_C_32F 134*47d993e7Ssuyashtn #define hipsparse_csr_spmv(a, b, c, d, e, f, g, h, i, j, k, l, m) hipsparseCcsrmv((a), (b), (c), (d), (e), (hipComplex *)(f), (g), (hipComplex *)(h), (i), (j), (hipComplex *)(k), (hipComplex *)(l), (hipComplex *)(m)) 135*47d993e7Ssuyashtn #define hipsparse_csr_spmm(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) hipsparseCcsrmm((a), (b), (c), (d), (e), (f), (hipComplex *)(g), (h), (hipComplex *)(i), (j), (k), (hipComplex *)(l), (m), (hipComplex *)(n), (hipComplex *)(o), (p)) 136*47d993e7Ssuyashtn #define hipsparse_csr2csc(a, b, c, d, e, f, g, h, i, j, k, l) hipsparseCcsr2csc((a), (b), (c), (d), (hipComplex *)(e), (f), (g), (hipComplex *)(h), (i), (j), (k), (l)) 137*47d993e7Ssuyashtn #define hipsparse_hyb_spmv(a, b, c, d, e, f, g, h) hipsparseChybmv((a), (b), (hipComplex *)(c), (d), (e), (hipComplex *)(f), (hipComplex *)(g), (hipComplex *)(h)) 138*47d993e7Ssuyashtn #define hipsparse_csr2hyb(a, b, c, d, e, f, g, h, i, j) hipsparseCcsr2hyb((a), (b), (c), (d), (hipComplex *)(e), (f), (g), (h), (i), (j)) 139*47d993e7Ssuyashtn #define hipsparse_hyb2csr(a, b, c, d, e, f) hipsparseChyb2csr((a), (b), (c), (hipComplex *)(d), (e), (f)) 140*47d993e7Ssuyashtn #define hipsparse_csr_spgemm(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) hipsparseCcsrgemm(a, b, c, d, e, f, g, h, (hipComplex *)i, j, k, l, m, (hipComplex *)n, o, p, q, (hipComplex *)r, s, t) 141*47d993e7Ssuyashtn // #define hipsparse_csr_spgeam(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s) hipsparseCcsrgeam(a, b, c, (hipComplex *)d, e, f, (hipComplex *)g, h, i, (hipComplex *)j, k, l, (hipComplex *)m, n, o, p, (hipComplex *)q, r, s) 142*47d993e7Ssuyashtn #elif defined(PETSC_USE_REAL_DOUBLE) 143*47d993e7Ssuyashtn #define hipsparse_scalartype HIP_C_64F 144*47d993e7Ssuyashtn #define hipsparse_csr_spmv(a, b, c, d, e, f, g, h, i, j, k, l, m) \ 145*47d993e7Ssuyashtn hipsparseZcsrmv((a), (b), (c), (d), (e), (hipDoubleComplex *)(f), (g), (hipDoubleComplex *)(h), (i), (j), (hipDoubleComplex *)(k), (hipDoubleComplex *)(l), (hipDoubleComplex *)(m)) 146*47d993e7Ssuyashtn #define hipsparse_csr_spmm(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) \ 147*47d993e7Ssuyashtn hipsparseZcsrmm((a), (b), (c), (d), (e), (f), (hipDoubleComplex *)(g), (h), (hipDoubleComplex *)(i), (j), (k), (hipDoubleComplex *)(l), (m), (hipDoubleComplex *)(n), (hipDoubleComplex *)(o), (p)) 148*47d993e7Ssuyashtn #define hipsparse_csr2csc(a, b, c, d, e, f, g, h, i, j, k, l) hipsparseZcsr2csc((a), (b), (c), (d), (hipDoubleComplex *)(e), (f), (g), (hipDoubleComplex *)(h), (i), (j), (k), (l)) 149*47d993e7Ssuyashtn #define hipsparse_hyb_spmv(a, b, c, d, e, f, g, h) hipsparseZhybmv((a), (b), (hipDoubleComplex *)(c), (d), (e), (hipDoubleComplex *)(f), (hipDoubleComplex *)(g), (hipDoubleComplex *)(h)) 150*47d993e7Ssuyashtn #define hipsparse_csr2hyb(a, b, c, d, e, f, g, h, i, j) hipsparseZcsr2hyb((a), (b), (c), (d), (hipDoubleComplex *)(e), (f), (g), (h), (i), (j)) 151*47d993e7Ssuyashtn #define hipsparse_hyb2csr(a, b, c, d, e, f) hipsparseZhyb2csr((a), (b), (c), (hipDoubleComplex *)(d), (e), (f)) 152*47d993e7Ssuyashtn #define hipsparse_csr_spgemm(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t) hipsparseZcsrgemm(a, b, c, d, e, f, g, h, (hipDoubleComplex *)i, j, k, l, m, (hipDoubleComplex *)n, o, p, q, (hipDoubleComplex *)r, s, t) 153*47d993e7Ssuyashtn // #define hipsparse_csr_spgeam(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s) hipsparseZcsrgeam(a, b, c, (hipDoubleComplex *)d, e, f, (hipDoubleComplex *)g, h, i, (hipDoubleComplex *)j, k, l, (hipDoubleComplex *)m, n, o, p, (hipDoubleComplex *)q, r, s) 154*47d993e7Ssuyashtn #endif /* Single or double */ 155*47d993e7Ssuyashtn #else /* not complex */ 156*47d993e7Ssuyashtn #if defined(PETSC_USE_REAL_SINGLE) 157*47d993e7Ssuyashtn #define hipsparse_scalartype HIP_R_32F 158*47d993e7Ssuyashtn #define hipsparse_csr_spmv hipsparseScsrmv 159*47d993e7Ssuyashtn #define hipsparse_csr_spmm hipsparseScsrmm 160*47d993e7Ssuyashtn #define hipsparse_csr2csc hipsparseScsr2csc 161*47d993e7Ssuyashtn #define hipsparse_hyb_spmv hipsparseShybmv 162*47d993e7Ssuyashtn #define hipsparse_csr2hyb hipsparseScsr2hyb 163*47d993e7Ssuyashtn #define hipsparse_hyb2csr hipsparseShyb2csr 164*47d993e7Ssuyashtn #define hipsparse_csr_spgemm hipsparseScsrgemm 165*47d993e7Ssuyashtn // #define hipsparse_csr_spgeam hipsparseScsrgeam 166*47d993e7Ssuyashtn #elif defined(PETSC_USE_REAL_DOUBLE) 167*47d993e7Ssuyashtn #define hipsparse_scalartype HIP_R_64F 168*47d993e7Ssuyashtn #define hipsparse_csr_spmv hipsparseDcsrmv 169*47d993e7Ssuyashtn #define hipsparse_csr_spmm hipsparseDcsrmm 170*47d993e7Ssuyashtn #define hipsparse_csr2csc hipsparseDcsr2csc 171*47d993e7Ssuyashtn #define hipsparse_hyb_spmv hipsparseDhybmv 172*47d993e7Ssuyashtn #define hipsparse_csr2hyb hipsparseDcsr2hyb 173*47d993e7Ssuyashtn #define hipsparse_hyb2csr hipsparseDhyb2csr 174*47d993e7Ssuyashtn #define hipsparse_csr_spgemm hipsparseDcsrgemm 175*47d993e7Ssuyashtn // #define hipsparse_csr_spgeam hipsparseDcsrgeam 176*47d993e7Ssuyashtn #endif /* Single or double */ 177*47d993e7Ssuyashtn #endif /* complex or not */ 178*47d993e7Ssuyashtn 179*47d993e7Ssuyashtn #define THRUSTINTARRAY32 thrust::device_vector<int> 180*47d993e7Ssuyashtn #define THRUSTINTARRAY thrust::device_vector<PetscInt> 181*47d993e7Ssuyashtn #define THRUSTARRAY thrust::device_vector<PetscScalar> 182*47d993e7Ssuyashtn 183*47d993e7Ssuyashtn /* A CSR matrix structure */ 184*47d993e7Ssuyashtn struct CsrMatrix { 185*47d993e7Ssuyashtn PetscInt num_rows; 186*47d993e7Ssuyashtn PetscInt num_cols; 187*47d993e7Ssuyashtn PetscInt num_entries; 188*47d993e7Ssuyashtn THRUSTINTARRAY32 *row_offsets; 189*47d993e7Ssuyashtn THRUSTINTARRAY32 *column_indices; 190*47d993e7Ssuyashtn THRUSTARRAY *values; 191*47d993e7Ssuyashtn }; 192*47d993e7Ssuyashtn 193*47d993e7Ssuyashtn /* This is struct holding the relevant data needed to a MatSolve */ 194*47d993e7Ssuyashtn struct Mat_SeqAIJHIPSPARSETriFactorStruct { 195*47d993e7Ssuyashtn /* Data needed for triangular solve */ 196*47d993e7Ssuyashtn hipsparseMatDescr_t descr; 197*47d993e7Ssuyashtn hipsparseOperation_t solveOp; 198*47d993e7Ssuyashtn CsrMatrix *csrMat; 199*47d993e7Ssuyashtn csrsvInfo_t solveInfo; 200*47d993e7Ssuyashtn hipsparseSolvePolicy_t solvePolicy; /* whether level information is generated and used */ 201*47d993e7Ssuyashtn int solveBufferSize; 202*47d993e7Ssuyashtn void *solveBuffer; 203*47d993e7Ssuyashtn size_t csr2cscBufferSize; /* to transpose the triangular factor (only used for CUDA >= 11.0) */ 204*47d993e7Ssuyashtn void *csr2cscBuffer; 205*47d993e7Ssuyashtn PetscScalar *AA_h; /* managed host buffer for moving values to the GPU */ 206*47d993e7Ssuyashtn }; 207*47d993e7Ssuyashtn 208*47d993e7Ssuyashtn /* This is a larger struct holding all the triangular factors for a solve, transpose solve, and any indices used in a reordering */ 209*47d993e7Ssuyashtn struct Mat_SeqAIJHIPSPARSETriFactors { 210*47d993e7Ssuyashtn Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactorPtr; /* pointer for lower triangular (factored matrix) on GPU */ 211*47d993e7Ssuyashtn Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactorPtr; /* pointer for upper triangular (factored matrix) on GPU */ 212*47d993e7Ssuyashtn Mat_SeqAIJHIPSPARSETriFactorStruct *loTriFactorPtrTranspose; /* pointer for lower triangular (factored matrix) on GPU for the transpose (useful for BiCG) */ 213*47d993e7Ssuyashtn Mat_SeqAIJHIPSPARSETriFactorStruct *upTriFactorPtrTranspose; /* pointer for upper triangular (factored matrix) on GPU for the transpose (useful for BiCG)*/ 214*47d993e7Ssuyashtn THRUSTINTARRAY *rpermIndices; /* indices used for any reordering */ 215*47d993e7Ssuyashtn THRUSTINTARRAY *cpermIndices; /* indices used for any reordering */ 216*47d993e7Ssuyashtn THRUSTARRAY *workVector; 217*47d993e7Ssuyashtn hipsparseHandle_t handle; /* a handle to the hipsparse library */ 218*47d993e7Ssuyashtn PetscInt nnz; /* number of nonzeros ... need this for accurate logging between ICC and ILU */ 219*47d993e7Ssuyashtn PetscScalar *a_band_d; /* GPU data for banded CSR LU factorization matrix diag(L)=1 */ 220*47d993e7Ssuyashtn int *i_band_d; /* this could be optimized away */ 221*47d993e7Ssuyashtn hipDeviceProp_t dev_prop; 222*47d993e7Ssuyashtn PetscBool init_dev_prop; 223*47d993e7Ssuyashtn 224*47d993e7Ssuyashtn /* csrilu0/csric0 appeared in earlier versions of AMD ROCm^{TM}, but we use it along with hipsparseSpSV, 225*47d993e7Ssuyashtn which first appeared in hipsparse with ROCm-4.5.0. 226*47d993e7Ssuyashtn */ 227*47d993e7Ssuyashtn PetscBool factorizeOnDevice; /* Do factorization on device or not */ 228*47d993e7Ssuyashtn #if PETSC_PKG_HIP_VERSION_GE(4, 5, 0) 229*47d993e7Ssuyashtn PetscScalar *csrVal; 230*47d993e7Ssuyashtn int *csrRowPtr, *csrColIdx; /* a,i,j of M. Using int since some hipsparse APIs only support 32-bit indices */ 231*47d993e7Ssuyashtn 232*47d993e7Ssuyashtn /* Mixed mat descriptor types? yes, different hipsparse APIs use different types */ 233*47d993e7Ssuyashtn hipsparseMatDescr_t matDescr_M; 234*47d993e7Ssuyashtn hipsparseSpMatDescr_t spMatDescr_L, spMatDescr_U; 235*47d993e7Ssuyashtn hipsparseSpSVDescr_t spsvDescr_L, spsvDescr_Lt, spsvDescr_U, spsvDescr_Ut; 236*47d993e7Ssuyashtn 237*47d993e7Ssuyashtn hipsparseDnVecDescr_t dnVecDescr_X, dnVecDescr_Y; 238*47d993e7Ssuyashtn PetscScalar *X, *Y; /* data array of dnVec X and Y */ 239*47d993e7Ssuyashtn 240*47d993e7Ssuyashtn /* Mixed size types? yes */ 241*47d993e7Ssuyashtn int factBufferSize_M; /* M ~= LU or LLt */ 242*47d993e7Ssuyashtn size_t spsvBufferSize_L, spsvBufferSize_Lt, spsvBufferSize_U, spsvBufferSize_Ut; 243*47d993e7Ssuyashtn /* hipsparse needs various buffers for factorization and solve of L, U, Lt, or Ut. 244*47d993e7Ssuyashtn To save memory, we share the factorization buffer with one of spsvBuffer_L/U. 245*47d993e7Ssuyashtn */ 246*47d993e7Ssuyashtn void *factBuffer_M, *spsvBuffer_L, *spsvBuffer_U, *spsvBuffer_Lt, *spsvBuffer_Ut; 247*47d993e7Ssuyashtn 248*47d993e7Ssuyashtn csrilu02Info_t ilu0Info_M; 249*47d993e7Ssuyashtn csric02Info_t ic0Info_M; 250*47d993e7Ssuyashtn int structural_zero, numerical_zero; 251*47d993e7Ssuyashtn hipsparseSolvePolicy_t policy_M; 252*47d993e7Ssuyashtn 253*47d993e7Ssuyashtn /* In MatSolveTranspose() for ILU0, we use the two flags to do on-demand solve */ 254*47d993e7Ssuyashtn PetscBool createdTransposeSpSVDescr; /* Have we created SpSV descriptors for Lt, Ut? */ 255*47d993e7Ssuyashtn PetscBool updatedTransposeSpSVAnalysis; /* Have we updated SpSV analysis with the latest L, U values? */ 256*47d993e7Ssuyashtn 257*47d993e7Ssuyashtn PetscLogDouble numericFactFlops; /* Estimated FLOPs in ILU0/ICC0 numeric factorization */ 258*47d993e7Ssuyashtn #endif 259*47d993e7Ssuyashtn }; 260*47d993e7Ssuyashtn 261*47d993e7Ssuyashtn struct Mat_HipsparseSpMV { 262*47d993e7Ssuyashtn PetscBool initialized; /* Don't rely on spmvBuffer != NULL to test if the struct is initialized, */ 263*47d993e7Ssuyashtn size_t spmvBufferSize; /* since I'm not sure if smvBuffer can be NULL even after hipsparseSpMV_bufferSize() */ 264*47d993e7Ssuyashtn void *spmvBuffer; 265*47d993e7Ssuyashtn hipsparseDnVecDescr_t vecXDescr, vecYDescr; /* descriptor for the dense vectors in y=op(A)x */ 266*47d993e7Ssuyashtn }; 267*47d993e7Ssuyashtn 268*47d993e7Ssuyashtn /* This is struct holding the relevant data needed to a MatMult */ 269*47d993e7Ssuyashtn struct Mat_SeqAIJHIPSPARSEMultStruct { 270*47d993e7Ssuyashtn void *mat; /* opaque pointer to a matrix. This could be either a hipsparseHybMat_t or a CsrMatrix */ 271*47d993e7Ssuyashtn hipsparseMatDescr_t descr; /* Data needed to describe the matrix for a multiply */ 272*47d993e7Ssuyashtn THRUSTINTARRAY *cprowIndices; /* compressed row indices used in the parallel SpMV */ 273*47d993e7Ssuyashtn PetscScalar *alpha_one; /* pointer to a device "scalar" storing the alpha parameter in the SpMV */ 274*47d993e7Ssuyashtn PetscScalar *beta_zero; /* pointer to a device "scalar" storing the beta parameter in the SpMV as zero*/ 275*47d993e7Ssuyashtn PetscScalar *beta_one; /* pointer to a device "scalar" storing the beta parameter in the SpMV as one */ 276*47d993e7Ssuyashtn hipsparseSpMatDescr_t matDescr; /* descriptor for the matrix, used by SpMV and SpMM */ 277*47d993e7Ssuyashtn Mat_HipsparseSpMV hipSpMV[3]; /* different Mat_CusparseSpMV structs for non-transpose, transpose, conj-transpose */ 278*47d993e7Ssuyashtn Mat_SeqAIJHIPSPARSEMultStruct() : matDescr(NULL) 279*47d993e7Ssuyashtn { 280*47d993e7Ssuyashtn for (int i = 0; i < 3; i++) hipSpMV[i].initialized = PETSC_FALSE; 281*47d993e7Ssuyashtn } 282*47d993e7Ssuyashtn }; 283*47d993e7Ssuyashtn 284*47d993e7Ssuyashtn /* This is a larger struct holding all the matrices for a SpMV, and SpMV Transpose */ 285*47d993e7Ssuyashtn struct Mat_SeqAIJHIPSPARSE { 286*47d993e7Ssuyashtn Mat_SeqAIJHIPSPARSEMultStruct *mat; /* pointer to the matrix on the GPU */ 287*47d993e7Ssuyashtn Mat_SeqAIJHIPSPARSEMultStruct *matTranspose; /* pointer to the matrix on the GPU (for the transpose ... useful for BiCG) */ 288*47d993e7Ssuyashtn THRUSTARRAY *workVector; /* pointer to a workvector to which we can copy the relevant indices of a vector we want to multiply */ 289*47d993e7Ssuyashtn THRUSTINTARRAY32 *rowoffsets_gpu; /* rowoffsets on GPU in non-compressed-row format. It is used to convert CSR to CSC */ 290*47d993e7Ssuyashtn PetscInt nrows; /* number of rows of the matrix seen by GPU */ 291*47d993e7Ssuyashtn MatHIPSPARSEStorageFormat format; /* the storage format for the matrix on the device */ 292*47d993e7Ssuyashtn PetscBool use_cpu_solve; /* Use AIJ_Seq (I)LU solve */ 293*47d993e7Ssuyashtn hipStream_t stream; /* a stream for the parallel SpMV ... this is not owned and should not be deleted */ 294*47d993e7Ssuyashtn hipsparseHandle_t handle; /* a handle to the cusparse library ... this may not be owned (if we're working in parallel i.e. multiGPUs) */ 295*47d993e7Ssuyashtn PetscObjectState nonzerostate; /* track nonzero state to possibly recreate the GPU matrix */ 296*47d993e7Ssuyashtn size_t csr2cscBufferSize; /* stuff used to compute the matTranspose above */ 297*47d993e7Ssuyashtn void *csr2cscBuffer; /* This is used as a C struct and is calloc'ed by PetscNewLog() */ 298*47d993e7Ssuyashtn // hipsparseCsr2CscAlg_t csr2cscAlg; /* algorithms can be selected from command line options */ 299*47d993e7Ssuyashtn hipsparseSpMVAlg_t spmvAlg; 300*47d993e7Ssuyashtn hipsparseSpMMAlg_t spmmAlg; 301*47d993e7Ssuyashtn THRUSTINTARRAY *csr2csc_i; 302*47d993e7Ssuyashtn PetscSplitCSRDataStructure deviceMat; /* Matrix on device for, eg, assembly */ 303*47d993e7Ssuyashtn THRUSTINTARRAY *cooPerm; /* permutation array that sorts the input coo entris by row and col */ 304*47d993e7Ssuyashtn THRUSTINTARRAY *cooPerm_a; /* ordered array that indicate i-th nonzero (after sorting) is the j-th unique nonzero */ 305*47d993e7Ssuyashtn 306*47d993e7Ssuyashtn /* Stuff for extended COO support */ 307*47d993e7Ssuyashtn PetscBool use_extended_coo; /* Use extended COO format */ 308*47d993e7Ssuyashtn PetscCount *jmap_d; /* perm[disp+jmap[i]..disp+jmap[i+1]) gives indices of entries in v[] associated with i-th nonzero of the matrix */ 309*47d993e7Ssuyashtn PetscCount *perm_d; 310*47d993e7Ssuyashtn 311*47d993e7Ssuyashtn Mat_SeqAIJHIPSPARSE() : use_extended_coo(PETSC_FALSE), perm_d(NULL), jmap_d(NULL) { } 312*47d993e7Ssuyashtn }; 313*47d993e7Ssuyashtn 314*47d993e7Ssuyashtn typedef struct Mat_SeqAIJHIPSPARSETriFactors *Mat_SeqAIJHIPSPARSETriFactors_p; 315*47d993e7Ssuyashtn 316*47d993e7Ssuyashtn PETSC_INTERN PetscErrorCode MatSeqAIJHIPSPARSECopyToGPU(Mat); 317*47d993e7Ssuyashtn PETSC_INTERN PetscErrorCode MatSetPreallocationCOO_SeqAIJHIPSPARSE_Basic(Mat, PetscCount, PetscInt[], PetscInt[]); 318*47d993e7Ssuyashtn PETSC_INTERN PetscErrorCode MatSetValuesCOO_SeqAIJHIPSPARSE_Basic(Mat, const PetscScalar[], InsertMode); 319*47d993e7Ssuyashtn PETSC_INTERN PetscErrorCode MatSeqAIJHIPSPARSEMergeMats(Mat, Mat, MatReuse, Mat *); 320*47d993e7Ssuyashtn PETSC_INTERN PetscErrorCode MatSeqAIJHIPSPARSETriFactors_Reset(Mat_SeqAIJHIPSPARSETriFactors_p *); 321*47d993e7Ssuyashtn 322*47d993e7Ssuyashtn static inline bool isHipMem(const void *data) 323*47d993e7Ssuyashtn { 324*47d993e7Ssuyashtn hipError_t cerr; 325*47d993e7Ssuyashtn struct hipPointerAttribute_t attr; 326*47d993e7Ssuyashtn enum hipMemoryType mtype; 327*47d993e7Ssuyashtn cerr = hipPointerGetAttributes(&attr, data); /* Do not check error since before CUDA 11.0, passing a host pointer returns hipErrorInvalidValue */ 328*47d993e7Ssuyashtn hipGetLastError(); /* Reset the last error */ 329*47d993e7Ssuyashtn mtype = attr.memoryType; 330*47d993e7Ssuyashtn if (cerr == hipSuccess && mtype == hipMemoryTypeDevice) return true; 331*47d993e7Ssuyashtn else return false; 332*47d993e7Ssuyashtn } 333*47d993e7Ssuyashtn 334*47d993e7Ssuyashtn #endif 335