xref: /libCEED/include/ceed/jit-source/magma/magma-common-nontensor.h (revision 509d4af65d23546c690c9766d8b29e47dc3b3afb)
15aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2f80f4a74SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3f80f4a74SSebastian Grimberg //
4f80f4a74SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause
5f80f4a74SSebastian Grimberg //
6f80f4a74SSebastian Grimberg // This file is part of CEED:  http://github.com/ceed
7f80f4a74SSebastian Grimberg 
83c1e2affSSebastian Grimberg /// @file
93c1e2affSSebastian Grimberg /// Internal header for MAGMA backend common non-tensor basis definitions
10*509d4af6SJeremy L Thompson #pragma once
11f80f4a74SSebastian Grimberg 
123c1e2affSSebastian Grimberg #include "magma-common-defs.h"
13f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Load A (no-trans) from global memory into per-thread registers.
// A holds P x Q entries; dA and sA are both indexed as [j * P + row].
// Expects a 2D thread config. with (P x BY) threads: tx in [0, P), ty in [0, BY).
//
//   tx, ty : thread coordinates inside the block
//   dA     : source matrix, P * Q entries
//   sA     : staging buffer shared by the block, at least P * Q entries
//   rA     : output, rA[j] = A(tx, j) for j in [0, Q)
//
// All threads of the block must call this function together, since it contains
// a __syncthreads() between filling sA and reading it back.
// There is no sync at the end of the function.
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int num_entries = P * Q;
  const int num_threads = P * BY;
  const int lane        = ty * P + tx;  // linear thread index inside the block
  int       base;

  // Cooperative copy dA -> sA; each pass moves num_threads consecutive entries
#pragma unroll
  for (base = 0; base < num_entries - num_threads; base += num_threads) {
    sA[base + lane] = dA[base + lane];
  }

  // Final (possibly partial) pass, guarded so no thread runs past the matrix
  const int last = base + lane;
  if (last < num_entries) {
    sA[last] = dA[last];
  }

  // sA must be complete before threads read entries staged by other threads
  __syncthreads();

  // Thread tx gathers row tx of A; consecutive columns are P entries apart
  const T *row = sA + tx;
#pragma unroll
  for (int q = 0; q < Q; q++) {
    rA[q] = row[q * P];
  }
}
38f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Load A (trans) from global memory into per-thread registers.
// A holds P x Q entries; dA and sA are both indexed as [row * Q + j], i.e. the
// Q values needed by one thread are contiguous.
// Expects a 2D thread config. with (P x BY) threads: tx in [0, P), ty in [0, BY).
//
//   tx, ty : thread coordinates inside the block
//   dA     : source matrix, P * Q entries
//   sA     : staging buffer shared by the block, at least P * Q entries
//   rA     : output, rA[j] = sA[tx * Q + j] for j in [0, Q)
//
// All threads of the block must call this function together, since it contains
// a __syncthreads() between filling sA and reading it back.
// There is no sync at the end of the function.
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int num_entries = P * Q;
  const int num_threads = P * BY;
  const int lane        = ty * P + tx;  // linear thread index inside the block
  int       base;

  // Cooperative copy dA -> sA; each pass moves num_threads consecutive entries
#pragma unroll
  for (base = 0; base < num_entries - num_threads; base += num_threads) {
    sA[base + lane] = dA[base + lane];
  }

  // Final (possibly partial) pass, guarded so no thread runs past the matrix
  const int last = base + lane;
  if (last < num_entries) {
    sA[last] = dA[last];
  }

  // sA must be complete before threads read entries staged by other threads
  __syncthreads();

  // Thread tx gathers the Q contiguous entries starting at tx * Q
  const T *segment = sA + tx * Q;
#pragma unroll
  for (int q = 0; q < Q; q++) {
    rA[q] = segment[q];
  }
}
63f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Copy B from global to shared memory.
// B is (Q x NB); only the first n columns (Q * n contiguous entries) are copied.
// Expects a 1D thread config. with (P x 1) threads, tx in [0, P).
//
//   tx : thread index
//   n  : number of columns to copy (the n == NB case uses an unrolled loop)
//   dB : source, at least Q * n entries
//   sB : destination, at least Q * n entries
//
// No barrier is issued here (no sync at the end of the function).
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, const int n, const T *dB, T *sB) {
  int base;

  if (n == NB) {
    // Full-width tile: the trip count is a compile-time constant, so unroll
#pragma unroll
    for (base = 0; base < Q * NB - P; base += P) {
      sB[base + tx] = dB[base + tx];
    }
  } else {
    // Partial tile: the trip count depends on the runtime value of n
    for (base = 0; base < Q * n - P; base += P) {
      sB[base + tx] = dB[base + tx];
    }
  }

  // Final (possibly partial) pass, guarded so no thread copies past Q * n
  const int last = base + tx;
  if (last < Q * n) {
    sB[last] = dB[last];
  }
}
87f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Write C from registers to global memory.
// C is (P x NB), indexed in dC as [col * P + row]; only the first n columns are
// written.
// Expects a 1D thread config. with (P x 1) threads, tx in [0, P).
//
//   tx : thread index; this thread writes row tx of C
//   n  : number of columns to write (the n == NB case uses an unrolled loop)
//   rC : this thread's row of C, NB entries
//   dC : destination
//
// No barrier is issued here (no sync at the end of the function).
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  // Row tx starts at offset tx; consecutive columns are P entries apart
  T *row = dC + tx;

  if (n == NB) {
    // Full-width tile: the trip count is a compile-time constant, so unroll
#pragma unroll
    for (int col = 0; col < NB; col++) {
      row[col * P] = rC[col];
    }
  } else {
    // Partial tile: only the first n columns are valid
    for (int col = 0; col < n; col++) {
      row[col * P] = rC[col];
    }
  }
}
1063c1e2affSSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Multiply C = A x B using 1D threads in P x 1 config.
// A (P x Q)  in reg., one row per thread
// B (Q x NB) in shared memory, indexed as sB[col * Q + k]
// C in registers -- one row per thread
//
//   rA : this thread's row of A, Q entries
//   sB : B, Q * NB entries
//   rC : output row of C, NB entries; previous contents are overwritten
//
// No barrier is issued here (no sync at the end of the function).
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T b_col[Q];

#pragma unroll
  for (int col = 0; col < NB; col++) {
    // Stage column col of B locally before using it in the dot product
    const T *src = sB + col * Q;
#pragma unroll
    for (int k = 0; k < Q; k++) {
      b_col[k] = src[k];
    }

    // C(tx, col) = sum_k A(tx, k) * B(k, col), accumulated in place
    T &acc = rC[col];
    acc    = static_cast<T>(0.0);
#pragma unroll
    for (int k = 0; k < Q; k++) {
      acc += rA[k] * b_col[k];
    }
  }
}
130f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Multiply-accumulate C += A x B using 1D threads in P x 1 config.
// A (P x Q)  in reg., one row per thread
// B (Q x NB) in shared memory, indexed as sB[col * Q + k]
// C in registers -- one row per thread
//
//   rA : this thread's row of A, Q entries
//   sB : B, Q * NB entries
//   rC : row of C, NB entries; the product is added to the existing values
//
// No barrier is issued here (no sync at the end of the function).
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void addmul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T b_col[Q];

#pragma unroll
  for (int col = 0; col < NB; col++) {
    // Stage column col of B locally before using it in the dot product
    const T *src = sB + col * Q;
#pragma unroll
    for (int k = 0; k < Q; k++) {
      b_col[k] = src[k];
    }

    // C(tx, col) += sum_k A(tx, k) * B(k, col), accumulated in place
    T &acc = rC[col];
#pragma unroll
    for (int k = 0; k < Q; k++) {
      acc += rA[k] * b_col[k];
    }
  }
}
153