// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for MAGMA backend common tensor basis definitions
#ifndef CEED_MAGMA_COMMON_TENSOR_H
#define CEED_MAGMA_COMMON_TENSOR_H

#include "magma-common-defs.h"

//////////////////////////////////////////////////////////////////////////////////////////
// Copy one 1D element (all NUM_COMP components) from global memory into the per-component
// shared-memory buffers sBuffer[0..NUM_COMP-1] (e.g. sU[][] or sV[][]).
//   devptr     : points directly at the element (entry 0 of component 0)
//   compstride : distance, in entries, between consecutive components in devptr
//   sBuffer    : NUM_COMP destination buffers, each with room for at least LENGTH entries
//   tx         : thread index; threads with tx >= LENGTH do nothing
// Thread tx copies entry tx of every component. No barrier is issued here: the caller must
// synchronize (e.g. __syncthreads()) before other threads read sBuffer.
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void read_1d(const T *devptr, const int compstride, T *sBuffer[NUM_COMP], const int tx) {
  if (tx >= LENGTH) return;  // idle thread: no entry to copy
  for (int c = 0; c < NUM_COMP; ++c) {
    const int src_idx = c * compstride + tx;
    sBuffer[c][tx]    = devptr[src_idx];
  }
}
//////////////////////////////////////////////////////////////////////////////////////////
// Copy one 1D element (all NUM_COMP components) from the per-component buffers
// sBuffer[0..NUM_COMP-1] (e.g. sV[][]) back to global memory.
//   sBuffer    : NUM_COMP source buffers, each holding at least LENGTH entries
//   devptr     : points directly at the element (entry 0 of component 0)
//   compstride : distance, in entries, between consecutive components in devptr
//   tx         : thread index; threads with tx >= LENGTH do nothing
// Thread tx writes entry tx of every component. This routine performs no synchronization; if
// sBuffer was filled by other threads, the caller must synchronize before calling it.
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void write_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  if (tx >= LENGTH) return;  // idle thread: no entry to write
  for (int c = 0; c < NUM_COMP; ++c) {
    const int dst_idx = c * compstride + tx;
    devptr[dst_idx]   = sBuffer[c][tx];
  }
}

//////////////////////////////////////////////////////////////////////////////////////////
// Load U of a 2D element into registers rU[i_DIM][comp][0..P-1], for all NUM_COMP components
// of a single dimension, transposing through shared memory.
//   dU         : assumed to be already offset by elem-stride and dim-stride
//   compstride : distance, in entries, between consecutive components in dU
//   rU         : register array rU[DIM_U][NUM_COMP][rU_SIZE]; rU_SIZE may exceed P (e.g. MAXP_Q)
//   sTmp       : shared-memory workspace of at least P*P entries
//   tx         : thread index; only threads with tx < P receive data
// For each component, the P*P values form P contiguous vectors of length P:
//   vec r = dU[r*P .. r*P + P-1],   r = 0 .. P-1
// Thread tx must end up holding all of vec tx. The component is first staged into sTmp by a
// cooperative copy in which adjacent threads access adjacent addresses. Each thread then reads
// its own vector back out of sTmp.
// Contains __syncthreads(), so every thread of the block must call it (including tx >= P).
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void readU_2d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  const bool owns_vector = (tx < P);
  for (int c = 0; c < NUM_COMP; ++c) {
    const T *dU_comp = dU + c * compstride;

    // Stage 1: cooperative copy of this component from global memory into sTmp.
    if (owns_vector) {
      for (int r = 0; r < P; ++r) sTmp[r * P + tx] = dU_comp[r * P + tx];
    }
    __syncthreads();

    // Stage 2: thread tx pulls vector tx out of sTmp into its registers.
    if (owns_vector) {
      const T *my_vec = sTmp + tx * P;
      for (int k = 0; k < P; ++k) rU[i_DIM][c][k] = my_vec[k];
    }
    __syncthreads();  // sTmp is overwritten by the next component (and may be reused by the caller)
  }
}
//////////////////////////////////////////////////////////////////////////////////////////
// Load V of a 2D element from global memory into registers rV[i_DIM][comp][0..Q-1], for all
// NUM_COMP components of a single dimension.
//   dV         : assumed to be already offset by elem-stride and dim-stride
//   compstride : distance, in entries, between consecutive components in dV
//   rV         : register array rV[DIM_V][NUM_COMP][rV_SIZE]; rV_SIZE may exceed Q (e.g. MAXP_Q)
//   tx         : thread index; only threads with tx < Q load data
// Thread tx receives dV[comp*compstride + j*Q + tx] for j = 0 .. Q-1. For each j, adjacent
// threads therefore read adjacent addresses. No shared memory or synchronization is involved.
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void readV_2d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q) return;  // idle thread
  for (int c = 0; c < NUM_COMP; ++c) {
    const T *src = dV + c * compstride + tx;
    for (int k = 0; k < Q; ++k) rV[i_DIM][c][k] = src[k * Q];
  }
}
//////////////////////////////////////////////////////////////////////////////////////////
// Store V of a 2D element from registers rV[i_DIM][comp][0..Q-1] to global memory, for all
// NUM_COMP components of a single dimension.
//   dV         : assumed to be already offset by elem-stride and dim-stride
//   compstride : distance, in entries, between consecutive components in dV
//   rV         : register array rV[DIM_V][NUM_COMP][rV_SIZE]; slice i_DIM is written out;
//                rV_SIZE may exceed Q (e.g. MAXP_Q)
//   tx         : thread index; only threads with tx < Q store data
// Thread tx writes dV[comp*compstride + j*Q + tx] for j = 0 .. Q-1 (the inverse of readV_2d).
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void writeV_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q) return;  // idle thread
  for (int c = 0; c < NUM_COMP; ++c) {
    T *dst = dV + c * compstride + tx;
    for (int k = 0; k < Q; ++k) dst[k * Q] = rV[i_DIM][c][k];
  }
}

//////////////////////////////////////////////////////////////////////////////////////////
// Load U of a 3D element into registers rU[i_DIM][comp][0..P-1], for all NUM_COMP components
// of a single dimension, transposing through shared memory.
//   dU         : assumed to be already offset by elem-stride and dim-stride
//   compstride : distance, in entries, between consecutive components in dU
//   rU         : register array rU[DIM_U][NUM_COMP][rU_SIZE]; rU_SIZE may exceed P (e.g. MAXP_Q)
//   sTmp       : shared-memory workspace of at least P*P*P entries
//   tx         : thread index; only threads with tx < P*P receive data
// For each component, the P^3 values form P^2 contiguous vectors of length P:
//   vec r = dU[r*P .. r*P + P-1],   r = 0 .. P^2-1
// Thread tx must end up holding all of vec tx. The component is staged into sTmp in P slabs of
// P^2 entries; within each slab, adjacent threads access adjacent addresses. Each thread then
// reads its own vector back out of sTmp.
// Contains __syncthreads(), so every thread of the block must call it (including tx >= P*P).
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void readU_3d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  const int  PP          = P * P;  // number of length-P vectors per component
  const bool owns_vector = (tx < PP);
  for (int c = 0; c < NUM_COMP; ++c) {
    const T *dU_comp = dU + c * compstride;

    // Stage 1: cooperative copy of this component, one P^2 slab at a time, into sTmp.
    if (owns_vector) {
      for (int slab = 0; slab < P; ++slab) sTmp[slab * PP + tx] = dU_comp[slab * PP + tx];
    }
    __syncthreads();

    // Stage 2: thread tx pulls vector tx out of sTmp into its registers.
    if (owns_vector) {
      const T *my_vec = sTmp + tx * P;
      for (int k = 0; k < P; ++k) rU[i_DIM][c][k] = my_vec[k];
    }
    __syncthreads();  // sTmp is overwritten by the next component (and may be reused by the caller)
  }
}
//////////////////////////////////////////////////////////////////////////////////////////
// Load V of a 3D element from global memory into registers rV[i_DIM][comp][0..Q-1], for all
// NUM_COMP components of a single dimension.
//   dV         : assumed to be already offset by elem-stride and dim-stride
//   compstride : distance, in entries, between consecutive components in dV
//   rV         : register array rV[DIM_V][NUM_COMP][rV_SIZE]; rV_SIZE may exceed Q (e.g. MAXP_Q)
//   tx         : thread index; only threads with tx < Q*Q load data
// Thread tx receives dV[comp*compstride + j*Q*Q + tx] for j = 0 .. Q-1. For each j, adjacent
// threads therefore read adjacent addresses. No shared memory or synchronization is involved.
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void readV_3d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  const int QQ = Q * Q;  // distance between consecutive entries owned by one thread
  if (tx >= QQ) return;  // idle thread
  for (int c = 0; c < NUM_COMP; ++c) {
    const T *src = dV + c * compstride + tx;
    for (int k = 0; k < Q; ++k) rV[i_DIM][c][k] = src[k * QQ];
  }
}
//////////////////////////////////////////////////////////////////////////////////////////
// Store V of a 3D element from registers rV[i_DIM][comp][0..Q-1] to global memory, for all
// NUM_COMP components of a single dimension.
//   dV         : assumed to point directly to the element (i.e. already offset by elem-stride)
//   compstride : distance, in entries, between consecutive components in dV
//   rV         : register array rV[DIM_V][NUM_COMP][rV_SIZE]; slice i_DIM is written out;
//                rV_SIZE may exceed Q (e.g. MAXP_Q)
//   tx         : thread index; only threads with tx < Q*Q store data
// Thread tx writes dV[comp*compstride + j*Q*Q + tx] for j = 0 .. Q-1 (the inverse of readV_3d).
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void writeV_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  const int QQ = Q * Q;  // distance between consecutive entries owned by one thread
  if (tx >= QQ) return;  // idle thread
  for (int c = 0; c < NUM_COMP; ++c) {
    T *dst = dV + c * compstride + tx;
    for (int k = 0; k < Q; ++k) dst[k * QQ] = rV[i_DIM][c][k];
  }
}

//////////////////////////////////////////////////////////////////////////////////////////
// Copy the matrix T from global memory (dT) into shared memory (sT).
// In both modes, sT is written as a column-major B x J matrix: entry (r, i) lives at sT[i*B + r].
//   transT == MagmaNoTrans : dT is column-major B x J (entry (r, i) at dT[i*B + r]);
//                            threads tx < B copy row tx unchanged
//   otherwise              : dT is column-major J x B (entry (r, i) at dT[i*J + r]);
//                            threads tx < J copy it transposed, so sT holds dT^T
// No barrier is issued here: the caller must synchronize (e.g. __syncthreads()) before
// other threads read sT.
template <int B, int J>
static __device__ __inline__ void dread_T_gm2sm(const int tx, const magma_trans_t transT, const CeedScalar *dT, CeedScalar *sT) {
  const bool no_trans = (transT == MagmaNoTrans);
  if (!no_trans) {
    // Transposed copy: row tx of the J x B source becomes column tx of sT.
    if (tx < J) {
      for (int col = 0; col < B; ++col) sT[tx * B + col] = dT[col * J + tx];
    }
    return;
  }
  // Straight copy of the B x J source.
  if (tx < B) {
    for (int col = 0; col < J; ++col) sT[col * B + tx] = dT[col * B + tx];
  }
  // must sync after call
}

#endif // CEED_MAGMA_COMMON_TENSOR_H