// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

// macros to abstract access to shared memory and the register file
#define sT(i, j)        sT[(j) * P_ + (i)]
#define sTmp(i, j, ldw) sTmp[(j) * (ldw) + (i)]

//////////////////////////////////////////////////////////////////////////////////////////
// interp basis action (3D)
template <typename T, int DIM_U, int DIM_V, int NCOMP_, int P_, int Q_, int rUsize, int rVsize>
static __device__ __inline__ void magma_interp_3d_device(const T *sT, magma_trans_t transT, T rU[DIM_U][NCOMP_][rUsize], T rV[DIM_V][NCOMP_][rVsize],
                                                         const int tx, T rTmp[Q_], T *swork) {
  // Assumptions
  // 1. 1D threads of size max(P_, Q_)^2
  // 2. input:  rU[DIM_U x NCOMP_ x rUsize] in registers (per thread)
  // 3. output: rV[DIM_V x NCOMP_ x rVsize] in registers (per thread)
  // 4. Three products per component
  //    4.1 Batch P_^2 of (1 x P_) matrices times (P_ x Q_) matrix => Batch P_^2 of (1 x Q_) matrices
  //    4.2 Batch P_ of (Q_ x P_) matrices times (P_ x Q_) matrix => Batch P_ of (Q_ x Q_) matrices
  //    4.3 Batch 1 of (Q_^2 x P_) matrix times (P_ x Q_) matrix => (Q_^2 x Q_) matrix
  // 5. Each thread computes one row of the output of each product
  // 6. Sync is recommended before and after the call

  for (int icomp = 0; icomp < NCOMP_; icomp++) {
    // Batch P_^2 of (1 x P_) matrices [reg] times (P_ x Q_) matrix [shmem] => Batch P_^2 of (1 x Q_) matrices [shmem]
    if (tx < (P_ * P_)) {
      const int batchid = tx;
      const int sld     = 1;
      T        *sTmp    = swork + batchid * (1 * Q_);
      for (int j = 0; j < Q_; j++) {
        rTmp[0] = 0.0;
        for (int i = 0; i < P_; i++) {
          rTmp[0] += rU[0][icomp][i] * sT(i, j);
        }
        sTmp(0, j, sld) = rTmp[0];
      }
    }  // end of: if (tx < P_*P_)
    __syncthreads();

    // Batch P_ of (Q_ x P_) matrices [shmem] times (P_ x Q_) matrix [shmem] => Batch P_ of (Q_ x Q_) matrices [reg]
    if (tx < (P_ * Q_)) {
      const int batchid = tx / Q_;
      const int tx_     = tx % Q_;
      const int sld     = Q_;
      T        *sTmp    = swork + batchid * (Q_ * P_);  // sTmp is input
      for (int j = 0; j < Q_; j++) {
        rTmp[j] = 0.0;
        for (int i = 0; i < P_; i++) {
          rTmp[j] += sTmp(tx_, i, sld) * sT(i, j);
        }
      }
    }
    __syncthreads();

    // write rTmp[] into shmem as batch P_ of (Q_ x Q_) matrices
    if (tx < (P_ * Q_)) {
      const int batchid = tx / Q_;
      const int tx_     = tx % Q_;
      const int sld     = Q_;
      T        *sTmp    = swork + batchid * (Q_ * Q_);
      for (int j = 0; j < Q_; j++) {
        sTmp(tx_, j, sld) = rTmp[j];
      }
    }
    __syncthreads();

    // Batch 1 of (Q_^2 x P_) matrices [shmem] times (P_ x Q_) matrix [shmem] => Batch 1 of (Q_^2 x Q_) matrices [reg]
    if (tx < (Q_ * Q_)) {
      // No need to declare batchid = (tx / Q_^2), it is always zero
      // No need to declare tx_ = (tx % Q_^2), it is always tx
      const int sld  = Q_ * Q_;
      T        *sTmp = swork;
      for (int j = 0; j < Q_; j++) {
        rTmp[0] = 0.0;
        for (int i = 0; i < P_; i++) {
          rTmp[0] += sTmp(tx, i, sld) * sT(i, j);
        }
        rV[0][icomp][j] += rTmp[0];
      }
    }
    __syncthreads();
  }
}
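//////////////////////////////////////////////////////////////////////////////////////////
// Illustrative sketch (compiled out; not part of the kernels in this file): a plain
// host-side reference for the contraction chain above, for a single component. It applies
// T (x) T (x) T to u in one shot, assuming the P x Q interpolation matrix is stored
// column-major exactly as the sT(i, j) macro reads it, i.e. T[p + q * P] = T(p, q).
// The function name and the explicit six-loop form are hypothetical, intended only as
// something to check the staged shared-memory version against.
#if 0
static void magma_interp_3d_reference(int P, int Q, const double *T, const double *u, double *v) {
  // u is P^3 with p1 fastest; v is Q^3 with q1 fastest; v is accumulated, matching rV += above
  for (int q3 = 0; q3 < Q; q3++) {
    for (int q2 = 0; q2 < Q; q2++) {
      for (int q1 = 0; q1 < Q; q1++) {
        double s = 0.0;
        for (int p3 = 0; p3 < P; p3++) {
          for (int p2 = 0; p2 < P; p2++) {
            for (int p1 = 0; p1 < P; p1++) {
              s += T[p1 + q1 * P] * T[p2 + q2 * P] * T[p3 + q3 * P] * u[p1 + (p2 + p3 * P) * P];
            }
          }
        }
        v[q1 + (q2 + q3 * Q) * Q] += s;
      }
    }
  }
}
#endif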
//////////////////////////////////////////////////////////////////////////////////////////
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(MAXPQ * MAXPQ, MAGMA_MAXTHREADS_3D)) __global__
    void magma_interpn_3d_kernel(const CeedScalar *dT, const CeedScalar *dU, const int estrdU, const int cstrdU, CeedScalar *dV, const int estrdV,
                                 const int cstrdV, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int     tx      = threadIdx.x;
  const int     ty      = threadIdx.y;
  const int     elem_id = (blockIdx.x * blockDim.y) + ty;
  magma_trans_t transT  = MagmaNoTrans;

  if (elem_id >= nelem) return;

  CeedScalar rU[1][NCOMP][P] = {0.0};  // for a non-fused operator, DIM is always 1
  CeedScalar rV[1][NCOMP][Q] = {0.0};  // for a non-fused operator, DIM is always 1
  CeedScalar rTmp[Q]         = {0.0};

  // shift global memory pointers by elem stride
  dU += elem_id * estrdU;
  dV += elem_id * estrdV;

  // assign shared memory pointers
  CeedScalar *sT   = (CeedScalar *)(shared_data);
  CeedScalar *sTmp = sT + P * Q;
  sTmp += ty * (max(P * P * MAXPQ, P * Q * Q));

  // read T
  if (ty == 0) {
    dread_T_gm2sm<P, Q>(tx, transT, dT, sT);
  }

  // read U (iDIM = 0 for both dU and rU; u_dimstride is always 0)
  readU_3d<CeedScalar, P, 1, NCOMP, P, 0>(dU, cstrdU, rU, sTmp, tx);
  // there is a sync at the end of this function

  magma_interp_3d_device<CeedScalar, 1, 1, NCOMP, P, Q, P, Q>(sT, transT, rU, rV, tx, rTmp, sTmp);
  __syncthreads();

  // write V
  writeV_3d<CeedScalar, Q, 1, NCOMP, Q, 0>(dV, cstrdV, rV, tx);
}

//////////////////////////////////////////////////////////////////////////////////////////
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(MAXPQ * MAXPQ, MAGMA_MAXTHREADS_3D)) __global__
    void magma_interpt_3d_kernel(const CeedScalar *dT, const CeedScalar *dU, const int estrdU, const int cstrdU, CeedScalar *dV, const int estrdV,
                                 const int cstrdV, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const int     elem_id = (blockIdx.x * blockDim.y) + ty;
  magma_trans_t transT  = MagmaTrans;

  if (elem_id >= nelem) return;

  CeedScalar rU[1][NCOMP][Q] = {0.0};  // for a non-fused operator, DIM is always 1
  CeedScalar rV[1][NCOMP][P] = {0.0};  // for a non-fused operator, DIM is always 1
  CeedScalar rTmp[P]         = {0.0};

  // shift global memory pointers by elem stride
  dU += elem_id * estrdU;
  dV += elem_id * estrdV;

  // assign shared memory pointers
  CeedScalar *sT   = (CeedScalar *)(shared_data);
  CeedScalar *sTmp = sT + Q * P;
  sTmp += ty * (max(Q * Q * MAXPQ, Q * P * P));

  // read T
  if (ty == 0) {
    dread_T_gm2sm<Q, P>(tx, transT, dT, sT);
  }

  // read V (the transpose action accumulates into V, so load its current values first)
  readV_3d<CeedScalar, P, 1, NCOMP, P, 0>(dV, cstrdV, rV, tx);

  // read U (iDIM = 0 for both dU and rU; u_dimstride is always 0)
  readU_3d<CeedScalar, Q, 1, NCOMP, Q, 0>(dU, cstrdU, rU, sTmp, tx);
  // there is a sync at the end of this function

  magma_interp_3d_device<CeedScalar, 1, 1, NCOMP, Q, P, Q, P>(sT, transT, rU, rV, tx, rTmp, sTmp);
  __syncthreads();

  // write V
  writeV_3d<CeedScalar, P, 1, NCOMP, P, 0>(dV, cstrdV, rV, tx);
}
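//////////////////////////////////////////////////////////////////////////////////////////
// Illustrative launch sketch (compiled out; not part of this file). It spells out the
// grid/block shape and the dynamic shared memory size implied by the pointer arithmetic
// in magma_interpn_3d_kernel: MAXPQ^2 threads in x, one element per ty slice, and
// P * Q scalars for sT plus max(P * P * MAXPQ, P * Q * Q) workspace scalars per slice.
// The wrapper name, the ntcol parameter, and the stream argument are hypothetical
// host-side choices, not names defined by this file.
#if 0
void launch_magma_interpn_3d(const CeedScalar *dT, const CeedScalar *dU, int estrdU, int cstrdU, CeedScalar *dV, int estrdV, int cstrdV, int nelem,
                             int ntcol, cudaStream_t stream) {
  dim3   threads(MAXPQ * MAXPQ, ntcol, 1);                  // one element handled per ty slice
  dim3   grid((nelem + ntcol - 1) / ntcol, 1, 1);           // ceil(nelem / ntcol) blocks
  size_t shmem = sizeof(CeedScalar) * (P * Q + ntcol * max(P * P * MAXPQ, P * Q * Q));
  magma_interpn_3d_kernel<<<grid, threads, shmem, stream>>>(dT, dU, estrdU, cstrdU, dV, estrdV, cstrdV, nelem);
}
#endif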