// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

#ifndef CEED_MAGMA_COMMON_NONTENSOR_H
#define CEED_MAGMA_COMMON_NONTENSOR_H

// Thread budget used by MAGMA_NONTENSOR_BASIS_NTCOL below.
#define NONTENSOR_MAX_THREADS (128)

// Declares a dynamically sized shared-memory array `name` of element type `type`.
// The definition is skipped if MAGMA_DEVICE_SHARED is already defined by the includer.
// Note: the macro is defined exactly once, directly in its function-like form; defining
// it first as an empty object-like macro and then again as a function-like macro would
// be an incompatible macro redefinition.
#ifndef MAGMA_DEVICE_SHARED
#ifdef CEED_MAGMA_USE_HIP
#define MAGMA_DEVICE_SHARED(type, name) HIP_DYNAMIC_SHARED(type, name)
#else
#define MAGMA_DEVICE_SHARED(type, name) extern __shared__ type name[];
#endif  // CEED_MAGMA_USE_HIP
#endif  // MAGMA_DEVICE_SHARED

// NONTENSOR_MAX_THREADS / N, combined with 1 through MAGMA_MAX.
// NOTE(review): MAGMA_MAX is defined outside this file; presumably it is a plain maximum,
// which would clamp the result to at least 1 -- confirm.
#define MAGMA_NONTENSOR_BASIS_NTCOL(N) (MAGMA_MAX(1, (NONTENSOR_MAX_THREADS / (N))))

// 2D accessors: element (i, j) of X lives at offset j * <leading dimension of X> + i.
// They rely on variables named dA/ldda, sA/slda, dB/lddb, sB/sldb being in scope at the
// point of use, and are #undef'd at the end of this header.
#define dA(i, j) dA[(j)*ldda + (i)]
#define sA(i, j) sA[(j)*slda + (i)]
#define dB(i, j) dB[(j)*lddb + (i)]
#define sB(i, j) sB[(j)*sldb + (i)]

////////////////////////////////////////////////////////////////////////////////
// read C from global to reg.
// C is (P_ x NB_)
// 1D thread config. with (Mx1) threads
// no sync at the end of the function
//
// Thread tx fills rC[j] with beta * dC[j * lddc + tx] for every column j < n;
// entries of rC with j >= n are set to zero.
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void read_C_g2r_1D_nosync(const int tx, const int n, T *dC, int lddc, const T &beta, T rC[NB_]) {
  // Entry of this thread in column 0; column `col` is `col * lddc` further on.
  const T *src = dC + tx;

  if (n == NB_) {
    // All NB_ columns are requested: no per-column guard.
#pragma unroll
    for (int col = 0; col < NB_; col++) {
      rC[col] = beta * src[col * lddc];
    }
  } else {
    // Guarded variant: columns at or beyond n are not read and produce zero.
#pragma unroll
    for (int col = 0; col < NB_; col++) {
      if (col < n) {
        rC[col] = beta * src[col * lddc];
      } else {
        rC[col] = 0;
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// write C from reg. to global
// C is (P_ x NB_)
// 1D thread config.
// with (Mx1) threads
// no sync at the end of the function
//
// Thread tx stores rC[j] to dC[j * lddc + tx] for every column j < n (and j < NB_);
// columns at or beyond n are left untouched in dC.
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB_], T *dC, int lddc) {
  // Entry of this thread in column 0; column `col` is `col * lddc` further on.
  T *dst = dC + tx;

  if (n == NB_) {
    // All NB_ columns are written: no per-column guard.
#pragma unroll
    for (int col = 0; col < NB_; col++) {
      dst[col * lddc] = rC[col];
    }
  } else {
    // Guarded variant: skip every column at or beyond n.
#pragma unroll
    for (int col = 0; col < NB_; col++) {
      if (col >= n) {
        continue;
      }
      dst[col * lddc] = rC[col];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// read A (no-trans) from global to reg.
// A is (P_ x Q_)
// 1D thread config.
// with (Mx1) threads
// no sync at the end of the function
//
// Thread tx loads rA[j] = dA[j * ldda + tx] for j in [0, Q_).
// sA and slda are accepted but not used by this routine.
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const T *dA, int ldda, T *sA, int slda, T rA[Q_]) {
  // Entry of this thread in column 0; column `col` is `col * ldda` further on.
  const T *row = dA + tx;
#pragma unroll
  for (int col = 0; col < Q_; col++) {
    rA[col] = row[col * ldda];
  }
}

////////////////////////////////////////////////////////////////////////////////
// read A (trans) from global to reg.
// A is (P_ x Q_)
// 1D thread config. with (Mx1) threads
// no sync at the end of the function
//
// Step 1: the threads copy Q_ * P_ consecutive entries of dA into sA (ldda is not used),
//         then meet at a __syncthreads().
// Step 2: thread tx loads rA[j] = sA[tx * slda + j] for j in [0, Q_).
// NOTE(review): step 1 strides by P_ * MAGMA_NONTENSOR_BASIS_NTCOL(P_) and flattens the
// thread index as ty * blockDim.x + tx -- presumably the block is launched with exactly
// that many threads; confirm against the caller.
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, int ldda, T *sA, int slda, T rA[Q_]) {
  const int total    = Q_ * P_;
  const int nthreads = P_ * MAGMA_NONTENSOR_BASIS_NTCOL(P_);
  const int tid      = ty * blockDim.x + tx;

  // Whole chunks of `nthreads` entries; the last (possibly full) chunk is left to the
  // guarded copy below, so `offset` is deliberately kept alive past the loop.
  int offset = 0;
#pragma unroll
  for (offset = 0; offset < total - nthreads; offset += nthreads) {
    sA[offset + tid] = dA[offset + tid];
  }

  // Remaining entries: only threads that still map inside the Q_ * P_ range copy.
  const int remaining = total - offset;
  if (tid < remaining) {
    sA[offset + tid] = dA[offset + tid];
  }
  // Make the shared copy visible before any thread reads entries written by others.
  __syncthreads();

  const T *row = sA + tx * slda;
#pragma unroll
  for (int col = 0; col < Q_; col++) {
    rA[col] = row[col];
  }
}

////////////////////////////////////////////////////////////////////////////////
// read B from global to shared
// B is (Q_ x NB_)
// 1D thread config. with (Mx1) threads
// no sync at the end of the function
//
// Copies entries of dB into the same positions of sB in steps of P_, with thread tx
// handling position tx of each step; a guarded copy finishes the range up to Q_ * n.
// lddb and sldb are accepted but not used by this routine.
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, int n, const T *dB, int lddb, T *sB, int sldb) {
  if (n == NB_) {
    // Trip count is a compile-time constant here, so the loop can be unrolled.
#pragma unroll
    for (int base = 0; base < (Q_ * NB_) - P_; base += P_) {
      sB[base + tx] = dB[base + tx];
    }
  } else {
    const int limit = (Q_ * n) - P_;
    for (int base = 0; base < limit; base += P_) {
      sB[base + tx] = dB[base + tx];
    }
  }

  // cleanup for B
  // NOTE(review): MAGMA_ROUNDUP is defined outside this file; presumably it rounds its
  // first argument up to a multiple of the second, making `tail` the offset at which the
  // loops above stopped -- confirm.
  const int tail = MAGMA_ROUNDUP(Q_ * n - P_, P_);
  const int left = (Q_ * n) - tail;
  if (tx < left) {
    sB[tail + tx] = dB[tail + tx];
  }
}

////////////////////////////////////////////////////////////////////////////////
// multiply C = AxB using 1D threads in Mx1 config
// A (MxK) in reg., one row per thread
// B
// (KxNB) in shared memory
// C in registers -- one row per thread
// no sync at the end of the function
//
// For each column i in [0, NB_):
//   rC[i] += alpha * sum over k in [0, Q_) of rA[k] * sB[i * sldb + k]
// tx is accepted but not used by this routine.
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(const int tx, const T &alpha, T rA[Q_], T *sB, int sldb, T rC[NB_]) {
  // Per-thread staging copy of the current column of B.
  T col_vals[Q_] = {0};
#pragma unroll
  for (int col = 0; col < NB_; col++) {
    const T *col_ptr = sB + col * sldb;

    // Stage the whole column first, then accumulate from the staged values.
#pragma unroll
    for (int k = 0; k < Q_; k++) {
      col_vals[k] = col_ptr[k];
    }

    T dot = 0;
#pragma unroll
    for (int k = 0; k < Q_; k++) {
      dot += rA[k] * col_vals[k];
    }
    rC[col] += alpha * dot;
  }
}

#undef dA
#undef sA
#undef dB
#undef sB

#endif  // CEED_MAGMA_COMMON_NONTENSOR_H