xref: /libCEED/rust/libceed-sys/c-src/include/ceed/jit-source/sycl/sycl-shared-basis-read-write-templates.h (revision bd882c8a454763a096666645dc9a6229d5263694)
1*bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*bd882c8aSJames Wright //
4*bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
5*bd882c8aSJames Wright //
6*bd882c8aSJames Wright // This file is part of CEED:  http://github.com/ceed
7*bd882c8aSJames Wright 
8*bd882c8aSJames Wright /// @file
9*bd882c8aSJames Wright /// Internal header for SYCL shared memory basis read/write templates
10*bd882c8aSJames Wright #ifndef _ceed_sycl_shared_basis_read_write_templates_h
11*bd882c8aSJames Wright #define _ceed_sycl_shared_basis_read_write_templates_h
12*bd882c8aSJames Wright 
13*bd882c8aSJames Wright #include <ceed.h>
14*bd882c8aSJames Wright #include "sycl-types.h"
15*bd882c8aSJames Wright 
16*bd882c8aSJames Wright //------------------------------------------------------------------------------
17*bd882c8aSJames Wright // Helper function: load matrices for basis actions
18*bd882c8aSJames Wright //------------------------------------------------------------------------------
19*bd882c8aSJames Wright inline void loadMatrix(const CeedInt N, const CeedScalar* restrict d_B, CeedScalar* restrict B) {
20*bd882c8aSJames Wright   const CeedInt item_id    = get_local_linear_id();
21*bd882c8aSJames Wright   const CeedInt group_size = get_local_size(0) * get_local_size(1) * get_local_size(2);
22*bd882c8aSJames Wright   for (CeedInt i = item_id; i < N; i += group_size) B[i] = d_B[i];
23*bd882c8aSJames Wright }
24*bd882c8aSJames Wright 
25*bd882c8aSJames Wright //------------------------------------------------------------------------------
26*bd882c8aSJames Wright // 1D
27*bd882c8aSJames Wright //------------------------------------------------------------------------------
28*bd882c8aSJames Wright 
29*bd882c8aSJames Wright //------------------------------------------------------------------------------
30*bd882c8aSJames Wright // E-vector -> single element
31*bd882c8aSJames Wright //------------------------------------------------------------------------------
32*bd882c8aSJames Wright inline void ReadElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
33*bd882c8aSJames Wright                                  const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar* restrict d_u,
34*bd882c8aSJames Wright                                  private CeedScalar* restrict r_u) {
35*bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
36*bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
37*bd882c8aSJames Wright 
38*bd882c8aSJames Wright   if (item_id_x < P_1D && elem < num_elem) {
39*bd882c8aSJames Wright     const CeedInt node = item_id_x;
40*bd882c8aSJames Wright     const CeedInt ind  = node * strides_node + elem * strides_elem;
41*bd882c8aSJames Wright     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
42*bd882c8aSJames Wright       r_u[comp] = d_u[ind + comp * strides_comp];
43*bd882c8aSJames Wright     }
44*bd882c8aSJames Wright   }
45*bd882c8aSJames Wright }
46*bd882c8aSJames Wright 
47*bd882c8aSJames Wright //------------------------------------------------------------------------------
48*bd882c8aSJames Wright // Single element -> E-vector
49*bd882c8aSJames Wright //------------------------------------------------------------------------------
50*bd882c8aSJames Wright inline void WriteElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
51*bd882c8aSJames Wright                                   const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar* restrict r_v,
52*bd882c8aSJames Wright                                   global CeedScalar* restrict d_v) {
53*bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
54*bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
55*bd882c8aSJames Wright 
56*bd882c8aSJames Wright   if (item_id_x < P_1D && elem < num_elem) {
57*bd882c8aSJames Wright     const CeedInt node = item_id_x;
58*bd882c8aSJames Wright     const CeedInt ind  = node * strides_node + elem * strides_elem;
59*bd882c8aSJames Wright     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
60*bd882c8aSJames Wright       d_v[ind + comp * strides_comp] = r_v[comp];
61*bd882c8aSJames Wright     }
62*bd882c8aSJames Wright   }
63*bd882c8aSJames Wright }
64*bd882c8aSJames Wright 
65*bd882c8aSJames Wright //------------------------------------------------------------------------------
66*bd882c8aSJames Wright // 2D
67*bd882c8aSJames Wright //------------------------------------------------------------------------------
68*bd882c8aSJames Wright 
69*bd882c8aSJames Wright //------------------------------------------------------------------------------
70*bd882c8aSJames Wright // E-vector -> single element
71*bd882c8aSJames Wright //------------------------------------------------------------------------------
72*bd882c8aSJames Wright inline void ReadElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
73*bd882c8aSJames Wright                                  const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar* restrict d_u,
74*bd882c8aSJames Wright                                  private CeedScalar* restrict r_u) {
75*bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
76*bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
77*bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
78*bd882c8aSJames Wright 
79*bd882c8aSJames Wright   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
80*bd882c8aSJames Wright     const CeedInt node = item_id_x + item_id_y * P_1D;
81*bd882c8aSJames Wright     const CeedInt ind  = node * strides_node + elem * strides_elem;
82*bd882c8aSJames Wright     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
83*bd882c8aSJames Wright       r_u[comp] = d_u[ind + comp * strides_comp];
84*bd882c8aSJames Wright     }
85*bd882c8aSJames Wright   }
86*bd882c8aSJames Wright }
87*bd882c8aSJames Wright 
88*bd882c8aSJames Wright //------------------------------------------------------------------------------
89*bd882c8aSJames Wright // Single element -> E-vector
90*bd882c8aSJames Wright //------------------------------------------------------------------------------
91*bd882c8aSJames Wright inline void WriteElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
92*bd882c8aSJames Wright                                   const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar* restrict r_v,
93*bd882c8aSJames Wright                                   global CeedScalar* restrict d_v) {
94*bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
95*bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
96*bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
97*bd882c8aSJames Wright 
98*bd882c8aSJames Wright   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
99*bd882c8aSJames Wright     const CeedInt node = item_id_x + item_id_y * P_1D;
100*bd882c8aSJames Wright     const CeedInt ind  = node * strides_node + elem * strides_elem;
101*bd882c8aSJames Wright     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
102*bd882c8aSJames Wright       d_v[ind + comp * strides_comp] = r_v[comp];
103*bd882c8aSJames Wright     }
104*bd882c8aSJames Wright   }
105*bd882c8aSJames Wright }
106*bd882c8aSJames Wright 
107*bd882c8aSJames Wright //------------------------------------------------------------------------------
108*bd882c8aSJames Wright // 3D
109*bd882c8aSJames Wright //------------------------------------------------------------------------------
110*bd882c8aSJames Wright 
111*bd882c8aSJames Wright //------------------------------------------------------------------------------
112*bd882c8aSJames Wright // E-vector -> single element
113*bd882c8aSJames Wright //------------------------------------------------------------------------------
114*bd882c8aSJames Wright inline void ReadElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
115*bd882c8aSJames Wright                                  const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar* restrict d_u,
116*bd882c8aSJames Wright                                  private CeedScalar* restrict r_u) {
117*bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
118*bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
119*bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
120*bd882c8aSJames Wright 
121*bd882c8aSJames Wright   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
122*bd882c8aSJames Wright     for (CeedInt z = 0; z < P_1D; z++) {
123*bd882c8aSJames Wright       const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D;
124*bd882c8aSJames Wright       const CeedInt ind  = node * strides_node + elem * strides_elem;
125*bd882c8aSJames Wright       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
126*bd882c8aSJames Wright         r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp];
127*bd882c8aSJames Wright       }
128*bd882c8aSJames Wright     }
129*bd882c8aSJames Wright   }
130*bd882c8aSJames Wright }
131*bd882c8aSJames Wright 
132*bd882c8aSJames Wright //------------------------------------------------------------------------------
133*bd882c8aSJames Wright // Single element -> E-vector
134*bd882c8aSJames Wright //------------------------------------------------------------------------------
135*bd882c8aSJames Wright inline void WriteElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
136*bd882c8aSJames Wright                                   const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar* restrict r_v,
137*bd882c8aSJames Wright                                   global CeedScalar* restrict d_v) {
138*bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
139*bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
140*bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
141*bd882c8aSJames Wright 
142*bd882c8aSJames Wright   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
143*bd882c8aSJames Wright     for (CeedInt z = 0; z < P_1D; z++) {
144*bd882c8aSJames Wright       const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D;
145*bd882c8aSJames Wright       const CeedInt ind  = node * strides_node + elem * strides_elem;
146*bd882c8aSJames Wright       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
147*bd882c8aSJames Wright         d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D];
148*bd882c8aSJames Wright       }
149*bd882c8aSJames Wright     }
150*bd882c8aSJames Wright   }
151*bd882c8aSJames Wright }
152*bd882c8aSJames Wright 
153*bd882c8aSJames Wright //------------------------------------------------------------------------------
154*bd882c8aSJames Wright 
155*bd882c8aSJames Wright #endif
156