1*9e201c85SYohann // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2*9e201c85SYohann // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*9e201c85SYohann // 4*9e201c85SYohann // SPDX-License-Identifier: BSD-2-Clause 5*9e201c85SYohann // 6*9e201c85SYohann // This file is part of CEED: http://github.com/ceed 7*9e201c85SYohann 8*9e201c85SYohann /// @file 9*9e201c85SYohann /// Internal header for HIP shared memory basis read/write templates 10*9e201c85SYohann #ifndef _ceed_hip_shared_basis_read_write_templates_h 11*9e201c85SYohann #define _ceed_hip_shared_basis_read_write_templates_h 12*9e201c85SYohann 13*9e201c85SYohann #include <ceed.h> 14*9e201c85SYohann 15*9e201c85SYohann 16*9e201c85SYohann //------------------------------------------------------------------------------ 17*9e201c85SYohann // Helper function: load matrices for basis actions 18*9e201c85SYohann //------------------------------------------------------------------------------ 19*9e201c85SYohann template <int SIZE> 20*9e201c85SYohann inline __device__ void loadMatrix(const CeedScalar* d_B, CeedScalar* B) { 21*9e201c85SYohann CeedInt tid = threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.y*blockDim.x; 22*9e201c85SYohann for (CeedInt i = tid; i < SIZE; i += blockDim.x*blockDim.y*blockDim.z) 23*9e201c85SYohann B[i] = d_B[i]; 24*9e201c85SYohann } 25*9e201c85SYohann 26*9e201c85SYohann //------------------------------------------------------------------------------ 27*9e201c85SYohann // 1D 28*9e201c85SYohann //------------------------------------------------------------------------------ 29*9e201c85SYohann 30*9e201c85SYohann //------------------------------------------------------------------------------ 31*9e201c85SYohann // E-vector -> single element 32*9e201c85SYohann //------------------------------------------------------------------------------ 33*9e201c85SYohann template <int NUM_COMP, int P_1D> 34*9e201c85SYohann inline __device__ void ReadElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) { 35*9e201c85SYohann if (data.t_id_x < P_1D) { 36*9e201c85SYohann const CeedInt node = data.t_id_x; 37*9e201c85SYohann const CeedInt ind = node * strides_node + elem * strides_elem; 38*9e201c85SYohann for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 39*9e201c85SYohann r_u[comp] = d_u[ind + comp * strides_comp]; 40*9e201c85SYohann } 41*9e201c85SYohann } 42*9e201c85SYohann } 43*9e201c85SYohann 44*9e201c85SYohann //------------------------------------------------------------------------------ 45*9e201c85SYohann // Single element -> E-vector 46*9e201c85SYohann //------------------------------------------------------------------------------ 47*9e201c85SYohann template <int NUM_COMP, int P_1D> 48*9e201c85SYohann inline __device__ void WriteElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) { 49*9e201c85SYohann if (data.t_id_x < P_1D) { 50*9e201c85SYohann const CeedInt node = data.t_id_x; 51*9e201c85SYohann const CeedInt ind = node * strides_node + elem * strides_elem; 52*9e201c85SYohann for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 53*9e201c85SYohann d_v[ind + comp * strides_comp] = r_v[comp]; 54*9e201c85SYohann } 55*9e201c85SYohann } 56*9e201c85SYohann } 57*9e201c85SYohann 58*9e201c85SYohann //------------------------------------------------------------------------------ 59*9e201c85SYohann // 2D 60*9e201c85SYohann //------------------------------------------------------------------------------ 61*9e201c85SYohann 62*9e201c85SYohann //------------------------------------------------------------------------------ 63*9e201c85SYohann // E-vector -> single element 64*9e201c85SYohann //------------------------------------------------------------------------------ 65*9e201c85SYohann template <int NUM_COMP, int P_1D> 66*9e201c85SYohann inline __device__ void ReadElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) { 67*9e201c85SYohann if (data.t_id_x < P_1D && data.t_id_y < P_1D) { 68*9e201c85SYohann const CeedInt node = data.t_id_x + data.t_id_y*P_1D; 69*9e201c85SYohann const CeedInt ind = node * strides_node + elem * strides_elem; 70*9e201c85SYohann for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 71*9e201c85SYohann r_u[comp] = d_u[ind + comp * strides_comp]; 72*9e201c85SYohann } 73*9e201c85SYohann } 74*9e201c85SYohann } 75*9e201c85SYohann 76*9e201c85SYohann //------------------------------------------------------------------------------ 77*9e201c85SYohann // Single element -> E-vector 78*9e201c85SYohann //------------------------------------------------------------------------------ 79*9e201c85SYohann template <int NUM_COMP, int P_1D> 80*9e201c85SYohann inline __device__ void WriteElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) { 81*9e201c85SYohann if (data.t_id_x < P_1D && data.t_id_y < P_1D) { 82*9e201c85SYohann const CeedInt node = data.t_id_x + data.t_id_y*P_1D; 83*9e201c85SYohann const CeedInt ind = node * strides_node + elem * strides_elem; 84*9e201c85SYohann for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 85*9e201c85SYohann d_v[ind + comp * strides_comp] = r_v[comp]; 86*9e201c85SYohann } 87*9e201c85SYohann } 88*9e201c85SYohann } 89*9e201c85SYohann 90*9e201c85SYohann //------------------------------------------------------------------------------ 91*9e201c85SYohann // 3D 92*9e201c85SYohann //------------------------------------------------------------------------------ 93*9e201c85SYohann 94*9e201c85SYohann //------------------------------------------------------------------------------ 95*9e201c85SYohann // E-vector -> single element 96*9e201c85SYohann //------------------------------------------------------------------------------ 97*9e201c85SYohann template <int NUM_COMP, int P_1D> 98*9e201c85SYohann inline __device__ void ReadElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) { 99*9e201c85SYohann if (data.t_id_x < P_1D && data.t_id_y < P_1D) { 100*9e201c85SYohann for (CeedInt z = 0; z < P_1D; z++) { 101*9e201c85SYohann const CeedInt node = data.t_id_x + data.t_id_y*P_1D + z*P_1D*P_1D; 102*9e201c85SYohann const CeedInt ind = node * strides_node + elem * strides_elem; 103*9e201c85SYohann for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 104*9e201c85SYohann r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp]; 105*9e201c85SYohann } 106*9e201c85SYohann } 107*9e201c85SYohann } 108*9e201c85SYohann } 109*9e201c85SYohann 110*9e201c85SYohann //------------------------------------------------------------------------------ 111*9e201c85SYohann // Single element -> E-vector 112*9e201c85SYohann //------------------------------------------------------------------------------ 113*9e201c85SYohann template <int NUM_COMP, int P_1D> 114*9e201c85SYohann inline __device__ void WriteElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) { 115*9e201c85SYohann if (data.t_id_x < P_1D && data.t_id_y < P_1D) { 116*9e201c85SYohann for (CeedInt z = 0; z < P_1D; z++) { 117*9e201c85SYohann const CeedInt node = data.t_id_x + data.t_id_y*P_1D + z*P_1D*P_1D; 118*9e201c85SYohann const CeedInt ind = node * strides_node + elem * strides_elem; 119*9e201c85SYohann for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 120*9e201c85SYohann d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D]; 121*9e201c85SYohann } 122*9e201c85SYohann } 123*9e201c85SYohann } 124*9e201c85SYohann } 125*9e201c85SYohann 126*9e201c85SYohann //------------------------------------------------------------------------------ 127*9e201c85SYohann 128*9e201c85SYohann #endif 129