xref: /libCEED/include/ceed/jit-source/magma/magma-basis-interp-deriv-nontensor.h (revision 86ad04ccfbb7dad1a7254c17df6ad1938164f618)
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed

/// @file
/// Internal header for MAGMA non-tensor basis interpolation and derivatives
109d15e85bSSebastian Grimberg #ifndef CEED_MAGMA_BASIS_INTERP_DERIV_NONTENSOR_H
119d15e85bSSebastian Grimberg #define CEED_MAGMA_BASIS_INTERP_DERIV_NONTENSOR_H
129d15e85bSSebastian Grimberg 
139d15e85bSSebastian Grimberg #include "magma-common-nontensor.h"
149d15e85bSSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Multi-component, non-transposed basis application.
//
// Work along n is split into ceil(n / NB) chunks of at most NB entries; the threads sharing a threadIdx.y value
// handle chunk (blockIdx.x * blockDim.y + threadIdx.y). The chunk's B data is staged into shared memory once,
// then for each of the Q_COMP components the next Q * P slab of dA is loaded, multiplied against the staged B,
// and the result is written to dC. dA advances by Q * P and dC by Q * n between components.
//
// shared_data layout: blockDim.y slots of P * NB scalars for B (one per threadIdx.y), followed by the staging
// area handed to the A reader.
// NOTE(review): the barrier between staging B and consuming it is presumably inside read_A_trans_g2r_1D_nosync
// (not visible in this file) -- confirm before reordering anything here.
template <typename T, int Q_COMP, int P, int Q, int NB>
static __device__ __inline__ void magma_basis_nontensor_device_n(const int n, CeedScalar const *dA, CeedScalar const *dB, CeedScalar *dC,
                                                                 CeedScalar *shared_data) {
  const int  tx         = threadIdx.x;
  const int  ty         = threadIdx.y;
  const int  chunk      = blockIdx.x * blockDim.y + ty;
  const int  num_chunks = (n + NB - 1) / NB;
  const int  chunk_len  = min(NB, n - chunk * NB);
  const bool has_work   = chunk < num_chunks;

  // per-row B slot first, A staging area behind all B slots
  CeedScalar *sB = shared_data + ty * P * NB;
  CeedScalar *sA = shared_data + blockDim.y * P * NB;

  // move the global pointers to this chunk
  dB += chunk * P * NB;
  dC += chunk * Q * NB;

  // B is shared by every component, so it is fetched a single time
  if (has_work) {
    read_B_g2s_1D_nosync<CeedScalar, Q, P, NB>(tx, chunk_len, dB, sB);
  }

  // unrolling this loop yields dramatic performance drop using hipcc, so let the compiler decide (no pragma unroll)
  for (int d = 0; d < Q_COMP; d++) {
    CeedScalar rA[P];
    CeedScalar rC[NB];

    // every thread of the block participates in loading A (A is P x Q)
    read_A_trans_g2r_1D_nosync<CeedScalar, Q, P, MAGMA_BASIS_NTCOL(Q, MAGMA_MAXTHREADS_1D)>(tx, ty, dA, sA, rA);

    mul_rAsBrC_1D_nosync<CeedScalar, Q, P, NB>(rA, sB, rC);

    // only rows that own a chunk store their result
    if (has_work) {
      write_C_r2g_1D_nosync<CeedScalar, Q, P, NB>(tx, chunk_len, rC, dC);
    }

    // next component
    dA += Q * P;
    dC += Q * n;

    // keep the block in step before the A staging area is reused
    __syncthreads();
  }
}
579d15e85bSSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Multi-component, transposed basis application.
//
// Work along n is split into ceil(n / NB) chunks of at most NB entries; the threads sharing a threadIdx.y value
// handle chunk (blockIdx.x * blockDim.y + threadIdx.y). For each of the Q_COMP components the next P * Q slab of
// dA and the matching part of dB are loaded and accumulated into a per-thread register tile, which is written to
// dC once after the last component. dA advances by P * Q and dB by Q * n between components.
//
// The A staging area starts at shared_data and the B slot of row ty starts at shared_data + ty * Q * NB, so the
// two regions overlap (exactly for ty == 0); the barriers below separate their uses.
template <typename T, int Q_COMP, int P, int Q, int NB>
static __device__ __inline__ void magma_basis_nontensor_device_t(const int n, CeedScalar const *dA, CeedScalar const *dB, CeedScalar *dC,
                                                                 CeedScalar *shared_data) {
  const int  tx         = threadIdx.x;
  const int  ty         = threadIdx.y;
  const int  chunk      = blockIdx.x * blockDim.y + ty;
  const int  num_chunks = (n + NB - 1) / NB;
  const int  chunk_len  = min(NB, n - chunk * NB);
  const bool has_work   = chunk < num_chunks;

  // A staging area and per-row B slot (A is P x Q)
  CeedScalar *sA = shared_data;
  CeedScalar *sB = shared_data + ty * Q * NB;

  // move the global pointers to this chunk
  dB += chunk * Q * NB;
  dC += chunk * P * NB;

  // accumulator over all components
  CeedScalar rC[NB] = {0.0};

  // unrolling this loop yields dramatic performance drop using hipcc, so let the compiler decide (no pragma unroll)
  for (int d = 0; d < Q_COMP; d++) {
    // every thread of the block participates in loading A
    CeedScalar rA[Q];
    read_A_notrans_g2r_1D_nosync<CeedScalar, P, Q, MAGMA_BASIS_NTCOL(P, MAGMA_MAXTHREADS_1D)>(tx, ty, dA, sA, rA);
    __syncthreads();

    // stage this component's B for rows that own a chunk
    if (has_work) {
      read_B_g2s_1D_nosync<CeedScalar, P, Q, NB>(tx, chunk_len, dB, sB);
    }
    __syncthreads();

    addmul_rAsBrC_1D_nosync<CeedScalar, P, Q, NB>(rA, sB, rC);

    // next component
    dA += P * Q;
    dB += Q * n;

    // shared memory is overwritten at the top of the next iteration
    __syncthreads();
  }

  // single store of the accumulated result
  if (has_work) {
    write_C_r2g_1D_nosync<CeedScalar, P, Q, NB>(tx, chunk_len, rC, dC);
  }
}
1039d15e85bSSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Single-component, non-transposed basis application.
//
// Work along n is split into ceil(n / NB) chunks of at most NB entries; the threads sharing a threadIdx.y value
// handle chunk id = blockIdx.x * blockDim.y + threadIdx.y. A is loaded by the whole block, B is staged into
// shared memory per chunk, and the product is written to dC.
//
// The A staging area starts at shared_data and the B slot of row ty starts at shared_data + ty * P * NB, so the
// two regions overlap (exactly for ty == 0); the barrier after the A load separates their uses.
//
// Rows without work (id >= nblocks) skip the loads, the multiply and the store, but still reach every
// __syncthreads() so that the barrier is executed uniformly by the whole block.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void magma_basis_nontensor_device_n1(const int n, CeedScalar const *dA, CeedScalar const *dB, CeedScalar *dC,
                                                                  CeedScalar *shared_data) {
  const int  tx       = threadIdx.x;
  const int  ty       = threadIdx.y;
  const int  id       = blockIdx.x * blockDim.y + ty;
  const int  nblocks  = (n + NB - 1) / NB;
  const int  myn      = min(NB, n - id * NB);
  const bool has_work = id < nblocks;

  dB += id * P * NB;
  dC += id * Q * NB;

  // A is P x Q
  CeedScalar *sA = shared_data;
  CeedScalar *sB = shared_data + ty * P * NB;

  // read A using all threads
  CeedScalar rA[P];
  read_A_trans_g2r_1D_nosync<CeedScalar, Q, P, MAGMA_BASIS_NTCOL(Q, MAGMA_MAXTHREADS_1D)>(tx, ty, dA, sA, rA);
  __syncthreads();

  // read B (threads with no work skip the load instead of exiting, so the barrier below stays uniform)
  if (has_work) {
    read_B_g2s_1D_nosync<CeedScalar, Q, P, NB>(tx, myn, dB, sB);
  }
  __syncthreads();

  if (has_work) {
    CeedScalar rC[NB];
    mul_rAsBrC_1D_nosync<CeedScalar, Q, P, NB>(rA, sB, rC);

    // write C
    write_C_r2g_1D_nosync<CeedScalar, Q, P, NB>(tx, myn, rC, dC);
  }
}
139*86ad04ccSSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Single-component, transposed basis application.
//
// Work along n is split into ceil(n / NB) chunks of at most NB entries; the threads sharing a threadIdx.y value
// handle chunk id = blockIdx.x * blockDim.y + threadIdx.y. A is loaded by the whole block, B is staged into
// shared memory per chunk, and the product is written to dC.
//
// The A staging area starts at shared_data and the B slot of row ty starts at shared_data + ty * Q * NB, so the
// two regions overlap (exactly for ty == 0); the barrier after the A load separates their uses.
//
// Rows without work (id >= nblocks) skip the loads, the multiply and the store, but still reach every
// __syncthreads() so that the barrier is executed uniformly by the whole block.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void magma_basis_nontensor_device_t1(const int n, CeedScalar const *dA, CeedScalar const *dB, CeedScalar *dC,
                                                                  CeedScalar *shared_data) {
  const int  tx       = threadIdx.x;
  const int  ty       = threadIdx.y;
  const int  id       = blockIdx.x * blockDim.y + ty;
  const int  nblocks  = (n + NB - 1) / NB;
  const int  myn      = min(NB, n - id * NB);
  const bool has_work = id < nblocks;

  dB += id * Q * NB;
  dC += id * P * NB;

  // A is P x Q
  CeedScalar *sA = shared_data;
  CeedScalar *sB = shared_data + ty * Q * NB;

  // read A using all threads
  CeedScalar rA[Q];
  read_A_notrans_g2r_1D_nosync<CeedScalar, P, Q, MAGMA_BASIS_NTCOL(P, MAGMA_MAXTHREADS_1D)>(tx, ty, dA, sA, rA);
  __syncthreads();

  // read B (threads with no work skip the load instead of exiting, so the barrier below stays uniform)
  if (has_work) {
    read_B_g2s_1D_nosync<CeedScalar, P, Q, NB>(tx, myn, dB, sB);
  }
  __syncthreads();

  if (has_work) {
    CeedScalar rC[NB];
    mul_rAsBrC_1D_nosync<CeedScalar, P, Q, NB>(rA, sB, rC);

    // write C
    write_C_r2g_1D_nosync<CeedScalar, P, Q, NB>(tx, myn, rC, dC);
  }
}
175*86ad04ccSSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Kernel entry point for non-transposed interpolation.
// Forwards to the single-component device routine when BASIS_Q_COMP_INTERP == 1 and to the multi-component
// routine otherwise, using block size BASIS_NB_INTERP_N and the kernel's shared-memory buffer.
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_Q, MAGMA_MAXTHREADS_1D)) __global__
    void magma_interp_nontensor_n(const int n, CeedScalar const *__restrict__ dA, CeedScalar const *__restrict__ dB, CeedScalar *__restrict__ dC) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data);
  CeedScalar *const smem = (CeedScalar *)shared_data;

#if BASIS_Q_COMP_INTERP == 1
  // one component: specialized path without the component loop
  magma_basis_nontensor_device_n1<CeedScalar, BASIS_P, BASIS_Q, BASIS_NB_INTERP_N>(n, dA, dB, dC, smem);
#else
  // several components: general path
  magma_basis_nontensor_device_n<CeedScalar, BASIS_Q_COMP_INTERP, BASIS_P, BASIS_Q, BASIS_NB_INTERP_N>(n, dA, dB, dC, smem);
#endif
}
1879d15e85bSSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Kernel entry point for transposed interpolation.
// Forwards to the single-component device routine when BASIS_Q_COMP_INTERP == 1 and to the multi-component
// routine otherwise, using block size BASIS_NB_INTERP_T and the kernel's shared-memory buffer.
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_P, MAGMA_MAXTHREADS_1D)) __global__
    void magma_interp_nontensor_t(const int n, CeedScalar const *__restrict__ dA, CeedScalar const *__restrict__ dB, CeedScalar *__restrict__ dC) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data);
  CeedScalar *const smem = (CeedScalar *)shared_data;

#if BASIS_Q_COMP_INTERP == 1
  // one component: specialized path without the component loop
  magma_basis_nontensor_device_t1<CeedScalar, BASIS_P, BASIS_Q, BASIS_NB_INTERP_T>(n, dA, dB, dC, smem);
#else
  // several components: general path
  magma_basis_nontensor_device_t<CeedScalar, BASIS_Q_COMP_INTERP, BASIS_P, BASIS_Q, BASIS_NB_INTERP_T>(n, dA, dB, dC, smem);
#endif
}
1999d15e85bSSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Kernel entry point for non-transposed derivative application.
// Forwards to the single-component device routine when BASIS_Q_COMP_DERIV == 1 and to the multi-component
// routine otherwise, using block size BASIS_NB_DERIV_N and the kernel's shared-memory buffer.
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_Q, MAGMA_MAXTHREADS_1D)) __global__
    void magma_deriv_nontensor_n(const int n, CeedScalar const *__restrict__ dA, CeedScalar const *__restrict__ dB, CeedScalar *__restrict__ dC) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data);
  CeedScalar *const smem = (CeedScalar *)shared_data;

#if BASIS_Q_COMP_DERIV == 1
  // one component: specialized path without the component loop
  magma_basis_nontensor_device_n1<CeedScalar, BASIS_P, BASIS_Q, BASIS_NB_DERIV_N>(n, dA, dB, dC, smem);
#else
  // several components: general path
  magma_basis_nontensor_device_n<CeedScalar, BASIS_Q_COMP_DERIV, BASIS_P, BASIS_Q, BASIS_NB_DERIV_N>(n, dA, dB, dC, smem);
#endif
}
2119d15e85bSSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Kernel entry point for transposed derivative application.
// Forwards to the single-component device routine when BASIS_Q_COMP_DERIV == 1 and to the multi-component
// routine otherwise, using block size BASIS_NB_DERIV_T and the kernel's shared-memory buffer.
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_P, MAGMA_MAXTHREADS_1D)) __global__
    void magma_deriv_nontensor_t(const int n, CeedScalar const *__restrict__ dA, CeedScalar const *__restrict__ dB, CeedScalar *__restrict__ dC) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data);
  CeedScalar *const smem = (CeedScalar *)shared_data;

#if BASIS_Q_COMP_DERIV == 1
  // one component: specialized path without the component loop
  magma_basis_nontensor_device_t1<CeedScalar, BASIS_P, BASIS_Q, BASIS_NB_DERIV_T>(n, dA, dB, dC, smem);
#else
  // several components: general path
  magma_basis_nontensor_device_t<CeedScalar, BASIS_Q_COMP_DERIV, BASIS_P, BASIS_Q, BASIS_NB_DERIV_T>(n, dA, dB, dC, smem);
#endif
}
2239d15e85bSSebastian Grimberg 
2249d15e85bSSebastian Grimberg #endif  // CEED_MAGMA_BASIS_INTERP_DERIV_NONTENSOR_H
225