xref: /libCEED/rust/libceed-sys/c-src/include/ceed/jit-source/hip/hip-shared-basis-read-write-templates.h (revision 9e201c85545dd39529c090846df629a32c15659b)
1*9e201c85SYohann // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*9e201c85SYohann // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*9e201c85SYohann //
4*9e201c85SYohann // SPDX-License-Identifier: BSD-2-Clause
5*9e201c85SYohann //
6*9e201c85SYohann // This file is part of CEED:  http://github.com/ceed
7*9e201c85SYohann 
8*9e201c85SYohann /// @file
9*9e201c85SYohann /// Internal header for HIP shared memory basis read/write templates
10*9e201c85SYohann #ifndef _ceed_hip_shared_basis_read_write_templates_h
11*9e201c85SYohann #define _ceed_hip_shared_basis_read_write_templates_h
12*9e201c85SYohann 
13*9e201c85SYohann #include <ceed.h>
14*9e201c85SYohann 
15*9e201c85SYohann 
16*9e201c85SYohann //------------------------------------------------------------------------------
17*9e201c85SYohann // Helper function: load matrices for basis actions
18*9e201c85SYohann //------------------------------------------------------------------------------
19*9e201c85SYohann template <int SIZE>
20*9e201c85SYohann inline __device__ void loadMatrix(const CeedScalar* d_B, CeedScalar* B) {
21*9e201c85SYohann   CeedInt tid = threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.y*blockDim.x;
22*9e201c85SYohann   for (CeedInt i = tid; i < SIZE; i += blockDim.x*blockDim.y*blockDim.z)
23*9e201c85SYohann     B[i] = d_B[i];
24*9e201c85SYohann }
25*9e201c85SYohann 
26*9e201c85SYohann //------------------------------------------------------------------------------
27*9e201c85SYohann // 1D
28*9e201c85SYohann //------------------------------------------------------------------------------
29*9e201c85SYohann 
30*9e201c85SYohann //------------------------------------------------------------------------------
31*9e201c85SYohann // E-vector -> single element
32*9e201c85SYohann //------------------------------------------------------------------------------
33*9e201c85SYohann template <int NUM_COMP, int P_1D>
34*9e201c85SYohann inline __device__ void ReadElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
35*9e201c85SYohann   if (data.t_id_x < P_1D) {
36*9e201c85SYohann     const CeedInt node = data.t_id_x;
37*9e201c85SYohann     const CeedInt ind = node * strides_node + elem * strides_elem;
38*9e201c85SYohann     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
39*9e201c85SYohann       r_u[comp] = d_u[ind + comp * strides_comp];
40*9e201c85SYohann     }
41*9e201c85SYohann   }
42*9e201c85SYohann }
43*9e201c85SYohann 
44*9e201c85SYohann //------------------------------------------------------------------------------
45*9e201c85SYohann // Single element -> E-vector
46*9e201c85SYohann //------------------------------------------------------------------------------
47*9e201c85SYohann template <int NUM_COMP, int P_1D>
48*9e201c85SYohann inline __device__ void WriteElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
49*9e201c85SYohann   if (data.t_id_x < P_1D) {
50*9e201c85SYohann     const CeedInt node = data.t_id_x;
51*9e201c85SYohann     const CeedInt ind = node * strides_node + elem * strides_elem;
52*9e201c85SYohann     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
53*9e201c85SYohann       d_v[ind + comp * strides_comp] = r_v[comp];
54*9e201c85SYohann     }
55*9e201c85SYohann   }
56*9e201c85SYohann }
57*9e201c85SYohann 
58*9e201c85SYohann //------------------------------------------------------------------------------
59*9e201c85SYohann // 2D
60*9e201c85SYohann //------------------------------------------------------------------------------
61*9e201c85SYohann 
62*9e201c85SYohann //------------------------------------------------------------------------------
63*9e201c85SYohann // E-vector -> single element
64*9e201c85SYohann //------------------------------------------------------------------------------
65*9e201c85SYohann template <int NUM_COMP, int P_1D>
66*9e201c85SYohann inline __device__ void ReadElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
67*9e201c85SYohann   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
68*9e201c85SYohann     const CeedInt node = data.t_id_x + data.t_id_y*P_1D;
69*9e201c85SYohann     const CeedInt ind = node * strides_node + elem * strides_elem;
70*9e201c85SYohann     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
71*9e201c85SYohann       r_u[comp] = d_u[ind + comp * strides_comp];
72*9e201c85SYohann     }
73*9e201c85SYohann   }
74*9e201c85SYohann }
75*9e201c85SYohann 
76*9e201c85SYohann //------------------------------------------------------------------------------
77*9e201c85SYohann // Single element -> E-vector
78*9e201c85SYohann //------------------------------------------------------------------------------
79*9e201c85SYohann template <int NUM_COMP, int P_1D>
80*9e201c85SYohann inline __device__ void WriteElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
81*9e201c85SYohann   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
82*9e201c85SYohann     const CeedInt node = data.t_id_x + data.t_id_y*P_1D;
83*9e201c85SYohann     const CeedInt ind = node * strides_node + elem * strides_elem;
84*9e201c85SYohann     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
85*9e201c85SYohann       d_v[ind + comp * strides_comp] = r_v[comp];
86*9e201c85SYohann     }
87*9e201c85SYohann   }
88*9e201c85SYohann }
89*9e201c85SYohann 
90*9e201c85SYohann //------------------------------------------------------------------------------
91*9e201c85SYohann // 3D
92*9e201c85SYohann //------------------------------------------------------------------------------
93*9e201c85SYohann 
94*9e201c85SYohann //------------------------------------------------------------------------------
95*9e201c85SYohann // E-vector -> single element
96*9e201c85SYohann //------------------------------------------------------------------------------
97*9e201c85SYohann template <int NUM_COMP, int P_1D>
98*9e201c85SYohann inline __device__ void ReadElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
99*9e201c85SYohann   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
100*9e201c85SYohann     for (CeedInt z = 0; z < P_1D; z++) {
101*9e201c85SYohann       const CeedInt node = data.t_id_x + data.t_id_y*P_1D + z*P_1D*P_1D;
102*9e201c85SYohann       const CeedInt ind = node * strides_node + elem * strides_elem;
103*9e201c85SYohann       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
104*9e201c85SYohann         r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp];
105*9e201c85SYohann       }
106*9e201c85SYohann     }
107*9e201c85SYohann   }
108*9e201c85SYohann }
109*9e201c85SYohann 
110*9e201c85SYohann //------------------------------------------------------------------------------
111*9e201c85SYohann // Single element -> E-vector
112*9e201c85SYohann //------------------------------------------------------------------------------
113*9e201c85SYohann template <int NUM_COMP, int P_1D>
114*9e201c85SYohann inline __device__ void WriteElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
115*9e201c85SYohann   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
116*9e201c85SYohann     for (CeedInt z = 0; z < P_1D; z++) {
117*9e201c85SYohann       const CeedInt node = data.t_id_x + data.t_id_y*P_1D + z*P_1D*P_1D;
118*9e201c85SYohann       const CeedInt ind = node * strides_node + elem * strides_elem;
119*9e201c85SYohann       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
120*9e201c85SYohann         d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D];
121*9e201c85SYohann       }
122*9e201c85SYohann     }
123*9e201c85SYohann   }
124*9e201c85SYohann }
125*9e201c85SYohann 
126*9e201c85SYohann //------------------------------------------------------------------------------
127*9e201c85SYohann 
128*9e201c85SYohann #endif
129