// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

// Column-major accessors for shared-memory buffers.
// sT(i, j)          -> sT[j * P_ + i]      (P_ x Q_ matrix, leading dimension P_)
// sTmp(i, j, ldw)   -> sTmp[j * ldw + i]
// sTmp2(i, j, ldw)  -> sTmp2[j * ldw + i]
// The macros expand to the names sT / sTmp / sTmp2 / P_, which must be in scope at the use site.
#define sT(i, j) sT[(j)*P_ + (i)]
#define sTmp(i, j, ldw) sTmp[(j) * (ldw) + (i)]
#define sTmp2(i, j, ldw) sTmp2[(j) * (ldw) + (i)]

//////////////////////////////////////////////////////////////////////////////////////////
// grad basis action (3D), applied along one dimension per call.
// This function is called three times at a higher level for 3D (once per iDIM_ = 0, 1, 2).
// The 1D gradient matrix sTgrad is applied along direction iDIM_; the 1D interpolation
// matrix sTinterp is applied along the other two directions.
//
// Template parameters
//   DIM_U  -- first extent of rU, i.e. rU[DIM_U][NCOMP_][rUsize]
//   DIM_V  -- first extent of rV, i.e. rV[DIM_V][NCOMP_][rVsize]
//   iDIM_  -- the index of the outermost loop over dimensions in grad
//   iDIM_U -- which dim index of rU is accessed (always 0 for notrans, 0, 1, or 2 for trans)
//   iDIM_V -- which dim index of rV is accessed (0, 1, or 2 for notrans, always 0 for trans)
// Arguments
//   sTinterp, sTgrad -- 1D interp / grad matrices in shared memory, P_ x Q_, accessed via sT(i, j)
//   beta  -- rV = beta * rV + result. beta == 0 overwrites rV without reading its previous
//            contents (so stale NaN/Inf values are not propagated); any other value scales
//            and accumulates (beta == 1 accumulates).
//   tx    -- thread index within the thread slice of one element
//   rTmp  -- scratch register; its incoming value is ignored
//   swork -- shared workspace of at least P_*P_*Q_ + P_*Q_*Q_ entries (per element slice)
template <typename T, int DIM_U, int DIM_V, int NCOMP_, int P_, int Q_, int rUsize, int rVsize, int iDIM_, int iDIM_U, int iDIM_V>
static __device__ __inline__ void magma_grad_3d_device(const T *sTinterp, const T *sTgrad, T rU[DIM_U][NCOMP_][rUsize], T rV[DIM_V][NCOMP_][rVsize],
                                                       T beta, const int tx, T rTmp, T *swork) {
  // Assumptions
  // 0. This device routine applies grad for one dim only (iDIM_), so it should be called thrice for 3D
  // 1. 1D threads of size max(P_,Q_)^2
  // 2. input:  rU[DIM_U x NCOMP_ x rUsize] in registers (per thread)
  // 3. output: rV[DIM_V x NCOMP_ x rVsize] in registers (per thread)
  // 4. Three products per each (dim,component) pair
  //  4.1 Batch P_^2 of (1xP_) matrices times (P_xQ_) matrix => Batch P_^2 of (1xQ_) matrices
  //  4.2 Batch P_ of (Q_xP_) matrices times (P_xQ_) matrix => Batch P_ of (Q_xQ_) matrices
  //  4.3 Batch 1 of (Q_^2xP_) matrix times (P_xQ_) matrix => (Q_^2xQ_) matrix
  // 5. Each thread computes one row of the output of each product
  // 6. This routine contains __syncthreads(), so every thread of the block must call it;
  //    a sync is recommended before and after the call

  T *sW1 = swork;                 // P_ * P_ * Q_ entries: output of product 4.1
  T *sW2 = sW1 + P_ * P_ * Q_;    // P_ * Q_ * Q_ entries: output of product 4.2
  for (int icomp = 0; icomp < NCOMP_; icomp++) {
    // Batch P_^2 of (1xP_) matrices [reg] times (P_xQ_) matrix [shmem] => Batch P_^2 of (1xQ_) matrices [shmem]
    if (tx < (P_ * P_)) {
      const int batchid = tx;
      const int sld     = 1;
      const T  *sT      = (iDIM_ == 0) ? sTgrad : sTinterp;
      T        *sTmp    = sW1 + batchid * (1 * Q_);
      for (int j = 0; j < Q_; j++) {
        rTmp = 0.0;
        for (int i = 0; i < P_; i++) {
          rTmp += rU[iDIM_U][icomp][i] * sT(i, j);
        }
        sTmp(0, j, sld) = rTmp;
      }
    }  // end of: if (tx < P_*P_)
    __syncthreads();

    // Batch P_ of (Q_xP_) matrices [shmem] times (P_xQ_) matrix [shmem] => Batch P_ of (Q_xQ_) matrices [shmem]
    if (tx < (P_ * Q_)) {
      const int batchid = tx / Q_;
      const int tx_     = tx % Q_;
      const int sld     = Q_;
      const T  *sT      = (iDIM_ == 1) ? sTgrad : sTinterp;
      T        *sTmp    = sW1 + batchid * (Q_ * P_);  // sTmp is input
      T        *sTmp2   = sW2 + batchid * (Q_ * Q_);  // sTmp2 is output
      for (int j = 0; j < Q_; j++) {
        rTmp = 0.0;
        for (int i = 0; i < P_; i++) {
          rTmp += sTmp(tx_, i, sld) * sT(i, j);
        }
        sTmp2(tx_, j, sld) = rTmp;
      }
    }
    __syncthreads();

    // Batch 1 of (Q_^2xP_) matrices [shmem] times (P_xQ_) matrix [shmem] => Batch 1 of (Q_^2xQ_) matrices [reg]
    if (tx < (Q_ * Q_)) {
      // No need to declare batchid = (tx / Q_^2) = always zero
      // No need to declare tx_ = (tx % Q_^2) = always tx
      const int sld  = Q_ * Q_;
      const T  *sT   = (iDIM_ == 2) ? sTgrad : sTinterp;
      T        *sTmp = sW2;  // sTmp is input
      for (int j = 0; j < Q_; j++) {
        rTmp = 0.0;
        for (int i = 0; i < P_; i++) {
          rTmp += sTmp(tx, i, sld) * sT(i, j);
        }
        if (beta == (T)0) {
          // Overwrite: do not read rV, so stale non-finite values cannot leak through 0 * rV.
          rV[iDIM_V][icomp][j] = rTmp;
        } else {
          rV[iDIM_V][icomp][j] *= beta;
          rV[iDIM_V][icomp][j] += rTmp;
        }
      }
    }
    __syncthreads();
  }  // loop over NCOMP_
}

//////////////////////////////////////////////////////////////////////////////////////////
// Non-transposed 3D gradient.
// For each element, reads U (dimension slot 0 of dU) and writes the three directional
// derivatives to dV at dimension offsets 0, dstrdV and 2 * dstrdV.
// Strides: estrd* between elements, cstrd* between components, dstrd* between dimensions.
// Launch layout: threadIdx.x indexes the thread slice of one element; threadIdx.y selects the
// element within the block (elem_id = blockIdx.x * blockDim.y + threadIdx.y).
// Dynamic shared memory (CeedScalar entries): 2 * P * Q for the 1D interp/grad matrices, plus
// blockDim.y * max(P * P * P, P * P * Q + P * Q * Q) of per-element workspace.
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(MAXPQ * MAXPQ, MAGMA_MAXTHREADS_3D)) __global__
    void magma_gradn_3d_kernel(const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, const CeedScalar *dU, const int estrdU, const int cstrdU,
                               const int dstrdU, CeedScalar *dV, const int estrdV, const int cstrdV, const int dstrdV, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int     tx      = threadIdx.x;
  const int     ty      = threadIdx.y;
  const int     elem_id = (blockIdx.x * blockDim.y) + ty;
  magma_trans_t transT  = MagmaNoTrans;

  // NOTE(review): threads past the last element return here, yet the remaining threads of the
  // block execute __syncthreads() below and inside the helpers. CUDA requires every thread of
  // the block to reach a barrier; this relies on exited threads not blocking it — confirm on
  // the target hardware. (The ty == 0 slice that loads the 1D matrices always has the smallest
  // elem_id in the block, so it never exits while others remain.)
  if (elem_id >= nelem) return;

  CeedScalar rU[1][NCOMP][P] = {0.0};  // here DIMU = 1, but might be different for a fused operator
  CeedScalar rV[1][NCOMP][Q] = {0.0};  // here DIMV = 1, but might be different for a fused operator
  CeedScalar rTmp            = 0.0;

  // shift global memory pointers by elem stride (64-bit product: elem_id * estrd may exceed INT_MAX)
  dU += (long long)elem_id * estrdU;
  dV += (long long)elem_id * estrdV;

  // assign shared memory pointers
  CeedScalar *sTinterp = (CeedScalar *)(shared_data);
  CeedScalar *sTgrad   = sTinterp + P * Q;
  CeedScalar *sTmp     = sTgrad + P * Q;
  sTmp += ty * (max(P * P * P, (P * P * Q) + (P * Q * Q)));

  // read T: the ty == 0 slice stages the 1D interp and grad matrices in shared memory
  if (ty == 0) {
    dread_T_gm2sm<P, Q>(tx, transT, dinterp1d, sTinterp);
    dread_T_gm2sm<P, Q>(tx, transT, dgrad1d, sTgrad);
  }
  __syncthreads();

  // No need to read V ( required only in transposed grad ); beta = 0 overwrites rV
  const CeedScalar beta = 0.0;

  /* read U (idim = 0 for dU, iDIM = 0 for rU) --
     there is a sync at the end of this function */
  readU_3d<CeedScalar, P, 1, NCOMP, P, 0>(dU, cstrdU, rU, sTmp, tx);

  /* first call (iDIM = 0, iDIMU = 0, iDIMV = 0) --
     output from rV[0][][] into dV (idim = 0) */
  magma_grad_3d_device<CeedScalar, 1, 1, NCOMP, P, Q, P, Q, 0, 0, 0>(sTinterp, sTgrad, rU, rV, beta, tx, rTmp, sTmp);
  /* there is a sync at the end of magma_grad_3d_device */
  writeV_3d<CeedScalar, Q, 1, NCOMP, Q, 0>(dV, cstrdV, rV, tx);

  /* second call (iDIM = 1, iDIMU = 0, iDIMV = 0) --
     output from rV[0][][] into dV (idim = 1) */
  magma_grad_3d_device<CeedScalar, 1, 1, NCOMP, P, Q, P, Q, 1, 0, 0>(sTinterp, sTgrad, rU, rV, beta, tx, rTmp, sTmp);
  /* there is a sync at the end of magma_grad_3d_device */
  writeV_3d<CeedScalar, Q, 1, NCOMP, Q, 0>(dV + 1LL * dstrdV, cstrdV, rV, tx);

  /* third call (iDIM = 2, iDIMU = 0, iDIMV = 0) --
     output from rV[0][][] into dV (idim = 2) */
  magma_grad_3d_device<CeedScalar, 1, 1, NCOMP, P, Q, P, Q, 2, 0, 0>(sTinterp, sTgrad, rU, rV, beta, tx, rTmp, sTmp);
  /* there is a sync at the end of magma_grad_3d_device */
  writeV_3d<CeedScalar, Q, 1, NCOMP, Q, 0>(dV + 2LL * dstrdV, cstrdV, rV, tx);
}

//////////////////////////////////////////////////////////////////////////////////////////
// Transposed 3D gradient.
// For each element, reads the current contents of dV (dimension slot 0), adds the transposed
// gradient contributions of the three dimension slots of dU (offsets 0, dstrdU, 2 * dstrdU),
// and writes the sum back to dV. The kernel therefore accumulates into dV.
// Strides: estrd* between elements, cstrd* between components, dstrd* between dimensions.
// Launch layout: threadIdx.x indexes the thread slice of one element; threadIdx.y selects the
// element within the block (elem_id = blockIdx.x * blockDim.y + threadIdx.y).
// Dynamic shared memory (CeedScalar entries): 2 * Q * P for the 1D interp/grad matrices, plus
// blockDim.y * max(Q * Q * Q, Q * Q * P + Q * P * P) of per-element workspace.
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(MAXPQ * MAXPQ, MAGMA_MAXTHREADS_3D)) __global__
    void magma_gradt_3d_kernel(const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, const CeedScalar *dU, const int estrdU, const int cstrdU,
                               const int dstrdU, CeedScalar *dV, const int estrdV, const int cstrdV, const int dstrdV, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int     tx      = threadIdx.x;
  const int     ty      = threadIdx.y;
  const int     elem_id = (blockIdx.x * blockDim.y) + ty;
  magma_trans_t transT  = MagmaTrans;

  // NOTE(review): threads past the last element return before the __syncthreads() calls below
  // (and inside the helpers); this relies on exited threads not blocking the barrier — confirm
  // on the target hardware.
  if (elem_id >= nelem) return;

  CeedScalar rU[1][NCOMP][Q] = {0.0};  // here DIMU = 1, but might be different for a fused operator
  CeedScalar rV[1][NCOMP][P] = {0.0};  // here DIMV = 1, but might be different for a fused operator
  CeedScalar rTmp            = 0.0;

  // shift global memory pointers by elem stride (64-bit product: elem_id * estrd may exceed INT_MAX)
  dU += (long long)elem_id * estrdU;
  dV += (long long)elem_id * estrdV;

  // assign shared memory pointers
  CeedScalar *sTinterp = (CeedScalar *)(shared_data);
  CeedScalar *sTgrad   = sTinterp + Q * P;
  CeedScalar *sTmp     = sTgrad + Q * P;
  sTmp += ty * (max(Q * Q * Q, (Q * Q * P) + (Q * P * P)));

  // read T: the ty == 0 slice stages the 1D interp and grad matrices in shared memory
  if (ty == 0) {
    dread_T_gm2sm<Q, P>(tx, transT, dinterp1d, sTinterp);
    dread_T_gm2sm<Q, P>(tx, transT, dgrad1d, sTgrad);
  }
  __syncthreads();

  // read V (since this is transposed mode); beta = 1 accumulates into it
  const CeedScalar beta = 1.0;
  readV_3d<CeedScalar, P, 1, NCOMP, P, 0>(dV, cstrdV, rV, tx);

  /* read U (idim = 0 for dU, iDIM = 0 for rU) --
     there is a sync at the end of this function */
  readU_3d<CeedScalar, Q, 1, NCOMP, Q, 0>(dU, cstrdU, rU, sTmp, tx);
  /* then first call (iDIM = 0, iDIMU = 0, iDIMV = 0) */
  magma_grad_3d_device<CeedScalar, 1, 1, NCOMP, Q, P, Q, P, 0, 0, 0>(sTinterp, sTgrad, rU, rV, beta, tx, rTmp, sTmp);
  /* there is a sync at the end of magma_grad_3d_device */

  /* read U (idim = 1 for dU, iDIM = 0 for rU) --
     there is a sync at the end of this function */
  readU_3d<CeedScalar, Q, 1, NCOMP, Q, 0>(dU + 1LL * dstrdU, cstrdU, rU, sTmp, tx);
  /* then second call (iDIM = 1, iDIMU = 0, iDIMV = 0) */
  magma_grad_3d_device<CeedScalar, 1, 1, NCOMP, Q, P, Q, P, 1, 0, 0>(sTinterp, sTgrad, rU, rV, beta, tx, rTmp, sTmp);
  /* there is a sync at the end of magma_grad_3d_device */

  /* read U (idim = 2 for dU, iDIM = 0 for rU) --
     there is a sync at the end of this function */
  readU_3d<CeedScalar, Q, 1, NCOMP, Q, 0>(dU + 2LL * dstrdU, cstrdU, rU, sTmp, tx);
  /* then third call (iDIM = 2, iDIMU = 0, iDIMV = 0) */
  magma_grad_3d_device<CeedScalar, 1, 1, NCOMP, Q, P, Q, P, 2, 0, 0>(sTinterp, sTgrad, rU, rV, beta, tx, rTmp, sTmp);
  /* there is a sync at the end of magma_grad_3d_device */

  // write V
  writeV_3d<CeedScalar, P, 1, NCOMP, P, 0>(dV, cstrdV, rV, tx);
}