xref: /libCEED/include/ceed/jit-source/magma/magma-common-nontensor.h (revision 3c1e2aff6d111c93ca8797996aaf987f66b08927)
1f80f4a74SSebastian Grimberg // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2f80f4a74SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3f80f4a74SSebastian Grimberg //
4f80f4a74SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause
5f80f4a74SSebastian Grimberg //
6f80f4a74SSebastian Grimberg // This file is part of CEED:  http://github.com/ceed
7f80f4a74SSebastian Grimberg 
8*3c1e2affSSebastian Grimberg /// @file
9*3c1e2affSSebastian Grimberg /// Internal header for MAGMA backend common non-tensor basis definitions
10f80f4a74SSebastian Grimberg #ifndef CEED_MAGMA_COMMON_NONTENSOR_H
11f80f4a74SSebastian Grimberg #define CEED_MAGMA_COMMON_NONTENSOR_H
12f80f4a74SSebastian Grimberg 
13*3c1e2affSSebastian Grimberg #include "magma-common-defs.h"
14f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Load one row of A (non-transposed) directly from global memory into registers.
//   A    : P x Q, column-major with leading dimension ldda
//   rA   : on return, rA[j] = A(tx, j) for j = 0 .. Q-1
//   tx   : row owned by the calling thread (1D launch, P threads along x)
//   sA, slda : not referenced by this routine (presumably kept for signature
//              parity with read_A_trans_g2r_1D_nosync -- NOTE(review): confirm)
// Issues no __syncthreads(); callers handle any required synchronization.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const T *dA, int ldda, T *sA, int slda, T rA[Q]) {
  // Base of this thread's row; consecutive columns are ldda elements apart.
  const T *dA_row = dA + tx;
#pragma unroll
  for (int col = 0; col < Q; ++col) {
    rA[col] = dA_row[col * ldda];
  }
}
27f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Load A (transposed access) from global memory into registers, staged through sA.
//   1. Every thread of the block cooperatively copies the Q*P contiguous
//      elements of dA into sA (element-for-element; ldda is NOT used, so dA is
//      assumed to be densely packed -- NOTE(review): confirm at call sites).
//      Threads are indexed flatly as ty * blockDim.x + tx, and the copy proceeds
//      in sweeps of nTH = MAGMA_BASIS_BOUNDS(P, MAGMA_MAXTHREADS_1D) elements
//      (assumed to equal the block's thread count -- TODO confirm).
//   2. __syncthreads() so sA is fully populated.
//   3. Thread tx reads its Q values: rA[j] = sA[tx * slda + j].
// sA must hold at least Q*P elements. Because of the barrier in step 2, every
// thread of the block must call this function. No barrier is issued after step 3.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, int ldda, T *sA, int slda, T rA[Q]) {
  const int total    = Q * P;
  const int nthreads = MAGMA_BASIS_BOUNDS(P, MAGMA_MAXTHREADS_1D);
  const int flat_id  = ty * blockDim.x + tx;
  int       offset   = 0;

  // Complete sweeps: every thread copies one element per sweep.
#pragma unroll
  for (; offset < total - nthreads; offset += nthreads) {
    sA[offset + flat_id] = dA[offset + flat_id];
  }

  // Final (possibly partial) sweep covering the remaining total - offset elements.
  if (flat_id < total - offset) {
    sA[offset + flat_id] = dA[offset + flat_id];
  }
  __syncthreads();

  // Each thread gathers its contiguous run of Q values from shared memory.
  const T *sA_row = sA + tx * slda;
#pragma unroll
  for (int k = 0; k < Q; ++k) {
    rA[k] = sA_row[k];
  }
}
53f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Copy B from global to shared memory.
//   B  : Q x n tile (n <= NB presumably), read as Q*n contiguous elements of dB
//        and written element-for-element to sB. lddb and sldb are not used, so
//        both buffers are assumed densely packed (NOTE(review): confirm).
//   tx : thread index in a 1D launch of P threads; the copy advances P elements
//        per sweep, with a guarded final sweep for the remainder.
// When n == NB the trip count is a compile-time constant and the loop is unrolled.
// Issues no __syncthreads(); the caller must synchronize before reading sB.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, const int n, const T *dB, int lddb, T *sB, int sldb) {
  const int total = Q * n;
  int       base  = 0;

  if (n == NB) {
    // Full tile: compile-time trip count.
#pragma unroll
    for (base = 0; base < (Q * NB) - P; base += P) {
      sB[base + tx] = dB[base + tx];
    }
  } else {
    // Partial tile: runtime trip count.
    for (base = 0; base < total - P; base += P) {
      sB[base + tx] = dB[base + tx];
    }
  }

  // Remainder: base now equals ceil((total - P) / P) * P (or 0 when total <= P).
  if (tx < total - base) {
    sB[base + tx] = dB[base + tx];
  }
}
78f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Store one row of C from registers to global memory.
//   C   : P x NB, column-major with leading dimension lddc
//   rC  : row tx of C held by the calling thread (rC[j] = C(tx, j))
//   n   : number of valid columns; only columns 0 .. n-1 are written
//   tx  : row owned by the calling thread (1D launch, P threads along x)
// Issues no __syncthreads().
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC, int lddc) {
  // Base of this thread's row; consecutive columns are lddc elements apart.
  T *dC_row = dC + tx;

  if (n == NB) {
    // Full tile: store every column.
#pragma unroll
    for (int col = 0; col < NB; ++col) {
      dC_row[col * lddc] = rC[col];
    }
  } else {
    // Partial tile: store only the first n columns.
#pragma unroll
    for (int col = 0; col < NB; ++col) {
      if (col < n) {
        dC_row[col * lddc] = rC[col];
      }
    }
  }
}
100*3c1e2affSSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Compute C = A x B for one row of A per thread (1D launch, P threads along x).
//   rA : row of A (length Q) held in registers by the calling thread
//   sB : B (Q x NB) in shared memory; column i starts at sB[i * sldb]
//   rC : on return, rC[i] = sum_k rA[k] * B(k, i) for i = 0 .. NB-1 (overwritten)
// tx is not referenced; the sB addresses read are the same for every thread.
// Issues no __syncthreads(); sB must already be populated and synchronized.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(const int tx, T rA[Q], T *sB, int sldb, T rC[NB]) {
  T rB[Q];
#pragma unroll
  for (int i = 0; i < NB; i++) {
    // Stage column i of B in registers.
#pragma unroll
    for (int k = 0; k < Q; k++) {
      rB[k] = sB[i * sldb + k];
    }
    // Zero with T(0) rather than the double literal 0.0 so the templated code
    // never involves a double; accumulation order is unchanged (k = 0 .. Q-1).
    T sum = T(0);
#pragma unroll
    for (int k = 0; k < Q; k++) {
      sum += rA[k] * rB[k];
    }
    rC[i] = sum;
  }
}
123f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Accumulate C += A x B for one row of A per thread (1D launch, P threads along x).
//   rA : row of A (length Q) held in registers by the calling thread
//   sB : B (Q x NB) in shared memory; column i starts at sB[i * sldb]
//   rC : updated in place, rC[i] += sum_k rA[k] * B(k, i) for i = 0 .. NB-1
// tx is not referenced; the sB addresses read are the same for every thread.
// Issues no __syncthreads(); sB must already be populated and synchronized.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void addmul_rAsBrC_1D_nosync(const int tx, T rA[Q], T *sB, int sldb, T rC[NB]) {
  T rB[Q];
#pragma unroll
  for (int col = 0; col < NB; ++col) {
    // Copy column col of B from shared memory into registers.
    const T *sB_col = sB + col * sldb;
#pragma unroll
    for (int k = 0; k < Q; ++k) {
      rB[k] = sB_col[k];
    }
    // Dot product with this thread's row of A, added onto the existing value.
#pragma unroll
    for (int k = 0; k < Q; ++k) {
      rC[col] += rA[k] * rB[k];
    }
  }
}
145f80f4a74SSebastian Grimberg 
146f80f4a74SSebastian Grimberg #endif  // CEED_MAGMA_COMMON_NONTENSOR_H
147