// Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for MAGMA backend common tensor basis definitions
#pragma once

#include "magma-common-defs.h"

////////////////////////////////////////////////////////////////////////////////
// read U or V of a 1D element into shared memory sU[][] or sV[][] -- for all components
// the devptr is assumed to point directly to the element
// thread tx copies entry tx of every component; threads with tx >= LENGTH do nothing
// no barrier is issued here -- must sync after call
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void read_1d(const T *devptr, const int compstride, T *sBuffer[NUM_COMP], const int tx) {
  if (tx < LENGTH) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      sBuffer[comp][tx] = devptr[comp * compstride + tx];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// write V of a 1D element into global memory from sV[][] -- for all components
// the devptr is assumed to point directly to the element
// thread tx copies entry tx of every component; threads with tx >= LENGTH do nothing
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void write_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  if (tx < LENGTH) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      devptr[comp * compstride + tx] = sBuffer[comp][tx];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// read U of a 2D element into registers rU[][][] -- for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM specifies which dimension is being read into in rU
// rU_SIZE can be different from P (e.g. max(P, Q)); entries [0, P) are written
// sTmp is a shared memory workspace of size P^2
// this routine calls __syncthreads() outside of the (tx < P) guards, so it must be reached by every thread of the block
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_2d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // read U as a batch P of (1 x P) vectors
  // vec 0  : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // vec 1  : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // ...
  // vec P-1: [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // threads collaboratively read vec0 and then vec1 and so on
  // but for the kernel, we want
  // thread 0 to hold all of vec0 in registers, and
  // thread 1 to hold all of vec1 in registers, and so on
  // so we need to transpose
  for (int comp = 0; comp < NUM_COMP; comp++) {
    // read from global memory into shared memory
    if (tx < P) {
      for (int i = 0; i < P; i++) {
        sTmp[i * P + tx] = dU[comp * compstride + i * P + tx];
      }
    }
    __syncthreads();

    // transposed read: thread tx picks up vec tx from the staged tile
    if (tx < P) {
      for (int i = 0; i < P; i++) {
        rU[i_DIM][comp][i] = sTmp[tx * P + i];
      }
    }
    // sTmp is overwritten by the next component, so wait until every thread has finished reading it
    __syncthreads();
  }
}

////////////////////////////////////////////////////////////////////////////////
// read V of a 2D element into registers rV[][][] -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being read into in rV
// rV_SIZE can be different from Q (e.g. max(P, Q)); entries [0, Q) are written
// thread tx reads entries j * Q + tx, j = 0, ..., Q-1, of every component; threads with tx >= Q do nothing
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_2d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx < Q) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        rV[i_DIM][comp][j] = dV[comp * compstride + j * Q + tx];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// write V of a 2D element from registers rV[][][] to global memory -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being written to in dV
// rV_SIZE can be different from Q (e.g. max(P, Q)); entries [0, Q) are read
// thread tx writes entries j * Q + tx, j = 0, ..., Q-1, of every component; threads with tx >= Q do nothing
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx < Q) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        dV[comp * compstride + j * Q + tx] = rV[i_DIM][comp][j];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// read U of a 3D element into registers rU[][][] -- for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM specifies which dimension is being read into in rU
// rU_SIZE can be different from P (e.g. max(P, Q)); entries [0, P) are written
// sTmp is a shared memory workspace of size P^3
// this routine calls __syncthreads() outside of the (tx < P * P) guards, so it must be reached by every thread of the block
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_3d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // read U as a batch P^2 of (1 x P) vectors
  // vec 0    : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // vec 1    : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // ...
  // vec P^2-1: [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // threads collaboratively read vec0 and then vec1 and so on
  // but for the kernel, we want
  // thread 0 to hold all of vec0 in registers, and
  // thread 1 to hold all of vec1 in registers, and so on
  // so we need to transpose
  for (int comp = 0; comp < NUM_COMP; comp++) {
    // read from global memory into shared memory
    if (tx < P * P) {
      for (int i = 0; i < P; i++) {
        sTmp[i * P * P + tx] = dU[comp * compstride + i * P * P + tx];
      }
    }
    __syncthreads();

    // transposed read: thread tx picks up vec tx from the staged data
    if (tx < P * P) {
      for (int i = 0; i < P; i++) {
        rU[i_DIM][comp][i] = sTmp[tx * P + i];
      }
    }
    // sTmp is overwritten by the next component, so wait until every thread has finished reading it
    __syncthreads();
  }
}

////////////////////////////////////////////////////////////////////////////////
// read V of a 3D element into registers rV[][][] -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being read into in rV
// rV_SIZE can be different from Q (e.g. max(P, Q)); entries [0, Q) are written
// thread tx reads entries j * Q^2 + tx, j = 0, ..., Q-1, of every component; threads with tx >= Q^2 do nothing
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_3d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx < Q * Q) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        rV[i_DIM][comp][j] = dV[comp * compstride + j * (Q * Q) + tx];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// write V of a 3D element from registers rV[][][] to global memory -- for all components of a single dim
// dV is assumed to point directly to the element (i.e. already offset by elem-stride)
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being written to in dV
// rV_SIZE can be different from Q (e.g. max(P, Q)); entries [0, Q) are read
// thread tx writes entries j * Q^2 + tx, j = 0, ..., Q-1, of every component; threads with tx >= Q^2 do nothing
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx < (Q * Q)) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        dV[comp * compstride + j * (Q * Q) + tx] = rV[i_DIM][comp][j];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// reads T (no-trans) into shared memory
// T is B x J
// the B * J entries are copied with unchanged layout: sT[i * B + tx] = dT[i * B + tx]
// threads with tx >= B do nothing
// no barrier is issued here -- must sync after call
template <int B, int J>
static __device__ __inline__ void read_T_notrans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx < B) {
    for (int i = 0; i < J; i++) {
      sT[i * B + tx] = dT[i * B + tx];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// reads T (trans) into shared memory
// T is J x B
// the B * J entries are transposed on the way in: sT[tx * B + i] = dT[i * J + tx]
// threads with tx >= J do nothing
// no barrier is issued here -- must sync after call
template <int B, int J>
static __device__ __inline__ void read_T_trans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx < J) {
    for (int i = 0; i < B; i++) {
      sT[tx * B + i] = dT[i * J + tx];
    }
  }
}