// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for MAGMA backend common non-tensor basis definitions
#ifndef CEED_MAGMA_COMMON_NONTENSOR_H
#define CEED_MAGMA_COMMON_NONTENSOR_H

#include "magma-common-defs.h"

////////////////////////////////////////////////////////////////////////////////
// read A (no-trans) from global to reg.
// A is (P x Q)
// 1D thread config. with (P x 1) threads
// no sync at the end of the function
//
// The calling thread, identified by tx, gathers Q values that lie P entries
// apart in dA, starting at offset tx:
//   rA[i] = dA[i * P + tx]   for i = 0 .. Q-1
// Only dA is read and only rA is written; no barrier is issued.
// The template parameter NB is not referenced by this function.
// NOTE(review): tx is presumably threadIdx.x in [0, P) -- confirm at call sites.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const T *dA, T rA[Q]) {
  // Q is a compile-time constant, so the loop can be fully unrolled.
#pragma unroll
  for (int i = 0; i < Q; i++) {
    rA[i] = dA[i * P + tx];
  }
}

////////////////////////////////////////////////////////////////////////////////
// read A (trans) from global to reg.
// A is (P x Q)
// 2D thread config.
// with (P x BY) threads
// no sync at the end of the function
//
// Works in two phases:
//   1. Every thread copies entries of dA into the staging buffer sA. The copy
//      advances P * BY entries per pass and each thread handles the entry at
//      offset tid = ty * blockDim.x + tx within the pass; a guarded final pass
//      covers the remainder so that no index >= P * Q is touched.
//   2. After a __syncthreads(), the thread reads Q consecutive entries
//      sA[tx * Q + j], j = 0 .. Q-1, into rA.
// There is a barrier between the two phases but none after the register
// loads, hence "no sync at the end". Because of that barrier, all threads of
// the block must call this function.
// sA is written at indices below P * Q, so it needs room for P * Q entries.
// NOTE(review): the copy pattern presumably expects blockDim.x == P and
// blockDim.y == BY so that tid spans [0, P * BY) -- confirm at launch sites.
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int tid = ty * blockDim.x + tx;
  int       i;

  // Full passes: every index i + tid stays below P * Q as long as tid < P * BY.
#pragma unroll
  for (i = 0; i < P * Q - P * BY; i += P * BY) {
    sA[i + tid] = dA[i + tid];
  }
  // Last (possibly partial) pass, guarded against running past the end of A.
  if (i + tid < P * Q) {
    sA[i + tid] = dA[i + tid];
  }
  // Make the staged copy visible to all threads before anyone reads from sA.
  __syncthreads();

#pragma unroll
  for (int j = 0; j < Q; j++) {
    rA[j] = sA[tx * Q + j];
  }
}

////////////////////////////////////////////////////////////////////////////////
// read B from global to shared
// B is (Q x NB)
// 1D thread config.
// with (P x 1) threads
// no sync at the end of the function
//
// Copies the first Q * n entries of dB into sB. The copy advances P entries
// per pass and the calling thread handles the entry at offset tx within each
// pass; a guarded final pass covers the remainder so that no index >= Q * n
// is touched.
// When n == NB the loop bound is a compile-time constant and the loop is
// unrolled; otherwise the bound depends on the runtime value n.
// No barrier is issued here, so the caller has to synchronize before other
// threads read sB.
// NOTE(review): n is presumably the number of valid columns of B, with
// n <= NB, and tx presumably lies in [0, P) -- confirm at call sites.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, const int n, const T *dB, T *sB) {
  int i;

  if (n != NB) {
    // Runtime trip count: leave the loop rolled.
    for (i = 0; i < Q * n - P; i += P) {
      sB[i + tx] = dB[i + tx];
    }
  } else {
    // Compile-time trip count: unroll.
#pragma unroll
    for (i = 0; i < Q * NB - P; i += P) {
      sB[i + tx] = dB[i + tx];
    }
  }
  // Last (possibly partial) pass, guarded against running past Q * n entries.
  if (i + tx < Q * n) {
    sB[i + tx] = dB[i + tx];
  }
}

////////////////////////////////////////////////////////////////////////////////
// write C from reg. to global
// C is (P x NB)
// 1D thread config.
// with (P x 1) threads
// no sync at the end of the function
//
// The calling thread scatters its values of rC to dC, P entries apart,
// starting at offset tx:
//   dC[i * P + tx] = rC[i]   for i = 0 .. n-1
// When n == NB the trip count is a compile-time constant and the loop is
// unrolled; otherwise only the first n entries of rC are written.
// The template parameter Q is not referenced by this function.
// NOTE(review): n is presumably the number of valid columns of C, with
// n <= NB -- confirm at call sites.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  if (n != NB) {
    for (int i = 0; i < n; i++) {
      dC[i * P + tx] = rC[i];
    }
  } else {
#pragma unroll
    for (int i = 0; i < NB; i++) {
      dC[i * P + tx] = rC[i];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// multiply C = A x B using 1D threads in P x 1 config
// A (P x Q) in reg., one row per thread
// B (Q x NB) in shared memory
// C in registers -- one row per thread
// no sync at the end of the function
//
// For every i in [0, NB) the thread computes
//   rC[i] = sum_{j=0}^{Q-1} rA[j] * sB[i * Q + j]
// Previous contents of rC are overwritten. The Q values of sB used for one
// output entry are first copied into the local array rB and then combined
// with rA. sB is only read; no barrier is issued, so the caller must make
// sure sB is fully populated before calling.
// The template parameter P is not referenced by this function.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T rB[Q];

#pragma unroll
  for (int i = 0; i < NB; i++) {
    // Stage the Q entries of B needed for output entry i.
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rB[j] = sB[i * Q + j];
    }
    // Explicit conversion keeps the accumulator initialisation in type T
    // instead of relying on an implicit conversion from a double literal.
    rC[i] = T(0.0);
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rC[i] += rA[j] * rB[j];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// multiply C += A x B using 1D threads in P x 1 config
// A (P x Q) in reg., one row per thread
// B (Q x NB) in shared memory
// C in registers -- one row per thread
// no sync at the end of the function
//
// For every i in [0, NB) the thread computes
//   rC[i] += sum_{j=0}^{Q-1} rA[j] * sB[i * Q + j]
// Unlike mul_rAsBrC_1D_nosync, rC is not cleared first, so the caller must
// have initialised it. sB is only read; no barrier is issued.
// The template parameter P is not referenced by this function.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void addmul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T rB[Q];

#pragma unroll
  for (int i = 0; i < NB; i++) {
    // Stage the Q entries of B needed for output entry i.
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rB[j] = sB[i * Q + j];
    }
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rC[i] += rA[j] * rB[j];
    }
  }
}

#endif  // CEED_MAGMA_COMMON_NONTENSOR_H