xref: /libCEED/include/ceed/jit-source/magma/magma-common-tensor.h (revision f80f4a748154eed4bc661c135f695b92b1bc45b9)
1*f80f4a74SSebastian Grimberg // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*f80f4a74SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*f80f4a74SSebastian Grimberg //
4*f80f4a74SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause
5*f80f4a74SSebastian Grimberg //
6*f80f4a74SSebastian Grimberg // This file is part of CEED:  http://github.com/ceed
7*f80f4a74SSebastian Grimberg 
8*f80f4a74SSebastian Grimberg #ifndef CEED_MAGMA_COMMON_TENSOR_H
9*f80f4a74SSebastian Grimberg #define CEED_MAGMA_COMMON_TENSOR_H
10*f80f4a74SSebastian Grimberg 
#define MAGMA_MAXTHREADS_1D 128
#define MAGMA_MAXTHREADS_2D 128
#define MAGMA_MAXTHREADS_3D 64
// Number of threads in the y-direction for basis kernels: how many groups of
// x threads fit into a budget of maxt threads, but never fewer than one.
#define MAGMA_BASIS_NTCOL(x, maxt) (((maxt) < (x)) ? 1 : ((maxt) / (x)))
// Total number of threads in a block (x * NTCOL), for use with __launch_bounds__().
// Every use of x is parenthesized so that expression arguments such as
// MAGMA_BASIS_BOUNDS(P + 1, maxt) expand with the intended precedence.
#define MAGMA_BASIS_BOUNDS(x, maxt) ((x) * MAGMA_BASIS_NTCOL(x, maxt))
20*f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// Copy U or V of one 1D element from global memory into shared buffers sU[][] / sV[][],
// for every component.
//   devptr     : points directly at the element (component 0, entry 0)
//   compstride : distance in devptr between consecutive components
//   sBuffer    : NCOMP_ destination pointers, one per component
//   tx         : thread index; threads with tx >= LENGTH return without copying
// No barrier is issued here: the caller must synchronize before other threads
// read sBuffer.
template <typename T, int LENGTH, int NCOMP_>
__device__ __inline__ void read_1d(const T *devptr, const int compstride, T *sBuffer[NCOMP_], const int tx) {
  if (tx >= LENGTH) return;

  const T *src = devptr + tx;
  for (int c = 0; c < NCOMP_; ++c) sBuffer[c][tx] = src[c * compstride];
}
33*f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// Copy V of one 1D element from the shared buffers sV[][] back to global memory,
// for every component.
//   sBuffer    : NCOMP_ source pointers, one per component
//   devptr     : points directly at the element (component 0, entry 0)
//   compstride : distance in devptr between consecutive components
//   tx         : thread index; threads with tx >= LENGTH return without writing
template <typename T, int LENGTH, int NCOMP_>
__device__ __inline__ void write_1d(T *sBuffer[NCOMP_], T *devptr, const int compstride, const int tx) {
  if (tx >= LENGTH) return;

  T *dst = devptr + tx;
  for (int c = 0; c < NCOMP_; ++c) dst[c * compstride] = sBuffer[c][tx];
}
45*f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// Load U of one 2D element into registers rU[iDIM][][], for all components of a single dim.
//   dU         : already offset by element stride and dimension stride
//   compstride : distance in dU between consecutive components
//   rU         : register array of shape [DIMU][NCOMP_][rUsize]; only slot iDIM is written
//   rUsize     : may be larger than P_ (e.g. MAXP_Q); only entries 0..P_-1 are written
//   sTmp       : shared-memory workspace holding at least P_ * P_ entries
//   tx         : thread index; only threads with tx < P_ move data
//
// Per component, U is P_ contiguous vectors of length P_. Loading directly would have
// each thread stride through memory, so the data goes through sTmp:
//   1. thread tx copies entry tx of every vector (consecutive threads touch consecutive
//      addresses) into sTmp;
//   2. after a barrier, thread tx pulls the whole vector number tx out of sTmp.
// Contains __syncthreads(): every thread of the block must call this function.
template <typename T, int P_, int DIMU, int NCOMP_, int rUsize, int iDIM>
__device__ __inline__ void readU_2d(const T *dU, const int compstride, T rU[DIMU][NCOMP_][rUsize], T *sTmp, const int tx) {
  const bool active = (tx < P_);

  for (int comp = 0; comp < NCOMP_; comp++) {
    const T *src = dU + comp * compstride;

    // stage 1: global -> shared
    if (active) {
      for (int vec = 0; vec < P_; vec++) sTmp[vec * P_ + tx] = src[vec * P_ + tx];
    }
    __syncthreads();

    // stage 2: shared -> registers, transposed
    if (active) {
      for (int k = 0; k < P_; k++) rU[iDIM][comp][k] = sTmp[tx * P_ + k];
    }
    // sTmp is overwritten by the next component; wait until all reads are done
    __syncthreads();
  }
}
82*f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// Load V of one 2D element from global memory into registers rV[iDIM][][], for all
// components of a single dim.
//   dV         : already offset by element stride and dimension stride
//   compstride : distance in dV between consecutive components
//   rV         : register array of shape [DIMV][NCOMP_][rVsize]; only slot iDIM is written
//   rVsize     : may be larger than Q_ (e.g. MAXP_Q); only entries 0..Q_-1 are written
//   tx         : thread index; threads with tx >= Q_ return without loading
// Thread tx receives dV[comp * compstride + j * Q_ + tx] into rV[iDIM][comp][j].
// No shared memory and no barrier are used.
template <typename T, int Q_, int DIMV, int NCOMP_, int rVsize, int iDIM>
__device__ __inline__ void readV_2d(const T *dV, const int compstride, T rV[DIMV][NCOMP_][rVsize], const int tx) {
  if (tx >= Q_) return;

  for (int comp = 0; comp < NCOMP_; comp++) {
    const T *src = dV + comp * compstride + tx;
    for (int j = 0; j < Q_; j++) rV[iDIM][comp][j] = src[j * Q_];
  }
}
99*f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// Store V of one 2D element from registers rV[iDIM][][] to global memory, for all
// components of a single dim.
//   dV         : already offset by element stride and dimension stride
//   compstride : distance in dV between consecutive components
//   rV         : register array of shape [DIMV][NCOMP_][rVsize]; only slot iDIM is read
//   rVsize     : may be larger than Q_ (e.g. MAXP_Q); only entries 0..Q_-1 are read
//   tx         : thread index; threads with tx >= Q_ return without writing
// rV[iDIM][comp][j] of thread tx is written to dV[comp * compstride + j * Q_ + tx].
template <typename T, int Q_, int DIMV, int NCOMP_, int rVsize, int iDIM>
__device__ __inline__ void writeV_2d(T *dV, const int compstride, T rV[DIMV][NCOMP_][rVsize], const int tx) {
  if (tx >= Q_) return;

  for (int comp = 0; comp < NCOMP_; comp++) {
    T *dst = dV + comp * compstride + tx;
    for (int j = 0; j < Q_; j++) dst[j * Q_] = rV[iDIM][comp][j];
  }
}
117*f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// Load U of one 3D element into registers rU[iDIM][][], for all components of a single dim.
//   dU         : already offset by element stride and dimension stride
//   compstride : distance in dU between consecutive components
//   rU         : register array of shape [DIMU][NCOMP_][rUsize]; only slot iDIM is written
//   rUsize     : may be larger than P_ (e.g. MAXP_Q); only entries 0..P_-1 are written
//   sTmp       : shared-memory workspace holding at least P_ * P_ * P_ entries
//   tx         : thread index; only threads with tx < P_ * P_ move data
//
// Per component, U is P_^2 contiguous vectors of length P_. The data is transposed
// through sTmp:
//   1. thread tx copies entry tx of each of the P_ slabs of P_^2 values (consecutive
//      threads touch consecutive addresses) into sTmp;
//   2. after a barrier, thread tx pulls the whole vector number tx out of sTmp.
// Contains __syncthreads(): every thread of the block must call this function.
template <typename T, int P_, int DIMU, int NCOMP_, int rUsize, int iDIM>
__device__ __inline__ void readU_3d(const T *dU, const int compstride, T rU[DIMU][NCOMP_][rUsize], T *sTmp, const int tx) {
  constexpr int slab   = P_ * P_;
  const bool    active = (tx < slab);

  for (int comp = 0; comp < NCOMP_; comp++) {
    const T *src = dU + comp * compstride;

    // stage 1: global -> shared
    if (active) {
      for (int s = 0; s < P_; s++) sTmp[s * slab + tx] = src[s * slab + tx];
    }
    __syncthreads();

    // stage 2: shared -> registers, transposed
    if (active) {
      for (int k = 0; k < P_; k++) rU[iDIM][comp][k] = sTmp[tx * P_ + k];
    }
    // sTmp is overwritten by the next component; wait until all reads are done
    __syncthreads();
  }
}
154*f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// Load V of one 3D element from global memory into registers rV[iDIM][][], for all
// components of a single dim.
//   dV         : already offset by element stride and dimension stride
//   compstride : distance in dV between consecutive components
//   rV         : register array of shape [DIMV][NCOMP_][rVsize]; only slot iDIM is written
//   rVsize     : may be larger than Q_ (e.g. MAXP_Q); only entries 0..Q_-1 are written
//   tx         : thread index; threads with tx >= Q_ * Q_ return without loading
// Thread tx receives dV[comp * compstride + j * Q_^2 + tx] into rV[iDIM][comp][j].
template <typename T, int Q_, int DIMV, int NCOMP_, int rVsize, int iDIM>
__device__ __inline__ void readV_3d(const T *dV, const int compstride, T rV[DIMV][NCOMP_][rVsize], const int tx) {
  constexpr int slab = Q_ * Q_;
  if (tx >= slab) return;

  for (int comp = 0; comp < NCOMP_; comp++) {
    const T *src = dV + comp * compstride + tx;
    for (int j = 0; j < Q_; j++) rV[iDIM][comp][j] = src[j * slab];
  }
}
171*f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// Store V of one 3D element from registers rV[iDIM][][] to global memory, for all
// components of a single dim.
//   dV         : points directly at the element (already offset by element stride)
//   compstride : distance in dV between consecutive components
//   rV         : register array of shape [DIMV][NCOMP_][rVsize]; only slot iDIM is read
//   rVsize     : may be larger than Q_ (e.g. MAXP_Q); only entries 0..Q_-1 are read
//   tx         : thread index; threads with tx >= Q_ * Q_ return without writing
// rV[iDIM][comp][j] of thread tx is written to dV[comp * compstride + j * Q_^2 + tx].
template <typename T, int Q_, int DIMV, int NCOMP_, int rVsize, int iDIM>
__device__ __inline__ void writeV_3d(T *dV, const int compstride, T rV[DIMV][NCOMP_][rVsize], const int tx) {
  constexpr int slab = Q_ * Q_;
  if (tx >= slab) return;

  for (int comp = 0; comp < NCOMP_; comp++) {
    T *dst = dV + comp * compstride + tx;
    for (int j = 0; j < Q_; j++) dst[j * slab] = rV[iDIM][comp][j];
  }
}
189*f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// Copy the operator matrix T from global memory (dT) into shared memory (sT).
// sT always ends up as B x J with leading dimension B, i.e. entry (b, j) at sT[j * B + b].
//   transT == MagmaNoTrans : dT is already B x J (ld B); threads tx < B copy it verbatim
//   otherwise              : dT is J x B (ld J); threads tx < J transpose it while copying
// No barrier is issued here: the caller must synchronize before sT is read.
template <int B, int J>
__device__ __inline__ void dread_T_gm2sm(const int tx, const magma_trans_t transT, const CeedScalar *dT, CeedScalar *sT) {
  if (transT == MagmaNoTrans) {
    if (tx < B) {
      for (int j = 0; j < J; j++) sT[j * B + tx] = dT[j * B + tx];
    }
    return;
  }

  if (tx < J) {
    for (int b = 0; b < B; b++) sT[tx * B + b] = dT[b * J + tx];
  }
}
212*f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// Gather one slice of U (shared or global memory) into registers:
// rU[i] = U[i * C + tx_] for i = 0..B-1. U must already point at the right slice.
template <int B>
__device__ __inline__ void dread_U_gsm2reg(const int C, const int tx_, const CeedScalar *U, CeedScalar rU[B]) {
  const CeedScalar *col = U + tx_;
  for (int i = 0; i < B; i++) rU[i] = col[i * C];
}
222*f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// Gather one slice of V (shared or global memory) into registers, without any scaling:
// rV[i] = V[i * C + tx_] for i = 0..J-1. V must already point at the right slice.
template <int J>
__device__ __inline__ void dread_V_gsm2reg(const int C, const int tx_, const CeedScalar *V, CeedScalar rV[J]) {
  const CeedScalar *col = V + tx_;
  for (int i = 0; i < J; i++) rV[i] = col[i * C];
}
232*f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// Scatter one slice of V from registers back to shared or global memory:
// V[i * C + tx_] = rV[i] for i = 0..J-1. V must already point at the right slice.
template <int J>
__device__ __inline__ void dwrite_V_reg2gsm(const int C, const int tx_, CeedScalar rV[J], CeedScalar *V) {
  CeedScalar *col = V + tx_;
  for (int i = 0; i < J; i++) col[i * C] = rV[i];
}
242*f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// Multiply a register slice of U by T to update a register slice of V:
//   rV[j] = alpha * sum_b rU[b] * sT[j * B + b] + beta * rV[j],   j = 0..J-1
// where sT holds T as B x J with leading dimension B.
// Following the BLAS gemm convention, when beta == 0 the incoming contents of rV are
// not referenced. Callers may skip loading rV in that case (dgemm_ceed_device does),
// and stale, uninitialized or NaN/Inf register values must not leak into the result
// through 0 * NaN = NaN.
template <int B, int J>
__device__ __inline__ void dgemm_slice(CeedScalar alpha, CeedScalar *sT, CeedScalar rU[B], CeedScalar beta, CeedScalar rV[J]) {
  for (int j = 0; j < J; j++) {
    CeedScalar rTmp = 0.0;
    for (int b = 0; b < B; b++) {
      rTmp += rU[b] * sT[j * B + b];
    }
    if (beta == 0.0) {
      rV[j] = alpha * rTmp;
    } else {
      rV[j] *= beta;
      rV[j] += alpha * rTmp;
    }
  }
}
257*f80f4a74SSebastian Grimberg 
//////////////////////////////////////////////////////////////////////////////////////////
// Per-thread gemm step: V_slice = alpha * op(U_slice, T) + beta * V_slice.
//   tx     : flat thread index; tx / C selects the slice, tx % C the column inside it
//   C      : number of columns per slice
//   sT     : T already resident in shared memory, B x J with leading dimension B
//   dU, dV : base pointers; slice s starts at dU + s*C*B and dV + s*C*J
//   rU, rV : register scratch of length B and J
// If beta == 0 the existing V is not loaded from memory.
// NOTE(review): A and transT are not referenced in this body; presumably kept for
// interface symmetry with the callers. sT is assumed already loaded and synchronized.
template <int B, int J>
__device__ __inline__ void dgemm_ceed_device(const int tx, const int A, const int C, magma_trans_t transT, CeedScalar *sT, const CeedScalar alpha,
                                             const CeedScalar beta, const CeedScalar *dU, CeedScalar *dV, CeedScalar rU[B], CeedScalar rV[J]) {
  const int slice = tx / C;
  const int col   = tx % C;

  const CeedScalar *U = dU + slice * C * B;
  CeedScalar       *V = dV + slice * C * J;

  // existing V only matters when it is going to be scaled and accumulated into
  if (beta != 0.0) dread_V_gsm2reg<J>(C, col, V, rV);

  dread_U_gsm2reg<B>(C, col, U, rU);
  dgemm_slice<B, J>(alpha, sT, rU, beta, rV);
  dwrite_V_reg2gsm<J>(C, col, rV, V);
}
283*f80f4a74SSebastian Grimberg 
284*f80f4a74SSebastian Grimberg #endif  // CEED_MAGMA_COMMON_TENSOR_H
285