xref: /libCEED/rust/libceed-sys/c-src/include/ceed/jit-source/sycl/sycl-gen-templates.h (revision 6ca0f394dabdca92269b68ec74be8bebae3befa4)
1*6ca0f394SUmesh Unnikrishnan // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*6ca0f394SUmesh Unnikrishnan // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*6ca0f394SUmesh Unnikrishnan //
4*6ca0f394SUmesh Unnikrishnan // SPDX-License-Identifier: BSD-2-Clause
5*6ca0f394SUmesh Unnikrishnan //
6*6ca0f394SUmesh Unnikrishnan // This file is part of CEED:  http://github.com/ceed
7*6ca0f394SUmesh Unnikrishnan 
8*6ca0f394SUmesh Unnikrishnan /// @file
9*6ca0f394SUmesh Unnikrishnan /// Internal header for SYCL backend macro and type definitions for JiT source
10*6ca0f394SUmesh Unnikrishnan #ifndef _ceed_sycl_gen_templates_h
11*6ca0f394SUmesh Unnikrishnan #define _ceed_sycl_gen_templates_h
12*6ca0f394SUmesh Unnikrishnan 
13*6ca0f394SUmesh Unnikrishnan #include <ceed/types.h>
14*6ca0f394SUmesh Unnikrishnan 
// 64-bit atomics: the OpenCL C atomic_double type used below depends on the
// cl_khr_int64_base_atomics and cl_khr_int64_extended_atomics extensions.
#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable
#pragma OPENCL EXTENSION cl_khr_int64_extended_atomics : enable
// Atomic element type for the offset-based E-vector -> L-vector writers
// (writeDofsOffset1d/2d/3d take `global CeedAtomicScalar*`).
// TODO: Handle FP32 case -- this is hard-wired to double.
// NOTE(review): it is only correct while CeedScalar (from ceed/types.h) is double.
typedef atomic_double CeedAtomicScalar;
19*6ca0f394SUmesh Unnikrishnan 
20*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
21*6ca0f394SUmesh Unnikrishnan // Load matrices for basis actions
22*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
23*6ca0f394SUmesh Unnikrishnan inline void loadMatrix(const CeedInt N, const CeedScalar* restrict d_B, CeedScalar* restrict B) {
24*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id    = get_local_linear_id();
25*6ca0f394SUmesh Unnikrishnan   const CeedInt group_size = get_local_size(0) * get_local_size(1) * get_local_size(2);
26*6ca0f394SUmesh Unnikrishnan   for (CeedInt i = item_id; i < N; i += group_size) B[i] = d_B[i];
27*6ca0f394SUmesh Unnikrishnan }
28*6ca0f394SUmesh Unnikrishnan 
29*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
30*6ca0f394SUmesh Unnikrishnan // 1D
31*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
32*6ca0f394SUmesh Unnikrishnan 
33*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
34*6ca0f394SUmesh Unnikrishnan // L-vector -> E-vector, offsets provided
35*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
36*6ca0f394SUmesh Unnikrishnan inline void readDofsOffset1d(const CeedInt num_comp, const CeedInt strides_comp, const CeedInt P_1D, const CeedInt num_elem,
37*6ca0f394SUmesh Unnikrishnan                              const global CeedInt* restrict indices, const global CeedScalar* restrict d_u, private CeedScalar* restrict r_u) {
38*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_x = get_local_id(0);
39*6ca0f394SUmesh Unnikrishnan   const CeedInt elem      = get_global_id(2);
40*6ca0f394SUmesh Unnikrishnan 
41*6ca0f394SUmesh Unnikrishnan   if (item_id_x < P_1D && elem < num_elem) {
42*6ca0f394SUmesh Unnikrishnan     const CeedInt node = item_id_x;
43*6ca0f394SUmesh Unnikrishnan     const CeedInt ind  = indices[node + elem * P_1D];
44*6ca0f394SUmesh Unnikrishnan     for (CeedInt comp = 0; comp < num_comp; ++comp) {
45*6ca0f394SUmesh Unnikrishnan       r_u[comp] = d_u[ind + strides_comp * comp];
46*6ca0f394SUmesh Unnikrishnan     }
47*6ca0f394SUmesh Unnikrishnan   }
48*6ca0f394SUmesh Unnikrishnan }
49*6ca0f394SUmesh Unnikrishnan 
50*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
51*6ca0f394SUmesh Unnikrishnan // L-vector -> E-vector, strided
52*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
53*6ca0f394SUmesh Unnikrishnan inline void readDofsStrided1d(const CeedInt num_comp, const CeedInt P_1D, const CeedInt strides_node, const CeedInt strides_comp,
54*6ca0f394SUmesh Unnikrishnan                               const CeedInt strides_elem, const CeedInt num_elem, global const CeedScalar* restrict d_u,
55*6ca0f394SUmesh Unnikrishnan                               private CeedScalar* restrict r_u) {
56*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_x = get_local_id(0);
57*6ca0f394SUmesh Unnikrishnan   const CeedInt elem      = get_global_id(2);
58*6ca0f394SUmesh Unnikrishnan 
59*6ca0f394SUmesh Unnikrishnan   if (item_id_x < P_1D && elem < num_elem) {
60*6ca0f394SUmesh Unnikrishnan     const CeedInt node = item_id_x;
61*6ca0f394SUmesh Unnikrishnan     const CeedInt ind  = node * strides_node + elem * strides_elem;
62*6ca0f394SUmesh Unnikrishnan     for (CeedInt comp = 0; comp < num_comp; comp++) {
63*6ca0f394SUmesh Unnikrishnan       r_u[comp] = d_u[ind + comp * strides_comp];
64*6ca0f394SUmesh Unnikrishnan     }
65*6ca0f394SUmesh Unnikrishnan   }
66*6ca0f394SUmesh Unnikrishnan }
67*6ca0f394SUmesh Unnikrishnan 
68*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
69*6ca0f394SUmesh Unnikrishnan // E-vector -> L-vector, offsets provided
70*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
71*6ca0f394SUmesh Unnikrishnan inline void writeDofsOffset1d(const CeedInt num_comp, const CeedInt strides_comp, const CeedInt P_1D, const CeedInt num_elem,
72*6ca0f394SUmesh Unnikrishnan                               const global CeedInt* restrict indices, const private CeedScalar* restrict r_v, global CeedAtomicScalar* restrict d_v) {
73*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_x = get_local_id(0);
74*6ca0f394SUmesh Unnikrishnan   const CeedInt elem      = get_global_id(2);
75*6ca0f394SUmesh Unnikrishnan 
76*6ca0f394SUmesh Unnikrishnan   if (item_id_x < P_1D && elem < num_elem) {
77*6ca0f394SUmesh Unnikrishnan     const CeedInt node = item_id_x;
78*6ca0f394SUmesh Unnikrishnan     const CeedInt ind  = indices[node + elem * P_1D];
79*6ca0f394SUmesh Unnikrishnan     for (CeedInt comp = 0; comp < num_comp; ++comp)
80*6ca0f394SUmesh Unnikrishnan       atomic_fetch_add_explicit(&d_v[ind + strides_comp * comp], r_v[comp], memory_order_relaxed, memory_scope_device);
81*6ca0f394SUmesh Unnikrishnan   }
82*6ca0f394SUmesh Unnikrishnan }
83*6ca0f394SUmesh Unnikrishnan 
84*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
85*6ca0f394SUmesh Unnikrishnan // E-vector -> L-vector, strided
86*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
87*6ca0f394SUmesh Unnikrishnan inline void writeDofsStrided1d(const CeedInt num_comp, const CeedInt P_1D, const CeedInt strides_node, const CeedInt strides_comp,
88*6ca0f394SUmesh Unnikrishnan                                const CeedInt strides_elem, const CeedInt num_elem, private const CeedScalar* restrict r_v,
89*6ca0f394SUmesh Unnikrishnan                                global CeedScalar* restrict d_v) {
90*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_x = get_local_id(0);
91*6ca0f394SUmesh Unnikrishnan   const CeedInt elem      = get_global_id(2);
92*6ca0f394SUmesh Unnikrishnan 
93*6ca0f394SUmesh Unnikrishnan   if (item_id_x < P_1D && elem < num_elem) {
94*6ca0f394SUmesh Unnikrishnan     const CeedInt node = item_id_x;
95*6ca0f394SUmesh Unnikrishnan     const CeedInt ind  = node * strides_node + elem * strides_elem;
96*6ca0f394SUmesh Unnikrishnan     for (CeedInt comp = 0; comp < num_comp; comp++) {
97*6ca0f394SUmesh Unnikrishnan       d_v[ind + comp * strides_comp] = r_v[comp];
98*6ca0f394SUmesh Unnikrishnan     }
99*6ca0f394SUmesh Unnikrishnan   }
100*6ca0f394SUmesh Unnikrishnan }
101*6ca0f394SUmesh Unnikrishnan 
102*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
103*6ca0f394SUmesh Unnikrishnan // 2D
104*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
105*6ca0f394SUmesh Unnikrishnan 
106*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
107*6ca0f394SUmesh Unnikrishnan // L-vector -> E-vector, offsets provided
108*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
109*6ca0f394SUmesh Unnikrishnan inline void readDofsOffset2d(const CeedInt num_comp, const CeedInt strides_comp, const CeedInt P_1D, const CeedInt num_elem,
110*6ca0f394SUmesh Unnikrishnan                              const global CeedInt* restrict indices, const global CeedScalar* restrict d_u, private CeedScalar* restrict r_u) {
111*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_x = get_local_id(0);
112*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_y = get_local_id(1);
113*6ca0f394SUmesh Unnikrishnan   const CeedInt elem      = get_global_id(2);
114*6ca0f394SUmesh Unnikrishnan 
115*6ca0f394SUmesh Unnikrishnan   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
116*6ca0f394SUmesh Unnikrishnan     const CeedInt node = item_id_x + item_id_y * P_1D;
117*6ca0f394SUmesh Unnikrishnan     const CeedInt ind  = indices[node + elem * P_1D * P_1D];
118*6ca0f394SUmesh Unnikrishnan     for (CeedInt comp = 0; comp < num_comp; ++comp) r_u[comp] = d_u[ind + strides_comp * comp];
119*6ca0f394SUmesh Unnikrishnan   }
120*6ca0f394SUmesh Unnikrishnan }
121*6ca0f394SUmesh Unnikrishnan 
122*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
123*6ca0f394SUmesh Unnikrishnan // L-vector -> E-vector, strided
124*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
125*6ca0f394SUmesh Unnikrishnan inline void readDofsStrided2d(const CeedInt num_comp, const CeedInt P_1D, const CeedInt strides_node, const CeedInt strides_comp,
126*6ca0f394SUmesh Unnikrishnan                               const CeedInt strides_elem, const CeedInt num_elem, const global CeedScalar* restrict d_u,
127*6ca0f394SUmesh Unnikrishnan                               private CeedScalar* restrict r_u) {
128*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_x = get_local_id(0);
129*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_y = get_local_id(1);
130*6ca0f394SUmesh Unnikrishnan   const CeedInt elem      = get_global_id(2);
131*6ca0f394SUmesh Unnikrishnan 
132*6ca0f394SUmesh Unnikrishnan   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
133*6ca0f394SUmesh Unnikrishnan     const CeedInt node = item_id_x + item_id_y * P_1D;
134*6ca0f394SUmesh Unnikrishnan     const CeedInt ind  = node * strides_node + elem * strides_elem;
135*6ca0f394SUmesh Unnikrishnan     for (CeedInt comp = 0; comp < num_comp; ++comp) r_u[comp] = d_u[ind + comp * strides_comp];
136*6ca0f394SUmesh Unnikrishnan   }
137*6ca0f394SUmesh Unnikrishnan }
138*6ca0f394SUmesh Unnikrishnan 
139*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
140*6ca0f394SUmesh Unnikrishnan // E-vector -> L-vector, offsets provided
141*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
142*6ca0f394SUmesh Unnikrishnan inline void writeDofsOffset2d(const CeedInt num_comp, const CeedInt strides_comp, const CeedInt P_1D, const CeedInt num_elem,
143*6ca0f394SUmesh Unnikrishnan                               const global CeedInt* restrict indices, const private CeedScalar* restrict r_v, global CeedAtomicScalar* restrict d_v) {
144*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_x = get_local_id(0);
145*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_y = get_local_id(1);
146*6ca0f394SUmesh Unnikrishnan   const CeedInt elem      = get_global_id(2);
147*6ca0f394SUmesh Unnikrishnan 
148*6ca0f394SUmesh Unnikrishnan   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
149*6ca0f394SUmesh Unnikrishnan     const CeedInt node = item_id_x + item_id_y * P_1D;
150*6ca0f394SUmesh Unnikrishnan     const CeedInt ind  = indices[node + elem * P_1D * P_1D];
151*6ca0f394SUmesh Unnikrishnan     for (CeedInt comp = 0; comp < num_comp; ++comp)
152*6ca0f394SUmesh Unnikrishnan       atomic_fetch_add_explicit(&d_v[ind + strides_comp * comp], r_v[comp], memory_order_relaxed, memory_scope_device);
153*6ca0f394SUmesh Unnikrishnan   }
154*6ca0f394SUmesh Unnikrishnan }
155*6ca0f394SUmesh Unnikrishnan 
156*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
157*6ca0f394SUmesh Unnikrishnan // E-vector -> L-vector, strided
158*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
159*6ca0f394SUmesh Unnikrishnan inline void writeDofsStrided2d(const CeedInt num_comp, const CeedInt P_1D, const CeedInt strides_node, const CeedInt strides_comp,
160*6ca0f394SUmesh Unnikrishnan                                const CeedInt strides_elem, const CeedInt num_elem, const private CeedScalar* restrict r_v,
161*6ca0f394SUmesh Unnikrishnan                                global CeedScalar* restrict d_v) {
162*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_x = get_local_id(0);
163*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_y = get_local_id(1);
164*6ca0f394SUmesh Unnikrishnan   const CeedInt elem      = get_global_id(2);
165*6ca0f394SUmesh Unnikrishnan 
166*6ca0f394SUmesh Unnikrishnan   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
167*6ca0f394SUmesh Unnikrishnan     const CeedInt node = item_id_x + item_id_y * P_1D;
168*6ca0f394SUmesh Unnikrishnan     const CeedInt ind  = node * strides_node + elem * strides_elem;
169*6ca0f394SUmesh Unnikrishnan     for (CeedInt comp = 0; comp < num_comp; ++comp) d_v[ind + comp * strides_comp] += r_v[comp];
170*6ca0f394SUmesh Unnikrishnan   }
171*6ca0f394SUmesh Unnikrishnan }
172*6ca0f394SUmesh Unnikrishnan 
173*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
174*6ca0f394SUmesh Unnikrishnan // 3D
175*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
176*6ca0f394SUmesh Unnikrishnan 
177*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
178*6ca0f394SUmesh Unnikrishnan // L-vector -> E-vector, offsets provided
179*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
180*6ca0f394SUmesh Unnikrishnan inline void readDofsOffset3d(const CeedInt num_comp, const CeedInt strides_comp, const CeedInt P_1D, const CeedInt num_elem,
181*6ca0f394SUmesh Unnikrishnan                              const global CeedInt* restrict indices, const global CeedScalar* restrict d_u, private CeedScalar* restrict r_u) {
182*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_x = get_local_id(0);
183*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_y = get_local_id(1);
184*6ca0f394SUmesh Unnikrishnan   const CeedInt elem      = get_global_id(2);
185*6ca0f394SUmesh Unnikrishnan 
186*6ca0f394SUmesh Unnikrishnan   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
187*6ca0f394SUmesh Unnikrishnan     for (CeedInt z = 0; z < P_1D; ++z) {
188*6ca0f394SUmesh Unnikrishnan       const CeedInt node = item_id_x + P_1D * (item_id_y + P_1D * z);
189*6ca0f394SUmesh Unnikrishnan       const CeedInt ind  = indices[node + elem * P_1D * P_1D * P_1D];
190*6ca0f394SUmesh Unnikrishnan       for (CeedInt comp = 0; comp < num_comp; ++comp) r_u[z + comp * P_1D] = d_u[ind + strides_comp * comp];
191*6ca0f394SUmesh Unnikrishnan     }
192*6ca0f394SUmesh Unnikrishnan   }
193*6ca0f394SUmesh Unnikrishnan }
194*6ca0f394SUmesh Unnikrishnan 
195*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
196*6ca0f394SUmesh Unnikrishnan // L-vector -> E-vector, strided
197*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
198*6ca0f394SUmesh Unnikrishnan inline void readDofsStrided3d(const CeedInt num_comp, const CeedInt P_1D, const CeedInt strides_node, const CeedInt strides_comp,
199*6ca0f394SUmesh Unnikrishnan                               const CeedInt strides_elem, const CeedInt num_elem, const global CeedScalar* restrict d_u,
200*6ca0f394SUmesh Unnikrishnan                               private CeedScalar* restrict r_u) {
201*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_x = get_local_id(0);
202*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_y = get_local_id(1);
203*6ca0f394SUmesh Unnikrishnan   const CeedInt elem      = get_global_id(2);
204*6ca0f394SUmesh Unnikrishnan 
205*6ca0f394SUmesh Unnikrishnan   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
206*6ca0f394SUmesh Unnikrishnan     for (CeedInt z = 0; z < P_1D; ++z) {
207*6ca0f394SUmesh Unnikrishnan       const CeedInt node = item_id_x + P_1D * (item_id_y + P_1D * z);
208*6ca0f394SUmesh Unnikrishnan       const CeedInt ind  = node * strides_node + elem * strides_elem;
209*6ca0f394SUmesh Unnikrishnan       for (CeedInt comp = 0; comp < num_comp; ++comp) r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp];
210*6ca0f394SUmesh Unnikrishnan     }
211*6ca0f394SUmesh Unnikrishnan   }
212*6ca0f394SUmesh Unnikrishnan }
213*6ca0f394SUmesh Unnikrishnan 
214*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
215*6ca0f394SUmesh Unnikrishnan // E-vector -> Q-vector, offests provided
216*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
217*6ca0f394SUmesh Unnikrishnan inline void readSliceQuadsOffset3d(const CeedInt num_comp, const CeedInt strides_comp, const CeedInt Q_1D, const CeedInt num_elem, const CeedInt q,
218*6ca0f394SUmesh Unnikrishnan                                    const global CeedInt* restrict indices, const global CeedScalar* restrict d_u, private CeedScalar* restrict r_u) {
219*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_x = get_local_id(0);
220*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_y = get_local_id(1);
221*6ca0f394SUmesh Unnikrishnan   const CeedInt elem      = get_global_id(2);
222*6ca0f394SUmesh Unnikrishnan 
223*6ca0f394SUmesh Unnikrishnan   if (item_id_x < Q_1D && item_id_y < Q_1D && elem < num_elem) {
224*6ca0f394SUmesh Unnikrishnan     const CeedInt node = item_id_x + Q_1D * (item_id_y + Q_1D * q);
225*6ca0f394SUmesh Unnikrishnan     const CeedInt ind  = indices[node + elem * Q_1D * Q_1D * Q_1D];
226*6ca0f394SUmesh Unnikrishnan     for (CeedInt comp = 0; comp < num_comp; ++comp) r_u[comp] = d_u[ind + strides_comp * comp];
227*6ca0f394SUmesh Unnikrishnan   }
228*6ca0f394SUmesh Unnikrishnan }
229*6ca0f394SUmesh Unnikrishnan 
230*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
231*6ca0f394SUmesh Unnikrishnan // E-vector -> Q-vector, strided
232*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
233*6ca0f394SUmesh Unnikrishnan inline void readSliceQuadsStrided3d(const CeedInt num_comp, const CeedInt Q_1D, CeedInt strides_node, CeedInt strides_comp, CeedInt strides_elem,
234*6ca0f394SUmesh Unnikrishnan                                     const CeedInt num_elem, const CeedInt q, const global CeedScalar* restrict d_u,
235*6ca0f394SUmesh Unnikrishnan                                     private CeedScalar* restrict r_u) {
236*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_x = get_local_id(0);
237*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_y = get_local_id(1);
238*6ca0f394SUmesh Unnikrishnan   const CeedInt elem      = get_global_id(2);
239*6ca0f394SUmesh Unnikrishnan 
240*6ca0f394SUmesh Unnikrishnan   if (item_id_x < Q_1D && item_id_y < Q_1D && elem < num_elem) {
241*6ca0f394SUmesh Unnikrishnan     const CeedInt node = item_id_x + Q_1D * (item_id_y + Q_1D * q);
242*6ca0f394SUmesh Unnikrishnan     const CeedInt ind  = node * strides_node + elem * strides_elem;
243*6ca0f394SUmesh Unnikrishnan     for (CeedInt comp = 0; comp < num_comp; ++comp) r_u[comp] = d_u[ind + comp * strides_comp];
244*6ca0f394SUmesh Unnikrishnan   }
245*6ca0f394SUmesh Unnikrishnan }
246*6ca0f394SUmesh Unnikrishnan 
247*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
248*6ca0f394SUmesh Unnikrishnan // E-vector -> L-vector, offsets provided
249*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
250*6ca0f394SUmesh Unnikrishnan inline void writeDofsOffset3d(const CeedInt num_comp, const CeedInt strides_comp, const CeedInt P_1D, const CeedInt num_elem,
251*6ca0f394SUmesh Unnikrishnan                               const global CeedInt* restrict indices, const private CeedScalar* restrict r_v, global CeedAtomicScalar* restrict d_v) {
252*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_x = get_local_id(0);
253*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_y = get_local_id(1);
254*6ca0f394SUmesh Unnikrishnan   const CeedInt elem      = get_global_id(2);
255*6ca0f394SUmesh Unnikrishnan 
256*6ca0f394SUmesh Unnikrishnan   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
257*6ca0f394SUmesh Unnikrishnan     for (CeedInt z = 0; z < P_1D; ++z) {
258*6ca0f394SUmesh Unnikrishnan       const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D;
259*6ca0f394SUmesh Unnikrishnan       const CeedInt ind  = indices[node + elem * P_1D * P_1D * P_1D];
260*6ca0f394SUmesh Unnikrishnan       for (CeedInt comp = 0; comp < num_comp; ++comp)
261*6ca0f394SUmesh Unnikrishnan         atomic_fetch_add_explicit(&d_v[ind + strides_comp * comp], r_v[z + comp * P_1D], memory_order_relaxed, memory_scope_device);
262*6ca0f394SUmesh Unnikrishnan     }
263*6ca0f394SUmesh Unnikrishnan   }
264*6ca0f394SUmesh Unnikrishnan }
265*6ca0f394SUmesh Unnikrishnan 
266*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
267*6ca0f394SUmesh Unnikrishnan // E-vector -> L-vector, strided
268*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
269*6ca0f394SUmesh Unnikrishnan inline void writeDofsStrided3d(const CeedInt num_comp, const CeedInt P_1D, const CeedInt strides_node, const CeedInt strides_comp,
270*6ca0f394SUmesh Unnikrishnan                                const CeedInt strides_elem, const CeedInt num_elem, const private CeedScalar* restrict r_v,
271*6ca0f394SUmesh Unnikrishnan                                global CeedScalar* restrict d_v) {
272*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_x = get_local_id(0);
273*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_y = get_local_id(1);
274*6ca0f394SUmesh Unnikrishnan   const CeedInt elem      = get_global_id(2);
275*6ca0f394SUmesh Unnikrishnan 
276*6ca0f394SUmesh Unnikrishnan   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
277*6ca0f394SUmesh Unnikrishnan     for (CeedInt z = 0; z < P_1D; ++z) {
278*6ca0f394SUmesh Unnikrishnan       const CeedInt node = item_id_x + P_1D * (item_id_y + P_1D * z);
279*6ca0f394SUmesh Unnikrishnan       const CeedInt ind  = node * strides_node + elem * strides_elem;
280*6ca0f394SUmesh Unnikrishnan       for (CeedInt comp = 0; comp < num_comp; ++comp) d_v[ind + comp * strides_comp] += r_v[z + comp * P_1D];
281*6ca0f394SUmesh Unnikrishnan     }
282*6ca0f394SUmesh Unnikrishnan   }
283*6ca0f394SUmesh Unnikrishnan }
284*6ca0f394SUmesh Unnikrishnan 
285*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
286*6ca0f394SUmesh Unnikrishnan // 3D collocated derivatives computation
287*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
288*6ca0f394SUmesh Unnikrishnan inline void gradCollo3d(const CeedInt num_comp, const CeedInt Q_1D, const CeedInt q, const private CeedScalar* restrict r_U,
289*6ca0f394SUmesh Unnikrishnan                         const local CeedScalar* s_G, private CeedScalar* restrict r_V, local CeedScalar* restrict scratch) {
290*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_x = get_local_id(0);
291*6ca0f394SUmesh Unnikrishnan   const CeedInt item_id_y = get_local_id(1);
292*6ca0f394SUmesh Unnikrishnan 
293*6ca0f394SUmesh Unnikrishnan   for (CeedInt comp = 0; comp < num_comp; ++comp) {
294*6ca0f394SUmesh Unnikrishnan     if (item_id_x < Q_1D && item_id_y < Q_1D) {
295*6ca0f394SUmesh Unnikrishnan       scratch[item_id_x + item_id_y * T_1D] = r_U[q + comp * Q_1D];
296*6ca0f394SUmesh Unnikrishnan     }
297*6ca0f394SUmesh Unnikrishnan     work_group_barrier(CLK_LOCAL_MEM_FENCE);
298*6ca0f394SUmesh Unnikrishnan 
299*6ca0f394SUmesh Unnikrishnan     if (item_id_x < Q_1D && item_id_y < Q_1D) {
300*6ca0f394SUmesh Unnikrishnan       // X derivative
301*6ca0f394SUmesh Unnikrishnan       r_V[comp + 0 * num_comp] = 0.0;
302*6ca0f394SUmesh Unnikrishnan       for (CeedInt i = 0; i < Q_1D; ++i)
303*6ca0f394SUmesh Unnikrishnan         r_V[comp + 0 * num_comp] += s_G[i + item_id_x * Q_1D] * scratch[i + item_id_y * T_1D];  // Contract x direction (X derivative)
304*6ca0f394SUmesh Unnikrishnan 
305*6ca0f394SUmesh Unnikrishnan       // Y derivative
306*6ca0f394SUmesh Unnikrishnan       r_V[comp + 1 * num_comp] = 0.0;
307*6ca0f394SUmesh Unnikrishnan       for (CeedInt i = 0; i < Q_1D; ++i)
308*6ca0f394SUmesh Unnikrishnan         r_V[comp + 1 * num_comp] += s_G[i + item_id_y * Q_1D] * scratch[item_id_x + i * T_1D];  // Contract y direction (Y derivative)
309*6ca0f394SUmesh Unnikrishnan 
310*6ca0f394SUmesh Unnikrishnan       // Z derivative
311*6ca0f394SUmesh Unnikrishnan       r_V[comp + 2 * num_comp] = 0.0;
312*6ca0f394SUmesh Unnikrishnan       for (CeedInt i = 0; i < Q_1D; ++i) r_V[comp + 2 * num_comp] += s_G[i + q * Q_1D] * r_U[i + comp * Q_1D];  // Contract z direction (Z derivative)
313*6ca0f394SUmesh Unnikrishnan     }
314*6ca0f394SUmesh Unnikrishnan 
315*6ca0f394SUmesh Unnikrishnan     work_group_barrier(CLK_LOCAL_MEM_FENCE);
316*6ca0f394SUmesh Unnikrishnan   }
317*6ca0f394SUmesh Unnikrishnan }
318*6ca0f394SUmesh Unnikrishnan 
//------------------------------------------------------------------------------
// 3D collocated derivatives transpose
//
// Transpose of gradCollo3d for z-layer q.
// For each component comp:
//   - the x and y derivative inputs r_U[comp + d * num_comp] (d = 0, 1) are
//     staged in scratch (row pitch T_1D) and contracted with s_G transposed
//     (s_G[out + i * Q_1D]);
//   - results are added into r_V[q + comp * Q_1D].
// The z input r_U[comp + 2 * num_comp] is spread across the whole column,
// r_V[i + comp * Q_1D] for all i.
// Every update is `+=`: this function never zeroes r_V.
// NOTE(review): callers presumably initialize r_V before the first layer.
// All barriers are executed unconditionally by every work-item.
// T_1D is a JiT-time define that is not visible here (assumed >= Q_1D).
//------------------------------------------------------------------------------
inline void gradColloTranspose3d(const CeedInt num_comp, const CeedInt Q_1D, const CeedInt q, const private CeedScalar* restrict r_U,
                                 const local CeedScalar* restrict s_G, private CeedScalar* restrict r_V, local CeedScalar* restrict scratch) {
  const CeedInt item_id_x = get_local_id(0);
  const CeedInt item_id_y = get_local_id(1);

  for (CeedInt comp = 0; comp < num_comp; ++comp) {
    // X derivative: stage the x-derivative input of this component in scratch
    if (item_id_x < Q_1D && item_id_y < Q_1D) {
      scratch[item_id_x + item_id_y * T_1D] = r_U[comp + 0 * num_comp];
    }
    work_group_barrier(CLK_LOCAL_MEM_FENCE);

    if (item_id_x < Q_1D && item_id_y < Q_1D) {
      for (CeedInt i = 0; i < Q_1D; ++i)
        r_V[q + comp * Q_1D] += s_G[item_id_x + i * Q_1D] * scratch[i + item_id_y * T_1D];  // Contract x direction (X derivative)
    }
    // All reads of scratch must finish before it is overwritten for the Y pass
    work_group_barrier(CLK_LOCAL_MEM_FENCE);

    // Y derivative: reuse scratch for the y-derivative input
    if (item_id_x < Q_1D && item_id_y < Q_1D) {
      scratch[item_id_x + item_id_y * T_1D] = r_U[comp + 1 * num_comp];
    }
    work_group_barrier(CLK_LOCAL_MEM_FENCE);

    if (item_id_x < Q_1D && item_id_y < Q_1D) {
      for (CeedInt i = 0; i < Q_1D; ++i)
        r_V[q + comp * Q_1D] += s_G[item_id_y + i * Q_1D] * scratch[item_id_x + i * T_1D];  // Contract y direction (Y derivative)
    }
    // All reads of scratch must finish before the next component overwrites it
    work_group_barrier(CLK_LOCAL_MEM_FENCE);

    // Z derivative: purely work-item-local (no scratch), spreads over the z-column
    if (item_id_x < Q_1D && item_id_y < Q_1D) {
      for (CeedInt i = 0; i < Q_1D; ++i)
        r_V[i + comp * Q_1D] += s_G[i + q * Q_1D] * r_U[comp + 2 * num_comp];  // PARTIAL contract z direction (Z derivative)
    }
  }
}
359*6ca0f394SUmesh Unnikrishnan 
360*6ca0f394SUmesh Unnikrishnan #endif
361