xref: /libCEED/rust/libceed-sys/c-src/backends/xsmm/ceed-xsmm-tensor.c (revision 3d0fd664d5609ed64f2fd4536a0c4cacff290615)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <math.h>
18 #include <string.h>
19 #include "ceed-xsmm.h"
20 
21 // Utility functions for index in pointer array
22 int CeedGetXsmmInd_Tensor(CeedInt nelem, CeedInt add, CeedTransposeMode tmode,
23                           CeedInt B, CeedInt C, CeedInt J, CeedInt currdim,
24                           CeedInt dim) {
25   return (nelem == 8 ? 1:0)*4*2*dim + (add ? 1:0)*4*dim +
26          (tmode ? 1:0)*2*dim + (B == J ? 1:0)*dim + currdim;
27 }
28 
29 int CeedGetXsmmInd_NonTensor(CeedInt add, CeedInt P, CeedInt Q, CeedInt B,
30                              CeedInt C, CeedInt J) {
31   return (C == 8 ? 1:0)*4*2 + (add ? 1:0)*4 +
32          (B == P ? (J == Q ? 0:1) : (B == Q ? 2:3));
33 }
34 
35 // Default Tensor Contact
36 static int CeedTensorContract_Xsmm_Default(CeedTensorContract contract,
37     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
38     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
39     CeedScalar *restrict v) {
40   CeedScalar alpha = 1.0, beta = 1.0;
41   char transu = 'N', transt = 'N';
42   if ((tmode == CEED_TRANSPOSE && C != 1)
43       || (tmode == CEED_NOTRANSPOSE && C == 1))
44     transt = 'T';
45 
46   if (!Add)
47     beta = 0.0;
48 
49   if (C != 1)
50     for (CeedInt a=0; a<A; a++)
51       // libXSMM GEMM
52       libxsmm_dgemm(&transu, &transt, &C, &J, &B,
53                     &alpha, &u[a*B*C], NULL, &t[0], NULL,
54                     &beta, &v[a*J*C], NULL);
55   else
56     // libXSMM GEMM
57     libxsmm_dgemm(&transt, &transu, &J, &A, &B,
58                   &alpha, &t[0], NULL, &u[0], NULL,
59                   &beta, &v[0], NULL);
60 
61   return 0;
62 }
63 
64 // Switch for Tensor Contract
65 static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A,
66                                         CeedInt B, CeedInt C, CeedInt J,
67                                         const CeedScalar *restrict t,
68                                         CeedTransposeMode tmode,
69                                         const CeedInt add,
70                                         const CeedScalar *restrict u,
71                                         CeedScalar *restrict v) {
72   int ierr;
73   CeedInt blksize = 8, ind, nelem;
74   CeedTensorContract_Xsmm *impl;
75   ierr = CeedTensorContractGetData(contract, (void *)&impl); CeedChk(ierr);
76 
77   // Get nelem and current dim
78   CeedScalar currdim = log(C/blksize) / log(J);
79   if (!(C % blksize) && currdim - (int)currdim < 1e-15)
80     nelem = blksize;
81   else {
82     nelem = 1;
83     currdim = log(C) / log(J);
84   }
85 
86   // Get kernel index
87   if (impl->tensorbasis)
88     ind = CeedGetXsmmInd_Tensor(nelem, add, tmode==CEED_TRANSPOSE?1:0, B, C,
89                                 J, (CeedInt)currdim, impl->dim);
90   else
91     ind = CeedGetXsmmInd_NonTensor(add, impl->P, impl->Q, B, C, J);
92 
93   // Run kernel or fallback to default implementation
94   if (C != 1 && impl->kernels[ind])
95     for (CeedInt a=0; a<A; a++)
96       impl->kernels[ind](&u[a*B*C], &t[0], &v[a*J*C], NULL, NULL, NULL);
97   else
98     CeedTensorContract_Xsmm_Default(contract, A, B, C, J, t, tmode, add, u, v);
99 
100   return 0;
101 }
102 
103 static int CeedTensorContractDestroy_Xsmm(CeedTensorContract contract) {
104   int ierr;
105   CeedTensorContract_Xsmm *impl;
106   ierr = CeedTensorContractGetData(contract, (void *)&impl); CeedChk(ierr);
107   ierr = CeedFree(&impl->kernels); CeedChk(ierr);
108   ierr = CeedFree(&impl); CeedChk(ierr);
109 
110   return 0;
111 }
112 
113 int CeedTensorContractCreate_Xsmm(CeedBasis basis,
114                                   CeedTensorContract contract) {
115   int ierr;
116   Ceed ceed;
117   ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChk(ierr);
118   CeedTensorContract_Xsmm *impl;
119   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
120 
121   // Set up pointers to kernels
122   ierr = CeedBasisGetTensorStatus(basis, &impl->tensorbasis); CeedChk(ierr);
123   if (impl->tensorbasis) {
124     ierr = CeedBasisGetNumNodes1D(basis, &impl->P); CeedChk(ierr);
125     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &impl->Q); CeedChk(ierr);
126     ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChk(ierr);
127     // Set up kernel pointer array
128     impl->numkernels = 2*2*4*impl->dim;
129     ierr = CeedCalloc(impl->numkernels, &impl->kernels); CeedChk(ierr);
130     for (CeedInt nelem = 1; nelem <= 8; nelem+=7) {
131       for (CeedInt add = 0; add <= 1; add++) {
132         for (CeedInt tmode = 0; tmode <= 1; tmode++) {
133           for (CeedInt grad = 0; grad <=1; grad++) {
134             for (CeedInt dim = 0; dim < impl->dim; dim++) {
135               const int flags = LIBXSMM_GEMM_FLAGS('N', tmode ? 'T' : 'N');
136               CeedInt B = grad ? impl->Q : (tmode ? impl->Q : impl->P),
137                       J = grad ? impl->Q : (tmode ? impl->P : impl->Q),
138                       C = nelem*CeedIntPow(J, dim);
139               int ind = CeedGetXsmmInd_Tensor(nelem, add, tmode, B, C, J, dim,
140                                               impl->dim);
141               CeedScalar alpha = 1.0, beta = 1.0;
142               if (!add) beta = 0.0;
143               impl->kernels[ind] = libxsmm_dmmdispatch(C, J, B,
144                                    NULL, NULL, NULL, &alpha,
145                                    &beta, &flags, NULL);
146             }
147           }
148         }
149       }
150     }
151   } else {
152     ierr = CeedBasisGetNumNodes(basis, &impl->P); CeedChk(ierr);
153     ierr = CeedBasisGetNumQuadraturePoints(basis, &impl->Q); CeedChk(ierr);
154     ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChk(ierr);
155     // Set up kernel pointer array
156     impl->numkernels = 4*2*2;
157     ierr = CeedCalloc(impl->numkernels, &impl->kernels); CeedChk(ierr);
158     for (CeedInt nelem = 1; nelem <= 8; nelem+=7) {
159       for (CeedInt add = 0; add <= 1; add++) {
160         for (CeedInt tmode = 0; tmode <= 1; tmode++) {
161           for (CeedInt grad = 1; grad <= impl->dim; grad+=impl->dim-1) {
162             const int flags = LIBXSMM_GEMM_FLAGS('N', tmode ? 'T' : 'N');
163             CeedInt B = tmode ? grad*impl->Q : impl->P,
164                     J = tmode ? impl->P : grad*impl->Q;
165             int ind = CeedGetXsmmInd_NonTensor(add, impl->P, impl->Q, B, nelem,
166                                                J);
167             CeedScalar alpha = 1.0, beta = 1.0;
168             if (!add) beta = 0.0;
169             impl->kernels[ind] = libxsmm_dmmdispatch(nelem, J, B,
170                                  NULL, NULL, NULL, &alpha,
171                                  &beta, &flags, NULL);
172           }
173         }
174       }
175     }
176   }
177   ierr = CeedTensorContractSetData(contract, (void *)&impl); CeedChk(ierr);
178 
179   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply",
180                                 CeedTensorContractApply_Xsmm); CeedChk(ierr);
181   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy",
182                                 CeedTensorContractDestroy_Xsmm); CeedChk(ierr);
183 
184   return 0;
185 }
186