xref: /libCEED/include/ceed/jit-source/sycl/sycl-shared-basis-tensor-templates.h (revision 5aed82e4fa97acf4ba24a7f10a35f5303a6798e0)
1*5aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3bd882c8aSJames Wright //
4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
5bd882c8aSJames Wright //
6bd882c8aSJames Wright // This file is part of CEED:  http://github.com/ceed
7bd882c8aSJames Wright 
8bd882c8aSJames Wright /// @file
9bd882c8aSJames Wright /// Internal header for SYCL shared memory tensor product basis templates
1094b7b29bSJeremy L Thompson #ifndef CEED_SYCL_SHARED_BASIS_TENSOR_TEMPLATES_H
1194b7b29bSJeremy L Thompson #define CEED_SYCL_SHARED_BASIS_TENSOR_TEMPLATES_H
12bd882c8aSJames Wright 
13bd882c8aSJames Wright #include <ceed.h>
14bd882c8aSJames Wright 
15bd882c8aSJames Wright //------------------------------------------------------------------------------
16bd882c8aSJames Wright // 1D
17bd882c8aSJames Wright //------------------------------------------------------------------------------
18bd882c8aSJames Wright 
19bd882c8aSJames Wright //------------------------------------------------------------------------------
20bd882c8aSJames Wright // 1D tensor contraction x
21bd882c8aSJames Wright //------------------------------------------------------------------------------
22bd882c8aSJames Wright inline void ContractX1d(const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict U, local const CeedScalar *restrict B,
23bd882c8aSJames Wright                         private CeedScalar *restrict V, local CeedScalar *restrict scratch) {
24bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
25bd882c8aSJames Wright 
26bd882c8aSJames Wright   scratch[item_id_x] = *U;
27bd882c8aSJames Wright   work_group_barrier(CLK_LOCAL_MEM_FENCE);
28bd882c8aSJames Wright 
29bd882c8aSJames Wright   *V = 0.0;
30bd882c8aSJames Wright   if (item_id_x < Q_1D) {
31bd882c8aSJames Wright     for (CeedInt i = 0; i < P_1D; i++) {
32bd882c8aSJames Wright       *V += B[i + item_id_x * P_1D] * scratch[i];  // Contract x direction
33bd882c8aSJames Wright     }
34bd882c8aSJames Wright   }
35bd882c8aSJames Wright   work_group_barrier(CLK_LOCAL_MEM_FENCE);
36bd882c8aSJames Wright }
37bd882c8aSJames Wright 
38bd882c8aSJames Wright //------------------------------------------------------------------------------
39bd882c8aSJames Wright // 1D transpose tensor contraction x
40bd882c8aSJames Wright //------------------------------------------------------------------------------
41bd882c8aSJames Wright inline void ContractTransposeX1d(const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict U, local const CeedScalar *restrict B,
42bd882c8aSJames Wright                                  private CeedScalar *restrict V, local CeedScalar *restrict scratch) {
43bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
44bd882c8aSJames Wright 
45bd882c8aSJames Wright   scratch[item_id_x] = *U;
46bd882c8aSJames Wright   work_group_barrier(CLK_LOCAL_MEM_FENCE);
47bd882c8aSJames Wright 
48bd882c8aSJames Wright   *V = 0.0;
49bd882c8aSJames Wright   if (item_id_x < P_1D) {
50bd882c8aSJames Wright     for (CeedInt i = 0; i < Q_1D; i++) {
51bd882c8aSJames Wright       *V += B[item_id_x + i * P_1D] * scratch[i];  // Contract x direction
52bd882c8aSJames Wright     }
53bd882c8aSJames Wright   }
54bd882c8aSJames Wright   work_group_barrier(CLK_LOCAL_MEM_FENCE);
55bd882c8aSJames Wright }
56bd882c8aSJames Wright 
57bd882c8aSJames Wright //------------------------------------------------------------------------------
58bd882c8aSJames Wright // 1D interpolate to quadrature points
59bd882c8aSJames Wright //------------------------------------------------------------------------------
60bd882c8aSJames Wright inline void Interp1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict r_U,
61bd882c8aSJames Wright                      local const CeedScalar *restrict s_B, private CeedScalar *restrict r_V, local CeedScalar *restrict scratch) {
62bd882c8aSJames Wright   for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
63bd882c8aSJames Wright     ContractX1d(P_1D, Q_1D, r_U + comp, s_B, r_V + comp, scratch);
64bd882c8aSJames Wright   }
65bd882c8aSJames Wright }
66bd882c8aSJames Wright 
67bd882c8aSJames Wright //------------------------------------------------------------------------------
68bd882c8aSJames Wright // 1D interpolate transpose
69bd882c8aSJames Wright //------------------------------------------------------------------------------
70bd882c8aSJames Wright inline void InterpTranspose1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict r_U,
71bd882c8aSJames Wright                               local const CeedScalar *restrict s_B, private CeedScalar *restrict r_V, local CeedScalar *restrict scratch) {
72bd882c8aSJames Wright   for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
73bd882c8aSJames Wright     ContractTransposeX1d(P_1D, Q_1D, r_U + comp, s_B, r_V + comp, scratch);
74bd882c8aSJames Wright   }
75bd882c8aSJames Wright }
76bd882c8aSJames Wright 
77bd882c8aSJames Wright //------------------------------------------------------------------------------
78bd882c8aSJames Wright // 1D derivatives at quadrature points
79bd882c8aSJames Wright //------------------------------------------------------------------------------
80bd882c8aSJames Wright inline void Grad1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict r_U,
81bd882c8aSJames Wright                    local const CeedScalar *restrict s_G, private CeedScalar *restrict r_V, local CeedScalar *restrict scratch) {
82bd882c8aSJames Wright   for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
83bd882c8aSJames Wright     ContractX1d(P_1D, Q_1D, r_U + comp, s_G, r_V + comp, scratch);
84bd882c8aSJames Wright   }
85bd882c8aSJames Wright }
86bd882c8aSJames Wright 
87bd882c8aSJames Wright //------------------------------------------------------------------------------
88bd882c8aSJames Wright // 1D derivatives transpose
89bd882c8aSJames Wright //------------------------------------------------------------------------------
90bd882c8aSJames Wright inline void GradTranspose1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict r_U,
91bd882c8aSJames Wright                             local const CeedScalar *restrict s_G, private CeedScalar *restrict r_V, local CeedScalar *restrict scratch) {
92bd882c8aSJames Wright   for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
93bd882c8aSJames Wright     ContractTransposeX1d(P_1D, Q_1D, r_U + comp, s_G, r_V + comp, scratch);
94bd882c8aSJames Wright   }
95bd882c8aSJames Wright }
96bd882c8aSJames Wright 
97bd882c8aSJames Wright //------------------------------------------------------------------------------
98bd882c8aSJames Wright // 1D quadrature weights
99bd882c8aSJames Wright //------------------------------------------------------------------------------
100bd882c8aSJames Wright inline void Weight1d(const CeedInt Q_1D, const CeedScalar *restrict q_weight_1d, CeedScalar *restrict w) {
101bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
102bd882c8aSJames Wright   *w                      = (item_id_x < Q_1D) ? q_weight_1d[item_id_x] : 0.0;
103bd882c8aSJames Wright }
104bd882c8aSJames Wright 
105bd882c8aSJames Wright //------------------------------------------------------------------------------
106bd882c8aSJames Wright // 2D
107bd882c8aSJames Wright //------------------------------------------------------------------------------
108bd882c8aSJames Wright 
109bd882c8aSJames Wright //------------------------------------------------------------------------------
110bd882c8aSJames Wright // 2D tensor contraction x
111bd882c8aSJames Wright //------------------------------------------------------------------------------
112bd882c8aSJames Wright inline void ContractX2d(const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict U, local const CeedScalar *restrict B,
113bd882c8aSJames Wright                         private CeedScalar *restrict V, local CeedScalar *restrict scratch) {
114bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
115bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
116bd882c8aSJames Wright 
117bd882c8aSJames Wright   scratch[item_id_x + item_id_y * T_1D] = *U;
118bd882c8aSJames Wright   work_group_barrier(CLK_LOCAL_MEM_FENCE);
119bd882c8aSJames Wright 
120bd882c8aSJames Wright   *V = 0.0;
121bd882c8aSJames Wright   if (item_id_x < Q_1D && item_id_y < P_1D) {
122bd882c8aSJames Wright     for (CeedInt i = 0; i < P_1D; i++) {
123bd882c8aSJames Wright       *V += B[i + item_id_x * P_1D] * scratch[i + item_id_y * T_1D];  // Contract x direction
124bd882c8aSJames Wright     }
125bd882c8aSJames Wright   }
126bd882c8aSJames Wright   work_group_barrier(CLK_LOCAL_MEM_FENCE);
127bd882c8aSJames Wright }
128bd882c8aSJames Wright 
129bd882c8aSJames Wright //------------------------------------------------------------------------------
130bd882c8aSJames Wright // 2D tensor contract y
131bd882c8aSJames Wright //------------------------------------------------------------------------------
132bd882c8aSJames Wright inline void ContractY2d(const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict U, local const CeedScalar *restrict B,
133bd882c8aSJames Wright                         private CeedScalar *restrict V, local CeedScalar *restrict scratch) {
134bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
135bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
136bd882c8aSJames Wright 
137bd882c8aSJames Wright   scratch[item_id_x + item_id_y * T_1D] = *U;
138bd882c8aSJames Wright   work_group_barrier(CLK_LOCAL_MEM_FENCE);
139bd882c8aSJames Wright 
140bd882c8aSJames Wright   *V = 0.0;
141bd882c8aSJames Wright   if (item_id_x < Q_1D && item_id_y < Q_1D) {
142bd882c8aSJames Wright     for (CeedInt i = 0; i < P_1D; i++) {
143bd882c8aSJames Wright       *V += B[i + item_id_y * P_1D] * scratch[item_id_x + i * T_1D];  // Contract y direction
144bd882c8aSJames Wright     }
145bd882c8aSJames Wright   }
146bd882c8aSJames Wright   work_group_barrier(CLK_LOCAL_MEM_FENCE);
147bd882c8aSJames Wright }
148bd882c8aSJames Wright 
149bd882c8aSJames Wright //------------------------------------------------------------------------------
150bd882c8aSJames Wright // 2D transpose tensor contract y
151bd882c8aSJames Wright //------------------------------------------------------------------------------
152bd882c8aSJames Wright inline void ContractTransposeY2d(const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict U, local const CeedScalar *restrict B,
153bd882c8aSJames Wright                                  private CeedScalar *restrict V, local CeedScalar *restrict scratch) {
154bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
155bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
156bd882c8aSJames Wright 
157bd882c8aSJames Wright   scratch[item_id_x + item_id_y * T_1D] = *U;
158bd882c8aSJames Wright   work_group_barrier(CLK_LOCAL_MEM_FENCE);
159bd882c8aSJames Wright 
160bd882c8aSJames Wright   *V = 0.0;
161bd882c8aSJames Wright   if (item_id_x < Q_1D && item_id_y < P_1D) {
162bd882c8aSJames Wright     for (CeedInt i = 0; i < Q_1D; i++) {
163bd882c8aSJames Wright       *V += B[item_id_y + i * P_1D] * scratch[item_id_x + i * T_1D];  // Contract y direction
164bd882c8aSJames Wright     }
165bd882c8aSJames Wright   }
166bd882c8aSJames Wright   work_group_barrier(CLK_LOCAL_MEM_FENCE);
167bd882c8aSJames Wright }
168bd882c8aSJames Wright 
169bd882c8aSJames Wright //------------------------------------------------------------------------------
170bd882c8aSJames Wright // 2D transpose tensor contract x
171bd882c8aSJames Wright //------------------------------------------------------------------------------
172bd882c8aSJames Wright inline void ContractTransposeX2d(const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict U, local const CeedScalar *restrict B,
173bd882c8aSJames Wright                                  private CeedScalar *restrict V, local CeedScalar *restrict scratch) {
174bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
175bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
176bd882c8aSJames Wright 
177bd882c8aSJames Wright   scratch[item_id_x + item_id_y * T_1D] = *U;
178bd882c8aSJames Wright   work_group_barrier(CLK_LOCAL_MEM_FENCE);
179bd882c8aSJames Wright 
180bd882c8aSJames Wright   *V = 0.0;
181bd882c8aSJames Wright   if (item_id_x < P_1D && item_id_y < P_1D) {
182bd882c8aSJames Wright     for (CeedInt i = 0; i < Q_1D; i++) {
183bd882c8aSJames Wright       *V += B[item_id_x + i * P_1D] * scratch[i + item_id_y * T_1D];  // Contract x direction
184bd882c8aSJames Wright     }
185bd882c8aSJames Wright   }
186bd882c8aSJames Wright   work_group_barrier(CLK_LOCAL_MEM_FENCE);
187bd882c8aSJames Wright }
188bd882c8aSJames Wright 
189bd882c8aSJames Wright //------------------------------------------------------------------------------
190bd882c8aSJames Wright // 2D transpose tensor contract and add x
191bd882c8aSJames Wright //------------------------------------------------------------------------------
192bd882c8aSJames Wright inline void ContractTransposeAddX2d(const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict U, local const CeedScalar *restrict B,
193bd882c8aSJames Wright                                     private CeedScalar *restrict V, local CeedScalar *restrict scratch) {
194bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
195bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
196bd882c8aSJames Wright 
197bd882c8aSJames Wright   scratch[item_id_x + item_id_y * T_1D] = *U;
198bd882c8aSJames Wright   work_group_barrier(CLK_LOCAL_MEM_FENCE);
199bd882c8aSJames Wright 
200bd882c8aSJames Wright   if (item_id_x < P_1D && item_id_y < P_1D) {
201bd882c8aSJames Wright     for (CeedInt i = 0; i < Q_1D; i++) {
202bd882c8aSJames Wright       *V += B[item_id_x + i * P_1D] * scratch[i + item_id_y * T_1D];  // Contract x direction
203bd882c8aSJames Wright     }
204bd882c8aSJames Wright   }
205bd882c8aSJames Wright   work_group_barrier(CLK_LOCAL_MEM_FENCE);
206bd882c8aSJames Wright }
207bd882c8aSJames Wright 
208bd882c8aSJames Wright //------------------------------------------------------------------------------
209bd882c8aSJames Wright // 2D interpolate to quadrature points
210bd882c8aSJames Wright //------------------------------------------------------------------------------
211bd882c8aSJames Wright inline void InterpTensor2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict r_U,
212bd882c8aSJames Wright                            local const CeedScalar *restrict s_B, private CeedScalar *restrict r_V, local CeedScalar *restrict scratch) {
213bd882c8aSJames Wright   CeedScalar r_t[1];
214bd882c8aSJames Wright 
215bd882c8aSJames Wright   for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
216bd882c8aSJames Wright     ContractX2d(P_1D, Q_1D, r_U + comp, s_B, r_t, scratch);
217bd882c8aSJames Wright     ContractY2d(P_1D, Q_1D, r_t, s_B, r_V + comp, scratch);
218bd882c8aSJames Wright   }
219bd882c8aSJames Wright }
220bd882c8aSJames Wright 
221bd882c8aSJames Wright //------------------------------------------------------------------------------
222bd882c8aSJames Wright // 2D interpolate transpose
223bd882c8aSJames Wright //------------------------------------------------------------------------------
224bd882c8aSJames Wright inline void InterpTransposeTensor2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict r_U,
225bd882c8aSJames Wright                                     local const CeedScalar *restrict s_B, private CeedScalar *restrict r_V, local CeedScalar *restrict scratch) {
226bd882c8aSJames Wright   CeedScalar r_t[1];
227bd882c8aSJames Wright 
228bd882c8aSJames Wright   for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
229bd882c8aSJames Wright     ContractTransposeY2d(P_1D, Q_1D, r_U + comp, s_B, r_t, scratch);
230bd882c8aSJames Wright     ContractTransposeX2d(P_1D, Q_1D, r_t, s_B, r_V + comp, scratch);
231bd882c8aSJames Wright   }
232bd882c8aSJames Wright }
233bd882c8aSJames Wright 
234bd882c8aSJames Wright //------------------------------------------------------------------------------
235bd882c8aSJames Wright // 2D derivatives at quadrature points
236bd882c8aSJames Wright //------------------------------------------------------------------------------
237bd882c8aSJames Wright inline void GradTensor2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict r_U,
238bd882c8aSJames Wright                          local const CeedScalar *restrict s_B, local const CeedScalar *restrict s_G, private CeedScalar *restrict r_V,
239bd882c8aSJames Wright                          local CeedScalar *restrict scratch) {
240bd882c8aSJames Wright   CeedScalar r_t[1];
241bd882c8aSJames Wright 
242bd882c8aSJames Wright   for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
243bd882c8aSJames Wright     ContractX2d(P_1D, Q_1D, r_U + comp, s_G, r_t, scratch);
244bd882c8aSJames Wright     ContractY2d(P_1D, Q_1D, r_t, s_B, r_V + comp + 0 * NUM_COMP, scratch);
245bd882c8aSJames Wright     ContractX2d(P_1D, Q_1D, r_U + comp, s_B, r_t, scratch);
246bd882c8aSJames Wright     ContractY2d(P_1D, Q_1D, r_t, s_G, r_V + comp + 1 * NUM_COMP, scratch);
247bd882c8aSJames Wright   }
248bd882c8aSJames Wright }
249bd882c8aSJames Wright 
250bd882c8aSJames Wright //------------------------------------------------------------------------------
251bd882c8aSJames Wright // 2D derivatives transpose
252bd882c8aSJames Wright //------------------------------------------------------------------------------
253bd882c8aSJames Wright inline void GradTransposeTensor2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict r_U,
254bd882c8aSJames Wright                                   local const CeedScalar *restrict s_B, local const CeedScalar *restrict s_G, private CeedScalar *restrict r_V,
255bd882c8aSJames Wright                                   local CeedScalar *restrict scratch) {
256bd882c8aSJames Wright   CeedScalar r_t[1];
257bd882c8aSJames Wright 
258bd882c8aSJames Wright   for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
259bd882c8aSJames Wright     ContractTransposeY2d(P_1D, Q_1D, r_U + comp + 0 * NUM_COMP, s_B, r_t, scratch);
260bd882c8aSJames Wright     ContractTransposeX2d(P_1D, Q_1D, r_t, s_G, r_V + comp, scratch);
261bd882c8aSJames Wright     ContractTransposeY2d(P_1D, Q_1D, r_U + comp + 1 * NUM_COMP, s_G, r_t, scratch);
262bd882c8aSJames Wright     ContractTransposeAddX2d(P_1D, Q_1D, r_t, s_B, r_V + comp, scratch);
263bd882c8aSJames Wright   }
264bd882c8aSJames Wright }
265bd882c8aSJames Wright 
266bd882c8aSJames Wright //------------------------------------------------------------------------------
267bd882c8aSJames Wright // 2D quadrature weights
268bd882c8aSJames Wright //------------------------------------------------------------------------------
269bd882c8aSJames Wright inline void WeightTensor2d(const CeedInt Q_1D, const CeedScalar *restrict q_weight_1d, CeedScalar *restrict w) {
270bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
271bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
272bd882c8aSJames Wright 
273bd882c8aSJames Wright   *w = (item_id_x < Q_1D && item_id_y < Q_1D) ? q_weight_1d[item_id_x] * q_weight_1d[item_id_y] : 0.0;
274bd882c8aSJames Wright }
275bd882c8aSJames Wright 
276bd882c8aSJames Wright //------------------------------------------------------------------------------
277bd882c8aSJames Wright // 3D
278bd882c8aSJames Wright //------------------------------------------------------------------------------
279bd882c8aSJames Wright 
280bd882c8aSJames Wright //------------------------------------------------------------------------------
281bd882c8aSJames Wright // 3D tensor contract x
282bd882c8aSJames Wright //------------------------------------------------------------------------------
283bd882c8aSJames Wright inline void ContractX3d(const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict U, local const CeedScalar *restrict B,
284bd882c8aSJames Wright                         private CeedScalar *restrict V, local CeedScalar *restrict scratch) {
285bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
286bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
287bd882c8aSJames Wright 
288bd882c8aSJames Wright   CeedScalar r_B[T_1D];
289bd882c8aSJames Wright   for (CeedInt i = 0; i < P_1D; i++) {
290bd882c8aSJames Wright     r_B[i] = B[i + item_id_x * P_1D];
291bd882c8aSJames Wright   }
292bd882c8aSJames Wright 
293bd882c8aSJames Wright   for (CeedInt k = 0; k < P_1D; k++) {
294bd882c8aSJames Wright     scratch[item_id_x + item_id_y * T_1D] = U[k];
295bd882c8aSJames Wright     work_group_barrier(CLK_LOCAL_MEM_FENCE);
296bd882c8aSJames Wright 
297bd882c8aSJames Wright     V[k] = 0.0;
298bd882c8aSJames Wright     if (item_id_x < Q_1D && item_id_y < P_1D) {
299bd882c8aSJames Wright       for (CeedInt i = 0; i < P_1D; i++) {
300bd882c8aSJames Wright         V[k] += r_B[i] * scratch[i + item_id_y * T_1D];  // Contract x direction
301bd882c8aSJames Wright       }
302bd882c8aSJames Wright     }
303bd882c8aSJames Wright     work_group_barrier(CLK_LOCAL_MEM_FENCE);
304bd882c8aSJames Wright   }
305bd882c8aSJames Wright }
306bd882c8aSJames Wright 
307bd882c8aSJames Wright //------------------------------------------------------------------------------
308bd882c8aSJames Wright // 3D tensor contract y
309bd882c8aSJames Wright //------------------------------------------------------------------------------
310bd882c8aSJames Wright inline void ContractY3d(const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict U, local const CeedScalar *restrict B,
311bd882c8aSJames Wright                         private CeedScalar *restrict V, local CeedScalar *restrict scratch) {
312bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
313bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
314bd882c8aSJames Wright 
315bd882c8aSJames Wright   CeedScalar r_B[T_1D];
316bd882c8aSJames Wright   for (CeedInt i = 0; i < P_1D; i++) {
317bd882c8aSJames Wright     r_B[i] = B[i + item_id_y * P_1D];
318bd882c8aSJames Wright   }
319bd882c8aSJames Wright 
320bd882c8aSJames Wright   for (CeedInt k = 0; k < P_1D; k++) {
321bd882c8aSJames Wright     scratch[item_id_x + item_id_y * T_1D] = U[k];
322bd882c8aSJames Wright     work_group_barrier(CLK_LOCAL_MEM_FENCE);
323bd882c8aSJames Wright 
324bd882c8aSJames Wright     V[k] = 0.0;
325bd882c8aSJames Wright     if (item_id_x < Q_1D && item_id_y < Q_1D) {
326bd882c8aSJames Wright       for (CeedInt i = 0; i < P_1D; i++) {
327bd882c8aSJames Wright         V[k] += r_B[i] * scratch[item_id_x + i * T_1D];  // Contract y direction
328bd882c8aSJames Wright       }
329bd882c8aSJames Wright     }
330bd882c8aSJames Wright     work_group_barrier(CLK_LOCAL_MEM_FENCE);
331bd882c8aSJames Wright   }
332bd882c8aSJames Wright }
333bd882c8aSJames Wright 
334bd882c8aSJames Wright //------------------------------------------------------------------------------
335bd882c8aSJames Wright // 3D tensor contract z
336bd882c8aSJames Wright //------------------------------------------------------------------------------
337bd882c8aSJames Wright inline void ContractZ3d(const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict U, local const CeedScalar *restrict B,
338bd882c8aSJames Wright                         private CeedScalar *restrict V, local CeedScalar *restrict scratch) {
339bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
340bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
341bd882c8aSJames Wright 
342bd882c8aSJames Wright   for (CeedInt k = 0; k < Q_1D; k++) {
343bd882c8aSJames Wright     V[k] = 0.0;
344bd882c8aSJames Wright     if (item_id_x < Q_1D && item_id_y < Q_1D) {
345bd882c8aSJames Wright       for (CeedInt i = 0; i < P_1D; i++) {
346bd882c8aSJames Wright         V[k] += B[i + k * P_1D] * U[i];  // Contract z direction
347bd882c8aSJames Wright       }
348bd882c8aSJames Wright     }
349bd882c8aSJames Wright   }
350bd882c8aSJames Wright }
351bd882c8aSJames Wright 
352bd882c8aSJames Wright //------------------------------------------------------------------------------
353bd882c8aSJames Wright // 3D transpose tensor contract z
354bd882c8aSJames Wright //------------------------------------------------------------------------------
355bd882c8aSJames Wright inline void ContractTransposeZ3d(const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict U, local const CeedScalar *restrict B,
356bd882c8aSJames Wright                                  private CeedScalar *restrict V, local CeedScalar *restrict scratch) {
357bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
358bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
359bd882c8aSJames Wright 
360bd882c8aSJames Wright   for (CeedInt k = 0; k < P_1D; k++) {
361bd882c8aSJames Wright     V[k] = 0.0;
362bd882c8aSJames Wright     if (item_id_x < Q_1D && item_id_y < Q_1D) {
363bd882c8aSJames Wright       for (CeedInt i = 0; i < Q_1D; i++) {
364bd882c8aSJames Wright         V[k] += B[k + i * P_1D] * U[i];  // Contract z direction
365bd882c8aSJames Wright       }
366bd882c8aSJames Wright     }
367bd882c8aSJames Wright   }
368bd882c8aSJames Wright }
369bd882c8aSJames Wright 
370bd882c8aSJames Wright //------------------------------------------------------------------------------
371bd882c8aSJames Wright // 3D transpose tensor contract y
372bd882c8aSJames Wright //------------------------------------------------------------------------------
373bd882c8aSJames Wright inline void ContractTransposeY3d(const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict U, local const CeedScalar *restrict B,
374bd882c8aSJames Wright                                  private CeedScalar *restrict V, local CeedScalar *restrict scratch) {
375bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
376bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
377bd882c8aSJames Wright 
378bd882c8aSJames Wright   CeedScalar r_B[T_1D];
379bd882c8aSJames Wright   for (CeedInt i = 0; i < Q_1D; i++) {
380bd882c8aSJames Wright     r_B[i] = B[item_id_y + i * P_1D];
381bd882c8aSJames Wright   }
382bd882c8aSJames Wright 
383bd882c8aSJames Wright   for (CeedInt k = 0; k < P_1D; k++) {
384bd882c8aSJames Wright     scratch[item_id_x + item_id_y * T_1D] = U[k];
385bd882c8aSJames Wright     work_group_barrier(CLK_LOCAL_MEM_FENCE);
386bd882c8aSJames Wright 
387bd882c8aSJames Wright     V[k] = 0.0;
388bd882c8aSJames Wright     if (item_id_x < Q_1D && item_id_y < P_1D) {
389bd882c8aSJames Wright       for (CeedInt i = 0; i < Q_1D; i++) {
390bd882c8aSJames Wright         V[k] += r_B[i] * scratch[item_id_x + i * T_1D];  // Contract y direction
391bd882c8aSJames Wright       }
392bd882c8aSJames Wright     }
393bd882c8aSJames Wright     work_group_barrier(CLK_LOCAL_MEM_FENCE);
394bd882c8aSJames Wright   }
395bd882c8aSJames Wright }
396bd882c8aSJames Wright 
397bd882c8aSJames Wright //------------------------------------------------------------------------------
// 3D transpose tensor contract add y
399bd882c8aSJames Wright //------------------------------------------------------------------------------
400bd882c8aSJames Wright inline void ContractTransposeAddY3d(const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict U, local const CeedScalar *restrict B,
401bd882c8aSJames Wright                                     private CeedScalar *restrict V, local CeedScalar *restrict scratch) {
402bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
403bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
404bd882c8aSJames Wright 
405bd882c8aSJames Wright   CeedScalar r_B[T_1D];
406bd882c8aSJames Wright   for (CeedInt i = 0; i < Q_1D; i++) {
407bd882c8aSJames Wright     r_B[i] = B[item_id_y + i * P_1D];
408bd882c8aSJames Wright   }
409bd882c8aSJames Wright 
410bd882c8aSJames Wright   for (CeedInt k = 0; k < P_1D; k++) {
411bd882c8aSJames Wright     scratch[item_id_x + item_id_y * T_1D] = U[k];
412bd882c8aSJames Wright     work_group_barrier(CLK_LOCAL_MEM_FENCE);
413bd882c8aSJames Wright     if (item_id_x < Q_1D && item_id_y < P_1D) {
414bd882c8aSJames Wright       for (CeedInt i = 0; i < Q_1D; i++) {
415bd882c8aSJames Wright         V[k] += r_B[i] * scratch[item_id_x + i * T_1D];  // Contract y direction
416bd882c8aSJames Wright       }
417bd882c8aSJames Wright     }
418bd882c8aSJames Wright     work_group_barrier(CLK_LOCAL_MEM_FENCE);
419bd882c8aSJames Wright   }
420bd882c8aSJames Wright }
421bd882c8aSJames Wright 
422bd882c8aSJames Wright //------------------------------------------------------------------------------
423bd882c8aSJames Wright // 3D transpose tensor contract x
424bd882c8aSJames Wright //------------------------------------------------------------------------------
425bd882c8aSJames Wright inline void ContractTransposeX3d(const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict U, local const CeedScalar *restrict B,
426bd882c8aSJames Wright                                  private CeedScalar *restrict V, local CeedScalar *restrict scratch) {
427bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
428bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
429bd882c8aSJames Wright 
430bd882c8aSJames Wright   CeedScalar r_B[T_1D];
431bd882c8aSJames Wright   for (CeedInt i = 0; i < Q_1D; i++) {
432bd882c8aSJames Wright     r_B[i] = B[item_id_x + i * P_1D];
433bd882c8aSJames Wright   }
434bd882c8aSJames Wright 
435bd882c8aSJames Wright   for (CeedInt k = 0; k < P_1D; k++) {
436bd882c8aSJames Wright     scratch[item_id_x + item_id_y * T_1D] = U[k];
437bd882c8aSJames Wright     work_group_barrier(CLK_LOCAL_MEM_FENCE);
438bd882c8aSJames Wright     V[k] = 0.0;
439bd882c8aSJames Wright     if (item_id_x < P_1D && item_id_y < P_1D) {
440bd882c8aSJames Wright       for (CeedInt i = 0; i < Q_1D; i++) {
441bd882c8aSJames Wright         V[k] += r_B[i] * scratch[i + item_id_y * T_1D];  // Contract x direction
442bd882c8aSJames Wright       }
443bd882c8aSJames Wright     }
444bd882c8aSJames Wright     work_group_barrier(CLK_LOCAL_MEM_FENCE);
445bd882c8aSJames Wright   }
446bd882c8aSJames Wright }
447bd882c8aSJames Wright 
448bd882c8aSJames Wright //------------------------------------------------------------------------------
449bd882c8aSJames Wright // 3D transpose tensor contract add x
450bd882c8aSJames Wright //------------------------------------------------------------------------------
451bd882c8aSJames Wright inline void ContractTransposeAddX3d(const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict U, local const CeedScalar *restrict B,
452bd882c8aSJames Wright                                     private CeedScalar *restrict V, local CeedScalar *restrict scratch) {
453bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
454bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
455bd882c8aSJames Wright 
456bd882c8aSJames Wright   CeedScalar r_B[T_1D];
457bd882c8aSJames Wright   for (CeedInt i = 0; i < Q_1D; i++) {
458bd882c8aSJames Wright     r_B[i] = B[item_id_x + i * P_1D];
459bd882c8aSJames Wright   }
460bd882c8aSJames Wright 
461bd882c8aSJames Wright   for (CeedInt k = 0; k < P_1D; k++) {
462bd882c8aSJames Wright     scratch[item_id_x + item_id_y * T_1D] = U[k];
463bd882c8aSJames Wright     work_group_barrier(CLK_LOCAL_MEM_FENCE);
464bd882c8aSJames Wright 
465bd882c8aSJames Wright     if (item_id_x < P_1D && item_id_y < P_1D) {
466bd882c8aSJames Wright       for (CeedInt i = 0; i < Q_1D; i++) {
467bd882c8aSJames Wright         V[k] += r_B[i] * scratch[i + item_id_y * T_1D];  // Contract x direction
468bd882c8aSJames Wright       }
469bd882c8aSJames Wright     }
470bd882c8aSJames Wright     work_group_barrier(CLK_LOCAL_MEM_FENCE);
471bd882c8aSJames Wright   }
472bd882c8aSJames Wright }
473bd882c8aSJames Wright 
474bd882c8aSJames Wright //------------------------------------------------------------------------------
475bd882c8aSJames Wright // 3D interpolate to quadrature points
476bd882c8aSJames Wright //------------------------------------------------------------------------------
477bd882c8aSJames Wright inline void InterpTensor3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict r_U,
478bd882c8aSJames Wright                            local const CeedScalar *restrict s_B, private CeedScalar *restrict r_V, local CeedScalar *restrict scratch) {
479bd882c8aSJames Wright   CeedScalar r_t1[T_1D];
480bd882c8aSJames Wright   CeedScalar r_t2[T_1D];
481bd882c8aSJames Wright 
482bd882c8aSJames Wright   for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
483bd882c8aSJames Wright     ContractX3d(P_1D, Q_1D, r_U + comp * P_1D, s_B, r_t1, scratch);
484bd882c8aSJames Wright     ContractY3d(P_1D, Q_1D, r_t1, s_B, r_t2, scratch);
485bd882c8aSJames Wright     ContractZ3d(P_1D, Q_1D, r_t2, s_B, r_V + comp * Q_1D, scratch);
486bd882c8aSJames Wright   }
487bd882c8aSJames Wright }
488bd882c8aSJames Wright 
489bd882c8aSJames Wright //------------------------------------------------------------------------------
490bd882c8aSJames Wright // 3D interpolate transpose
491bd882c8aSJames Wright //------------------------------------------------------------------------------
492bd882c8aSJames Wright inline void InterpTransposeTensor3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict r_U,
493bd882c8aSJames Wright                                     local const CeedScalar *restrict s_B, private CeedScalar *restrict r_V, local CeedScalar *restrict scratch) {
494bd882c8aSJames Wright   CeedScalar r_t1[T_1D];
495bd882c8aSJames Wright   CeedScalar r_t2[T_1D];
496bd882c8aSJames Wright 
497bd882c8aSJames Wright   for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
498bd882c8aSJames Wright     ContractTransposeZ3d(P_1D, Q_1D, r_U + comp * Q_1D, s_B, r_t1, scratch);
499bd882c8aSJames Wright     ContractTransposeY3d(P_1D, Q_1D, r_t1, s_B, r_t2, scratch);
500bd882c8aSJames Wright     ContractTransposeX3d(P_1D, Q_1D, r_t2, s_B, r_V + comp * P_1D, scratch);
501bd882c8aSJames Wright   }
502bd882c8aSJames Wright }
503bd882c8aSJames Wright 
504bd882c8aSJames Wright //------------------------------------------------------------------------------
505bd882c8aSJames Wright // 3D derivatives at quadrature points
506bd882c8aSJames Wright //------------------------------------------------------------------------------
507bd882c8aSJames Wright inline void GradTensor3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict r_U,
508bd882c8aSJames Wright                          local const CeedScalar *restrict s_B, local const CeedScalar *restrict s_G, private CeedScalar *restrict r_V,
509bd882c8aSJames Wright                          local CeedScalar *restrict scratch) {
510bd882c8aSJames Wright   CeedScalar r_t1[T_1D];
511bd882c8aSJames Wright   CeedScalar r_t2[T_1D];
512bd882c8aSJames Wright 
513bd882c8aSJames Wright   for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
514bd882c8aSJames Wright     ContractX3d(P_1D, Q_1D, r_U + comp * P_1D, s_G, r_t1, scratch);
515bd882c8aSJames Wright     ContractY3d(P_1D, Q_1D, r_t1, s_B, r_t2, scratch);
516bd882c8aSJames Wright     ContractZ3d(P_1D, Q_1D, r_t2, s_B, r_V + comp * Q_1D + 0 * NUM_COMP * Q_1D, scratch);
517bd882c8aSJames Wright     ContractX3d(P_1D, Q_1D, r_U + comp * P_1D, s_B, r_t1, scratch);
518bd882c8aSJames Wright     ContractY3d(P_1D, Q_1D, r_t1, s_G, r_t2, scratch);
519bd882c8aSJames Wright     ContractZ3d(P_1D, Q_1D, r_t2, s_B, r_V + comp * Q_1D + 1 * NUM_COMP * Q_1D, scratch);
520bd882c8aSJames Wright     ContractX3d(P_1D, Q_1D, r_U + comp * P_1D, s_B, r_t1, scratch);
521bd882c8aSJames Wright     ContractY3d(P_1D, Q_1D, r_t1, s_B, r_t2, scratch);
522bd882c8aSJames Wright     ContractZ3d(P_1D, Q_1D, r_t2, s_G, r_V + comp * Q_1D + 2 * NUM_COMP * Q_1D, scratch);
523bd882c8aSJames Wright   }
524bd882c8aSJames Wright }
525bd882c8aSJames Wright 
526bd882c8aSJames Wright //------------------------------------------------------------------------------
527bd882c8aSJames Wright // 3D derivatives transpose
528bd882c8aSJames Wright //------------------------------------------------------------------------------
529bd882c8aSJames Wright inline void GradTransposeTensor3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict r_U,
530bd882c8aSJames Wright                                   local const CeedScalar *restrict s_B, local const CeedScalar *restrict s_G, private CeedScalar *restrict r_V,
531bd882c8aSJames Wright                                   local CeedScalar *restrict scratch) {
532bd882c8aSJames Wright   CeedScalar r_t1[T_1D];
533bd882c8aSJames Wright   CeedScalar r_t2[T_1D];
534bd882c8aSJames Wright 
535bd882c8aSJames Wright   for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
536bd882c8aSJames Wright     ContractTransposeZ3d(P_1D, Q_1D, r_U + comp * Q_1D + 0 * NUM_COMP * Q_1D, s_B, r_t1, scratch);
537bd882c8aSJames Wright     ContractTransposeY3d(P_1D, Q_1D, r_t1, s_B, r_t2, scratch);
538bd882c8aSJames Wright     ContractTransposeX3d(P_1D, Q_1D, r_t2, s_G, r_V + comp * P_1D, scratch);
539bd882c8aSJames Wright     ContractTransposeZ3d(P_1D, Q_1D, r_U + comp * Q_1D + 1 * NUM_COMP * Q_1D, s_B, r_t1, scratch);
540bd882c8aSJames Wright     ContractTransposeY3d(P_1D, Q_1D, r_t1, s_G, r_t2, scratch);
541bd882c8aSJames Wright     ContractTransposeAddX3d(P_1D, Q_1D, r_t2, s_B, r_V + comp * P_1D, scratch);
542bd882c8aSJames Wright     ContractTransposeZ3d(P_1D, Q_1D, r_U + comp * Q_1D + 2 * NUM_COMP * Q_1D, s_G, r_t1, scratch);
543bd882c8aSJames Wright     ContractTransposeY3d(P_1D, Q_1D, r_t1, s_B, r_t2, scratch);
544bd882c8aSJames Wright     ContractTransposeAddX3d(P_1D, Q_1D, r_t2, s_B, r_V + comp * P_1D, scratch);
545bd882c8aSJames Wright   }
546bd882c8aSJames Wright }
547bd882c8aSJames Wright 
548bd882c8aSJames Wright //------------------------------------------------------------------------------
549bd882c8aSJames Wright // 3D derivatives at quadrature points
550bd882c8aSJames Wright //------------------------------------------------------------------------------
551bd882c8aSJames Wright inline void GradTensorCollocated3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict r_U,
552bd882c8aSJames Wright                                    local const CeedScalar *restrict s_B, local const CeedScalar *restrict s_G, private CeedScalar *restrict r_V,
553bd882c8aSJames Wright                                    local CeedScalar *restrict scratch) {
554bd882c8aSJames Wright   CeedScalar r_t1[T_1D];
555bd882c8aSJames Wright   CeedScalar r_t2[T_1D];
556bd882c8aSJames Wright 
557bd882c8aSJames Wright   for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
558bd882c8aSJames Wright     ContractX3d(P_1D, Q_1D, r_U + comp * P_1D, s_B, r_t1, scratch);
559bd882c8aSJames Wright     ContractY3d(P_1D, Q_1D, r_t1, s_B, r_t2, scratch);
560bd882c8aSJames Wright     ContractZ3d(P_1D, Q_1D, r_t2, s_B, r_t1, scratch);
561bd882c8aSJames Wright     ContractX3d(Q_1D, Q_1D, r_t1, s_G, r_V + comp * Q_1D + 0 * NUM_COMP * Q_1D, scratch);
562bd882c8aSJames Wright     ContractY3d(Q_1D, Q_1D, r_t1, s_G, r_V + comp * Q_1D + 1 * NUM_COMP * Q_1D, scratch);
563bd882c8aSJames Wright     ContractZ3d(Q_1D, Q_1D, r_t1, s_G, r_V + comp * Q_1D + 2 * NUM_COMP * Q_1D, scratch);
564bd882c8aSJames Wright   }
565bd882c8aSJames Wright }
566bd882c8aSJames Wright 
567bd882c8aSJames Wright //------------------------------------------------------------------------------
568bd882c8aSJames Wright // 3D derivatives transpose
569bd882c8aSJames Wright //------------------------------------------------------------------------------
570bd882c8aSJames Wright inline void GradTransposeTensorCollocated3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt Q_1D, private const CeedScalar *restrict r_U,
571bd882c8aSJames Wright                                             local const CeedScalar *restrict s_B, local const CeedScalar *restrict s_G,
572bd882c8aSJames Wright                                             private CeedScalar *restrict r_V, local CeedScalar *restrict scratch) {
573bd882c8aSJames Wright   CeedScalar r_t1[T_1D];
574bd882c8aSJames Wright   CeedScalar r_t2[T_1D];
575bd882c8aSJames Wright 
576bd882c8aSJames Wright   for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
577bd882c8aSJames Wright     ContractTransposeZ3d(Q_1D, Q_1D, r_U + comp * Q_1D + 2 * NUM_COMP * Q_1D, s_G, r_t2, scratch);
578bd882c8aSJames Wright     ContractTransposeAddY3d(Q_1D, Q_1D, r_U + comp * Q_1D + 1 * NUM_COMP * Q_1D, s_G, r_t2, scratch);
579bd882c8aSJames Wright     ContractTransposeAddX3d(Q_1D, Q_1D, r_U + comp * Q_1D + 0 * NUM_COMP * Q_1D, s_G, r_t2, scratch);
580bd882c8aSJames Wright     ContractTransposeZ3d(P_1D, Q_1D, r_t2, s_B, r_t1, scratch);
581bd882c8aSJames Wright     ContractTransposeY3d(P_1D, Q_1D, r_t1, s_B, r_t2, scratch);
582bd882c8aSJames Wright     ContractTransposeX3d(P_1D, Q_1D, r_t2, s_B, r_V + comp * P_1D, scratch);
583bd882c8aSJames Wright   }
584bd882c8aSJames Wright }
585bd882c8aSJames Wright 
586bd882c8aSJames Wright //------------------------------------------------------------------------------
587bd882c8aSJames Wright // 3D quadrature weights
588bd882c8aSJames Wright //------------------------------------------------------------------------------
589bd882c8aSJames Wright // template <int Q_1D>
590bd882c8aSJames Wright inline void WeightTensor3d(const CeedInt Q_1D, const CeedScalar *restrict q_weight_1d, CeedScalar *restrict w) {
591bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
592bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
593bd882c8aSJames Wright 
594bd882c8aSJames Wright   if (item_id_x < Q_1D && item_id_y < Q_1D) {
595bd882c8aSJames Wright     const CeedScalar w_xy = q_weight_1d[item_id_x] * q_weight_1d[item_id_y];
596bd882c8aSJames Wright     for (CeedInt q = 0; q < Q_1D; ++q) w[q] = w_xy * q_weight_1d[q];
597bd882c8aSJames Wright   } else {
598bd882c8aSJames Wright     for (CeedInt q = 0; q < Q_1D; q++) w[q] = 0.0;
599bd882c8aSJames Wright   }
600bd882c8aSJames Wright }
601bd882c8aSJames Wright 
602bd882c8aSJames Wright //------------------------------------------------------------------------------
603bd882c8aSJames Wright 
60494b7b29bSJeremy L Thompson #endif  // CEED_SYCL_SHARED_BASIS_TENSOR_TEMPLATES_H
605