xref: /libCEED/include/ceed/jit-source/magma/magma-common-tensor.h (revision 509d4af65d23546c690c9766d8b29e47dc3b3afb)
// Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed

/// @file
/// Internal header for MAGMA backend common tensor basis definitions
10*509d4af6SJeremy L Thompson #pragma once
11f80f4a74SSebastian Grimberg 
123c1e2affSSebastian Grimberg #include "magma-common-defs.h"
13f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// load U or V of a 1D element into the shared memory buffers sBuffer[comp][] -- for all components
// devptr is assumed to point directly to the element; component comp starts at devptr + comp * compstride
// only threads with tx < LENGTH take part, each one copying entry tx of every component
// no barrier is issued here: must sync after call
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void read_1d(const T *devptr, const int compstride, T *sBuffer[NUM_COMP], const int tx) {
  // threads beyond the element length have nothing to copy
  if (tx >= LENGTH) return;

  for (int c = 0; c < NUM_COMP; c++) {
    const T *src = devptr + c * compstride;

    sBuffer[c][tx] = src[tx];
  }
}
26f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// store V of a 1D element from the buffers sBuffer[comp][] into global memory -- for all components
// devptr is assumed to point directly to the element; component comp starts at devptr + comp * compstride
// only threads with tx < LENGTH take part, each one writing entry tx of every component
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void write_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  // threads beyond the element length have nothing to write
  if (tx >= LENGTH) return;

  for (int c = 0; c < NUM_COMP; c++) {
    T *dst = devptr + c * compstride;

    dst[tx] = sBuffer[c][tx];
  }
}
38f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// load U of a 2D element into registers rU[][][] -- for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM selects which dimension slot of rU is filled
// rU_SIZE can be larger than P (e.g. max(P, Q)); only entries 0..P-1 are written
// sTmp is a shared memory workspace of size P^2, reused for every component
// the function calls __syncthreads(), so it must be reached by every thread of the block,
// including threads with tx >= P (they only take part in the barriers)
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_2d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // U is stored as a batch of P vectors of length P
  //   vec 0  : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  //   vec 1  : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  //   ...
  //   vec P-1: [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // the threads cooperate to load vec 0, then vec 1, and so on, whereas the kernel
  // wants thread k to hold all of vec k in registers; hence the transpose through sTmp
  const bool is_active = (tx < P);

  for (int c = 0; c < NUM_COMP; c++) {
    const T *dU_comp = dU + c * compstride;

    // stage 1: global memory -> shared memory, thread tx copies entry tx of each vector
    if (is_active) {
      for (int vec = 0; vec < P; vec++) {
        const int idx = vec * P + tx;

        sTmp[idx] = dU_comp[idx];
      }
    }
    __syncthreads();  // staged data must be complete before any thread reads its own vector

    // stage 2: shared memory -> registers, thread tx collects vector tx
    if (is_active) {
      for (int k = 0; k < P; k++) {
        rU[i_DIM][c][k] = sTmp[tx * P + k];
      }
    }
    __syncthreads();  // sTmp is overwritten by the next component
  }
}
75f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// load V of a 2D element into registers rV[][][] -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dimension slot of rV is filled
// rV_SIZE can be larger than Q (e.g. max(P, Q)); only entries 0..Q-1 are written
// thread tx (tx < Q) picks up entry tx of each of the Q consecutive length-Q chunks of dV
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_2d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // threads outside the element do not load anything
  if (tx >= Q) return;

  for (int c = 0; c < NUM_COMP; c++) {
    const T *dV_comp = dV + c * compstride;

    for (int k = 0; k < Q; k++) {
      rV[i_DIM][c][k] = dV_comp[k * Q + tx];
    }
  }
}
92f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// store V of a 2D element from registers rV[][][] to global memory -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dimension slot of rV is written out
// rV_SIZE can be larger than Q (e.g. max(P, Q)); only entries 0..Q-1 are read
// thread tx (tx < Q) writes entry tx of each of the Q consecutive length-Q chunks of dV
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // threads outside the element do not store anything
  if (tx >= Q) return;

  for (int c = 0; c < NUM_COMP; c++) {
    T *dV_comp = dV + c * compstride;

    for (int k = 0; k < Q; k++) {
      dV_comp[k * Q + tx] = rV[i_DIM][c][k];
    }
  }
}
109f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// load U of a 3D element into registers rU[][][] -- for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM selects which dimension slot of rU is filled
// rU_SIZE can be larger than P (e.g. max(P, Q)); only entries 0..P-1 are written
// sTmp is a shared memory workspace of size P^3, reused for every component
// the function calls __syncthreads(), so it must be reached by every thread of the block,
// including threads with tx >= P^2 (they only take part in the barriers)
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_3d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // U is stored as a batch of P^2 vectors of length P
  //   vec 0    : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  //   vec 1    : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  //   ...
  //   vec P^2-1: [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // the threads cooperate to load the data in chunks of P^2 consecutive values, whereas
  // the kernel wants thread k to hold all of vec k in registers; hence the transpose through sTmp
  const int  plane     = P * P;
  const bool is_active = (tx < plane);

  for (int c = 0; c < NUM_COMP; c++) {
    const T *dU_comp = dU + c * compstride;

    // stage 1: global memory -> shared memory, P chunks of P^2 values each
    if (is_active) {
      for (int chunk = 0; chunk < P; chunk++) {
        const int idx = chunk * plane + tx;

        sTmp[idx] = dU_comp[idx];
      }
    }
    __syncthreads();  // staged data must be complete before any thread reads its own vector

    // stage 2: shared memory -> registers, thread tx collects vector tx
    if (is_active) {
      for (int k = 0; k < P; k++) {
        rU[i_DIM][c][k] = sTmp[tx * P + k];
      }
    }
    __syncthreads();  // sTmp is overwritten by the next component
  }
}
146f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// load V of a 3D element into registers rV[][][] -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dimension slot of rV is filled
// rV_SIZE can be larger than Q (e.g. max(P, Q)); only entries 0..Q-1 are written
// thread tx (tx < Q^2) picks up entry tx of each of the Q consecutive chunks of Q^2 values of dV
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_3d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // threads outside the Q x Q plane do not load anything
  if (tx >= Q * Q) return;

  for (int c = 0; c < NUM_COMP; c++) {
    const T *dV_comp = dV + c * compstride;

    for (int k = 0; k < Q; k++) {
      rV[i_DIM][c][k] = dV_comp[k * (Q * Q) + tx];
    }
  }
}
163f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// store V of a 3D element from registers rV[][][] to global memory -- for all components of a single dim
// dV is assumed to point directly to the element (i.e. already offset by elem-stride)
// NOTE(review): the other 2D/3D read/write helpers in this file describe their pointer as offset by
//               elem-stride and dim-stride -- confirm against the callers which one applies here
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dimension slot of rV is written out
// rV_SIZE can be larger than Q (e.g. max(P, Q)); only entries 0..Q-1 are read
// thread tx (tx < Q^2) writes entry tx of each of the Q consecutive chunks of Q^2 values of dV
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // threads outside the Q x Q plane do not store anything
  if (tx >= (Q * Q)) return;

  for (int c = 0; c < NUM_COMP; c++) {
    T *dV_comp = dV + c * compstride;

    for (int k = 0; k < Q; k++) {
      dV_comp[k * (Q * Q) + tx] = rV[i_DIM][c][k];
    }
  }
}
180f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// copies T (no-trans) from dT into shared memory sT, keeping the layout unchanged
// T is B x J; thread tx (tx < B) copies entry i * B + tx for i = 0..J-1
// no barrier is issued here: must sync after call
template <int B, int J>
static __device__ __inline__ void read_T_notrans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  // only the first B threads take part in the copy
  if (tx >= B) return;

  for (int col = 0; col < J; col++) {
    const int idx = col * B + tx;

    sT[idx] = dT[idx];
  }
}
1949e0c01faSSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// copies T (trans) from dT into shared memory sT, transposing on the fly
// T is J x B; thread tx (tx < J) reads dT[i * J + tx] and stores it at sT[tx * B + i] for i = 0..B-1
// no barrier is issued here: must sync after call
template <int B, int J>
static __device__ __inline__ void read_T_trans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  // only the first J threads take part in the copy
  if (tx >= J) return;

  CeedScalar *sT_row = sT + tx * B;

  for (int k = 0; k < B; k++) {
    sT_row[k] = dT[k * J + tx];
  }
}
208