xref: /libCEED/rust/libceed-sys/c-src/backends/blocked/ceed-blocked-operator.c (revision a7652942538d3752d3104ef0802e524662ebe038)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-blocked.h"
18 #include "../ref/ceed-ref.h"
19 
20 static int CeedOperatorDestroy_Blocked(CeedOperator op) {
21   int ierr;
22   CeedOperator_Blocked *impl;
23   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
24 
25   for (CeedInt i=0; i<impl->numein+impl->numeout; i++) {
26     ierr = CeedElemRestrictionDestroy(&impl->blkrestr[i]); CeedChk(ierr);
27     ierr = CeedVectorDestroy(&impl->evecs[i]); CeedChk(ierr);
28   }
29   ierr = CeedFree(&impl->blkrestr); CeedChk(ierr);
30   ierr = CeedFree(&impl->evecs); CeedChk(ierr);
31   ierr = CeedFree(&impl->edata); CeedChk(ierr);
32   ierr = CeedFree(&impl->inputstate); CeedChk(ierr);
33 
34   for (CeedInt i=0; i<impl->numein; i++) {
35     ierr = CeedVectorDestroy(&impl->evecsin[i]); CeedChk(ierr);
36     ierr = CeedVectorDestroy(&impl->qvecsin[i]); CeedChk(ierr);
37   }
38   ierr = CeedFree(&impl->evecsin); CeedChk(ierr);
39   ierr = CeedFree(&impl->qvecsin); CeedChk(ierr);
40 
41   for (CeedInt i=0; i<impl->numeout; i++) {
42     ierr = CeedVectorDestroy(&impl->evecsout[i]); CeedChk(ierr);
43     ierr = CeedVectorDestroy(&impl->qvecsout[i]); CeedChk(ierr);
44   }
45   ierr = CeedFree(&impl->evecsout); CeedChk(ierr);
46   ierr = CeedFree(&impl->qvecsout); CeedChk(ierr);
47 
48   ierr = CeedFree(&impl); CeedChk(ierr);
49   return 0;
50 }
51 
52 /*
53   Setup infields or outfields
54  */
55 static int CeedOperatorSetupFields_Blocked(CeedQFunction qf, CeedOperator op,
56     bool inOrOut,
57     CeedElemRestriction *blkrestr,
58     CeedVector *fullevecs, CeedVector *evecs,
59     CeedVector *qvecs, CeedInt starte,
60     CeedInt numfields, CeedInt Q) {
61   CeedInt dim, ierr, ncomp, P;
62   Ceed ceed;
63   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
64   CeedBasis basis;
65   CeedElemRestriction r;
66   CeedOperatorField *opfields;
67   CeedQFunctionField *qffields;
68   if (inOrOut) {
69     ierr = CeedOperatorGetFields(op, NULL, &opfields);
70     CeedChk(ierr);
71     ierr = CeedQFunctionGetFields(qf, NULL, &qffields);
72     CeedChk(ierr);
73   } else {
74     ierr = CeedOperatorGetFields(op, &opfields, NULL);
75     CeedChk(ierr);
76     ierr = CeedQFunctionGetFields(qf, &qffields, NULL);
77     CeedChk(ierr);
78   }
79   const CeedInt blksize = 8;
80 
81   // Loop over fields
82   for (CeedInt i=0; i<numfields; i++) {
83     CeedEvalMode emode;
84     ierr = CeedQFunctionFieldGetEvalMode(qffields[i], &emode); CeedChk(ierr);
85 
86     if (emode != CEED_EVAL_WEIGHT) {
87       ierr = CeedOperatorFieldGetElemRestriction(opfields[i], &r);
88       CeedChk(ierr);
89       CeedElemRestriction_Ref *data;
90       ierr = CeedElemRestrictionGetData(r, (void *)&data); CeedChk(ierr);
91       Ceed ceed;
92       ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChk(ierr);
93       CeedInt nelem, elemsize, ndof;
94       ierr = CeedElemRestrictionGetNumElements(r, &nelem); CeedChk(ierr);
95       ierr = CeedElemRestrictionGetElementSize(r, &elemsize); CeedChk(ierr);
96       ierr = CeedElemRestrictionGetNumDoF(r, &ndof); CeedChk(ierr);
97       ierr = CeedElemRestrictionGetNumComponents(r, &ncomp); CeedChk(ierr);
98       ierr = CeedElemRestrictionCreateBlocked(ceed, nelem, elemsize,
99                                               blksize, ndof, ncomp,
100                                               CEED_MEM_HOST, CEED_COPY_VALUES,
101                                               data->indices, &blkrestr[i+starte]);
102       CeedChk(ierr);
103       ierr = CeedElemRestrictionCreateVector(blkrestr[i+starte], NULL,
104                                              &fullevecs[i+starte]);
105       CeedChk(ierr);
106     }
107 
108     switch(emode) {
109     case CEED_EVAL_NONE:
110       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
111       CeedChk(ierr);
112       ierr = CeedVectorCreate(ceed, Q*ncomp*blksize, &evecs[i]); CeedChk(ierr);
113       ierr = CeedVectorCreate(ceed, Q*ncomp*blksize, &qvecs[i]); CeedChk(ierr);
114       break;
115     case CEED_EVAL_INTERP:
116       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
117       CeedChk(ierr);
118       ierr = CeedElemRestrictionGetElementSize(r, &P);
119       CeedChk(ierr);
120       ierr = CeedVectorCreate(ceed, P*ncomp*blksize, &evecs[i]); CeedChk(ierr);
121       ierr = CeedVectorCreate(ceed, Q*ncomp*blksize, &qvecs[i]); CeedChk(ierr);
122       break;
123     case CEED_EVAL_GRAD:
124       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
125       ierr = CeedQFunctionFieldGetNumComponents(qffields[i], &ncomp);
126       CeedChk(ierr);
127       ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
128       ierr = CeedElemRestrictionGetElementSize(r, &P);
129       CeedChk(ierr);
130       ierr = CeedVectorCreate(ceed, P*ncomp*blksize, &evecs[i]); CeedChk(ierr);
131       ierr = CeedVectorCreate(ceed, Q*ncomp*dim*blksize, &qvecs[i]); CeedChk(ierr);
132       break;
133     case CEED_EVAL_WEIGHT: // Only on input fields
134       ierr = CeedOperatorFieldGetBasis(opfields[i], &basis); CeedChk(ierr);
135       ierr = CeedVectorCreate(ceed, Q*blksize, &qvecs[i]); CeedChk(ierr);
136       ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
137                             CEED_EVAL_WEIGHT, NULL, qvecs[i]); CeedChk(ierr);
138 
139       break;
140     case CEED_EVAL_DIV:
141       break; // Not implimented
142     case CEED_EVAL_CURL:
143       break; // Not implimented
144     }
145   }
146   return 0;
147 }
148 
149 /*
150   CeedOperator needs to connect all the named fields (be they active or passive)
151   to the named inputs and outputs of its CeedQFunction.
152  */
153 static int CeedOperatorSetup_Blocked(CeedOperator op) {
154   int ierr;
155   bool setupdone;
156   ierr = CeedOperatorGetSetupStatus(op, &setupdone); CeedChk(ierr);
157   if (setupdone) return 0;
158   Ceed ceed;
159   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
160   CeedOperator_Blocked *impl;
161   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
162   CeedQFunction qf;
163   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
164   CeedInt Q, numinputfields, numoutputfields;
165   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
166   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
167   CeedChk(ierr);
168   CeedOperatorField *opinputfields, *opoutputfields;
169   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
170   CeedChk(ierr);
171   CeedQFunctionField *qfinputfields, *qfoutputfields;
172   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
173   CeedChk(ierr);
174 
175   // Allocate
176   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->blkrestr);
177   CeedChk(ierr);
178   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->evecs);
179   CeedChk(ierr);
180   ierr = CeedCalloc(numinputfields + numoutputfields, &impl->edata);
181   CeedChk(ierr);
182 
183   ierr = CeedCalloc(16, &impl->inputstate); CeedChk(ierr);
184   ierr = CeedCalloc(16, &impl->evecsin); CeedChk(ierr);
185   ierr = CeedCalloc(16, &impl->evecsout); CeedChk(ierr);
186   ierr = CeedCalloc(16, &impl->qvecsin); CeedChk(ierr);
187   ierr = CeedCalloc(16, &impl->qvecsout); CeedChk(ierr);
188 
189   impl->numein = numinputfields; impl->numeout = numoutputfields;
190 
191   // Set up infield and outfield pointer arrays
192   // Infields
193   ierr = CeedOperatorSetupFields_Blocked(qf, op, 0, impl->blkrestr,
194                                          impl->evecs, impl->evecsin,
195                                          impl->qvecsin, 0,
196                                          numinputfields, Q);
197   CeedChk(ierr);
198   // Outfields
199   ierr = CeedOperatorSetupFields_Blocked(qf, op, 1, impl->blkrestr,
200                                          impl->evecs, impl->evecsout,
201                                          impl->qvecsout, numinputfields,
202                                          numoutputfields, Q);
203   CeedChk(ierr);
204 
205   ierr = CeedOperatorSetSetupDone(op); CeedChk(ierr);
206 
207   return 0;
208 }
209 
210 static int CeedOperatorApply_Blocked(CeedOperator op, CeedVector invec,
211                                      CeedVector outvec, CeedRequest *request) {
212   int ierr;
213   CeedOperator_Blocked *impl;
214   ierr = CeedOperatorGetData(op, (void *)&impl); CeedChk(ierr);
215   const CeedInt blksize = 8;
216   CeedInt Q, elemsize, numinputfields, numoutputfields, numelements, ncomp;
217   ierr = CeedOperatorGetNumElements(op, &numelements); CeedChk(ierr);
218   ierr = CeedOperatorGetNumQuadraturePoints(op, &Q); CeedChk(ierr);
219   CeedInt nblks = (numelements/blksize) + !!(numelements%blksize);
220   CeedQFunction qf;
221   ierr = CeedOperatorGetQFunction(op, &qf); CeedChk(ierr);
222   ierr= CeedQFunctionGetNumArgs(qf, &numinputfields, &numoutputfields);
223   CeedChk(ierr);
224   CeedTransposeMode lmode;
225   CeedOperatorField *opinputfields, *opoutputfields;
226   ierr = CeedOperatorGetFields(op, &opinputfields, &opoutputfields);
227   CeedChk(ierr);
228   CeedQFunctionField *qfinputfields, *qfoutputfields;
229   ierr = CeedQFunctionGetFields(qf, &qfinputfields, &qfoutputfields);
230   CeedChk(ierr);
231   CeedEvalMode emode;
232   CeedVector vec;
233   CeedBasis basis;
234   CeedElemRestriction Erestrict;
235   uint64_t state;
236 
237   // Setup
238   ierr = CeedOperatorSetup_Blocked(op); CeedChk(ierr);
239 
240   // Input Evecs and Restriction
241   for (CeedInt i=0; i<numinputfields; i++) {
242     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
243     CeedChk(ierr);
244     if (emode == CEED_EVAL_WEIGHT) { // Skip
245     } else {
246       // Get input vector
247       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
248       if (vec != CEED_VECTOR_ACTIVE) {
249         // Restrict
250         ierr = CeedVectorGetState(vec, &state); CeedChk(ierr);
251         if (state != impl->inputstate[i]) {
252           ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode);
253           CeedChk(ierr);
254           ierr = CeedElemRestrictionApply(impl->blkrestr[i], CEED_NOTRANSPOSE,
255                                           lmode, vec, impl->evecs[i], request);
256           CeedChk(ierr);
257           impl->inputstate[i] = state;
258         }
259       } else {
260         // Set Qvec for CEED_EVAL_NONE
261         if (emode == CEED_EVAL_NONE) {
262           ierr = CeedVectorGetArray(impl->evecsin[i], CEED_MEM_HOST,
263                                     &impl->edata[i]); CeedChk(ierr);
264           ierr = CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_HOST,
265                                     CEED_USE_POINTER,
266                                     impl->edata[i]); CeedChk(ierr);
267           ierr = CeedVectorRestoreArray(impl->evecsin[i],
268                                         &impl->edata[i]); CeedChk(ierr);
269         }
270       }
271       // Get evec
272       ierr = CeedVectorGetArrayRead(impl->evecs[i], CEED_MEM_HOST,
273                                     (const CeedScalar **) &impl->edata[i]);
274       CeedChk(ierr);
275     }
276   }
277 
278   // Output Lvecs, Evecs, and Qvecs
279   for (CeedInt i=0; i<numoutputfields; i++) {
280     // Zero Lvecs
281     ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
282     if (vec == CEED_VECTOR_ACTIVE) {
283       if (!impl->add) {
284         vec = outvec;
285         ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
286       }
287     } else {
288       ierr = CeedVectorSetValue(vec, 0.0); CeedChk(ierr);
289     }
290     // Set Qvec if needed
291     ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
292     CeedChk(ierr);
293     if (emode == CEED_EVAL_NONE) {
294       // Set qvec to single block evec
295       ierr = CeedVectorGetArray(impl->evecsout[i], CEED_MEM_HOST,
296                                 &impl->edata[i + numinputfields]);
297       CeedChk(ierr);
298       ierr = CeedVectorSetArray(impl->qvecsout[i], CEED_MEM_HOST,
299                                 CEED_USE_POINTER,
300                                 impl->edata[i + numinputfields]); CeedChk(ierr);
301       ierr = CeedVectorRestoreArray(impl->evecsout[i],
302                                     &impl->edata[i + numinputfields]);
303       CeedChk(ierr);
304     }
305   }
306   impl->add = false;
307 
308   // Loop through elements
309   for (CeedInt e=0; e<nblks*blksize; e+=blksize) {
310     // Input basis apply if needed
311     for (CeedInt i=0; i<numinputfields; i++) {
312       CeedInt activein = 0;
313       // Get elemsize, emode, ncomp
314       ierr = CeedOperatorFieldGetElemRestriction(opinputfields[i], &Erestrict);
315       CeedChk(ierr);
316       ierr = CeedElemRestrictionGetElementSize(Erestrict, &elemsize);
317       CeedChk(ierr);
318       ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
319       CeedChk(ierr);
320       ierr = CeedQFunctionFieldGetNumComponents(qfinputfields[i], &ncomp);
321       CeedChk(ierr);
322       // Restrict block active input
323       ierr = CeedOperatorFieldGetVector(opinputfields[i], &vec); CeedChk(ierr);
324       if (vec == CEED_VECTOR_ACTIVE) {
325         ierr = CeedOperatorFieldGetLMode(opinputfields[i], &lmode);
326         CeedChk(ierr);
327         ierr = CeedElemRestrictionApplyBlock(impl->blkrestr[i], e/blksize,
328                                              CEED_NOTRANSPOSE, lmode, invec,
329                                              impl->evecsin[i], request);
330         CeedChk(ierr);
331         activein = 1;
332       }
333       // Basis action
334       switch(emode) {
335       case CEED_EVAL_NONE:
336         if (!activein) {
337           ierr = CeedVectorSetArray(impl->qvecsin[i], CEED_MEM_HOST,
338                                     CEED_USE_POINTER,
339                                     &impl->edata[i][e*Q*ncomp]); CeedChk(ierr);
340         }
341         break;
342       case CEED_EVAL_INTERP:
343         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis);
344         CeedChk(ierr);
345         if (!activein) {
346           ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
347                                     CEED_USE_POINTER,
348                                     &impl->edata[i][e*elemsize*ncomp]);
349           CeedChk(ierr);
350         }
351         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
352                               CEED_EVAL_INTERP, impl->evecsin[i],
353                               impl->qvecsin[i]); CeedChk(ierr);
354         break;
355       case CEED_EVAL_GRAD:
356         ierr = CeedOperatorFieldGetBasis(opinputfields[i], &basis);
357         CeedChk(ierr);
358          if (!activein) {
359           ierr = CeedVectorSetArray(impl->evecsin[i], CEED_MEM_HOST,
360                                     CEED_USE_POINTER,
361                                     &impl->edata[i][e*elemsize*ncomp]);
362           CeedChk(ierr);
363         }
364         ierr = CeedBasisApply(basis, blksize, CEED_NOTRANSPOSE,
365                               CEED_EVAL_GRAD, impl->evecsin[i],
366                               impl->qvecsin[i]); CeedChk(ierr);
367         break;
368       case CEED_EVAL_WEIGHT:
369         break;  // No action
370       case CEED_EVAL_DIV:
371         break; // Not implimented
372       case CEED_EVAL_CURL:
373         break; // Not implimented
374       }
375     }
376 
377     // Q function
378     ierr = CeedQFunctionApply(qf, Q*blksize, impl->qvecsin, impl->qvecsout);
379     CeedChk(ierr);
380 
381     // Output basis apply and restrict
382     for (CeedInt i=0; i<numoutputfields; i++) {
383       // Get elemsize, emode, ncomp
384       ierr = CeedOperatorFieldGetElemRestriction(opoutputfields[i], &Erestrict);
385       CeedChk(ierr);
386       ierr = CeedQFunctionFieldGetEvalMode(qfoutputfields[i], &emode);
387       CeedChk(ierr);
388       // Basis action
389       switch(emode) {
390       case CEED_EVAL_NONE:
391         break; // No action
392       case CEED_EVAL_INTERP:
393         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
394         CeedChk(ierr);
395         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
396                               CEED_EVAL_INTERP, impl->qvecsout[i],
397                               impl->evecsout[i]); CeedChk(ierr);
398         break;
399       case CEED_EVAL_GRAD:
400         ierr = CeedOperatorFieldGetBasis(opoutputfields[i], &basis);
401         CeedChk(ierr);
402         ierr = CeedBasisApply(basis, blksize, CEED_TRANSPOSE,
403                               CEED_EVAL_GRAD, impl->qvecsout[i],
404                               impl->evecsout[i]); CeedChk(ierr);
405         break;
406       case CEED_EVAL_WEIGHT: {
407         Ceed ceed;
408         ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
409         return CeedError(ceed, 1,
410                          "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
411         break; // Should not occur
412       }
413       case CEED_EVAL_DIV:
414         break; // Not implimented
415       case CEED_EVAL_CURL:
416         break; // Not implimented
417       }
418       // Restrict output block
419       // Get output vector
420       ierr = CeedOperatorFieldGetVector(opoutputfields[i], &vec); CeedChk(ierr);
421       if (vec == CEED_VECTOR_ACTIVE)
422         vec = outvec;
423       // Restrict
424       ierr = CeedOperatorFieldGetLMode(opoutputfields[i], &lmode);
425       CeedChk(ierr);
426       ierr = CeedElemRestrictionApplyBlock(impl->blkrestr[i+impl->numein],
427                                            e/blksize, CEED_TRANSPOSE,
428                                            lmode, impl->evecsout[i],
429                                            vec, request); CeedChk(ierr);
430     }
431   }
432 
433   // Restore input arrays
434   for (CeedInt i=0; i<numinputfields; i++) {
435     ierr = CeedQFunctionFieldGetEvalMode(qfinputfields[i], &emode);
436     CeedChk(ierr);
437     if (emode == CEED_EVAL_WEIGHT) { // Skip
438     } else {
439       ierr = CeedVectorRestoreArrayRead(impl->evecs[i],
440                                         (const CeedScalar **) &impl->edata[i]);
441       CeedChk(ierr);
442     }
443   }
444 
445   return 0;
446 }
447 
448 int CeedOperatorCreate_Blocked(CeedOperator op) {
449   int ierr;
450   Ceed ceed;
451   ierr = CeedOperatorGetCeed(op, &ceed); CeedChk(ierr);
452   CeedOperator_Blocked *impl;
453 
454   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
455   ierr = CeedOperatorSetData(op, (void *)&impl); CeedChk(ierr);
456 
457   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Apply",
458                                 CeedOperatorApply_Blocked); CeedChk(ierr);
459   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
460                                 CeedOperatorDestroy_Blocked); CeedChk(ierr);
461   return 0;
462 }
463