xref: /libCEED/rust/libceed-sys/c-src/include/ceed/jit-source/sycl/sycl-gen-templates.h (revision 94b7b29b41ad8a17add4c577886859ef16f89dec)
16ca0f394SUmesh Unnikrishnan // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
26ca0f394SUmesh Unnikrishnan // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
36ca0f394SUmesh Unnikrishnan //
46ca0f394SUmesh Unnikrishnan // SPDX-License-Identifier: BSD-2-Clause
56ca0f394SUmesh Unnikrishnan //
66ca0f394SUmesh Unnikrishnan // This file is part of CEED:  http://github.com/ceed
76ca0f394SUmesh Unnikrishnan 
86ca0f394SUmesh Unnikrishnan /// @file
96ca0f394SUmesh Unnikrishnan /// Internal header for SYCL backend macro and type definitions for JiT source
10*94b7b29bSJeremy L Thompson #ifndef CEED_SYCL_GEN_TEMPLATES_H
11*94b7b29bSJeremy L Thompson #define CEED_SYCL_GEN_TEMPLATES_H
126ca0f394SUmesh Unnikrishnan 
136ca0f394SUmesh Unnikrishnan #include <ceed/types.h>
146ca0f394SUmesh Unnikrishnan 
// Enable the OpenCL 64-bit atomic extensions. NOTE(review): presumably required so that the
// atomic_fetch_add_explicit calls on CeedAtomicScalar in this file compile on the target device — confirm.
#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable
#pragma OPENCL EXTENSION cl_khr_int64_extended_atomics : enable
// Atomic scalar type used as the destination of the offset-based E-vector -> L-vector writers below.
// TODO: Handle FP32 case (this typedef is unconditionally atomic_double, independent of CeedScalar)
typedef atomic_double CeedAtomicScalar;
196ca0f394SUmesh Unnikrishnan 
206ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
216ca0f394SUmesh Unnikrishnan // Load matrices for basis actions
226ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// Cooperative copy of a basis matrix from d_B into B.
// The work-group's items stride over the N entries: item k copies entries
// k, k + group_size, k + 2 * group_size, ... where k is its local linear id and
// group_size is the product of the three local sizes.
// No barrier is issued here. NOTE(review): presumably the caller synchronizes
// before other work-items read B — confirm.
inline void loadMatrix(const CeedInt N, const CeedScalar* restrict d_B, CeedScalar* restrict B) {
  const CeedInt stride = get_local_size(0) * get_local_size(1) * get_local_size(2);

  for (CeedInt idx = get_local_linear_id(); idx < N; idx += stride) {
    B[idx] = d_B[idx];
  }
}
286ca0f394SUmesh Unnikrishnan 
296ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
306ca0f394SUmesh Unnikrishnan // 1D
316ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
326ca0f394SUmesh Unnikrishnan 
336ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
346ca0f394SUmesh Unnikrishnan // L-vector -> E-vector, offsets provided
356ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// L-vector -> E-vector gather for a 1D element, using an offsets (indices) array.
// Work-item x of element `elem` (global id 2) loads node x: the node's L-vector
// base index is indices[elem * P_1D + x], and component c is read from
// d_u[base + c * strides_comp] into r_u[c].
// Work-items with x >= P_1D or elem >= num_elem leave r_u untouched.
inline void readDofsOffset1d(const CeedInt num_comp, const CeedInt strides_comp, const CeedInt P_1D, const CeedInt num_elem,
                             const global CeedInt* restrict indices, const global CeedScalar* restrict d_u, private CeedScalar* restrict r_u) {
  const CeedInt node = get_local_id(0);
  const CeedInt elem = get_global_id(2);

  if (node >= P_1D || elem >= num_elem) return;

  const CeedInt base = indices[elem * P_1D + node];
  for (CeedInt c = 0; c < num_comp; c++) {
    r_u[c] = d_u[base + c * strides_comp];
  }
}
496ca0f394SUmesh Unnikrishnan 
506ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
516ca0f394SUmesh Unnikrishnan // L-vector -> E-vector, strided
526ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// L-vector -> E-vector gather for a 1D element with a strided layout.
// Work-item x of element `elem` (global id 2) loads component c of node x from
// d_u[x * strides_node + elem * strides_elem + c * strides_comp] into r_u[c].
// Work-items with x >= P_1D or elem >= num_elem leave r_u untouched.
inline void readDofsStrided1d(const CeedInt num_comp, const CeedInt P_1D, const CeedInt strides_node, const CeedInt strides_comp,
                              const CeedInt strides_elem, const CeedInt num_elem, global const CeedScalar* restrict d_u,
                              private CeedScalar* restrict r_u) {
  const CeedInt node = get_local_id(0);
  const CeedInt elem = get_global_id(2);

  if (node >= P_1D || elem >= num_elem) return;

  const CeedInt offset = elem * strides_elem + node * strides_node;
  for (CeedInt c = 0; c < num_comp; c++) r_u[c] = d_u[offset + c * strides_comp];
}
676ca0f394SUmesh Unnikrishnan 
686ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
696ca0f394SUmesh Unnikrishnan // E-vector -> L-vector, offsets provided
706ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// E-vector -> L-vector scatter-add for a 1D element, using an offsets (indices) array.
// Work-item x of element `elem` (global id 2) atomically adds r_v[c] into
// d_v[indices[elem * P_1D + x] + c * strides_comp] for every component c.
// The adds use relaxed ordering and device scope.
// NOTE(review): atomics are presumably needed because several elements can map
// to the same L-vector entry through `indices` — confirm.
// Work-items with x >= P_1D or elem >= num_elem write nothing.
inline void writeDofsOffset1d(const CeedInt num_comp, const CeedInt strides_comp, const CeedInt P_1D, const CeedInt num_elem,
                              const global CeedInt* restrict indices, const private CeedScalar* restrict r_v, global CeedAtomicScalar* restrict d_v) {
  const CeedInt node = get_local_id(0);
  const CeedInt elem = get_global_id(2);

  if (node >= P_1D || elem >= num_elem) return;

  const CeedInt base = indices[elem * P_1D + node];
  for (CeedInt c = 0; c < num_comp; c++) {
    atomic_fetch_add_explicit(&d_v[base + c * strides_comp], r_v[c], memory_order_relaxed, memory_scope_device);
  }
}
836ca0f394SUmesh Unnikrishnan 
846ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
856ca0f394SUmesh Unnikrishnan // E-vector -> L-vector, strided
866ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// E-vector -> L-vector write for a 1D element with a strided layout.
// Work-item x of element `elem` (global id 2) adds component c of node x (r_v[c])
// into d_v[x * strides_node + elem * strides_elem + c * strides_comp].
// Work-items with x >= P_1D or elem >= num_elem write nothing.
// The value is accumulated rather than stored. This matches writeDofsStrided2d,
// writeDofsStrided3d and the atomic offset writers, so existing contents of d_v
// are preserved in every dimension.
inline void writeDofsStrided1d(const CeedInt num_comp, const CeedInt P_1D, const CeedInt strides_node, const CeedInt strides_comp,
                               const CeedInt strides_elem, const CeedInt num_elem, private const CeedScalar* restrict r_v,
                               global CeedScalar* restrict d_v) {
  const CeedInt item_id_x = get_local_id(0);
  const CeedInt elem      = get_global_id(2);

  if (item_id_x < P_1D && elem < num_elem) {
    const CeedInt node = item_id_x;
    const CeedInt ind  = node * strides_node + elem * strides_elem;
    for (CeedInt comp = 0; comp < num_comp; comp++) {
      // Accumulate (previously overwrote with '='), consistent with the 2D/3D strided writers
      d_v[ind + comp * strides_comp] += r_v[comp];
    }
  }
}
1016ca0f394SUmesh Unnikrishnan 
1026ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
1036ca0f394SUmesh Unnikrishnan // 2D
1046ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
1056ca0f394SUmesh Unnikrishnan 
1066ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
1076ca0f394SUmesh Unnikrishnan // L-vector -> E-vector, offsets provided
1086ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// L-vector -> E-vector gather for a 2D element, using an offsets (indices) array.
// Work-item (x, y) of element `elem` (global id 2) loads node x + y * P_1D.
// Its L-vector base index is indices[elem * P_1D * P_1D + node], and component c
// is read from d_u[base + c * strides_comp] into r_u[c].
// Work-items outside the P_1D x P_1D tile, or with elem >= num_elem, leave r_u untouched.
inline void readDofsOffset2d(const CeedInt num_comp, const CeedInt strides_comp, const CeedInt P_1D, const CeedInt num_elem,
                             const global CeedInt* restrict indices, const global CeedScalar* restrict d_u, private CeedScalar* restrict r_u) {
  const CeedInt tx   = get_local_id(0);
  const CeedInt ty   = get_local_id(1);
  const CeedInt elem = get_global_id(2);

  if (tx >= P_1D || ty >= P_1D || elem >= num_elem) return;

  const CeedInt elem_nodes = P_1D * P_1D;
  const CeedInt base       = indices[elem * elem_nodes + (ty * P_1D + tx)];

  for (CeedInt c = 0; c < num_comp; c++) {
    r_u[c] = d_u[base + c * strides_comp];
  }
}
1216ca0f394SUmesh Unnikrishnan 
1226ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
1236ca0f394SUmesh Unnikrishnan // L-vector -> E-vector, strided
1246ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// L-vector -> E-vector gather for a 2D element with a strided layout.
// Work-item (x, y) of element `elem` (global id 2) loads component c of node
// x + y * P_1D from d_u[node * strides_node + elem * strides_elem + c * strides_comp]
// into r_u[c].
// Work-items outside the P_1D x P_1D tile, or with elem >= num_elem, leave r_u untouched.
inline void readDofsStrided2d(const CeedInt num_comp, const CeedInt P_1D, const CeedInt strides_node, const CeedInt strides_comp,
                              const CeedInt strides_elem, const CeedInt num_elem, const global CeedScalar* restrict d_u,
                              private CeedScalar* restrict r_u) {
  const CeedInt elem = get_global_id(2);
  const CeedInt tx   = get_local_id(0);
  const CeedInt ty   = get_local_id(1);

  const bool active = (tx < P_1D) && (ty < P_1D) && (elem < num_elem);
  if (active) {
    const CeedInt offset = (tx + ty * P_1D) * strides_node + elem * strides_elem;
    for (CeedInt c = 0; c < num_comp; c++) r_u[c] = d_u[offset + c * strides_comp];
  }
}
1386ca0f394SUmesh Unnikrishnan 
1396ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
1406ca0f394SUmesh Unnikrishnan // E-vector -> L-vector, offsets provided
1416ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// E-vector -> L-vector scatter-add for a 2D element, using an offsets (indices) array.
// Work-item (x, y) of element `elem` (global id 2) atomically adds r_v[c] into
// d_v[indices[elem * P_1D * P_1D + x + y * P_1D] + c * strides_comp] for every component c.
// The adds use relaxed ordering and device scope.
// NOTE(review): atomics presumably guard L-vector entries shared between elements — confirm.
// Work-items outside the P_1D x P_1D tile, or with elem >= num_elem, write nothing.
inline void writeDofsOffset2d(const CeedInt num_comp, const CeedInt strides_comp, const CeedInt P_1D, const CeedInt num_elem,
                              const global CeedInt* restrict indices, const private CeedScalar* restrict r_v, global CeedAtomicScalar* restrict d_v) {
  const CeedInt tx   = get_local_id(0);
  const CeedInt ty   = get_local_id(1);
  const CeedInt elem = get_global_id(2);

  if (tx >= P_1D || ty >= P_1D || elem >= num_elem) return;

  const CeedInt base = indices[elem * P_1D * P_1D + ty * P_1D + tx];
  for (CeedInt c = 0; c < num_comp; c++) {
    atomic_fetch_add_explicit(&d_v[base + c * strides_comp], r_v[c], memory_order_relaxed, memory_scope_device);
  }
}
1556ca0f394SUmesh Unnikrishnan 
1566ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
1576ca0f394SUmesh Unnikrishnan // E-vector -> L-vector, strided
1586ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// E-vector -> L-vector write for a 2D element with a strided layout.
// Work-item (x, y) of element `elem` (global id 2) adds r_v[c] into
// d_v[node * strides_node + elem * strides_elem + c * strides_comp], with node = x + y * P_1D.
// The add is non-atomic.
// Work-items outside the P_1D x P_1D tile, or with elem >= num_elem, write nothing.
inline void writeDofsStrided2d(const CeedInt num_comp, const CeedInt P_1D, const CeedInt strides_node, const CeedInt strides_comp,
                               const CeedInt strides_elem, const CeedInt num_elem, const private CeedScalar* restrict r_v,
                               global CeedScalar* restrict d_v) {
  const CeedInt elem = get_global_id(2);
  const CeedInt tx   = get_local_id(0);
  const CeedInt ty   = get_local_id(1);

  const bool active = (tx < P_1D) && (ty < P_1D) && (elem < num_elem);
  if (active) {
    const CeedInt offset = (tx + ty * P_1D) * strides_node + elem * strides_elem;
    for (CeedInt c = 0; c < num_comp; c++) d_v[offset + c * strides_comp] += r_v[c];
  }
}
1726ca0f394SUmesh Unnikrishnan 
1736ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
1746ca0f394SUmesh Unnikrishnan // 3D
1756ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
1766ca0f394SUmesh Unnikrishnan 
1776ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
1786ca0f394SUmesh Unnikrishnan // L-vector -> E-vector, offsets provided
1796ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// L-vector -> E-vector gather for a 3D element, using an offsets (indices) array.
// Work-item (x, y) of element `elem` (global id 2) loads the column of nodes
// x + P_1D * (y + P_1D * z), for z = 0..P_1D-1.
// Component c of level z is stored in r_u[z + c * P_1D]; each node's L-vector
// base index comes from indices[elem * P_1D^3 + node].
// Work-items outside the P_1D x P_1D tile, or with elem >= num_elem, leave r_u untouched.
inline void readDofsOffset3d(const CeedInt num_comp, const CeedInt strides_comp, const CeedInt P_1D, const CeedInt num_elem,
                             const global CeedInt* restrict indices, const global CeedScalar* restrict d_u, private CeedScalar* restrict r_u) {
  const CeedInt tx   = get_local_id(0);
  const CeedInt ty   = get_local_id(1);
  const CeedInt elem = get_global_id(2);

  if (tx >= P_1D || ty >= P_1D || elem >= num_elem) return;

  const CeedInt slab      = P_1D * P_1D;
  const CeedInt elem_base = elem * slab * P_1D;
  for (CeedInt z = 0; z < P_1D; z++) {
    const CeedInt ind = indices[elem_base + z * slab + ty * P_1D + tx];
    for (CeedInt c = 0; c < num_comp; c++) r_u[c * P_1D + z] = d_u[ind + c * strides_comp];
  }
}
1946ca0f394SUmesh Unnikrishnan 
1956ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
1966ca0f394SUmesh Unnikrishnan // L-vector -> E-vector, strided
1976ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// L-vector -> E-vector gather for a 3D element with a strided layout.
// Work-item (x, y) of element `elem` (global id 2) loads the column of nodes
// x + P_1D * (y + P_1D * z), for z = 0..P_1D-1.
// Component c of level z is read from
// d_u[node * strides_node + elem * strides_elem + c * strides_comp] into r_u[z + c * P_1D].
// Work-items outside the P_1D x P_1D tile, or with elem >= num_elem, leave r_u untouched.
inline void readDofsStrided3d(const CeedInt num_comp, const CeedInt P_1D, const CeedInt strides_node, const CeedInt strides_comp,
                              const CeedInt strides_elem, const CeedInt num_elem, const global CeedScalar* restrict d_u,
                              private CeedScalar* restrict r_u) {
  const CeedInt elem = get_global_id(2);
  const CeedInt tx   = get_local_id(0);
  const CeedInt ty   = get_local_id(1);

  const bool active = (tx < P_1D) && (ty < P_1D) && (elem < num_elem);
  if (!active) return;

  const CeedInt elem_offset = elem * strides_elem;
  for (CeedInt z = 0; z < P_1D; z++) {
    const CeedInt node   = tx + P_1D * (ty + P_1D * z);
    const CeedInt offset = node * strides_node + elem_offset;
    for (CeedInt c = 0; c < num_comp; c++) r_u[c * P_1D + z] = d_u[offset + c * strides_comp];
  }
}
2136ca0f394SUmesh Unnikrishnan 
2146ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// E-vector -> Q-vector, offsets provided
2166ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// Gather a single z-slice `q` of a 3D element's Q_1D^3 points, using an offsets (indices) array.
// Work-item (x, y) of element `elem` (global id 2) loads point x + Q_1D * (y + Q_1D * q).
// Its base index is indices[elem * Q_1D^3 + point], and component c is read from
// d_u[base + c * strides_comp] into r_u[c].
// Work-items outside the Q_1D x Q_1D tile, or with elem >= num_elem, leave r_u untouched.
inline void readSliceQuadsOffset3d(const CeedInt num_comp, const CeedInt strides_comp, const CeedInt Q_1D, const CeedInt num_elem, const CeedInt q,
                                   const global CeedInt* restrict indices, const global CeedScalar* restrict d_u, private CeedScalar* restrict r_u) {
  const CeedInt tx   = get_local_id(0);
  const CeedInt ty   = get_local_id(1);
  const CeedInt elem = get_global_id(2);

  if (tx >= Q_1D || ty >= Q_1D || elem >= num_elem) return;

  const CeedInt pts_per_elem = Q_1D * Q_1D * Q_1D;
  const CeedInt point        = tx + ty * Q_1D + q * Q_1D * Q_1D;
  const CeedInt ind          = indices[elem * pts_per_elem + point];
  for (CeedInt c = 0; c < num_comp; c++) r_u[c] = d_u[ind + c * strides_comp];
}
2296ca0f394SUmesh Unnikrishnan 
2306ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
2316ca0f394SUmesh Unnikrishnan // E-vector -> Q-vector, strided
2326ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// Gather a single z-slice `q` of a 3D element's Q_1D^3 points, with a strided layout.
// Work-item (x, y) of element `elem` (global id 2) loads point x + Q_1D * (y + Q_1D * q).
// Component c is read from
// d_u[point * strides_node + elem * strides_elem + c * strides_comp] into r_u[c].
// Work-items outside the Q_1D x Q_1D tile, or with elem >= num_elem, leave r_u untouched.
inline void readSliceQuadsStrided3d(const CeedInt num_comp, const CeedInt Q_1D, CeedInt strides_node, CeedInt strides_comp, CeedInt strides_elem,
                                    const CeedInt num_elem, const CeedInt q, const global CeedScalar* restrict d_u,
                                    private CeedScalar* restrict r_u) {
  const CeedInt elem = get_global_id(2);
  const CeedInt tx   = get_local_id(0);
  const CeedInt ty   = get_local_id(1);

  const bool active = (tx < Q_1D) && (ty < Q_1D) && (elem < num_elem);
  if (active) {
    const CeedInt point  = tx + Q_1D * (ty + Q_1D * q);
    const CeedInt offset = elem * strides_elem + point * strides_node;
    for (CeedInt c = 0; c < num_comp; c++) r_u[c] = d_u[offset + c * strides_comp];
  }
}
2466ca0f394SUmesh Unnikrishnan 
2476ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
2486ca0f394SUmesh Unnikrishnan // E-vector -> L-vector, offsets provided
2496ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// E-vector -> L-vector scatter-add for a 3D element, using an offsets (indices) array.
// Work-item (x, y) of element `elem` (global id 2) handles the column of nodes
// x + P_1D * (y + P_1D * z), for z = 0..P_1D-1.
// For each node it atomically adds r_v[z + c * P_1D] into
// d_v[indices[elem * P_1D^3 + node] + c * strides_comp], using relaxed ordering and device scope.
// NOTE(review): atomics presumably guard L-vector entries shared between elements — confirm.
// Work-items outside the P_1D x P_1D tile, or with elem >= num_elem, write nothing.
inline void writeDofsOffset3d(const CeedInt num_comp, const CeedInt strides_comp, const CeedInt P_1D, const CeedInt num_elem,
                              const global CeedInt* restrict indices, const private CeedScalar* restrict r_v, global CeedAtomicScalar* restrict d_v) {
  const CeedInt tx   = get_local_id(0);
  const CeedInt ty   = get_local_id(1);
  const CeedInt elem = get_global_id(2);

  if (tx >= P_1D || ty >= P_1D || elem >= num_elem) return;

  const CeedInt elem_base = elem * P_1D * P_1D * P_1D;
  for (CeedInt z = 0; z < P_1D; z++) {
    const CeedInt ind = indices[elem_base + tx + P_1D * (ty + P_1D * z)];
    for (CeedInt c = 0; c < num_comp; c++) {
      atomic_fetch_add_explicit(&d_v[ind + c * strides_comp], r_v[c * P_1D + z], memory_order_relaxed, memory_scope_device);
    }
  }
}
2656ca0f394SUmesh Unnikrishnan 
2666ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
2676ca0f394SUmesh Unnikrishnan // E-vector -> L-vector, strided
2686ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// E-vector -> L-vector write for a 3D element with a strided layout.
// Work-item (x, y) of element `elem` (global id 2) handles the column of nodes
// x + y * P_1D + z * P_1D^2, for z = 0..P_1D-1.
// It adds r_v[z + c * P_1D] into d_v[node * strides_node + elem * strides_elem + c * strides_comp].
// The add is non-atomic.
// Work-items outside the P_1D x P_1D tile, or with elem >= num_elem, write nothing.
inline void writeDofsStrided3d(const CeedInt num_comp, const CeedInt P_1D, const CeedInt strides_node, const CeedInt strides_comp,
                               const CeedInt strides_elem, const CeedInt num_elem, const private CeedScalar* restrict r_v,
                               global CeedScalar* restrict d_v) {
  const CeedInt elem = get_global_id(2);
  const CeedInt tx   = get_local_id(0);
  const CeedInt ty   = get_local_id(1);

  const bool active = (tx < P_1D) && (ty < P_1D) && (elem < num_elem);
  if (!active) return;

  const CeedInt slab        = P_1D * P_1D;
  const CeedInt elem_offset = elem * strides_elem;
  for (CeedInt z = 0; z < P_1D; z++) {
    const CeedInt offset = (tx + ty * P_1D + z * slab) * strides_node + elem_offset;
    for (CeedInt c = 0; c < num_comp; c++) d_v[offset + c * strides_comp] += r_v[c * P_1D + z];
  }
}
2846ca0f394SUmesh Unnikrishnan 
2856ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
2866ca0f394SUmesh Unnikrishnan // 3D collocated derivatives computation
2876ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// Collocated 3D gradient at z-level `q` for the point column owned by work-item (x, y).
//
// r_U     : private; holds Q_1D z-values per component, indexed r_U[z + comp * Q_1D].
// s_G     : local; Q_1D x Q_1D collocated derivative matrix, indexed s_G[i + row * Q_1D].
// r_V     : private output; r_V[comp + d * num_comp] receives the d-th derivative
//           (d = 0: x, 1: y, 2: z) of component comp. It is overwritten, not accumulated.
// scratch : local; one z-slice with row pitch T_1D, used to share the level-q values
//           across the work-group for the x/y contractions.
//
// T_1D is not defined in this file. NOTE(review): presumably the JiT code generator
// supplies it as the slice row pitch (>= Q_1D) — confirm.
// Both work_group_barrier calls sit outside the bounds checks, so every work-item
// in the group must call this function.
inline void gradCollo3d(const CeedInt num_comp, const CeedInt Q_1D, const CeedInt q, const private CeedScalar* restrict r_U,
                        const local CeedScalar* s_G, private CeedScalar* restrict r_V, local CeedScalar* restrict scratch) {
  const CeedInt item_id_x = get_local_id(0);
  const CeedInt item_id_y = get_local_id(1);

  for (CeedInt comp = 0; comp < num_comp; ++comp) {
    // Publish this item's level-q value so neighbours along x and y can read it
    if (item_id_x < Q_1D && item_id_y < Q_1D) {
      scratch[item_id_x + item_id_y * T_1D] = r_U[q + comp * Q_1D];
    }
    work_group_barrier(CLK_LOCAL_MEM_FENCE);

    if (item_id_x < Q_1D && item_id_y < Q_1D) {
      // X derivative
      r_V[comp + 0 * num_comp] = 0.0;
      for (CeedInt i = 0; i < Q_1D; ++i)
        r_V[comp + 0 * num_comp] += s_G[i + item_id_x * Q_1D] * scratch[i + item_id_y * T_1D];  // Contract x direction (X derivative)

      // Y derivative
      r_V[comp + 1 * num_comp] = 0.0;
      for (CeedInt i = 0; i < Q_1D; ++i)
        r_V[comp + 1 * num_comp] += s_G[i + item_id_y * Q_1D] * scratch[item_id_x + i * T_1D];  // Contract y direction (Y derivative)

      // Z derivative: the whole z column is already in private memory, so no scratch is needed
      r_V[comp + 2 * num_comp] = 0.0;
      for (CeedInt i = 0; i < Q_1D; ++i) r_V[comp + 2 * num_comp] += s_G[i + q * Q_1D] * r_U[i + comp * Q_1D];  // Contract z direction (Z derivative)
    }

    // Ensure all reads of scratch finish before the next component overwrites it
    work_group_barrier(CLK_LOCAL_MEM_FENCE);
  }
}
3186ca0f394SUmesh Unnikrishnan 
3196ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
3206ca0f394SUmesh Unnikrishnan // 3D collocated derivatives transpose
3216ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
// Transpose of the collocated 3D gradient for z-level `q` of the column owned by work-item (x, y).
//
// r_U     : private; the three derivative values at level q, indexed r_U[comp + d * num_comp]
//           (d = 0: x, 1: y, 2: z).
// s_G     : local; Q_1D x Q_1D collocated derivative matrix, read transposed here.
// r_V     : private; Q_1D z-values per component, indexed r_V[z + comp * Q_1D].
//           It is only ever incremented (+=), never zeroed here. The caller must
//           initialize it.
//           NOTE(review): the z contribution is added to every z-entry, so this is
//           presumably called once per q = 0..Q_1D-1 to complete the z contraction — confirm.
// scratch : local; one slice with row pitch T_1D, used to share derivative values across
//           the work-group.
//
// T_1D is not defined in this file. NOTE(review): presumably the JiT code generator
// supplies it — confirm.
// All work_group_barrier calls sit outside the bounds checks, so every work-item in
// the group must call this function.
inline void gradColloTranspose3d(const CeedInt num_comp, const CeedInt Q_1D, const CeedInt q, const private CeedScalar* restrict r_U,
                                 const local CeedScalar* restrict s_G, private CeedScalar* restrict r_V, local CeedScalar* restrict scratch) {
  const CeedInt item_id_x = get_local_id(0);
  const CeedInt item_id_y = get_local_id(1);

  for (CeedInt comp = 0; comp < num_comp; ++comp) {
    // X derivative
    if (item_id_x < Q_1D && item_id_y < Q_1D) {
      scratch[item_id_x + item_id_y * T_1D] = r_U[comp + 0 * num_comp];
    }
    work_group_barrier(CLK_LOCAL_MEM_FENCE);

    if (item_id_x < Q_1D && item_id_y < Q_1D) {
      for (CeedInt i = 0; i < Q_1D; ++i)
        r_V[q + comp * Q_1D] += s_G[item_id_x + i * Q_1D] * scratch[i + item_id_y * T_1D];  // Contract x direction (X derivative)
    }
    // Reads of the x data must finish before scratch is reused for y
    work_group_barrier(CLK_LOCAL_MEM_FENCE);

    // Y derivative
    if (item_id_x < Q_1D && item_id_y < Q_1D) {
      scratch[item_id_x + item_id_y * T_1D] = r_U[comp + 1 * num_comp];
    }
    work_group_barrier(CLK_LOCAL_MEM_FENCE);

    if (item_id_x < Q_1D && item_id_y < Q_1D) {
      for (CeedInt i = 0; i < Q_1D; ++i)
        r_V[q + comp * Q_1D] += s_G[item_id_y + i * Q_1D] * scratch[item_id_x + i * T_1D];  // Contract y direction (Y derivative)
    }
    // Reads of the y data must finish before the next component overwrites scratch
    work_group_barrier(CLK_LOCAL_MEM_FENCE);

    // Z derivative (private memory only; this level's contribution is spread over all z entries)
    if (item_id_x < Q_1D && item_id_y < Q_1D) {
      for (CeedInt i = 0; i < Q_1D; ++i)
        r_V[i + comp * Q_1D] += s_G[i + q * Q_1D] * r_U[comp + 2 * num_comp];  // PARTIAL contract z direction (Z derivative)
    }
  }
}
3596ca0f394SUmesh Unnikrishnan 
36049ed4312SSebastian Grimberg //------------------------------------------------------------------------------
36149ed4312SSebastian Grimberg 
362*94b7b29bSJeremy L Thompson #endif  // CEED_SYCL_GEN_TEMPLATES_H
363