// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for SYCL shared memory basis read/write templates
#ifndef _ceed_sycl_shared_basis_read_write_templates_h
#define _ceed_sycl_shared_basis_read_write_templates_h

#include <ceed.h>
#include "sycl-types.h"

//------------------------------------------------------------------------------
// Helper function: load matrices for basis actions
//------------------------------------------------------------------------------
// Copies entries [0, N) of d_B into B, splitting the work across a work-group.
// With G = total work-group size (product of the three local sizes), the
// work-item whose linear local id is k copies entries k, k + G, k + 2G, ...
// The copy is therefore complete only if every work-item of the group calls
// this function.
// NOTE(review): no barrier is issued here; callers presumably synchronize the
// work-group before reading B - confirm at the call sites.
inline void loadMatrix(const CeedInt N, const CeedScalar* restrict d_B, CeedScalar* restrict B) {
  const CeedInt stride = get_local_size(0) * get_local_size(1) * get_local_size(2);
  CeedInt       idx    = get_local_linear_id();

  while (idx < N) {
    B[idx] = d_B[idx];
    idx += stride;
  }
}

//------------------------------------------------------------------------------
// 1D
//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// E-vector -> single element
//------------------------------------------------------------------------------
// 1D strided gather. The work-item with local id 0 equal to x loads node x of
// element get_global_id(2): r_u[comp] = d_u[x * strides_node + comp * strides_comp
// + elem * strides_elem] for comp in [0, NUM_COMP). Work-items with x >= P_1D,
// or whose element index is >= num_elem, leave r_u untouched.
inline void ReadElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                 const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar* restrict d_u,
                                 private CeedScalar* restrict r_u) {
  const CeedInt node = get_local_id(0);
  const CeedInt e    = get_global_id(2);

  // Idle work-items (padding in x, or past the last element) do nothing.
  if (node >= P_1D || e >= num_elem) return;

  const CeedInt base = node * strides_node + e * strides_elem;
  for (CeedInt c = 0; c < NUM_COMP; ++c) r_u[c] = d_u[base + c * strides_comp];
}

//------------------------------------------------------------------------------
// Single element -> E-vector
//------------------------------------------------------------------------------
// 1D strided scatter, the inverse of ReadElementStrided1d. The work-item with
// local id 0 equal to x stores r_v[comp] to node x of element get_global_id(2)
// at offset x * strides_node + comp * strides_comp + elem * strides_elem.
// Work-items with x >= P_1D, or whose element index is >= num_elem, write nothing.
inline void WriteElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                  const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar* restrict r_v,
                                  global CeedScalar* restrict d_v) {
  const CeedInt node = get_local_id(0);
  const CeedInt e    = get_global_id(2);

  if (node >= P_1D || e >= num_elem) return;

  const CeedInt base = node * strides_node + e * strides_elem;
  for (CeedInt c = 0; c < NUM_COMP; ++c) d_v[base + c * strides_comp] = r_v[c];
}

//------------------------------------------------------------------------------
// 2D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single element
//------------------------------------------------------------------------------
// 2D strided gather. The work-item at local ids (x, y) in dimensions (0, 1)
// loads node x + y * P_1D of element get_global_id(2) into r_u[comp] for every
// component comp in [0, NUM_COMP). Work-items with x or y >= P_1D, or whose
// element index is >= num_elem, leave r_u untouched.
inline void ReadElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                 const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar* restrict d_u,
                                 private CeedScalar* restrict r_u) {
  const CeedInt tx = get_local_id(0);
  const CeedInt ty = get_local_id(1);
  const CeedInt e  = get_global_id(2);

  if (tx >= P_1D || ty >= P_1D || e >= num_elem) return;

  const CeedInt node = tx + ty * P_1D;  // x varies fastest within the element
  const CeedInt base = node * strides_node + e * strides_elem;
  for (CeedInt c = 0; c < NUM_COMP; ++c) r_u[c] = d_u[base + c * strides_comp];
}
//------------------------------------------------------------------------------
// Single element -> E-vector
//------------------------------------------------------------------------------
// 2D strided scatter, the inverse of ReadElementStrided2d. The work-item at
// local ids (x, y) stores r_v[comp] to node x + y * P_1D of element
// get_global_id(2) for every component. Work-items with x or y >= P_1D, or whose
// element index is >= num_elem, write nothing.
inline void WriteElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                  const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar* restrict r_v,
                                  global CeedScalar* restrict d_v) {
  const CeedInt tx = get_local_id(0);
  const CeedInt ty = get_local_id(1);
  const CeedInt e  = get_global_id(2);

  if (tx >= P_1D || ty >= P_1D || e >= num_elem) return;

  const CeedInt node = tx + ty * P_1D;
  const CeedInt base = node * strides_node + e * strides_elem;
  for (CeedInt c = 0; c < NUM_COMP; ++c) d_v[base + c * strides_comp] = r_v[c];
}

//------------------------------------------------------------------------------
// 3D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single element
//------------------------------------------------------------------------------
// 3D strided gather. The work-item at local ids (x, y) walks the whole z column
// of element get_global_id(2). For each z in [0, P_1D) and component comp it
// loads node x + y * P_1D + z * P_1D * P_1D into r_u[z + comp * P_1D], so r_u
// must hold at least NUM_COMP * P_1D values. Work-items with x or y >= P_1D, or
// whose element index is >= num_elem, leave r_u untouched.
inline void ReadElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                 const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar* restrict d_u,
                                 private CeedScalar* restrict r_u) {
  const CeedInt tx = get_local_id(0);
  const CeedInt ty = get_local_id(1);
  const CeedInt e  = get_global_id(2);

  if (tx >= P_1D || ty >= P_1D || e >= num_elem) return;

  const CeedInt column     = tx + ty * P_1D;  // node index within one z layer
  const CeedInt layer_size = P_1D * P_1D;     // nodes per z layer
  for (CeedInt z = 0; z < P_1D; ++z) {
    const CeedInt base = (column + z * layer_size) * strides_node + e * strides_elem;
    for (CeedInt c = 0; c < NUM_COMP; ++c) r_u[z + c * P_1D] = d_u[base + c * strides_comp];
  }
}

//------------------------------------------------------------------------------
// Single element -> E-vector
//------------------------------------------------------------------------------
// 3D strided scatter, the inverse of ReadElementStrided3d. The work-item at
// local ids (x, y) writes its whole z column: for each z in [0, P_1D) and
// component comp, r_v[z + comp * P_1D] goes to node x + y * P_1D + z * P_1D * P_1D
// of element get_global_id(2). Work-items with x or y >= P_1D, or whose element
// index is >= num_elem, write nothing.
inline void WriteElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                  const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar* restrict r_v,
                                  global CeedScalar* restrict d_v) {
  const CeedInt tx = get_local_id(0);
  const CeedInt ty = get_local_id(1);
  const CeedInt e  = get_global_id(2);

  if (tx >= P_1D || ty >= P_1D || e >= num_elem) return;

  const CeedInt column     = tx + ty * P_1D;
  const CeedInt layer_size = P_1D * P_1D;
  for (CeedInt z = 0; z < P_1D; ++z) {
    const CeedInt base = (column + z * layer_size) * strides_node + e * strides_elem;
    for (CeedInt c = 0; c < NUM_COMP; ++c) d_v[base + c * strides_comp] = r_v[z + c * P_1D];
  }
}

//------------------------------------------------------------------------------

#endif