xref: /libCEED/rust/libceed-sys/c-src/backends/xsmm/ceed-xsmm-tensor.c (revision 65e7b5e84b2ea3ca1f657b4d4f5436f51e94c9a5)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-xsmm.h"
18 
19 //------------------------------------------------------------------------------
20 // Get Kernel Index
21 //------------------------------------------------------------------------------
22 int CeedGetXsmmInd_Tensor(CeedInt nelem, CeedInt add, CeedTransposeMode tmode,
23                           CeedInt B, CeedInt C, CeedInt J, CeedInt currdim,
24                           CeedInt dim) {
25   return (nelem == 8 ? 1:0)*4*2*dim + (add ? 1:0)*4*dim +
26          (tmode ? 1:0)*2*dim + (B == J ? 1:0)*dim + currdim;
27 }
28 
29 int CeedGetXsmmInd_NonTensor(CeedInt add, CeedInt P, CeedInt Q, CeedInt B,
30                              CeedInt C, CeedInt J) {
31   return (C == 8 ? 1:0)*4*2 + (add ? 1:0)*4 +
32          (B == P ? (J == Q ? 0:1) : (B == Q ? 2:3));
33 }
34 
35 //------------------------------------------------------------------------------
36 // Tensor Contract C=1
37 //------------------------------------------------------------------------------
38 static int CeedTensorContract_Xsmm_C1(CeedTensorContract contract,
39                                       CeedInt A, CeedInt B, CeedInt C,
40                                       CeedInt J, const CeedScalar *restrict t,
41                                       CeedTransposeMode tmode,
42                                       const CeedInt Add,
43                                       const CeedScalar *restrict u,
44                                       CeedScalar *restrict v) {
45   CeedScalar alpha = 1.0, beta = 1.0;
46   char transu = 'N', transt = 'N';
47   if ((tmode == CEED_TRANSPOSE && C != 1)
48       || (tmode == CEED_NOTRANSPOSE && C == 1))
49     transt = 'T';
50 
51   if (!Add)
52     beta = 0.0;
53 
54   // libXSMM GEMM
55   libxsmm_dgemm(&transt, &transu, &J, &A, &B,
56                 &alpha, &t[0], NULL, &u[0], NULL,
57                 &beta, &v[0], NULL);
58 
59   return 0;
60 }
61 
62 //------------------------------------------------------------------------------
63 // Tensor Contract Apply
64 //------------------------------------------------------------------------------
65 static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A,
66                                         CeedInt B, CeedInt C, CeedInt J,
67                                         const CeedScalar *restrict t,
68                                         CeedTransposeMode tmode,
69                                         const CeedInt add,
70                                         const CeedScalar *restrict u,
71                                         CeedScalar *restrict v) {
72   int ierr;
73   CeedInt blksize = 8, ind, nelem;
74   CeedTensorContract_Xsmm *impl;
75   ierr = CeedTensorContractGetData(contract, (void *)&impl); CeedChk(ierr);
76 
77   // Get nelem and current dim
78   CeedScalar currdim = log(C/blksize) / log(J);
79   if (!(C % blksize) && currdim - (int)currdim < 1e-15) {
80     nelem = blksize;
81   } else {
82     nelem = 1;
83     currdim = log(C) / log(J);
84   }
85 
86   // Get kernel index
87   if (impl->tensorbasis)
88     ind = CeedGetXsmmInd_Tensor(nelem, add, tmode==CEED_TRANSPOSE?1:0, B, C, J,
89                                 (CeedInt)currdim, impl->dim);
90   else
91     ind = CeedGetXsmmInd_NonTensor(add, impl->P, impl->Q, B, C, J);
92 
93   // Run kernel or fallback to default implementation
94   if (C != 1)
95     for (CeedInt a=0; a<A; a++)
96       impl->kernels[ind](&u[a*B*C], &t[0], &v[a*J*C], NULL, NULL, NULL);
97   else
98     CeedTensorContract_Xsmm_C1(contract, A, B, C, J, t, tmode, add, u, v);
99 
100   return 0;
101 }
102 
103 //------------------------------------------------------------------------------
104 // Tensor Contract Destroy
105 //------------------------------------------------------------------------------
106 static int CeedTensorContractDestroy_Xsmm(CeedTensorContract contract) {
107   int ierr;
108   CeedTensorContract_Xsmm *impl;
109   ierr = CeedTensorContractGetData(contract, (void *)&impl); CeedChk(ierr);
110   ierr = CeedFree(&impl->kernels); CeedChk(ierr);
111   ierr = CeedFree(&impl); CeedChk(ierr);
112 
113   return 0;
114 }
115 
116 //------------------------------------------------------------------------------
117 // Tensor Contract Create
118 //------------------------------------------------------------------------------
119 int CeedTensorContractCreate_Xsmm(CeedBasis basis,
120                                   CeedTensorContract contract) {
121   int ierr;
122   Ceed ceed;
123   ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChk(ierr);
124   CeedTensorContract_Xsmm *impl;
125   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
126 
127   // Set up pointers to kernels
128   ierr = CeedBasisGetTensorStatus(basis, &impl->tensorbasis); CeedChk(ierr);
129   if (impl->tensorbasis) {
130     ierr = CeedBasisGetNumNodes1D(basis, &impl->P); CeedChk(ierr);
131     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &impl->Q); CeedChk(ierr);
132     ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChk(ierr);
133     // Set up kernel pointer array
134     impl->numkernels = 2*2*4*impl->dim;
135     ierr = CeedCalloc(impl->numkernels, &impl->kernels); CeedChk(ierr);
136     for (CeedInt nelem = 1; nelem <= 8; nelem+=7)
137       for (CeedInt add = 0; add <= 1; add++)
138         for (CeedInt tmode = 0; tmode <= 1; tmode++)
139           for (CeedInt grad = 0; grad <=1; grad++)
140             for (CeedInt dim = 0; dim < impl->dim; dim++) {
141               const int flags = LIBXSMM_GEMM_FLAGS('N', tmode ? 'T' : 'N');
142               CeedInt B = grad ? impl->Q : (tmode ? impl->Q : impl->P),
143                       J = grad ? impl->Q : (tmode ? impl->P : impl->Q),
144                       C = nelem*CeedIntPow(J, dim);
145               int ind = CeedGetXsmmInd_Tensor(nelem, add, tmode, B, C, J, dim,
146                                               impl->dim);
147               CeedScalar alpha = 1.0, beta = 1.0;
148               if (!add) beta = 0.0;
149               impl->kernels[ind] = libxsmm_dmmdispatch(C, J, B,
150                                    NULL, NULL, NULL, &alpha,
151                                    &beta, &flags, NULL);
152               if (!impl->kernels[ind])
153                 // LCOV_EXCL_START
154                 return CeedError(ceed, 1, "LIBXSMM kernel failed to build.");
155               // LCOV_EXCL_STOP
156             }
157   } else {
158     ierr = CeedBasisGetNumNodes(basis, &impl->P); CeedChk(ierr);
159     ierr = CeedBasisGetNumQuadraturePoints(basis, &impl->Q); CeedChk(ierr);
160     ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChk(ierr);
161     // Set up kernel pointer array
162     impl->numkernels = 4*2*2;
163     ierr = CeedCalloc(impl->numkernels, &impl->kernels); CeedChk(ierr);
164     for (CeedInt nelem = 1; nelem <= 8; nelem+=7)
165       for (CeedInt add = 0; add <= 1; add++)
166         for (CeedInt tmode = 0; tmode <= 1; tmode++)
167           for (CeedInt grad = 1; grad <= impl->dim; grad+=impl->dim-1) {
168             const int flags = LIBXSMM_GEMM_FLAGS('N', tmode ? 'T' : 'N');
169             CeedInt B = tmode ? grad*impl->Q : impl->P,
170                     J = tmode ? impl->P : grad*impl->Q;
171             int ind = CeedGetXsmmInd_NonTensor(add, impl->P, impl->Q, B, nelem,
172                                                J);
173             CeedScalar alpha = 1.0, beta = 1.0;
174             if (!add) beta = 0.0;
175             impl->kernels[ind] = libxsmm_dmmdispatch(nelem, J, B,
176                                  NULL, NULL, NULL, &alpha,
177                                  &beta, &flags, NULL);
178             if (!impl->kernels[ind])
179               // LCOV_EXCL_START
180               return CeedError(ceed, 1, "LIBXSMM kernel failed to build.");
181             // LCOV_EXCL_STOP
182           }
183   }
184   ierr = CeedTensorContractSetData(contract, (void *)&impl); CeedChk(ierr);
185 
186   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply",
187                                 CeedTensorContractApply_Xsmm); CeedChk(ierr);
188   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy",
189                                 CeedTensorContractDestroy_Xsmm); CeedChk(ierr);
190 
191   return 0;
192 }
193 //------------------------------------------------------------------------------
194