15aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2f80f4a74SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3f80f4a74SSebastian Grimberg // 4f80f4a74SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause 5f80f4a74SSebastian Grimberg // 6f80f4a74SSebastian Grimberg // This file is part of CEED: http://github.com/ceed 7f80f4a74SSebastian Grimberg 83c1e2affSSebastian Grimberg /// @file 93c1e2affSSebastian Grimberg /// Internal header for MAGMA backend common non-tensor basis definitions 10*509d4af6SJeremy L Thompson #pragma once 11f80f4a74SSebastian Grimberg 123c1e2affSSebastian Grimberg #include "magma-common-defs.h" 13f80f4a74SSebastian Grimberg 14f80f4a74SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 15f80f4a74SSebastian Grimberg // read A (no-trans) from global to reg. 163c1e2affSSebastian Grimberg // A is (P x Q) 17833aa127SSebastian Grimberg // 2D thread config. 
// with (P x BY) threads
// no sync at the end of the function
//
// All P * BY threads cooperate to copy the P * Q entries of dA into the
// staging buffer sA (presumably shared memory -- confirm against the caller),
// then each thread gathers Q values from sA into its private array rA.
//
// Template parameters:
//   T  - scalar type
//   P  - extent of the tx dimension of the thread block; stride between
//        consecutive rA entries in sA
//   Q  - number of values gathered per thread
//   BY - extent of the ty dimension of the thread block
// Arguments:
//   tx, ty - thread coordinates; the flat id is ty * P + tx, so the copy covers
//            all of A only if the flat ids span [0, P * BY)
//   dA     - source array, at least P * Q entries are read
//   sA     - staging buffer, at least P * Q entries are written
//   rA     - output, rA[j] = sA[j * P + tx] for j in [0, Q)
//
// NOTE: despite the "1D" in the name, the flat thread id is built from both tx
// and ty. The function contains a __syncthreads(), so every thread of the block
// must call it. There is no barrier after the final reads from sA: the caller
// must synchronize before sA is overwritten.
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  // A zero stride would make the copy loop below never advance.
  static_assert(P > 0 && BY > 0, "P and BY must be positive");

  const int tid = ty * P + tx;
  int       i;

  // Full sweeps of P * BY entries; the last (possibly partial) sweep is handled
  // by the guarded copy after the loop.
#pragma unroll
  for (i = 0; i < P * Q - P * BY; i += P * BY) {
    sA[i + tid] = dA[i + tid];
  }
  if (i + tid < P * Q) {
    sA[i + tid] = dA[i + tid];
  }
  // Make the staged copy visible to the whole block before anyone reads it.
  __syncthreads();

  // Thread tx picks entries tx, tx + P, tx + 2P, ... of the staged array.
#pragma unroll
  for (int j = 0; j < Q; j++) {
    rA[j] = sA[j * P + tx];
  }
}

////////////////////////////////////////////////////////////////////////////////
// read A (trans) from global to reg.
// A is (P x Q)
// 2D thread config.
// with (P x BY) threads
// no sync at the end of the function
//
// All P * BY threads cooperate to copy the P * Q entries of dA into the
// staging buffer sA (presumably shared memory -- confirm against the caller),
// then each thread gathers Q consecutive values from sA into its private
// array rA.
//
// Template parameters:
//   T  - scalar type
//   P  - extent of the tx dimension of the thread block
//   Q  - number of values gathered per thread; also the stride between the
//        segments read by neighbouring threads
//   BY - extent of the ty dimension of the thread block
// Arguments:
//   tx, ty - thread coordinates; the flat id is ty * P + tx, so the copy covers
//            all of A only if the flat ids span [0, P * BY)
//   dA     - source array, at least P * Q entries are read
//   sA     - staging buffer, at least P * Q entries are written
//   rA     - output, rA[j] = sA[tx * Q + j] for j in [0, Q)
//
// NOTE: despite the "1D" in the name, the flat thread id is built from both tx
// and ty. The function contains a __syncthreads(), so every thread of the block
// must call it. There is no barrier after the final reads from sA: the caller
// must synchronize before sA is overwritten.
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  // A zero stride would make the copy loop below never advance.
  static_assert(P > 0 && BY > 0, "P and BY must be positive");

  const int tid = ty * P + tx;
  int       i;

  // Full sweeps of P * BY entries; the last (possibly partial) sweep is handled
  // by the guarded copy after the loop.
#pragma unroll
  for (i = 0; i < P * Q - P * BY; i += P * BY) {
    sA[i + tid] = dA[i + tid];
  }
  if (i + tid < P * Q) {
    sA[i + tid] = dA[i + tid];
  }
  // Make the staged copy visible to the whole block before anyone reads it.
  __syncthreads();

  // Thread tx picks the contiguous segment [tx * Q, tx * Q + Q) of the staged array.
#pragma unroll
  for (int j = 0; j < Q; j++) {
    rA[j] = sA[tx * Q + j];
  }
}

////////////////////////////////////////////////////////////////////////////////
// read B from global to shared
// B is (Q x NB)
// 1D thread config.
// with (P x 1) threads
// no sync at the end of the function
//
// P threads cooperate to copy the first Q * n entries of dB into sB.
//
// Template parameters:
//   T  - scalar type
//   P  - number of cooperating threads; stride of the copy loop
//   Q  - number of rows of B
//   NB - compile-time maximum number of columns of B
// Arguments:
//   tx - thread index; the copy covers every entry only if tx spans [0, P)
//   n  - number of columns actually copied; when n == NB the loop bound is a
//        compile-time constant and the loop is marked for unrolling
//   dB - source array, Q * n entries are read
//   sB - destination array, Q * n entries are written
//
// NOTE: this function contains no barrier at all. The caller must synchronize
// before any thread reads entries of sB written by another thread.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, const int n, const T *dB, T *sB) {
  // A zero stride would make the copy loops below never advance.
  static_assert(P > 0, "P must be positive");

  int i;

  // Full sweeps of P entries; the last (possibly partial) sweep is handled by
  // the guarded copy after the branches.
  if (n != NB) {
    for (i = 0; i < Q * n - P; i += P) {
      sB[i + tx] = dB[i + tx];
    }
  } else {
#pragma unroll
    for (i = 0; i < Q * NB - P; i += P) {
      sB[i + tx] = dB[i + tx];
    }
  }
  if (i + tx < Q * n) {
    sB[i + tx] = dB[i + tx];
  }
}

////////////////////////////////////////////////////////////////////////////////
// write C from reg. to global
// C is (P x NB)
// 1D thread config.
// with (P x 1) threads
// no sync at the end of the function
//
// Each thread stores the first n entries of its private array rC into dC,
// entry i going to dC[i * P + tx].
//
// Template parameters:
//   T  - scalar type
//   P  - stride in dC between consecutive entries of one thread
//   Q  - unused in this function
//   NB - compile-time maximum number of entries per thread
// Arguments:
//   tx - thread index (offset within each group of P entries)
//   n  - number of entries actually written; when n == NB the loop bound is a
//        compile-time constant and the loop is marked for unrolling
//   rC - per-thread values to store
//   dC - destination array
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  if (n != NB) {
    for (int i = 0; i < n; i++) {
      dC[i * P + tx] = rC[i];
    }
  } else {
#pragma unroll
    for (int i = 0; i < NB; i++) {
      dC[i * P + tx] = rC[i];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// multiply C = A x B using 1D threads in P x 1 config
// A (P x Q) in reg., one row per thread
// B (Q x NB) in shared memory
// C in registers -- one row per thread
// no sync at the end of the function
//
// For each i in [0, NB): rC[i] = sum over j in [0, Q) of rA[j] * sB[i * Q + j].
// Any previous content of rC is overwritten. The Q values of sB for one i are
// first copied into a local array and then accumulated.
//
// Template parameters:
//   T  - scalar type
//   P  - unused in this function
//   Q  - length of the dot products
//   NB - number of dot products per thread
// Arguments:
//   rA - per-thread array of Q values
//   sB - array of at least Q * NB values, read by every calling thread
//   rC - per-thread output array of NB values
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T rB[Q];

#pragma unroll
  for (int i = 0; i < NB; i++) {
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rB[j] = sB[i * Q + j];
    }
    rC[i] = 0.0;
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rC[i] += rA[j] * rB[j];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// multiply C += A x B using 1D threads in P x 1 config
// A (P x Q) in reg., one row per thread
// B (Q x NB) in shared memory
// C in registers -- one row per thread
// no sync at the end of the function
//
// For each i in [0, NB): rC[i] += sum over j in [0, Q) of rA[j] * sB[i * Q + j].
// rC is accumulated into, so it must hold valid values on entry.
//
// Template parameters:
//   T  - scalar type
//   P  - unused in this function
//   Q  - length of the dot products
//   NB - number of dot products per thread
// Arguments:
//   rA - per-thread array of Q values
//   sB - array of at least Q * NB values, read by every calling thread
//   rC - per-thread accumulator array of NB values (read and written)
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void addmul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T rB[Q];

#pragma unroll
  for (int i = 0; i < NB; i++) {
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rB[j] = sB[i * Q + j];
    }
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rC[i] += rA[j] * rB[j];
    }
  }
}