xref: /libCEED/include/ceed/jit-source/magma/magma-common-nontensor.h (revision 9d15e85b4f78ffb2d2860753c87a3b1789cc3bb6)
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed

/// @file
/// Internal header for MAGMA backend common non-tensor basis definitions
10f80f4a74SSebastian Grimberg #ifndef CEED_MAGMA_COMMON_NONTENSOR_H
11f80f4a74SSebastian Grimberg #define CEED_MAGMA_COMMON_NONTENSOR_H
12f80f4a74SSebastian Grimberg 
133c1e2affSSebastian Grimberg #include "magma-common-defs.h"
14f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// read A (no-trans) from global to reg.
// A is (P x Q); entry (r, c) is fetched from dA[c * P + r]
// 1D thread config. with (P x 1) threads, tx selects the row r
// each thread fills rA[0 .. Q-1] with its own row of A
// the NB template parameter is not referenced by this function
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const T *dA, T rA[Q]) {
  // First entry of this thread's row; consecutive columns are P elements apart
  const T *row_start = dA + tx;

#pragma unroll
  for (int col = 0; col < Q; ++col) {
    rA[col] = row_start[col * P];
  }
}
27f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// read A (trans) from global to reg., staging through sA
// A is (P x Q)
// 2D thread config. with (P x BY) threads (tx and ty are both used, despite
// the "1D" in the function name)
// step 1: all threads copy the P * Q entries of dA into sA, unchanged in order,
//         P * BY entries per pass; sA must hold at least P * Q entries
// step 2: __syncthreads(), so every thread in the block must call this function
// step 3: each thread loads rA[j] = sA[tx * Q + j] for j = 0 .. Q-1
// no sync at the end of the function
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int flat_tid = tx + ty * blockDim.x;
  const int total    = P * Q;
  const int chunk    = P * BY;
  int       base;

  // Full passes: every thread copies exactly one entry per pass, no bounds check needed
#pragma unroll
  for (base = 0; base < total - chunk; base += chunk) {
    sA[base + flat_tid] = dA[base + flat_tid];
  }
  // Last (possibly partial) pass: only threads that still map inside A copy
  const int last = base + flat_tid;
  if (last < total) {
    sA[last] = dA[last];
  }
  // The staged copy must be complete before any thread reads entries written by others
  __syncthreads();

  const T *my_slice = sA + tx * Q;
#pragma unroll
  for (int j = 0; j < Q; ++j) {
    rA[j] = my_slice[j];
  }
}
52f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// read B from global to shared
// B is (Q x NB); only the first Q * n entries are copied (n == NB copies all)
// 1D thread config. with (P x 1) threads, each pass copies P entries
// the copy keeps the linear order: sB[k] = dB[k] for k < Q * n
// no sync at the end of the function; the caller must synchronize before
// reading entries of sB written by other threads
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, const int n, const T *dB, T *sB) {
  int offset;

  if (n == NB) {
    // Full-width case: the trip count is a compile-time constant, so unroll
#pragma unroll
    for (offset = 0; offset < Q * NB - P; offset += P) {
      sB[offset + tx] = dB[offset + tx];
    }
  } else {
    // Partial-width case: the trip count depends on the runtime value n
    for (offset = 0; offset < Q * n - P; offset += P) {
      sB[offset + tx] = dB[offset + tx];
    }
  }

  // Last (possibly partial) pass, guarded so no thread touches entries past Q * n
  const int tail = offset + tx;
  if (tail < Q * n) {
    sB[tail] = dB[tail];
  }
}
76f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// write C from reg. to global
// C is (P x NB); entry (r, c) is stored to dC[c * P + r]
// 1D thread config. with (P x 1) threads, tx selects the row r
// only the first n columns are written (n == NB writes all of them)
// the Q template parameter is not referenced by this function
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  // First entry of this thread's row; consecutive columns are P elements apart
  T *row_out = dC + tx;

  if (n == NB) {
    // Full-width case: the trip count is a compile-time constant, so unroll
#pragma unroll
    for (int col = 0; col < NB; ++col) {
      row_out[col * P] = rC[col];
    }
  } else {
    // Partial-width case: the trip count depends on the runtime value n
    for (int col = 0; col < n; ++col) {
      row_out[col * P] = rC[col];
    }
  }
}
953c1e2affSSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// multiply C = A x B using 1D threads in P x 1 config
// A (P x Q)  in reg., one row per thread (rA holds the Q entries of that row)
// B (Q x NB) in shared memory, column c occupies sB[c * Q .. c * Q + Q - 1]
// C in registers -- one row per thread; rC[c] is overwritten for c = 0 .. NB-1
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T col_vals[Q];  // private copy of the current column of B

#pragma unroll
  for (int col = 0; col < NB; ++col) {
    const T *sB_col = sB + col * Q;

    // Stage the whole column before the dot product starts
#pragma unroll
    for (int k = 0; k < Q; ++k) {
      col_vals[k] = sB_col[k];
    }

    // Reset the output entry, then accumulate the dot product in index order
    rC[col] = 0.0;
#pragma unroll
    for (int k = 0; k < Q; ++k) {
      rC[col] += rA[k] * col_vals[k];
    }
  }
}
119f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// multiply C += A x B using 1D threads in P x 1 config
// A (P x Q)  in reg., one row per thread (rA holds the Q entries of that row)
// B (Q x NB) in shared memory, column c occupies sB[c * Q .. c * Q + Q - 1]
// C in registers -- one row per thread; rC[c] keeps its incoming value and the
// dot product is added on top, so the caller must initialize rC beforehand
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void addmul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T col_vals[Q];  // private copy of the current column of B

#pragma unroll
  for (int col = 0; col < NB; ++col) {
    const T *sB_col = sB + col * Q;

    // Stage the whole column before the dot product starts
#pragma unroll
    for (int k = 0; k < Q; ++k) {
      col_vals[k] = sB_col[k];
    }

    // Accumulate onto the existing output entry in index order
#pragma unroll
    for (int k = 0; k < Q; ++k) {
      rC[col] += rA[k] * col_vals[k];
    }
  }
}
142f80f4a74SSebastian Grimberg 
143f80f4a74SSebastian Grimberg #endif  // CEED_MAGMA_COMMON_NONTENSOR_H
144