Lines Matching +full:- +full:t

1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
4 // SPDX-License-Identifier: BSD-2-Clause
10 #include "magma-common-tensor.h"
18 template <typename T, bool Add>
21 template <typename T>
22 struct magma_grad_2d_device_accumulate<T, true> {
23 static __device__ __inline__ void op(T &rV, const T &rTmp) { rV += rTmp; }
26 template <typename T>
27 struct magma_grad_2d_device_accumulate<T, false> {
28 static __device__ __inline__ void op(T &rV, const T &rTmp) { rV = rTmp; }
34 // DIM_U -- for the size of rU[DIM_U * NUM_COMP * MAX_P_Q]
35 // DIM_V -- for the size of rV[DIM_V * NUM_COMP * MAX_P_Q]
36 // i_DIM -- the index of the outermost loop over dimensions in grad
37 // i_DIM_U -- which dim index of rU is accessed (always 0 for notrans, 0 or 1 for trans)
38 // i_DIM_V -- which dim index of rV is accessed (0 or 1 for notrans, always 0 for trans)
39 template <typename T, int DIM_U, int DIM_V, int NUM_COMP, int P, int Q, int rU_SIZE, int rV_SIZE, i…
40 static __device__ __inline__ void magma_grad_2d_device(const T *sTinterp, const T *sTgrad, T rU[DIM…
41T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx, T rTmp, T *swork) {
54 // 1st product -- Batch P of (1xP) matrices [reg] x (PxQ) [shmem] => Batch P of (1xQ) matrices
59 const T *sT = (i_DIM == 0) ? sTgrad : sTinterp;
60 T *sTmp = swork + batchid * (1 * Q);
71 // 2nd product -- Batch 1 of a (QxP) matrix [shmem] x (PxQ) [shmem] => (QxQ) matrix [reg]
75 const T *sT = (i_DIM == 1) ? sTgrad : sTinterp;
76 T *sTmp = swork + batchid * (Q * P);
82 magma_grad_2d_device_accumulate<T, ADD>::op(rV[i_DIM_V][comp][j], rTmp);
115 // read T
121 /* read U (idim = 0 for dU, i_DIM = 0 for rU) --
125 /* first call (i_DIM = 0, i_DIM_U = 0, i_DIM_V = 0) --
132 /* second call (i_DIM = 1, i_DIM_U = 0, i_DIM_V = 0) --
166 // read T
173 /* read U (idim = 0 for dU, i_DIM = 0 for rU) --
180 /* read U (idim = 1 for dU, i_DIM = 0 for rU) --
217 // read T
224 /* read U (idim = 0 for dU, i_DIM = 0 for rU) --
231 /* read U (idim = 1 for dU, i_DIM = 0 for rU) --