// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

// macros to abstract access of shared memory and reg. file
// sT(i, j): element (i, j) of the P_ x Q_ basis matrix in shared memory, column-major (leading dim P_)
#define sT(i, j) sT[(j)*P_ + (i)]
// sTmp(i, j, ldw): element (i, j) of a shared scratch matrix with leading dimension ldw
#define sTmp(i, j, ldw) sTmp[(j) * (ldw) + (i)]

//////////////////////////////////////////////////////////////////////////////////////////
// interp basis action (2D)
//
// Applies the 2D tensor-product interpolation basis to one element's data by
// two successive 1D contractions against sT, accumulating into rV (note the
// `+=` below -- callers that want rV = result must zero/preload rV first).
//
// Parameters:
//   sT     - shared-memory P_ x Q_ basis matrix (see sT macro above)
//   transT - NOTE(review): accepted but never read in this function; the
//            transpose appears to be baked into sT by the caller's load
//            (dread_T_gm2sm) -- confirm against callers
//   rU     - input,  registers, [DIM_U][NCOMP_][rUsize] per thread
//   rV     - output, registers, [DIM_V][NCOMP_][rVsize] per thread (accumulated)
//   tx     - thread index within the element (threadIdx.x in the callers)
//   rTmp   - register scratch (passed by value; caller's copy is not updated)
//   swork  - shared-memory scratch for this element; product 1 writes a
//            P_ x Q_ intermediate here (P_ batches of Q_ contiguous values)
template <typename T, int DIM_U, int DIM_V, int NCOMP_, int P_, int Q_, int rUsize, int rVsize>
static __device__ __inline__ void magma_interp_2d_device(const T *sT, magma_trans_t transT, T rU[DIM_U][NCOMP_][rUsize], T rV[DIM_V][NCOMP_][rVsize],
                                                         const int tx, T rTmp, T *swork) {
  // Assumptions
  // 1. 1D threads of size max(P_,Q_)
  // 2. input: rU[DIM_U x NCOMP_ x rUsize] in registers (per thread)
  // 3. output: rV[DIM_V x NCOMP_ x rVsize] in registers (per thread)
  // 4. Two products per component
  //    4.1 Batch P_ of (1xP_) matrices times (P_xQ_) matrix => Batch P_ of (1xQ_) matrices
  //    4.2 Batch 1 of (Q_xP_) matrix times (P_xQ_) matrix => (Q_xQ_) matrix
  // 5. Each thread computes one row of the output of each product
  // 6. Sync is recommended before and after the call

  for (int icomp = 0; icomp < NCOMP_; icomp++) {
    // 1st product -- Batch P_ of (1xP_) matrices [reg] x (P_xQ_) [shmem] => Batch P_ of (1xQ_) matrices
    // the batch output P_ x (1xQ_) is written on the fly to shmem
    if (tx < P_) {
      const int batchid = tx;
      const int sld = 1;
      // each thread owns a private Q_-length slice of swork, so no race here
      T *sTmp = swork + batchid * (1 * Q_);
      for (int j = 0; j < Q_; j++) {
        rTmp = 0.0;
        for (int i = 0; i < P_; i++) {
          rTmp += rU[0][icomp][i] * sT(i, j);
        }
        sTmp(0, j, sld) = rTmp;
      }
    }  // end of: if (tx < P_)
    // barrier is outside the divergent branch: all threads of the block reach it.
    // Required: product 2 reads the intermediate that product 1 just wrote.
    __syncthreads();

    // 2nd product -- Batch 1 of a (Q_xP_) matrix [shmem] x (P_xQ_) [shmem] => (Q_xQ_) matrix [reg]
    if (tx < Q_) {
      const int batchid = 0;
      const int sld = Q_;
      T *sTmp = swork + batchid * (Q_ * P_);
      for (int j = 0; j < Q_; j++) {
        rTmp = 0.0;
        for (int i = 0; i < P_; i++) {
          // sTmp(tx, i, sld) = swork[i*Q_ + tx]: reads the product-1 output
          // transposed, i.e. row tx of the (Q_xP_) left operand
          rTmp += sTmp(tx, i, sld) * sT(i, j);
        }
        // accumulate -- caller controls whether rV starts at zero or holds prior data
        rV[0][icomp][j] += rTmp;
      }
    }
    // protect the shared intermediate before the next component overwrites it
    __syncthreads();
  }
}

//////////////////////////////////////////////////////////////////////////////////////////
// Non-transpose 2D interp kernel: per element, dV = basis action applied to dU.
// Thread layout: threadIdx.x indexes within an element (presumably blockDim.x >= MAXPQ
// via MAGMA_BASIS_BOUNDS -- confirm), threadIdx.y selects the element within the block.
// Dynamic shared layout: [sT: P*Q][per-ty scratch: P*MAXPQ each].
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(MAXPQ, MAGMA_MAXTHREADS_2D)) __global__
void magma_interpn_2d_kernel(const CeedScalar *dT, const CeedScalar *dU, const int estrdU, const int cstrdU, CeedScalar *dV, const int estrdV,
                             const int cstrdV, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const int elem_id = (blockIdx.x * blockDim.y) + ty;
  magma_trans_t transT = MagmaNoTrans;

  // NOTE(review): early return before the block-wide __syncthreads() calls below
  // relies on exited threads not participating in the barrier -- a common MAGMA
  // pattern, but formally divergent-barrier territory; confirm intended.
  if (elem_id >= nelem) return;

  CeedScalar rU[1][NCOMP][P] = {0.0};  // for a non fused operator DIM is always 1
  CeedScalar rV[1][NCOMP][Q] = {0.0};  // for a non fused operator DIM is always 1
  CeedScalar rTmp = 0.0;

  // shift global memory pointers by elem stride
  dU += elem_id * estrdU;
  dV += elem_id * estrdV;

  // assign shared memory pointers
  CeedScalar *sT = (CeedScalar *)(shared_data);
  CeedScalar *sTmp = sT + P * Q;
  sTmp += ty * (P * MAXPQ);  // private scratch slice per element (per ty)

  // read T -- one ty row loads the shared basis matrix for the whole block
  if (ty == 0) {
    dread_T_gm2sm<P, Q>(tx, transT, dT, sT);
  }

  // read U -- there is a sync at the end of this function
  readU_2d<CeedScalar, P, 1, NCOMP, P, 0>(dU, cstrdU, rU, sTmp, tx);

  // no sync needed here -- readU_2d already syncs at the end
  // rV was zero-initialized above, so the += inside the device routine yields dV = result
  magma_interp_2d_device<CeedScalar, 1, 1, NCOMP, P, Q, P, Q>(sT, transT, rU, rV, tx, rTmp, sTmp);
  __syncthreads();

  // write V
  writeV_2d<CeedScalar, Q, 1, NCOMP, Q, 0>(dV, cstrdV, rV, tx);
}

//////////////////////////////////////////////////////////////////////////////////////////
// Transpose 2D interp kernel: per element, dV += transpose basis action applied to dU
// (dV is preloaded via readV_2d below, so results accumulate into existing dV).
// P/Q roles are swapped relative to the non-transpose kernel; shared layout:
// [sT: Q*P][per-ty scratch: Q*MAXPQ each].
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(MAXPQ, MAGMA_MAXTHREADS_2D)) __global__
void magma_interpt_2d_kernel(const CeedScalar *dT, const CeedScalar *dU, const int estrdU, const int cstrdU, CeedScalar *dV, const int estrdV,
                             const int cstrdV, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const int elem_id = (blockIdx.x * blockDim.y) + ty;
  magma_trans_t transT = MagmaTrans;

  // NOTE(review): same early-return-before-barrier pattern as the kernel above.
  if (elem_id >= nelem) return;

  CeedScalar rU[1][NCOMP][Q] = {0.0};  // for a non fused operator DIM is always 1
  CeedScalar rV[1][NCOMP][P] = {0.0};  // for a non fused operator DIM is always 1
  CeedScalar rTmp = 0.0;

  // shift global memory pointers by elem stride
  dU += elem_id * estrdU;
  dV += elem_id * estrdV;

  // assign shared memory pointers
  CeedScalar *sT = (CeedScalar *)(shared_data);
  CeedScalar *sTmp = sT + Q * P;
  sTmp += ty * (Q * MAXPQ);  // private scratch slice per element (per ty)

  // read T -- one ty row loads the shared basis matrix (transposed load) for the block
  if (ty == 0) {
    dread_T_gm2sm<Q, P>(tx, transT, dT, sT);
  }

  // read V -- preload existing output so the device routine's += accumulates into dV
  readV_2d<CeedScalar, P, 1, NCOMP, P, 0>(dV, cstrdV, rV, tx);

  // read U -- there is a sync at the end of this function
  readU_2d<CeedScalar, Q, 1, NCOMP, Q, 0>(dU, cstrdU, rU, sTmp, tx);

  // no sync needed here -- readU_2d already syncs at the end
  magma_interp_2d_device<CeedScalar, 1, 1, NCOMP, Q, P, Q, P>(sT, transT, rU, rV, tx, rTmp, sTmp);
  __syncthreads();

  // write V
  writeV_2d<CeedScalar, P, 1, NCOMP, P, 0>(dV, cstrdV, rV, tx);
}