Lines Matching +full:- +full:t
1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
4 // SPDX-License-Identifier: BSD-2-Clause
10 #include "magma-common-tensor.h"
19 template <typename T, bool Add>
22 template <typename T>
23 struct magma_grad_3d_device_accumulate<T, true> {
24 static __device__ __inline__ void op(T &rV, const T &rTmp) { rV += rTmp; }
27 template <typename T>
28 struct magma_grad_3d_device_accumulate<T, false> {
29 static __device__ __inline__ void op(T &rV, const T &rTmp) { rV = rTmp; }
35 // DIM_U -- for the size of rU[DIM_U * NUM_COMP * MAX_P_Q]
36 // DIM_V -- for the size of rV[DIM_V * NUM_COMP * MAX_P_Q]
37 // i_DIM -- the index of the outermost loop over dimensions in grad
38 // i_DIM_U -- which dim index of rU is accessed (always 0 for notrans, 0, 1, or 2 for trans)
39 // i_DIM_V -- which dim index of rV is accessed (0, 1, or 2 for notrans, always 0 for trans)
40 template <typename T, int DIM_U, int DIM_V, int NUM_COMP, int P, int Q, int rU_SIZE, int rV_SIZE, i…
41 static __device__ __inline__ void magma_grad_3d_device(const T *sTinterp, const T *sTgrad, T rU[DIM…
42 … T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx, T rTmp, T *swork) {
55 T *sW1 = swork;
56 T *sW2 = sW1 + P * P * Q;
62 const T *sT = (i_DIM == 0) ? sTgrad : sTinterp;
63 T *sTmp = sW1 + batchid * (1 * Q);
79 const T *sT = (i_DIM == 1) ? sTgrad : sTinterp;
80 T *sTmp = sW1 + batchid * (Q * P); // sTmp is input
81 T *sTmp2 = sW2 + batchid * (Q * Q); // sTmp2 is output
97 const T *sT = (i_DIM == 2) ? sTgrad : sTinterp;
98 T *sTmp = sW2; // sTmp is input
104 magma_grad_3d_device_accumulate<T, ADD>::op(rV[i_DIM_V][comp][j], rTmp);
137 // read T
144 /* read U (idim = 0 for dU, i_DIM = 0 for rU) --
148 /* first call (i_DIM = 0, i_DIM_U = 0, i_DIM_V = 0) --
155 /* second call (i_DIM = 1, i_DIM_U = 0, i_DIM_V = 0) --
162 /* third call (i_DIM = 2, i_DIM_U = 0, i_DIM_V = 0) --
196 // read T
203 /* read U (idim = 0 for dU, i_DIM = 0 for rU) --
210 /* read U (idim = 1 for dU, i_DIM = 0 for rU) --
217 /* read U (idim = 2 for dU, i_DIM = 0 for rU) --
254 // read T
261 /* read U (idim = 0 for dU, i_DIM = 0 for rU) --
268 /* read U (idim = 1 for dU, i_DIM = 0 for rU) --
275 /* read U (idim = 2 for dU, i_DIM = 0 for rU) --