1*9ba83ac0SJeremy L Thompson // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors. 2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3bd882c8aSJames Wright // 4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 5bd882c8aSJames Wright // 6bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 7bd882c8aSJames Wright 8bd882c8aSJames Wright /// @file 9bd882c8aSJames Wright /// Internal header for SYCL shared memory basis read/write templates 10c0b5abf0SJeremy L Thompson #include <ceed/types.h> 11bd882c8aSJames Wright 12bd882c8aSJames Wright //------------------------------------------------------------------------------ 13bd882c8aSJames Wright // Helper function: load matrices for basis actions 14bd882c8aSJames Wright //------------------------------------------------------------------------------ 15bd882c8aSJames Wright inline void loadMatrix(const CeedInt N, const CeedScalar *restrict d_B, CeedScalar *restrict B) { 16bd882c8aSJames Wright const CeedInt item_id = get_local_linear_id(); 17bd882c8aSJames Wright const CeedInt group_size = get_local_size(0) * get_local_size(1) * get_local_size(2); 18bd882c8aSJames Wright for (CeedInt i = item_id; i < N; i += group_size) B[i] = d_B[i]; 19bd882c8aSJames Wright } 20bd882c8aSJames Wright 21bd882c8aSJames Wright //------------------------------------------------------------------------------ 22bd882c8aSJames Wright // 1D 23bd882c8aSJames Wright //------------------------------------------------------------------------------ 24bd882c8aSJames Wright 25bd882c8aSJames Wright //------------------------------------------------------------------------------ 26bd882c8aSJames Wright // E-vector -> single element 27bd882c8aSJames Wright //------------------------------------------------------------------------------ 28bd882c8aSJames Wright inline void ReadElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 29bd882c8aSJames Wright const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u, 30bd882c8aSJames Wright private CeedScalar *restrict r_u) { 31bd882c8aSJames Wright const CeedInt item_id_x = get_local_id(0); 32bd882c8aSJames Wright const CeedInt elem = get_global_id(2); 33bd882c8aSJames Wright 34bd882c8aSJames Wright if (item_id_x < P_1D && elem < num_elem) { 35bd882c8aSJames Wright const CeedInt node = item_id_x; 36bd882c8aSJames Wright const CeedInt ind = node * strides_node + elem * strides_elem; 37bd882c8aSJames Wright for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 38bd882c8aSJames Wright r_u[comp] = d_u[ind + comp * strides_comp]; 39bd882c8aSJames Wright } 40bd882c8aSJames Wright } 41bd882c8aSJames Wright } 42bd882c8aSJames Wright 43bd882c8aSJames Wright //------------------------------------------------------------------------------ 44bd882c8aSJames Wright // Single element -> E-vector 45bd882c8aSJames Wright //------------------------------------------------------------------------------ 46bd882c8aSJames Wright inline void WriteElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 47bd882c8aSJames Wright const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v, 48bd882c8aSJames Wright global CeedScalar *restrict d_v) { 49bd882c8aSJames Wright const CeedInt item_id_x = get_local_id(0); 50bd882c8aSJames Wright const CeedInt elem = get_global_id(2); 51bd882c8aSJames Wright 52bd882c8aSJames Wright if (item_id_x < P_1D && elem < num_elem) { 53bd882c8aSJames Wright const CeedInt node = item_id_x; 54bd882c8aSJames Wright const CeedInt ind = node * strides_node + elem * strides_elem; 55bd882c8aSJames Wright for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 56bd882c8aSJames Wright d_v[ind + comp * strides_comp] = r_v[comp]; 57bd882c8aSJames Wright } 58bd882c8aSJames Wright } 59bd882c8aSJames Wright } 60bd882c8aSJames Wright 61bd882c8aSJames Wright //------------------------------------------------------------------------------ 62bd882c8aSJames Wright // 2D 63bd882c8aSJames Wright //------------------------------------------------------------------------------ 64bd882c8aSJames Wright 65bd882c8aSJames Wright //------------------------------------------------------------------------------ 66bd882c8aSJames Wright // E-vector -> single element 67bd882c8aSJames Wright //------------------------------------------------------------------------------ 68bd882c8aSJames Wright inline void ReadElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 69bd882c8aSJames Wright const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u, 70bd882c8aSJames Wright private CeedScalar *restrict r_u) { 71bd882c8aSJames Wright const CeedInt item_id_x = get_local_id(0); 72bd882c8aSJames Wright const CeedInt item_id_y = get_local_id(1); 73bd882c8aSJames Wright const CeedInt elem = get_global_id(2); 74bd882c8aSJames Wright 75bd882c8aSJames Wright if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) { 76bd882c8aSJames Wright const CeedInt node = item_id_x + item_id_y * P_1D; 77bd882c8aSJames Wright const CeedInt ind = node * strides_node + elem * strides_elem; 78bd882c8aSJames Wright for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 79bd882c8aSJames Wright r_u[comp] = d_u[ind + comp * strides_comp]; 80bd882c8aSJames Wright } 81bd882c8aSJames Wright } 82bd882c8aSJames Wright } 83bd882c8aSJames Wright 84bd882c8aSJames Wright //------------------------------------------------------------------------------ 85bd882c8aSJames Wright // Single element -> E-vector 86bd882c8aSJames Wright //------------------------------------------------------------------------------ 87bd882c8aSJames Wright inline void WriteElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 88bd882c8aSJames Wright const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v, 89bd882c8aSJames Wright global CeedScalar *restrict d_v) { 90bd882c8aSJames Wright const CeedInt item_id_x = get_local_id(0); 91bd882c8aSJames Wright const CeedInt item_id_y = get_local_id(1); 92bd882c8aSJames Wright const CeedInt elem = get_global_id(2); 93bd882c8aSJames Wright 94bd882c8aSJames Wright if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) { 95bd882c8aSJames Wright const CeedInt node = item_id_x + item_id_y * P_1D; 96bd882c8aSJames Wright const CeedInt ind = node * strides_node + elem * strides_elem; 97bd882c8aSJames Wright for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 98bd882c8aSJames Wright d_v[ind + comp * strides_comp] = r_v[comp]; 99bd882c8aSJames Wright } 100bd882c8aSJames Wright } 101bd882c8aSJames Wright } 102bd882c8aSJames Wright 103bd882c8aSJames Wright //------------------------------------------------------------------------------ 104bd882c8aSJames Wright // 3D 105bd882c8aSJames Wright //------------------------------------------------------------------------------ 106bd882c8aSJames Wright 107bd882c8aSJames Wright //------------------------------------------------------------------------------ 108bd882c8aSJames Wright // E-vector -> single element 109bd882c8aSJames Wright //------------------------------------------------------------------------------ 110bd882c8aSJames Wright inline void ReadElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 111bd882c8aSJames Wright const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u, 112bd882c8aSJames Wright private CeedScalar *restrict r_u) { 113bd882c8aSJames Wright const CeedInt item_id_x = get_local_id(0); 114bd882c8aSJames Wright const CeedInt item_id_y = get_local_id(1); 115bd882c8aSJames Wright const CeedInt elem = get_global_id(2); 116bd882c8aSJames Wright 117bd882c8aSJames Wright if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) { 118bd882c8aSJames Wright for (CeedInt z = 0; z < P_1D; z++) { 119bd882c8aSJames Wright const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D; 120bd882c8aSJames Wright const CeedInt ind = node * strides_node + elem * strides_elem; 121bd882c8aSJames Wright for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 122bd882c8aSJames Wright r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp]; 123bd882c8aSJames Wright } 124bd882c8aSJames Wright } 125bd882c8aSJames Wright } 126bd882c8aSJames Wright } 127bd882c8aSJames Wright 128bd882c8aSJames Wright //------------------------------------------------------------------------------ 129bd882c8aSJames Wright // Single element -> E-vector 130bd882c8aSJames Wright //------------------------------------------------------------------------------ 131bd882c8aSJames Wright inline void WriteElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 132bd882c8aSJames Wright const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v, 133bd882c8aSJames Wright global CeedScalar *restrict d_v) { 134bd882c8aSJames Wright const CeedInt item_id_x = get_local_id(0); 135bd882c8aSJames Wright const CeedInt item_id_y = get_local_id(1); 136bd882c8aSJames Wright const CeedInt elem = get_global_id(2); 137bd882c8aSJames Wright 138bd882c8aSJames Wright if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) { 139bd882c8aSJames Wright for (CeedInt z = 0; z < P_1D; z++) { 140bd882c8aSJames Wright const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D; 141bd882c8aSJames Wright const CeedInt ind = node * strides_node + elem * strides_elem; 142bd882c8aSJames Wright for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 143bd882c8aSJames Wright d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D]; 144bd882c8aSJames Wright } 145bd882c8aSJames Wright } 146bd882c8aSJames Wright } 147bd882c8aSJames Wright } 148