// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

// macros to abstract access of shared memory and reg. file
// NOTE: sT(i, j) expands using the identifiers `sT` and `P_` that are in scope at the point of use,
//       and sTmp(i, j, ldw) uses the identifier `sTmp`; both are column-major style accessors.
#define sT(i, j) sT[(j)*P_ + (i)]
#define sTmp(i, j, ldw) sTmp[(j) * (ldw) + (i)]

//////////////////////////////////////////////////////////////////////////////////////////
// grad basis action (2D)
// This function is called two times at a higher level for 2D
// DIM_U  -- first extent of the register array rU[DIM_U][NCOMP_][rUsize]
// DIM_V  -- first extent of the register array rV[DIM_V][NCOMP_][rVsize]
// iDIM_  -- the index of the outermost loop over dimensions in grad
// iDIM_U -- which dim index of rU is accessed (always 0 for notrans, 0 or 1 for trans)
// iDIM_V -- which dim index of rV is accessed (0 or 1 for notrans, always 0 for trans)
// the scalar beta is used to specify whether to accumulate to rV, or overwrite it:
//   rV[iDIM_V][icomp][j] = beta * rV[iDIM_V][icomp][j] + result
// (the multiplication is always performed, so with beta == 0 a non-finite value already in rV
//  still propagates into the output)
// rTmp is a by-value scratch accumulator; its incoming value is never read
// swork points to this thread row's scratch space; P_ * Q_ entries are written and read here
template <typename T, int DIM_U, int DIM_V, int NCOMP_, int P_, int Q_, int rUsize, int rVsize, int iDIM_, int iDIM_U, int iDIM_V>
static __device__ __inline__ void magma_grad_2d_device(const T *sTinterp, const T *sTgrad, T rU[DIM_U][NCOMP_][rUsize], T rV[DIM_V][NCOMP_][rVsize],
                                                       T beta, const int tx, T rTmp, T *swork) {
  // Assumptions
  // 0. This device routine applies grad for one dim only (iDIM_), so it should be called twice for 2D
  // 1. 1D threads of size max(P_,Q_)
  // 2. input:  rU[DIM_U x NCOMP_ x P_] in registers (per thread)
  // 3. output: rV[DIM_V x NCOMP_ x Q_] in registers (per thread)
  // 4. Two products per each (dim,component) pair
  //  4.1 Batch P_ of (1xP_) matrices times (P_xQ_) matrix => Batch P_ of (1xQ_) matrices
  //  4.2 Batch 1 of (Q_xP_) matrix times (P_xQ_) matrix => (Q_xQ_) matrix
  // 5. Each thread computes one row of the output of each product
  // 6. Sync is recommended before and after the call

  for (int icomp = 0; icomp < NCOMP_; icomp++) {
    // 1st product -- Batch P_ of (1xP_) matrices [reg] x (P_xQ_) [shmem] => Batch P_ of (1xQ_) matrices
    // the batch output P_ x (1xQ_) is written on the fly to shmem
    if (tx < P_) {
      const int batchid = tx;
      const int sld     = 1;
      // the gradient matrix is applied along dimension iDIM_, the interpolation matrix along the other one
      const T  *sT   = (iDIM_ == 0) ? sTgrad : sTinterp;
      T        *sTmp = swork + batchid * (1 * Q_);
      for (int j = 0; j < Q_; j++) {
        rTmp = 0.0;
        for (int i = 0; i < P_; i++) {
          rTmp += rU[iDIM_U][icomp][i] * sT(i, j);
        }
        sTmp(0, j, sld) = rTmp;
      }
    }  // end of: if (tx < P_)
    // barrier is outside the branch so every thread that calls this routine reaches it;
    // it publishes the rows written above before other threads read them
    __syncthreads();

    // 2nd product -- Batch 1 of a (Q_xP_) matrix [shmem] x (P_xQ_) [shmem] => (Q_xQ_) matrix [reg]
    if (tx < Q_) {
      const int batchid = 0;
      const int sld     = Q_;
      const T  *sT   = (iDIM_ == 1) ? sTgrad : sTinterp;
      T        *sTmp = swork + batchid * (Q_ * P_);
      for (int j = 0; j < Q_; j++) {
        rTmp = 0.0;
        for (int i = 0; i < P_; i++) {
          // element tx of the row written by thread i in the 1st product
          rTmp += sTmp(tx, i, sld) * sT(i, j);
        }
        rV[iDIM_V][icomp][j] *= beta;
        rV[iDIM_V][icomp][j] += rTmp;
      }
    }
    // scratch space is reused by the next component (and by the caller), so wait for all readers
    __syncthreads();
  }  // loop over NCOMP_
}

//////////////////////////////////////////////////////////////////////////////////////////
// Non-transposed 2D gradient kernel.
// Thread layout (from the indexing below): threadIdx.x is the 1D thread index used inside the
// tensor contractions, threadIdx.y selects the element within the block, and
// elem_id = blockIdx.x * blockDim.y + threadIdx.y.
// Shared memory layout (in CeedScalar entries, from the pointer arithmetic below):
//   [0, P*Q)          sTinterp
//   [P*Q, 2*P*Q)      sTgrad
//   then P*MAXPQ entries of scratch per threadIdx.y value
// P, Q, NCOMP, MAXPQ, CeedScalar and the MAGMA_* macros are defined outside this file.
// NOTE(review): threads with elem_id >= nelem return before the block barriers executed by the
//               remaining threads -- confirm this is acceptable on all targeted architectures.
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(MAXPQ, MAGMA_MAXTHREADS_2D)) __global__
    void magma_gradn_2d_kernel(const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, const CeedScalar *dU, const int estrdU, const int cstrdU,
                               const int dstrdU, CeedScalar *dV, const int estrdV, const int cstrdV, const int dstrdV, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int     tx      = threadIdx.x;
  const int     ty      = threadIdx.y;
  const int     elem_id = (blockIdx.x * blockDim.y) + ty;
  magma_trans_t transT  = MagmaNoTrans;

  if (elem_id >= nelem) return;

  CeedScalar rU[1][NCOMP][P] = {0.0};  // here DIMU = 1, but might be different for a fused operator
  CeedScalar rV[1][NCOMP][Q] = {0.0};  // here DIMV = 1, but might be different for a fused operator
  CeedScalar rTmp            = 0.0;

  // shift global memory pointers by elem stride
  // (offset is computed in 64-bit so that elem_id * stride cannot overflow int)
  dU += static_cast<long long>(elem_id) * estrdU;
  dV += static_cast<long long>(elem_id) * estrdV;

  // assign shared memory pointers
  CeedScalar *sTinterp = (CeedScalar *)(shared_data);
  CeedScalar *sTgrad   = sTinterp + P * Q;
  CeedScalar *sTmp     = sTgrad + P * Q;
  sTmp += ty * (P * MAXPQ);

  // read T
  // (no explicit barrier here: the kernel relies on the sync inside readU_2d, defined elsewhere,
  //  before sTinterp/sTgrad are consumed)
  if (ty == 0) {
    dread_T_gm2sm<P, Q>(tx, transT, dinterp1d, sTinterp);
    dread_T_gm2sm<P, Q>(tx, transT, dgrad1d, sTgrad);
  }

  // No need to read V ( required only in transposed grad )
  const CeedScalar beta = 0.0;

  /* read U (idim = 0 for dU, iDIM = 0 for rU) --
     there is a sync at the end of this function */
  readU_2d<CeedScalar, P, 1, NCOMP, P, 0>(dU + (0 * dstrdU), cstrdU, rU, sTmp, tx);

  /* first call (iDIM = 0, iDIMU = 0, iDIMV = 0) --
     output from rV[0][][] into dV (idim = 0) */
  magma_grad_2d_device<CeedScalar, 1, 1, NCOMP, P, Q, P, Q, 0, 0, 0>(sTinterp, sTgrad, rU, rV, beta, tx, rTmp, sTmp);
  /* there is a sync at the end of magma_grad_2d_device */
  writeV_2d<CeedScalar, Q, 1, NCOMP, Q, 0>(dV + (0 * dstrdV), cstrdV, rV, tx);

  /* second call (iDIM = 1, iDIMU = 0, iDIMV = 0) --
     output from rV[0][][] into dV (idim = 1);
     rV still holds the idim = 0 result and is scaled by beta = 0 before accumulation */
  magma_grad_2d_device<CeedScalar, 1, 1, NCOMP, P, Q, P, Q, 1, 0, 0>(sTinterp, sTgrad, rU, rV, beta, tx, rTmp, sTmp);
  /* there is a sync at the end of magma_grad_2d_device */
  writeV_2d<CeedScalar, Q, 1, NCOMP, Q, 0>(dV + (1 * dstrdV), cstrdV, rV, tx);
}

//////////////////////////////////////////////////////////////////////////////////////////
// Transposed 2D gradient kernel.
// Same thread layout as magma_gradn_2d_kernel; the roles of P and Q are swapped:
// the input rU holds Q entries per component, the output rV holds P entries per component.
// Shared memory layout (in CeedScalar entries, from the pointer arithmetic below):
//   [0, Q*P)          sTinterp
//   [Q*P, 2*Q*P)      sTgrad
//   then Q*MAXPQ entries of scratch per threadIdx.y value
// The existing contents of dV (idim = 0) are read first and both dimension contributions are
// accumulated on top of them (beta = 1) before the single write-back.
// NOTE(review): threads with elem_id >= nelem return before the block barriers executed by the
//               remaining threads -- confirm this is acceptable on all targeted architectures.
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(MAXPQ, MAGMA_MAXTHREADS_2D)) __global__
    void magma_gradt_2d_kernel(const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, const CeedScalar *dU, const int estrdU, const int cstrdU,
                               const int dstrdU, CeedScalar *dV, const int estrdV, const int cstrdV, const int dstrdV, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int     tx      = threadIdx.x;
  const int     ty      = threadIdx.y;
  const int     elem_id = (blockIdx.x * blockDim.y) + ty;
  magma_trans_t transT  = MagmaTrans;

  if (elem_id >= nelem) return;

  CeedScalar rU[1][NCOMP][Q] = {0.0};  // here DIMU = 1, but might be different for a fused operator
  CeedScalar rV[1][NCOMP][P] = {0.0};  // here DIMV = 1, but might be different for a fused operator
  CeedScalar rTmp            = 0.0;

  // shift global memory pointers by elem stride
  // (offset is computed in 64-bit so that elem_id * stride cannot overflow int)
  dU += static_cast<long long>(elem_id) * estrdU;
  dV += static_cast<long long>(elem_id) * estrdV;

  // assign shared memory pointers
  CeedScalar *sTinterp = (CeedScalar *)(shared_data);
  CeedScalar *sTgrad   = sTinterp + Q * P;
  CeedScalar *sTmp     = sTgrad + Q * P;
  sTmp += ty * (Q * MAXPQ);

  // read T
  if (ty == 0) {
    dread_T_gm2sm<Q, P>(tx, transT, dinterp1d, sTinterp);
    dread_T_gm2sm<Q, P>(tx, transT, dgrad1d, sTgrad);
  }
  __syncthreads();

  /* read V (since this is transposed mode --
     idim = 0 for dV, iDIM = 0 for rV) */
  const CeedScalar beta = 1.0;
  readV_2d<CeedScalar, P, 1, NCOMP, P, 0>(dV + (0 * dstrdV), cstrdV, rV, tx);

  /* read U (idim = 0 for dU, iDIM = 0 for rU) --
     there is a sync at the end of this function */
  readU_2d<CeedScalar, Q, 1, NCOMP, Q, 0>(dU + (0 * dstrdU), cstrdU, rU, sTmp, tx);
  /* first call (iDIM = 0, iDIMU = 0, iDIMV = 0) */
  magma_grad_2d_device<CeedScalar, 1, 1, NCOMP, Q, P, Q, P, 0, 0, 0>(sTinterp, sTgrad, rU, rV, beta, tx, rTmp, sTmp);
  /* there is a sync at the end of magma_grad_2d_device */

  /* read U (idim = 1 for dU, iDIM = 0 for rU) --
     there is a sync at the end of this function */
  readU_2d<CeedScalar, Q, 1, NCOMP, Q, 0>(dU + (1 * dstrdU), cstrdU, rU, sTmp, tx);
  /* second call (iDIM = 1, iDIMU = 0, iDIMV = 0) */
  magma_grad_2d_device<CeedScalar, 1, 1, NCOMP, Q, P, Q, P, 1, 0, 0>(sTinterp, sTgrad, rU, rV, beta, tx, rTmp, sTmp);
  /* there is a sync at the end of magma_grad_2d_device */

  // write V
  writeV_2d<CeedScalar, P, 1, NCOMP, P, 0>(dV + (0 * dstrdV), cstrdV, rV, tx);
}