1*9ba83ac0SJeremy L Thompson // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors. 2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3bd882c8aSJames Wright // 4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 5bd882c8aSJames Wright // 6bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 7bd882c8aSJames Wright 8bd882c8aSJames Wright /// @file 9bd882c8aSJames Wright /// Internal header for SYCL shared memory tensor product basis 10c0b5abf0SJeremy L Thompson #include <ceed/types.h> 11bd882c8aSJames Wright 12bd882c8aSJames Wright #include "sycl-shared-basis-read-write-templates.h" 13bd882c8aSJames Wright #include "sycl-shared-basis-tensor-templates.h" 14bd882c8aSJames Wright 15bd882c8aSJames Wright // 16bd882c8aSJames Wright // BASIS_NUM_NODES = CeedIntPow(BASIS_P_1D,DIM) 17bd882c8aSJames Wright // BASIS_NUM_QPTS = CeedIntPow(BASIS_Q_1D,DIM) 18bd882c8aSJames Wright 19bd882c8aSJames Wright //------------------------------------------------------------------------------ 20bd882c8aSJames Wright // Interp kernel by dim 21bd882c8aSJames Wright //------------------------------------------------------------------------------ 22bd882c8aSJames Wright kernel void Interp(const CeedInt num_elem, global const CeedScalar *restrict d_interp_1d, global const CeedScalar *restrict d_U, 23bd882c8aSJames Wright global CeedScalar *restrict d_V) { 24bd882c8aSJames Wright local CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D]; 25bd882c8aSJames Wright private 26bd882c8aSJames Wright CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)]; 27bd882c8aSJames Wright private 28bd882c8aSJames Wright CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)]; 29bd882c8aSJames Wright 30bd882c8aSJames Wright local CeedScalar scratch[BASIS_INTERP_SCRATCH_SIZE]; 31bd882c8aSJames Wright local CeedScalar *elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1); 32bd882c8aSJames Wright 33bd882c8aSJames Wright loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B); 34bd882c8aSJames Wright work_group_barrier(CLK_LOCAL_MEM_FENCE); 35bd882c8aSJames Wright 36bd882c8aSJames Wright if (BASIS_DIM == 1) { 37bd882c8aSJames Wright ReadElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U); 38bd882c8aSJames Wright Interp1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch); 39bd882c8aSJames Wright WriteElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V); 40bd882c8aSJames Wright 41bd882c8aSJames Wright } else if (BASIS_DIM == 2) { 42bd882c8aSJames Wright ReadElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U); 43bd882c8aSJames Wright InterpTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch); 44bd882c8aSJames Wright WriteElementStrided2d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V); 45bd882c8aSJames Wright 46bd882c8aSJames Wright } else if (BASIS_DIM == 3) { 47bd882c8aSJames Wright ReadElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U); 48bd882c8aSJames Wright InterpTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch); 49bd882c8aSJames Wright WriteElementStrided3d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V); 50bd882c8aSJames Wright } 51bd882c8aSJames Wright } 52bd882c8aSJames Wright 53bd882c8aSJames Wright kernel void InterpTranspose(const CeedInt num_elem, global const CeedScalar *restrict d_interp_1d, global const CeedScalar *restrict d_U, 54bd882c8aSJames Wright global CeedScalar *restrict d_V) { 55bd882c8aSJames Wright // local size: 56bd882c8aSJames Wright // 1d: elems_per_block * T_1d 57bd882c8aSJames Wright // 2d,3d: elems_per_block * T_1d * T_1d 58bd882c8aSJames Wright local CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D]; 59bd882c8aSJames Wright private 60bd882c8aSJames Wright CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)]; 61bd882c8aSJames Wright private 62bd882c8aSJames Wright CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)]; 63bd882c8aSJames Wright 64bd882c8aSJames Wright local CeedScalar scratch[BASIS_INTERP_SCRATCH_SIZE]; 65bd882c8aSJames Wright local CeedScalar *elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1); 66bd882c8aSJames Wright 67bd882c8aSJames Wright loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B); 68bd882c8aSJames Wright work_group_barrier(CLK_LOCAL_MEM_FENCE); 69bd882c8aSJames Wright 70bd882c8aSJames Wright if (BASIS_DIM == 1) { 71bd882c8aSJames Wright ReadElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U); 72bd882c8aSJames Wright InterpTranspose1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch); 73bd882c8aSJames Wright WriteElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V); 74bd882c8aSJames Wright 75bd882c8aSJames Wright } else if (BASIS_DIM == 2) { 76bd882c8aSJames Wright ReadElementStrided2d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U); 77bd882c8aSJames Wright InterpTransposeTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch); 78bd882c8aSJames Wright WriteElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V); 79bd882c8aSJames Wright 80bd882c8aSJames Wright } else if (BASIS_DIM == 3) { 81bd882c8aSJames Wright ReadElementStrided3d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U); 82bd882c8aSJames Wright InterpTransposeTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch); 83bd882c8aSJames Wright WriteElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V); 84bd882c8aSJames Wright } 85bd882c8aSJames Wright } 86bd882c8aSJames Wright 87bd882c8aSJames Wright //------------------------------------------------------------------------------ 88bd882c8aSJames Wright // Grad kernel by dim 89bd882c8aSJames Wright //------------------------------------------------------------------------------ 90bd882c8aSJames Wright kernel void Grad(const CeedInt num_elem, global const CeedScalar *restrict d_interp_1d, global const CeedScalar *restrict d_grad_1d, 91bd882c8aSJames Wright global const CeedScalar *restrict d_U, global CeedScalar *restrict d_V) { 92bd882c8aSJames Wright local CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D]; // Todo, don't allocate s_B for dimension 1 93bd882c8aSJames Wright local CeedScalar s_G[BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D)]; 94bd882c8aSJames Wright 95bd882c8aSJames Wright private 96bd882c8aSJames Wright CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)]; 97bd882c8aSJames Wright private 98bd882c8aSJames Wright CeedScalar r_V[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)]; 99bd882c8aSJames Wright 100bd882c8aSJames Wright local CeedScalar scratch[BASIS_GRAD_SCRATCH_SIZE]; 101bd882c8aSJames Wright local CeedScalar *elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1); 102bd882c8aSJames Wright 103bd882c8aSJames Wright loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B); 104bd882c8aSJames Wright loadMatrix(BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D), d_grad_1d, s_G); 105bd882c8aSJames Wright work_group_barrier(CLK_LOCAL_MEM_FENCE); 106bd882c8aSJames Wright 107bd882c8aSJames Wright if (BASIS_DIM == 1) { 108bd882c8aSJames Wright ReadElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U); 109bd882c8aSJames Wright Grad1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_G, r_V, elem_scratch); 110bd882c8aSJames Wright WriteElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V); 111bd882c8aSJames Wright 112bd882c8aSJames Wright } else if (BASIS_DIM == 2) { 113bd882c8aSJames Wright ReadElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U); 114bd882c8aSJames Wright GradTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch); 115bd882c8aSJames Wright WriteElementStrided2d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V); 116bd882c8aSJames Wright 117bd882c8aSJames Wright } else if (BASIS_DIM == 3) { 118bd882c8aSJames Wright ReadElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U); 119bd882c8aSJames Wright if (BASIS_HAS_COLLOCATED_GRAD) GradTensorCollocated3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch); 120bd882c8aSJames Wright else GradTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch); 121bd882c8aSJames Wright WriteElementStrided3d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V); 122bd882c8aSJames Wright } 123bd882c8aSJames Wright } 124bd882c8aSJames Wright 125bd882c8aSJames Wright kernel void GradTranspose(const CeedInt num_elem, global const CeedScalar *restrict d_interp_1d, global const CeedScalar *restrict d_grad_1d, 126bd882c8aSJames Wright global const CeedScalar *restrict d_U, global CeedScalar *restrict d_V) { 127bd882c8aSJames Wright local CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D]; // Todo, don't allocate s_B for dimension 1 128bd882c8aSJames Wright local CeedScalar s_G[BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D)]; 129bd882c8aSJames Wright 130bd882c8aSJames Wright private 131bd882c8aSJames Wright CeedScalar r_U[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)]; 132bd882c8aSJames Wright private 133bd882c8aSJames Wright CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)]; 134bd882c8aSJames Wright 135bd882c8aSJames Wright local CeedScalar scratch[BASIS_GRAD_SCRATCH_SIZE]; 136bd882c8aSJames Wright local CeedScalar *elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1); 137bd882c8aSJames Wright 138bd882c8aSJames Wright loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B); 139bd882c8aSJames Wright loadMatrix(BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D), d_grad_1d, s_G); 140bd882c8aSJames Wright work_group_barrier(CLK_LOCAL_MEM_FENCE); 141bd882c8aSJames Wright 142bd882c8aSJames Wright if (BASIS_DIM == 1) { 143bd882c8aSJames Wright ReadElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U); 144bd882c8aSJames Wright GradTranspose1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_G, r_V, elem_scratch); 145bd882c8aSJames Wright WriteElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V); 146bd882c8aSJames Wright 147bd882c8aSJames Wright } else if (BASIS_DIM == 2) { 148bd882c8aSJames Wright ReadElementStrided2d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U); 149bd882c8aSJames Wright GradTransposeTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch); 150bd882c8aSJames Wright WriteElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V); 151bd882c8aSJames Wright 152bd882c8aSJames Wright } else if (BASIS_DIM == 3) { 153bd882c8aSJames Wright ReadElementStrided3d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U); 154bd882c8aSJames Wright if (BASIS_HAS_COLLOCATED_GRAD) GradTransposeTensorCollocated3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch); 155bd882c8aSJames Wright else GradTransposeTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch); 156bd882c8aSJames Wright WriteElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V); 157bd882c8aSJames Wright } 158bd882c8aSJames Wright } 159bd882c8aSJames Wright 160bd882c8aSJames Wright //------------------------------------------------------------------------------ 161bd882c8aSJames Wright // Weight kernels by dim 162bd882c8aSJames Wright //------------------------------------------------------------------------------ 163bd882c8aSJames Wright kernel void Weight(const CeedInt num_elem, global const CeedScalar *restrict q_weight_1d, global CeedScalar *restrict d_W) { 164bd882c8aSJames Wright private 165bd882c8aSJames Wright CeedScalar r_W[BASIS_DIM > 2 ? BASIS_Q_1D : 1]; 166bd882c8aSJames Wright 167bd882c8aSJames Wright // void prefetch(q_weight_1d,BASIS_Q_1D); 168bd882c8aSJames Wright 169bd882c8aSJames Wright if (BASIS_DIM == 1) { 170bd882c8aSJames Wright Weight1d(BASIS_Q_1D, q_weight_1d, r_W); 171bd882c8aSJames Wright WriteElementStrided1d(1, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_W, d_W); 172bd882c8aSJames Wright 173bd882c8aSJames Wright } else if (BASIS_DIM == 2) { 174bd882c8aSJames Wright WeightTensor2d(BASIS_Q_1D, q_weight_1d, r_W); 175bd882c8aSJames Wright WriteElementStrided2d(1, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_W, d_W); 176bd882c8aSJames Wright 177bd882c8aSJames Wright } else if (BASIS_DIM == 3) { 178bd882c8aSJames Wright WeightTensor3d(BASIS_Q_1D, q_weight_1d, r_W); 179bd882c8aSJames Wright WriteElementStrided3d(1, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_W, d_W); 180bd882c8aSJames Wright } 181bd882c8aSJames Wright } 182