xref: /libCEED/include/ceed/jit-source/magma/magma-common-nontensor.h (revision f80f4a748154eed4bc661c135f695b92b1bc45b9)
1*f80f4a74SSebastian Grimberg // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*f80f4a74SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*f80f4a74SSebastian Grimberg //
4*f80f4a74SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause
5*f80f4a74SSebastian Grimberg //
6*f80f4a74SSebastian Grimberg // This file is part of CEED:  http://github.com/ceed
7*f80f4a74SSebastian Grimberg 
8*f80f4a74SSebastian Grimberg #ifndef CEED_MAGMA_COMMON_NONTENSOR_H
9*f80f4a74SSebastian Grimberg #define CEED_MAGMA_COMMON_NONTENSOR_H
10*f80f4a74SSebastian Grimberg 
// Thread budget used by MAGMA_NONTENSOR_BASIS_NTCOL below
#define NONTENSOR_MAX_THREADS (128)

// MAGMA_DEVICE_SHARED(type, name) declares a dynamically sized shared memory array.
// It is only defined here when no earlier definition exists.
// NOTE: the macro must not be pre-defined as an empty object-like macro before the
// function-like definitions below; doing so is an incompatible macro redefinition.
#ifndef MAGMA_DEVICE_SHARED
#ifdef CEED_MAGMA_USE_HIP
#define MAGMA_DEVICE_SHARED(type, name) HIP_DYNAMIC_SHARED(type, name)
#else
#define MAGMA_DEVICE_SHARED(type, name) extern __shared__ type name[];
#endif  // CEED_MAGMA_USE_HIP
#endif  // MAGMA_DEVICE_SHARED

// NONTENSOR_MAX_THREADS / N (integer division), but never less than 1
#define MAGMA_NONTENSOR_BASIS_NTCOL(N) (MAGMA_MAX(1, (NONTENSOR_MAX_THREADS / (N))))

// Column-major (i, j) accessors; each expects the matching leading dimension
// (ldda, slda, lddb, sldb) to be in scope at the point of use.
// They are #undef'd at the end of this header.
#define dA(i, j) dA[(j)*ldda + (i)]
#define sA(i, j) sA[(j)*slda + (i)]
#define dB(i, j) dB[(j)*lddb + (i)]
#define sB(i, j) sB[(j)*sldb + (i)]
28*f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Load C from global memory into registers, scaled by beta.
// Thread tx reads dC[j * lddc + tx] for j = 0, ..., NB_ - 1 and stores beta times that value in rC[j].
// When n != NB_, entries with j >= n are not read from dC and are set to zero instead.
// No synchronization is performed by this function.
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void read_C_g2r_1D_nosync(const int tx, const int n, T *dC, int lddc, const T &beta, T rC[NB_]) {
  // First entry handled by this thread; consecutive columns are lddc apart
  const T *col = dC + tx;

  if (n == NB_) {
    // Full block: every column is loaded unconditionally
#pragma unroll
    for (int c = 0; c < NB_; ++c) {
      rC[c] = beta * col[c * lddc];
    }
    return;
  }

  // Partial block: only the first n columns are loaded
#pragma unroll
  for (int c = 0; c < NB_; ++c) {
    if (c < n) {
      rC[c] = beta * col[c * lddc];
    } else {
      rC[c] = 0;
    }
  }
}
48*f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Store C from registers to global memory.
// Thread tx writes rC[j] to dC[j * lddc + tx] for j = 0, ..., NB_ - 1.
// When n != NB_, only entries with j < n are written.
// No synchronization is performed by this function.
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB_], T *dC, int lddc) {
  // First entry handled by this thread; consecutive columns are lddc apart
  T *col = dC + tx;

  if (n == NB_) {
    // Full block: every column is stored unconditionally
#pragma unroll
    for (int c = 0; c < NB_; ++c) {
      col[c * lddc] = rC[c];
    }
    return;
  }

  // Partial block: columns at or beyond n are left untouched
#pragma unroll
  for (int c = 0; c < NB_; ++c) {
    if (c < n) {
      col[c * lddc] = rC[c];
    }
  }
}
70*f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Load A (no-trans) from global memory into registers.
// Thread tx reads dA[j * ldda + tx] for j = 0, ..., Q_ - 1 into rA[j].
// The sA and slda arguments are not used by this variant.
// No synchronization is performed by this function.
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const T *dA, int ldda, T *sA, int slda, T rA[Q_]) {
  // First entry handled by this thread; consecutive columns are ldda apart
  const T *row = dA + tx;

#pragma unroll
  for (int k = 0; k < Q_; ++k) {
    rA[k] = row[k * ldda];
  }
}
83*f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Load A (trans) from global memory into registers, staged through sA.
// Step 1: threads copy the first Q_ * P_ entries of dA into sA at identical offsets
//         (ldda is not used, so dA is read as one contiguous array).
//         The flat thread id is ty * blockDim.x + tx, and the copy advances in strides of
//         P_ * MAGMA_NONTENSOR_BASIS_NTCOL(P_).
//         NOTE(review): presumably the block is launched with exactly that many threads -- confirm against caller.
// Step 2: __syncthreads(), so every thread of the block must call this function.
// Step 3: thread tx reads sA[tx * slda + j] for j = 0, ..., Q_ - 1 into rA[j].
// There is no sync at the end of the function.
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, int ldda, T *sA, int slda, T rA[Q_]) {
  const int num_entries = Q_ * P_;
  const int num_threads = P_ * MAGMA_NONTENSOR_BASIS_NTCOL(P_);
  const int tid         = ty * blockDim.x + tx;
  int       offset      = 0;

  // Full-width passes; the last (possibly partial) chunk is handled below
#pragma unroll
  for (offset = 0; offset < num_entries - num_threads; offset += num_threads) {
    sA[offset + tid] = dA[offset + tid];
  }

  // Final chunk: only threads that map to a valid entry take part
  if (offset + tid < num_entries) {
    sA[offset + tid] = dA[offset + tid];
  }
  __syncthreads();

  // Each thread picks up Q_ consecutive entries starting at tx * slda
  const T *row = sA + tx * slda;
#pragma unroll
  for (int k = 0; k < Q_; ++k) {
    rA[k] = row[k];
  }
}
110*f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Copy B from global memory to shared memory.
// The first Q_ * n entries of dB are copied to sB at identical offsets, P_ entries per pass,
// with thread tx handling offset tx within each pass.
// lddb and sldb are not used, so both arrays are addressed as contiguous storage.
// NOTE(review): presumably called with tx in [0, P_) -- confirm against caller.
// No synchronization is performed by this function.
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, int n, const T *dB, int lddb, T *sB, int sldb) {
  if (n == NB_) {
    // Full block: the trip count is a compile-time constant, so the loop can be unrolled
#pragma unroll
    for (int off = 0; off < (Q_ * NB_) - P_; off += P_) {
      sB[off + tx] = dB[off + tx];
    }
  } else {
    // Partial block: the trip count depends on the runtime value n
    for (int off = 0; off < (Q_ * n) - P_; off += P_) {
      sB[off + tx] = dB[off + tx];
    }
  }

  // Remaining entries that the strided passes above did not reach
  const int tail = MAGMA_ROUNDUP(Q_ * n - P_, P_);
  if (tail + tx < (Q_ * n)) {
    sB[tail + tx] = dB[tail + tx];
  }
}
135*f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Accumulate rC[i] += alpha * sum_k rA[k] * sB[i * sldb + k] for i = 0, ..., NB_ - 1.
// rA holds Q_ values and rC holds NB_ values owned by the calling thread.
// For each i, Q_ values are read from sB starting at offset i * sldb.
// The tx argument is not used.
// No synchronization is performed by this function.
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(const int tx, const T &alpha, T rA[Q_], T *sB, int sldb, T rC[NB_]) {
  T rB[Q_] = {0};

#pragma unroll
  for (int col = 0; col < NB_; ++col) {
    // Stage the current column of B in a local array
    const T *b_col = sB + col * sldb;
#pragma unroll
    for (int k = 0; k < Q_; ++k) {
      rB[k] = b_col[k];
    }

    // Dot product of the staged column with this thread's values of A
    T dot = 0;
#pragma unroll
    for (int k = 0; k < Q_; ++k) {
      dot += rA[k] * rB[k];
    }
    rC[col] += alpha * dot;
  }
}
160*f80f4a74SSebastian Grimberg 
161*f80f4a74SSebastian Grimberg #undef dA
162*f80f4a74SSebastian Grimberg #undef sA
163*f80f4a74SSebastian Grimberg #undef dB
164*f80f4a74SSebastian Grimberg #undef sB
165*f80f4a74SSebastian Grimberg 
166*f80f4a74SSebastian Grimberg #endif  // CEED_MAGMA_COMMON_NONTENSOR_H
167