xref: /libCEED/include/ceed/jit-source/magma/magma-common-tensor.h (revision 3c1e2aff6d111c93ca8797996aaf987f66b08927)
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed

/// @file
/// Internal header for MAGMA backend common tensor basis definitions
10f80f4a74SSebastian Grimberg #ifndef CEED_MAGMA_COMMON_TENSOR_H
11f80f4a74SSebastian Grimberg #define CEED_MAGMA_COMMON_TENSOR_H
12f80f4a74SSebastian Grimberg 
13*3c1e2affSSebastian Grimberg #include "magma-common-defs.h"
14f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// load U or V of a 1D element into the per-component buffers sBuffer[comp][] (the shared memory sU[][] / sV[][] of the kernels)
// devptr is assumed to point directly to the element; component comp starts at devptr + comp * compstride
// thread tx copies entry tx of every component, threads with tx >= LENGTH do nothing
// no barrier is issued here, so the caller must sync after the call
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void read_1d(const T *devptr, const int compstride, T *sBuffer[NUM_COMP], const int tx) {
  // threads beyond the element length have nothing to copy
  if (tx >= LENGTH) return;

  for (int c = 0; c < NUM_COMP; c++) {
    const int src = c * compstride + tx;
    sBuffer[c][tx] = devptr[src];
  }
}
27f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// store V of a 1D element from the per-component buffers sBuffer[comp][] (sV[][] in the kernels) into global memory
// devptr is assumed to point directly to the element; component comp starts at devptr + comp * compstride
// thread tx copies entry tx of every component, threads with tx >= LENGTH do nothing
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void write_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  // threads beyond the element length have nothing to copy
  if (tx >= LENGTH) return;

  for (int c = 0; c < NUM_COMP; c++) {
    const int dst = c * compstride + tx;
    devptr[dst] = sBuffer[c][tx];
  }
}
39f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// read U of a 2D element into registers rU[][][] --  for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM specifies which dimension is being read into in rU (checked at compile time against DIM_U)
// rU_SIZE can be larger than P (e.g. MAXP_Q) but never smaller (checked at compile time)
// sTmp is a shared memory workspace of size P^2
// contains __syncthreads(), so it must be called by every thread of the block, including threads with tx >= P
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void readU_2d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // each thread fills rU[i_DIM][comp][0 .. P-1]; reject instantiations that would index outside of rU
  static_assert(P <= rU_SIZE, "readU_2d: rU_SIZE must be at least P");
  static_assert(i_DIM >= 0 && i_DIM < DIM_U, "readU_2d: i_DIM must be a valid index into DIM_U");

  // read U as a batch of P vectors of size (1 x P)
  // vec 0  : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // vec 1  : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // ...
  // vec P-1: [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // threads collaboratively read vec0 and then vec1 and so on
  // but for the kernel, we want
  // thread 0 to hold all of vec0 in registers, and
  // thread 1 to hold all of vec1 in registers, and so on
  // so we need to transpose
  for (int comp = 0; comp < NUM_COMP; comp++) {
    // read from global memory into shared memory, one vector per iteration
    if (tx < P) {
      for (int i = 0; i < P; i++) {
        sTmp[i * P + tx] = dU[comp * compstride + i * P + tx];
      }
    }
    // the barrier stays outside of the branch so that all threads reach it
    __syncthreads();

    // transposed read: thread tx picks up vector tx from shared memory
    if (tx < P) {
      for (int i = 0; i < P; i++) {
        rU[i_DIM][comp][i] = sTmp[tx * P + i];
      }
    }
    // sTmp is overwritten by the next component, so wait until everyone has read it
    __syncthreads();
  }
}
76f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// read V of a 2D element into registers rV[][][] --  for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being read into in rV (checked at compile time against DIM_V)
// rV_SIZE can be larger than Q (e.g. MAXP_Q) but never smaller (checked at compile time)
// thread tx reads entries tx, Q + tx, ..., (Q-1) * Q + tx of every component; threads with tx >= Q do nothing
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void readV_2d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // each thread fills rV[i_DIM][comp][0 .. Q-1]; reject instantiations that would index outside of rV
  static_assert(Q <= rV_SIZE, "readV_2d: rV_SIZE must be at least Q");
  static_assert(i_DIM >= 0 && i_DIM < DIM_V, "readV_2d: i_DIM must be a valid index into DIM_V");

  if (tx < Q) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        rV[i_DIM][comp][j] = dV[comp * compstride + j * Q + tx];
      }
    }
  }
}
93f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// write V of a 2D element from registers rV[][][] to global memory --  for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride (the target dimension in dV is selected by the caller)
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being read from in rV (checked at compile time against DIM_V)
// rV_SIZE can be larger than Q (e.g. MAXP_Q) but never smaller (checked at compile time)
// thread tx writes entries tx, Q + tx, ..., (Q-1) * Q + tx of every component; threads with tx >= Q do nothing
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void writeV_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // each thread reads rV[i_DIM][comp][0 .. Q-1]; reject instantiations that would index outside of rV
  static_assert(Q <= rV_SIZE, "writeV_2d: rV_SIZE must be at least Q");
  static_assert(i_DIM >= 0 && i_DIM < DIM_V, "writeV_2d: i_DIM must be a valid index into DIM_V");

  if (tx < Q) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        dV[comp * compstride + j * Q + tx] = rV[i_DIM][comp][j];
      }
    }
  }
}
111f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// read U of a 3D element into registers rU[][][] --  for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM specifies which dimension is being read into in rU (checked at compile time against DIM_U)
// rU_SIZE can be larger than P (e.g. MAXP_Q) but never smaller (checked at compile time)
// sTmp is a shared memory workspace of size P^3
// contains __syncthreads(), so it must be called by every thread of the block, including threads with tx >= P^2
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void readU_3d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // each thread fills rU[i_DIM][comp][0 .. P-1]; reject instantiations that would index outside of rU
  static_assert(P <= rU_SIZE, "readU_3d: rU_SIZE must be at least P");
  static_assert(i_DIM >= 0 && i_DIM < DIM_U, "readU_3d: i_DIM must be a valid index into DIM_U");

  // read U as a batch of P^2 vectors of size (1 x P)
  // vec 0    : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // vec 1    : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // ...
  // vec P^2-1: [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // threads collaboratively read vec0 and then vec1 and so on
  // but for the kernel, we want
  // thread 0 to hold all of vec0 in registers, and
  // thread 1 to hold all of vec1 in registers, and so on
  // so we need to transpose
  for (int comp = 0; comp < NUM_COMP; comp++) {
    // read from global memory into shared memory, P^2 contiguous entries per iteration
    if (tx < P * P) {
      for (int i = 0; i < P; i++) {
        sTmp[i * P * P + tx] = dU[comp * compstride + i * P * P + tx];
      }
    }
    // the barrier stays outside of the branch so that all threads reach it
    __syncthreads();

    // transposed read: thread tx picks up vector tx from shared memory
    if (tx < P * P) {
      for (int i = 0; i < P; i++) {
        rU[i_DIM][comp][i] = sTmp[tx * P + i];
      }
    }
    // sTmp is overwritten by the next component, so wait until everyone has read it
    __syncthreads();
  }
}
148f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// read V of a 3D element into registers rV[][][] --  for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being read into in rV (checked at compile time against DIM_V)
// rV_SIZE can be larger than Q (e.g. MAXP_Q) but never smaller (checked at compile time)
// thread tx reads entries tx, Q^2 + tx, ..., (Q-1) * Q^2 + tx of every component; threads with tx >= Q^2 do nothing
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void readV_3d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // each thread fills rV[i_DIM][comp][0 .. Q-1]; reject instantiations that would index outside of rV
  static_assert(Q <= rV_SIZE, "readV_3d: rV_SIZE must be at least Q");
  static_assert(i_DIM >= 0 && i_DIM < DIM_V, "readV_3d: i_DIM must be a valid index into DIM_V");

  if (tx < Q * Q) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        rV[i_DIM][comp][j] = dV[comp * compstride + j * (Q * Q) + tx];
      }
    }
  }
}
165f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// write V of a 3D element from registers rV[][][] to global memory --  for all components of a single dim
// dV is assumed to point directly to the element (i.e. already offset by elem-stride)
// this function applies no dim offset to dV, so any dim-stride must already be applied by the caller
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being read from in rV (checked at compile time against DIM_V)
// rV_SIZE can be larger than Q (e.g. MAXP_Q) but never smaller (checked at compile time)
// thread tx writes entries tx, Q^2 + tx, ..., (Q-1) * Q^2 + tx of every component; threads with tx >= Q^2 do nothing
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void writeV_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // each thread reads rV[i_DIM][comp][0 .. Q-1]; reject instantiations that would index outside of rV
  static_assert(Q <= rV_SIZE, "writeV_3d: rV_SIZE must be at least Q");
  static_assert(i_DIM >= 0 && i_DIM < DIM_V, "writeV_3d: i_DIM must be a valid index into DIM_V");

  if (tx < (Q * Q)) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        dV[comp * compstride + j * (Q * Q) + tx] = rV[i_DIM][comp][j];
      }
    }
  }
}
183f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// copy the matrix T from dT into sT (shared memory in the kernels); sT always ends up with leading dimension B
//   transT == MagmaNoTrans: dT is B x J and is copied as is, thread tx handles row tx (tx < B)
//   otherwise             : dT is J x B and is stored transposed, thread tx handles row tx of dT (tx < J)
// no barrier is issued here, so the caller must sync after the call
template <int B, int J>
static __device__ __inline__ void dread_T_gm2sm(const int tx, const magma_trans_t transT, const CeedScalar *dT, CeedScalar *sT) {
  if (transT != MagmaNoTrans) {
    // T is J x B: entry (tx, col) of dT lands at (col, tx) of sT
    if (tx < J) {
      for (int col = 0; col < B; col++) {
        const int src = col * J + tx;
        const int dst = tx * B + col;
        sT[dst] = dT[src];
      }
    }
  } else {
    // T is B x J: source and destination share the same layout
    if (tx < B) {
      for (int col = 0; col < J; col++) {
        const int idx = col * B + tx;
        sT[idx] = dT[idx];
      }
    }
  }
  // must sync after call
}
206f80f4a74SSebastian Grimberg 
207f80f4a74SSebastian Grimberg #endif  // CEED_MAGMA_COMMON_TENSOR_H
208