xref: /libCEED/examples/fluids/src/mat-ceed.c (revision c8564c30f55def119ddd4dd977f1131cd364f231)
1*c8564c30SJames Wright /// @file
2*c8564c30SJames Wright /// MatCeed and its related operators
3*c8564c30SJames Wright 
4*c8564c30SJames Wright #include <ceed.h>
5*c8564c30SJames Wright #include <ceed/backend.h>
6*c8564c30SJames Wright #include <mat-ceed-impl.h>
7*c8564c30SJames Wright #include <mat-ceed.h>
8*c8564c30SJames Wright #include <petscdmplex.h>
9*c8564c30SJames Wright #include <stdlib.h>
10*c8564c30SJames Wright #include <string.h>
11*c8564c30SJames Wright 
12*c8564c30SJames Wright PetscClassId  MATCEED_CLASSID;
13*c8564c30SJames Wright PetscLogEvent MATCEED_MULT, MATCEED_MULT_TRANSPOSE;
14*c8564c30SJames Wright 
15*c8564c30SJames Wright /**
16*c8564c30SJames Wright   @brief Register MATCEED log events.
17*c8564c30SJames Wright 
18*c8564c30SJames Wright   Not collective across MPI processes.
19*c8564c30SJames Wright 
20*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
21*c8564c30SJames Wright **/
22*c8564c30SJames Wright static PetscErrorCode MatCeedRegisterLogEvents() {
23*c8564c30SJames Wright   static bool registered = false;
24*c8564c30SJames Wright 
25*c8564c30SJames Wright   PetscFunctionBeginUser;
26*c8564c30SJames Wright   if (registered) PetscFunctionReturn(PETSC_SUCCESS);
27*c8564c30SJames Wright   PetscCall(PetscClassIdRegister("MATCEED", &MATCEED_CLASSID));
28*c8564c30SJames Wright   PetscCall(PetscLogEventRegister("MATCEED Mult", MATCEED_CLASSID, &MATCEED_MULT));
29*c8564c30SJames Wright   PetscCall(PetscLogEventRegister("MATCEED Mult Transpose", MATCEED_CLASSID, &MATCEED_MULT_TRANSPOSE));
30*c8564c30SJames Wright   registered = true;
31*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
32*c8564c30SJames Wright }
33*c8564c30SJames Wright 
34*c8564c30SJames Wright /**
35*c8564c30SJames Wright   @brief Translate PetscMemType to CeedMemType
36*c8564c30SJames Wright 
37*c8564c30SJames Wright   @param[in]  mem_type  PetscMemType
38*c8564c30SJames Wright 
39*c8564c30SJames Wright   @return Equivalent CeedMemType
40*c8564c30SJames Wright **/
41*c8564c30SJames Wright static inline CeedMemType MemTypeP2C(PetscMemType mem_type) { return PetscMemTypeDevice(mem_type) ? CEED_MEM_DEVICE : CEED_MEM_HOST; }
42*c8564c30SJames Wright 
43*c8564c30SJames Wright /**
44*c8564c30SJames Wright   @brief Translate array of `CeedInt` to `PetscInt`.
45*c8564c30SJames Wright     If the types differ, `array_ceed` is freed with `free()` and `array_petsc` is allocated with `malloc()`.
46*c8564c30SJames Wright     Caller is responsible for freeing `array_petsc` with `free()`.
47*c8564c30SJames Wright 
48*c8564c30SJames Wright   Not collective across MPI processes.
49*c8564c30SJames Wright 
50*c8564c30SJames Wright   @param[in]      num_entries  Number of array entries
51*c8564c30SJames Wright   @param[in,out]  array_ceed   Array of `CeedInt`
52*c8564c30SJames Wright   @param[out]     array_petsc  Array of `PetscInt`
53*c8564c30SJames Wright 
54*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
55*c8564c30SJames Wright **/
56*c8564c30SJames Wright static inline PetscErrorCode IntArrayC2P(PetscInt num_entries, CeedInt **array_ceed, PetscInt **array_petsc) {
57*c8564c30SJames Wright   const CeedInt  int_c = 0;
58*c8564c30SJames Wright   const PetscInt int_p = 0;
59*c8564c30SJames Wright 
60*c8564c30SJames Wright   PetscFunctionBeginUser;
61*c8564c30SJames Wright   if (sizeof(int_c) == sizeof(int_p)) {
62*c8564c30SJames Wright     *array_petsc = (PetscInt *)*array_ceed;
63*c8564c30SJames Wright   } else {
64*c8564c30SJames Wright     *array_petsc = malloc(num_entries * sizeof(PetscInt));
65*c8564c30SJames Wright     for (PetscInt i = 0; i < num_entries; i++) (*array_petsc)[i] = (*array_ceed)[i];
66*c8564c30SJames Wright     free(*array_ceed);
67*c8564c30SJames Wright   }
68*c8564c30SJames Wright   *array_ceed = NULL;
69*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
70*c8564c30SJames Wright }
71*c8564c30SJames Wright 
72*c8564c30SJames Wright /**
73*c8564c30SJames Wright   @brief Transfer array from PETSc `Vec` to `CeedVector`.
74*c8564c30SJames Wright 
75*c8564c30SJames Wright   Collective across MPI processes.
76*c8564c30SJames Wright 
77*c8564c30SJames Wright   @param[in]   ceed      libCEED context
78*c8564c30SJames Wright   @param[in]   X_petsc   PETSc `Vec`
79*c8564c30SJames Wright   @param[out]  mem_type  PETSc `MemType`
80*c8564c30SJames Wright   @param[out]  x_ceed    `CeedVector`
81*c8564c30SJames Wright 
82*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
83*c8564c30SJames Wright **/
84*c8564c30SJames Wright static inline PetscErrorCode VecP2C(Ceed ceed, Vec X_petsc, PetscMemType *mem_type, CeedVector x_ceed) {
85*c8564c30SJames Wright   PetscScalar *x;
86*c8564c30SJames Wright 
87*c8564c30SJames Wright   PetscFunctionBeginUser;
88*c8564c30SJames Wright   PetscCall(VecGetArrayAndMemType(X_petsc, &x, mem_type));
89*c8564c30SJames Wright   PetscCeedCall(ceed, CeedVectorSetArray(x_ceed, MemTypeP2C(*mem_type), CEED_USE_POINTER, x));
90*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
91*c8564c30SJames Wright }
92*c8564c30SJames Wright 
93*c8564c30SJames Wright /**
94*c8564c30SJames Wright   @brief Transfer array from `CeedVector` to PETSc `Vec`.
95*c8564c30SJames Wright 
96*c8564c30SJames Wright   Collective across MPI processes.
97*c8564c30SJames Wright 
98*c8564c30SJames Wright   @param[in]   ceed      libCEED context
99*c8564c30SJames Wright   @param[in]   x_ceed    `CeedVector`
100*c8564c30SJames Wright   @param[in]   mem_type  PETSc `MemType`
101*c8564c30SJames Wright   @param[out]  X_petsc   PETSc `Vec`
102*c8564c30SJames Wright 
103*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
104*c8564c30SJames Wright **/
105*c8564c30SJames Wright static inline PetscErrorCode VecC2P(Ceed ceed, CeedVector x_ceed, PetscMemType mem_type, Vec X_petsc) {
106*c8564c30SJames Wright   PetscScalar *x;
107*c8564c30SJames Wright 
108*c8564c30SJames Wright   PetscFunctionBeginUser;
109*c8564c30SJames Wright   PetscCeedCall(ceed, CeedVectorTakeArray(x_ceed, MemTypeP2C(mem_type), &x));
110*c8564c30SJames Wright   PetscCall(VecRestoreArrayAndMemType(X_petsc, &x));
111*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
112*c8564c30SJames Wright }
113*c8564c30SJames Wright 
114*c8564c30SJames Wright /**
115*c8564c30SJames Wright   @brief Transfer read only array from PETSc `Vec` to `CeedVector`.
116*c8564c30SJames Wright 
117*c8564c30SJames Wright   Collective across MPI processes.
118*c8564c30SJames Wright 
119*c8564c30SJames Wright   @param[in]   ceed      libCEED context
120*c8564c30SJames Wright   @param[in]   X_petsc   PETSc `Vec`
121*c8564c30SJames Wright   @param[out]  mem_type  PETSc `MemType`
122*c8564c30SJames Wright   @param[out]  x_ceed    `CeedVector`
123*c8564c30SJames Wright 
124*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
125*c8564c30SJames Wright **/
126*c8564c30SJames Wright static inline PetscErrorCode VecReadP2C(Ceed ceed, Vec X_petsc, PetscMemType *mem_type, CeedVector x_ceed) {
127*c8564c30SJames Wright   PetscScalar *x;
128*c8564c30SJames Wright 
129*c8564c30SJames Wright   PetscFunctionBeginUser;
130*c8564c30SJames Wright   PetscCall(VecGetArrayReadAndMemType(X_petsc, (const PetscScalar **)&x, mem_type));
131*c8564c30SJames Wright   PetscCeedCall(ceed, CeedVectorSetArray(x_ceed, MemTypeP2C(*mem_type), CEED_USE_POINTER, x));
132*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
133*c8564c30SJames Wright }
134*c8564c30SJames Wright 
135*c8564c30SJames Wright /**
136*c8564c30SJames Wright   @brief Transfer read only array from `CeedVector` to PETSc `Vec`.
137*c8564c30SJames Wright 
138*c8564c30SJames Wright   Collective across MPI processes.
139*c8564c30SJames Wright 
140*c8564c30SJames Wright   @param[in]   ceed      libCEED context
141*c8564c30SJames Wright   @param[in]   x_ceed    `CeedVector`
142*c8564c30SJames Wright   @param[in]   mem_type  PETSc `MemType`
143*c8564c30SJames Wright   @param[out]  X_petsc   PETSc `Vec`
144*c8564c30SJames Wright 
145*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
146*c8564c30SJames Wright **/
147*c8564c30SJames Wright static inline PetscErrorCode VecReadC2P(Ceed ceed, CeedVector x_ceed, PetscMemType mem_type, Vec X_petsc) {
148*c8564c30SJames Wright   PetscScalar *x;
149*c8564c30SJames Wright 
150*c8564c30SJames Wright   PetscFunctionBeginUser;
151*c8564c30SJames Wright   PetscCeedCall(ceed, CeedVectorTakeArray(x_ceed, MemTypeP2C(mem_type), &x));
152*c8564c30SJames Wright   PetscCall(VecRestoreArrayReadAndMemType(X_petsc, (const PetscScalar **)&x));
153*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
154*c8564c30SJames Wright }
155*c8564c30SJames Wright 
156*c8564c30SJames Wright /**
157*c8564c30SJames Wright   @brief Setup inner `Mat` for `PC` operations not directly supported by libCEED.
158*c8564c30SJames Wright 
159*c8564c30SJames Wright   Collective across MPI processes.
160*c8564c30SJames Wright 
161*c8564c30SJames Wright   @param[in]   mat_ceed   `MATCEED` to setup
162*c8564c30SJames Wright   @param[out]  mat_inner  Inner `Mat`
163*c8564c30SJames Wright 
164*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
165*c8564c30SJames Wright **/
166*c8564c30SJames Wright static PetscErrorCode MatCeedSetupInnerMat(Mat mat_ceed, Mat *mat_inner) {
167*c8564c30SJames Wright   MatCeedContext ctx;
168*c8564c30SJames Wright 
169*c8564c30SJames Wright   PetscFunctionBeginUser;
170*c8564c30SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
171*c8564c30SJames Wright 
172*c8564c30SJames Wright   PetscCheck(ctx->dm_x == ctx->dm_y, PetscObjectComm((PetscObject)mat_ceed), PETSC_ERR_SUP, "PC only supported for MATCEED on a single DM");
173*c8564c30SJames Wright 
174*c8564c30SJames Wright   // Check cl mat type
175*c8564c30SJames Wright   {
176*c8564c30SJames Wright     PetscBool is_internal_mat_type_cl = PETSC_FALSE;
177*c8564c30SJames Wright     char      internal_mat_type_cl[64];
178*c8564c30SJames Wright 
179*c8564c30SJames Wright     // Check for specific CL inner mat type for this Mat
180*c8564c30SJames Wright     {
181*c8564c30SJames Wright       const char *mat_ceed_prefix = NULL;
182*c8564c30SJames Wright 
183*c8564c30SJames Wright       PetscCall(MatGetOptionsPrefix(mat_ceed, &mat_ceed_prefix));
184*c8564c30SJames Wright       PetscOptionsBegin(PetscObjectComm((PetscObject)mat_ceed), mat_ceed_prefix, "", NULL);
185*c8564c30SJames Wright       PetscCall(PetscOptionsFList("-ceed_inner_mat_type", "MATCEED inner assembled MatType for PC support", NULL, MatList, internal_mat_type_cl,
186*c8564c30SJames Wright                                   internal_mat_type_cl, sizeof(internal_mat_type_cl), &is_internal_mat_type_cl));
187*c8564c30SJames Wright       PetscOptionsEnd();
188*c8564c30SJames Wright       if (is_internal_mat_type_cl) {
189*c8564c30SJames Wright         PetscCall(PetscFree(ctx->internal_mat_type));
190*c8564c30SJames Wright         PetscCall(PetscStrallocpy(internal_mat_type_cl, &ctx->internal_mat_type));
191*c8564c30SJames Wright       }
192*c8564c30SJames Wright     }
193*c8564c30SJames Wright   }
194*c8564c30SJames Wright 
195*c8564c30SJames Wright   // Create sparse matrix
196*c8564c30SJames Wright   {
197*c8564c30SJames Wright     MatType dm_mat_type, dm_mat_type_copy;
198*c8564c30SJames Wright 
199*c8564c30SJames Wright     PetscCall(DMGetMatType(ctx->dm_x, &dm_mat_type));
200*c8564c30SJames Wright     PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy));
201*c8564c30SJames Wright     PetscCall(DMSetMatType(ctx->dm_x, ctx->internal_mat_type));
202*c8564c30SJames Wright     PetscCall(DMCreateMatrix(ctx->dm_x, mat_inner));
203*c8564c30SJames Wright     PetscCall(DMSetMatType(ctx->dm_x, dm_mat_type_copy));
204*c8564c30SJames Wright     PetscCall(PetscFree(dm_mat_type_copy));
205*c8564c30SJames Wright   }
206*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
207*c8564c30SJames Wright }
208*c8564c30SJames Wright 
209*c8564c30SJames Wright /**
210*c8564c30SJames Wright   @brief Assemble the point block diagonal of a `MATCEED` into a `MATAIJ` or similar.
211*c8564c30SJames Wright          The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`.
212*c8564c30SJames Wright          The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail.
213*c8564c30SJames Wright 
214*c8564c30SJames Wright   Collective across MPI processes.
215*c8564c30SJames Wright 
216*c8564c30SJames Wright   @param[in]      mat_ceed  `MATCEED` to assemble
217*c8564c30SJames Wright   @param[in,out]  mat_coo   `MATAIJ` or similar to assemble into
218*c8564c30SJames Wright 
219*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
220*c8564c30SJames Wright **/
221*c8564c30SJames Wright static PetscErrorCode MatCeedAssemblePointBlockDiagonalCOO(Mat mat_ceed, Mat mat_coo) {
222*c8564c30SJames Wright   MatCeedContext ctx;
223*c8564c30SJames Wright 
224*c8564c30SJames Wright   PetscFunctionBeginUser;
225*c8564c30SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
226*c8564c30SJames Wright 
227*c8564c30SJames Wright   // Check if COO pattern set
228*c8564c30SJames Wright   {
229*c8564c30SJames Wright     PetscInt index = -1;
230*c8564c30SJames Wright 
231*c8564c30SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) {
232*c8564c30SJames Wright       if (ctx->mats_assembled_pbd[i] == mat_coo) index = i;
233*c8564c30SJames Wright     }
234*c8564c30SJames Wright     if (index == -1) {
235*c8564c30SJames Wright       PetscInt     *rows_petsc = NULL, *cols_petsc = NULL;
236*c8564c30SJames Wright       CeedInt      *rows_ceed, *cols_ceed;
237*c8564c30SJames Wright       PetscCount    num_entries;
238*c8564c30SJames Wright       PetscLogStage stage_amg_setup;
239*c8564c30SJames Wright 
240*c8564c30SJames Wright       // -- Assemble sparsity pattern if mat hasn't been assembled before
241*c8564c30SJames Wright       PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup));
242*c8564c30SJames Wright       if (stage_amg_setup == -1) {
243*c8564c30SJames Wright         PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup));
244*c8564c30SJames Wright       }
245*c8564c30SJames Wright       PetscCall(PetscLogStagePush(stage_amg_setup));
246*c8564c30SJames Wright       PetscCeedCall(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonalSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed));
247*c8564c30SJames Wright       PetscCall(IntArrayC2P(num_entries, &rows_ceed, &rows_petsc));
248*c8564c30SJames Wright       PetscCall(IntArrayC2P(num_entries, &cols_ceed, &cols_petsc));
249*c8564c30SJames Wright       PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc));
250*c8564c30SJames Wright       free(rows_petsc);
251*c8564c30SJames Wright       free(cols_petsc);
252*c8564c30SJames Wright       if (!ctx->coo_values_pbd) PetscCeedCall(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_pbd));
253*c8564c30SJames Wright       PetscCall(PetscRealloc(++ctx->num_mats_assembled_pbd * sizeof(Mat), &ctx->mats_assembled_pbd));
254*c8564c30SJames Wright       ctx->mats_assembled_pbd[ctx->num_mats_assembled_pbd - 1] = mat_coo;
255*c8564c30SJames Wright       PetscCall(PetscLogStagePop());
256*c8564c30SJames Wright     }
257*c8564c30SJames Wright   }
258*c8564c30SJames Wright 
259*c8564c30SJames Wright   // Assemble mat_ceed
260*c8564c30SJames Wright   PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY));
261*c8564c30SJames Wright   {
262*c8564c30SJames Wright     const CeedScalar *values;
263*c8564c30SJames Wright     MatType           mat_type;
264*c8564c30SJames Wright     CeedMemType       mem_type = CEED_MEM_HOST;
265*c8564c30SJames Wright     PetscBool         is_spd, is_spd_known;
266*c8564c30SJames Wright 
267*c8564c30SJames Wright     PetscCall(MatGetType(mat_coo, &mat_type));
268*c8564c30SJames Wright     if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE;
269*c8564c30SJames Wright     else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE;
270*c8564c30SJames Wright     else mem_type = CEED_MEM_HOST;
271*c8564c30SJames Wright 
272*c8564c30SJames Wright     PetscCeedCall(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonal(ctx->op_mult, ctx->coo_values_pbd, CEED_REQUEST_IMMEDIATE));
273*c8564c30SJames Wright     PetscCeedCall(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_pbd, mem_type, &values));
274*c8564c30SJames Wright     PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES));
275*c8564c30SJames Wright     PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd));
276*c8564c30SJames Wright     if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd));
277*c8564c30SJames Wright     PetscCeedCall(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_pbd, &values));
278*c8564c30SJames Wright   }
279*c8564c30SJames Wright   PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY));
280*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
281*c8564c30SJames Wright }
282*c8564c30SJames Wright 
283*c8564c30SJames Wright /**
284*c8564c30SJames Wright   @brief Assemble inner `Mat` for diagonal `PC` operations
285*c8564c30SJames Wright 
286*c8564c30SJames Wright   Collective across MPI processes.
287*c8564c30SJames Wright 
288*c8564c30SJames Wright   @param[in]   mat_ceed      `MATCEED` to invert
289*c8564c30SJames Wright   @param[in]   use_ceed_pbd  Boolean flag to use libCEED PBD assembly
290*c8564c30SJames Wright   @param[out]  mat_inner     Inner `Mat` for diagonal operations
291*c8564c30SJames Wright 
292*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
293*c8564c30SJames Wright **/
294*c8564c30SJames Wright static PetscErrorCode MatCeedAssembleInnerBlockDiagonalMat(Mat mat_ceed, PetscBool use_ceed_pbd, Mat *mat_inner) {
295*c8564c30SJames Wright   MatCeedContext ctx;
296*c8564c30SJames Wright 
297*c8564c30SJames Wright   PetscFunctionBeginUser;
298*c8564c30SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
299*c8564c30SJames Wright   if (use_ceed_pbd) {
300*c8564c30SJames Wright     // Check if COO pattern set
301*c8564c30SJames Wright     if (!ctx->mat_assembled_pbd_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_pbd_internal));
302*c8564c30SJames Wright 
303*c8564c30SJames Wright     // Assemble mat_assembled_full_internal
304*c8564c30SJames Wright     PetscCall(MatCeedAssemblePointBlockDiagonalCOO(mat_ceed, ctx->mat_assembled_pbd_internal));
305*c8564c30SJames Wright     if (mat_inner) *mat_inner = ctx->mat_assembled_pbd_internal;
306*c8564c30SJames Wright   } else {
307*c8564c30SJames Wright     // Check if COO pattern set
308*c8564c30SJames Wright     if (!ctx->mat_assembled_full_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_full_internal));
309*c8564c30SJames Wright 
310*c8564c30SJames Wright     // Assemble mat_assembled_full_internal
311*c8564c30SJames Wright     PetscCall(MatCeedAssembleCOO(mat_ceed, ctx->mat_assembled_full_internal));
312*c8564c30SJames Wright     if (mat_inner) *mat_inner = ctx->mat_assembled_full_internal;
313*c8564c30SJames Wright   }
314*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
315*c8564c30SJames Wright }
316*c8564c30SJames Wright 
317*c8564c30SJames Wright /**
318*c8564c30SJames Wright   @brief Get `MATCEED` diagonal block for Jacobi.
319*c8564c30SJames Wright 
320*c8564c30SJames Wright   Collective across MPI processes.
321*c8564c30SJames Wright 
322*c8564c30SJames Wright   @param[in]   mat_ceed   `MATCEED` to invert
323*c8564c30SJames Wright   @param[out]  mat_block  The diagonal block matrix
324*c8564c30SJames Wright 
325*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
326*c8564c30SJames Wright **/
327*c8564c30SJames Wright static PetscErrorCode MatGetDiagonalBlock_Ceed(Mat mat_ceed, Mat *mat_block) {
328*c8564c30SJames Wright   Mat            mat_inner = NULL;
329*c8564c30SJames Wright   MatCeedContext ctx;
330*c8564c30SJames Wright 
331*c8564c30SJames Wright   PetscFunctionBeginUser;
332*c8564c30SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
333*c8564c30SJames Wright 
334*c8564c30SJames Wright   // Assemble inner mat if needed
335*c8564c30SJames Wright   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner));
336*c8564c30SJames Wright 
337*c8564c30SJames Wright   // Get block diagonal
338*c8564c30SJames Wright   PetscCall(MatGetDiagonalBlock(mat_inner, mat_block));
339*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
340*c8564c30SJames Wright }
341*c8564c30SJames Wright 
342*c8564c30SJames Wright /**
343*c8564c30SJames Wright   @brief Invert `MATCEED` diagonal block for Jacobi.
344*c8564c30SJames Wright 
345*c8564c30SJames Wright   Collective across MPI processes.
346*c8564c30SJames Wright 
347*c8564c30SJames Wright   @param[in]   mat_ceed  `MATCEED` to invert
348*c8564c30SJames Wright   @param[out]  values    The block inverses in column major order
349*c8564c30SJames Wright 
350*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
351*c8564c30SJames Wright **/
352*c8564c30SJames Wright static PetscErrorCode MatInvertBlockDiagonal_Ceed(Mat mat_ceed, const PetscScalar **values) {
353*c8564c30SJames Wright   Mat            mat_inner = NULL;
354*c8564c30SJames Wright   MatCeedContext ctx;
355*c8564c30SJames Wright 
356*c8564c30SJames Wright   PetscFunctionBeginUser;
357*c8564c30SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
358*c8564c30SJames Wright 
359*c8564c30SJames Wright   // Assemble inner mat if needed
360*c8564c30SJames Wright   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner));
361*c8564c30SJames Wright 
362*c8564c30SJames Wright   // Invert PB diagonal
363*c8564c30SJames Wright   PetscCall(MatInvertBlockDiagonal(mat_inner, values));
364*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
365*c8564c30SJames Wright }
366*c8564c30SJames Wright 
367*c8564c30SJames Wright /**
368*c8564c30SJames Wright   @brief Invert `MATCEED` variable diagonal block for Jacobi.
369*c8564c30SJames Wright 
370*c8564c30SJames Wright   Collective across MPI processes.
371*c8564c30SJames Wright 
372*c8564c30SJames Wright   @param[in]   mat_ceed     `MATCEED` to invert
373*c8564c30SJames Wright   @param[in]   num_blocks   The number of blocks on the process
374*c8564c30SJames Wright   @param[in]   block_sizes  The size of each block on the process
375*c8564c30SJames Wright   @param[out]  values       The block inverses in column major order
376*c8564c30SJames Wright 
377*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
378*c8564c30SJames Wright **/
379*c8564c30SJames Wright static PetscErrorCode MatInvertVariableBlockDiagonal_Ceed(Mat mat_ceed, PetscInt num_blocks, const PetscInt *block_sizes, PetscScalar *values) {
380*c8564c30SJames Wright   Mat            mat_inner = NULL;
381*c8564c30SJames Wright   MatCeedContext ctx;
382*c8564c30SJames Wright 
383*c8564c30SJames Wright   PetscFunctionBeginUser;
384*c8564c30SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
385*c8564c30SJames Wright 
386*c8564c30SJames Wright   // Assemble inner mat if needed
387*c8564c30SJames Wright   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_vpbd_valid, &mat_inner));
388*c8564c30SJames Wright 
389*c8564c30SJames Wright   // Invert PB diagonal
390*c8564c30SJames Wright   PetscCall(MatInvertVariableBlockDiagonal(mat_inner, num_blocks, block_sizes, values));
391*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
392*c8564c30SJames Wright }
393*c8564c30SJames Wright 
394*c8564c30SJames Wright // -----------------------------------------------------------------------------
395*c8564c30SJames Wright // MatCeed
396*c8564c30SJames Wright // -----------------------------------------------------------------------------
397*c8564c30SJames Wright 
398*c8564c30SJames Wright /**
399*c8564c30SJames Wright   @brief Create PETSc `Mat` from libCEED operators.
400*c8564c30SJames Wright 
401*c8564c30SJames Wright   Collective across MPI processes.
402*c8564c30SJames Wright 
403*c8564c30SJames Wright   @param[in]   dm_x                      Input `DM`
404*c8564c30SJames Wright   @param[in]   dm_y                      Output `DM`
405*c8564c30SJames Wright   @param[in]   op_mult                   `CeedOperator` for forward evaluation
406*c8564c30SJames Wright   @param[in]   op_mult_transpose         `CeedOperator` for transpose evaluation
407*c8564c30SJames Wright   @param[out]  mat                        New MatCeed
408*c8564c30SJames Wright 
409*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
410*c8564c30SJames Wright **/
411*c8564c30SJames Wright PetscErrorCode MatCeedCreate(DM dm_x, DM dm_y, CeedOperator op_mult, CeedOperator op_mult_transpose, Mat *mat) {
412*c8564c30SJames Wright   PetscInt       X_l_size, X_g_size, Y_l_size, Y_g_size;
413*c8564c30SJames Wright   VecType        vec_type;
414*c8564c30SJames Wright   MatCeedContext ctx;
415*c8564c30SJames Wright 
416*c8564c30SJames Wright   PetscFunctionBeginUser;
417*c8564c30SJames Wright   PetscCall(MatCeedRegisterLogEvents());
418*c8564c30SJames Wright 
419*c8564c30SJames Wright   // Collect context data
420*c8564c30SJames Wright   PetscCall(DMGetVecType(dm_x, &vec_type));
421*c8564c30SJames Wright   {
422*c8564c30SJames Wright     Vec X;
423*c8564c30SJames Wright 
424*c8564c30SJames Wright     PetscCall(DMGetGlobalVector(dm_x, &X));
425*c8564c30SJames Wright     PetscCall(VecGetSize(X, &X_g_size));
426*c8564c30SJames Wright     PetscCall(VecGetLocalSize(X, &X_l_size));
427*c8564c30SJames Wright     PetscCall(DMRestoreGlobalVector(dm_x, &X));
428*c8564c30SJames Wright   }
429*c8564c30SJames Wright   if (dm_y) {
430*c8564c30SJames Wright     Vec Y;
431*c8564c30SJames Wright 
432*c8564c30SJames Wright     PetscCall(DMGetGlobalVector(dm_y, &Y));
433*c8564c30SJames Wright     PetscCall(VecGetSize(Y, &Y_g_size));
434*c8564c30SJames Wright     PetscCall(VecGetLocalSize(Y, &Y_l_size));
435*c8564c30SJames Wright     PetscCall(DMRestoreGlobalVector(dm_y, &Y));
436*c8564c30SJames Wright   } else {
437*c8564c30SJames Wright     dm_y     = dm_x;
438*c8564c30SJames Wright     Y_g_size = X_g_size;
439*c8564c30SJames Wright     Y_l_size = X_l_size;
440*c8564c30SJames Wright   }
441*c8564c30SJames Wright   // Create context
442*c8564c30SJames Wright   {
443*c8564c30SJames Wright     Vec X_loc, Y_loc_transpose = NULL;
444*c8564c30SJames Wright 
445*c8564c30SJames Wright     PetscCall(DMCreateLocalVector(dm_x, &X_loc));
446*c8564c30SJames Wright     PetscCall(VecZeroEntries(X_loc));
447*c8564c30SJames Wright     if (op_mult_transpose) {
448*c8564c30SJames Wright       PetscCall(DMCreateLocalVector(dm_y, &Y_loc_transpose));
449*c8564c30SJames Wright       PetscCall(VecZeroEntries(Y_loc_transpose));
450*c8564c30SJames Wright     }
451*c8564c30SJames Wright     PetscCall(MatCeedContextCreate(dm_x, dm_y, X_loc, Y_loc_transpose, op_mult, op_mult_transpose, MATCEED_MULT, MATCEED_MULT_TRANSPOSE, &ctx));
452*c8564c30SJames Wright     PetscCall(VecDestroy(&X_loc));
453*c8564c30SJames Wright     PetscCall(VecDestroy(&Y_loc_transpose));
454*c8564c30SJames Wright   }
455*c8564c30SJames Wright 
456*c8564c30SJames Wright   // Create mat
457*c8564c30SJames Wright   PetscCall(MatCreateShell(PetscObjectComm((PetscObject)dm_x), Y_l_size, X_l_size, Y_g_size, X_g_size, ctx, mat));
458*c8564c30SJames Wright   PetscCall(PetscObjectChangeTypeName((PetscObject)*mat, MATCEED));
459*c8564c30SJames Wright   // -- Set block and variable block sizes
460*c8564c30SJames Wright   if (dm_x == dm_y) {
461*c8564c30SJames Wright     MatType dm_mat_type, dm_mat_type_copy;
462*c8564c30SJames Wright     Mat     temp_mat;
463*c8564c30SJames Wright 
464*c8564c30SJames Wright     PetscCall(DMGetMatType(dm_x, &dm_mat_type));
465*c8564c30SJames Wright     PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy));
466*c8564c30SJames Wright     PetscCall(DMSetMatType(dm_x, MATAIJ));
467*c8564c30SJames Wright     PetscCall(DMCreateMatrix(dm_x, &temp_mat));
468*c8564c30SJames Wright     PetscCall(DMSetMatType(dm_x, dm_mat_type_copy));
469*c8564c30SJames Wright     PetscCall(PetscFree(dm_mat_type_copy));
470*c8564c30SJames Wright 
471*c8564c30SJames Wright     {
472*c8564c30SJames Wright       PetscInt        block_size, num_blocks, max_vblock_size = PETSC_INT_MAX;
473*c8564c30SJames Wright       const PetscInt *vblock_sizes;
474*c8564c30SJames Wright 
475*c8564c30SJames Wright       // -- Get block sizes
476*c8564c30SJames Wright       PetscCall(MatGetBlockSize(temp_mat, &block_size));
477*c8564c30SJames Wright       PetscCall(MatGetVariableBlockSizes(temp_mat, &num_blocks, &vblock_sizes));
478*c8564c30SJames Wright       {
479*c8564c30SJames Wright         PetscInt local_min_max[2] = {0}, global_min_max[2] = {0, PETSC_INT_MAX};
480*c8564c30SJames Wright 
481*c8564c30SJames Wright         for (PetscInt i = 0; i < num_blocks; i++) local_min_max[1] = PetscMax(local_min_max[1], vblock_sizes[i]);
482*c8564c30SJames Wright         PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_min_max, global_min_max));
483*c8564c30SJames Wright         max_vblock_size = global_min_max[1];
484*c8564c30SJames Wright       }
485*c8564c30SJames Wright 
486*c8564c30SJames Wright       // -- Copy block sizes
487*c8564c30SJames Wright       if (block_size > 1) PetscCall(MatSetBlockSize(*mat, block_size));
488*c8564c30SJames Wright       if (num_blocks) PetscCall(MatSetVariableBlockSizes(*mat, num_blocks, (PetscInt *)vblock_sizes));
489*c8564c30SJames Wright 
490*c8564c30SJames Wright       // -- Check libCEED compatibility
491*c8564c30SJames Wright       {
492*c8564c30SJames Wright         bool is_composite;
493*c8564c30SJames Wright 
494*c8564c30SJames Wright         ctx->is_ceed_pbd_valid  = PETSC_TRUE;
495*c8564c30SJames Wright         ctx->is_ceed_vpbd_valid = PETSC_TRUE;
496*c8564c30SJames Wright         PetscCeedCall(ctx->ceed, CeedOperatorIsComposite(op_mult, &is_composite));
497*c8564c30SJames Wright         if (is_composite) {
498*c8564c30SJames Wright           CeedInt       num_sub_operators;
499*c8564c30SJames Wright           CeedOperator *sub_operators;
500*c8564c30SJames Wright 
501*c8564c30SJames Wright           PetscCeedCall(ctx->ceed, CeedCompositeOperatorGetNumSub(op_mult, &num_sub_operators));
502*c8564c30SJames Wright           PetscCeedCall(ctx->ceed, CeedCompositeOperatorGetSubList(op_mult, &sub_operators));
503*c8564c30SJames Wright           for (CeedInt i = 0; i < num_sub_operators; i++) {
504*c8564c30SJames Wright             CeedInt                  num_bases, num_comp;
505*c8564c30SJames Wright             CeedBasis               *active_bases;
506*c8564c30SJames Wright             CeedOperatorAssemblyData assembly_data;
507*c8564c30SJames Wright 
508*c8564c30SJames Wright             PetscCeedCall(ctx->ceed, CeedOperatorGetOperatorAssemblyData(sub_operators[i], &assembly_data));
509*c8564c30SJames Wright             PetscCeedCall(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL));
510*c8564c30SJames Wright             PetscCeedCall(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp));
511*c8564c30SJames Wright             if (num_bases > 1) {
512*c8564c30SJames Wright               ctx->is_ceed_pbd_valid  = PETSC_FALSE;
513*c8564c30SJames Wright               ctx->is_ceed_vpbd_valid = PETSC_FALSE;
514*c8564c30SJames Wright             }
515*c8564c30SJames Wright             if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE;
516*c8564c30SJames Wright             if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE;
517*c8564c30SJames Wright           }
518*c8564c30SJames Wright         } else {
519*c8564c30SJames Wright           // LCOV_EXCL_START
520*c8564c30SJames Wright           CeedInt                  num_bases, num_comp;
521*c8564c30SJames Wright           CeedBasis               *active_bases;
522*c8564c30SJames Wright           CeedOperatorAssemblyData assembly_data;
523*c8564c30SJames Wright 
524*c8564c30SJames Wright           PetscCeedCall(ctx->ceed, CeedOperatorGetOperatorAssemblyData(op_mult, &assembly_data));
525*c8564c30SJames Wright           PetscCeedCall(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL));
526*c8564c30SJames Wright           PetscCeedCall(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp));
527*c8564c30SJames Wright           if (num_bases > 1) {
528*c8564c30SJames Wright             ctx->is_ceed_pbd_valid  = PETSC_FALSE;
529*c8564c30SJames Wright             ctx->is_ceed_vpbd_valid = PETSC_FALSE;
530*c8564c30SJames Wright           }
531*c8564c30SJames Wright           if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE;
532*c8564c30SJames Wright           if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE;
533*c8564c30SJames Wright           // LCOV_EXCL_STOP
534*c8564c30SJames Wright         }
535*c8564c30SJames Wright         {
536*c8564c30SJames Wright           PetscInt local_is_valid[2], global_is_valid[2];
537*c8564c30SJames Wright 
538*c8564c30SJames Wright           local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_pbd_valid;
539*c8564c30SJames Wright           PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid));
540*c8564c30SJames Wright           ctx->is_ceed_pbd_valid = global_is_valid[0];
541*c8564c30SJames Wright           local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_vpbd_valid;
542*c8564c30SJames Wright           PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid));
543*c8564c30SJames Wright           ctx->is_ceed_vpbd_valid = global_is_valid[0];
544*c8564c30SJames Wright         }
545*c8564c30SJames Wright       }
546*c8564c30SJames Wright     }
547*c8564c30SJames Wright     PetscCall(MatDestroy(&temp_mat));
548*c8564c30SJames Wright   }
549*c8564c30SJames Wright   // -- Set internal mat type
550*c8564c30SJames Wright   {
551*c8564c30SJames Wright     VecType vec_type;
552*c8564c30SJames Wright     MatType internal_mat_type = MATAIJ;
553*c8564c30SJames Wright 
554*c8564c30SJames Wright     PetscCall(VecGetType(ctx->X_loc, &vec_type));
555*c8564c30SJames Wright     if (strstr(vec_type, VECCUDA)) internal_mat_type = MATAIJCUSPARSE;
556*c8564c30SJames Wright     else if (strstr(vec_type, VECKOKKOS)) internal_mat_type = MATAIJKOKKOS;
557*c8564c30SJames Wright     else internal_mat_type = MATAIJ;
558*c8564c30SJames Wright     PetscCall(PetscStrallocpy(internal_mat_type, &ctx->internal_mat_type));
559*c8564c30SJames Wright   }
560*c8564c30SJames Wright   // -- Set mat operations
561*c8564c30SJames Wright   PetscCall(MatShellSetContextDestroy(*mat, (PetscErrorCode(*)(void *))MatCeedContextDestroy));
562*c8564c30SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_MULT, (void (*)(void))MatMult_Ceed));
563*c8564c30SJames Wright   if (op_mult_transpose) PetscCall(MatShellSetOperation(*mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed));
564*c8564c30SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed));
565*c8564c30SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed));
566*c8564c30SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed));
567*c8564c30SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed));
568*c8564c30SJames Wright   PetscCall(MatShellSetVecType(*mat, vec_type));
569*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
570*c8564c30SJames Wright }
571*c8564c30SJames Wright 
572*c8564c30SJames Wright /**
573*c8564c30SJames Wright   @brief Copy `MATCEED` into a compatible `Mat` with type `MatShell` or `MATCEED`.
574*c8564c30SJames Wright 
575*c8564c30SJames Wright   Collective across MPI processes.
576*c8564c30SJames Wright 
577*c8564c30SJames Wright   @param[in]   mat_ceed   `MATCEED` to copy from
578*c8564c30SJames Wright   @param[out]  mat_other  `MatShell` or `MATCEED` to copy into
579*c8564c30SJames Wright 
580*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
581*c8564c30SJames Wright **/
582*c8564c30SJames Wright PetscErrorCode MatCeedCopy(Mat mat_ceed, Mat mat_other) {
583*c8564c30SJames Wright   PetscFunctionBeginUser;
584*c8564c30SJames Wright   PetscCall(MatCeedRegisterLogEvents());
585*c8564c30SJames Wright 
586*c8564c30SJames Wright   // Check type compatibility
587*c8564c30SJames Wright   {
588*c8564c30SJames Wright     MatType mat_type_ceed, mat_type_other;
589*c8564c30SJames Wright 
590*c8564c30SJames Wright     PetscCall(MatGetType(mat_ceed, &mat_type_ceed));
591*c8564c30SJames Wright     PetscCheck(!strcmp(mat_type_ceed, MATCEED), PETSC_COMM_SELF, PETSC_ERR_LIB, "mat_ceed must have type " MATCEED);
592*c8564c30SJames Wright     PetscCall(MatGetType(mat_ceed, &mat_type_other));
593*c8564c30SJames Wright     PetscCheck(!strcmp(mat_type_other, MATCEED) || !strcmp(mat_type_other, MATSHELL), PETSC_COMM_SELF, PETSC_ERR_LIB,
594*c8564c30SJames Wright                "mat_other must have type " MATCEED " or " MATSHELL);
595*c8564c30SJames Wright   }
596*c8564c30SJames Wright 
597*c8564c30SJames Wright   // Check dimension compatibility
598*c8564c30SJames Wright   {
599*c8564c30SJames Wright     PetscInt X_l_ceed_size, X_g_ceed_size, Y_l_ceed_size, Y_g_ceed_size, X_l_other_size, X_g_other_size, Y_l_other_size, Y_g_other_size;
600*c8564c30SJames Wright 
601*c8564c30SJames Wright     PetscCall(MatGetSize(mat_ceed, &Y_g_ceed_size, &X_g_ceed_size));
602*c8564c30SJames Wright     PetscCall(MatGetLocalSize(mat_ceed, &Y_l_ceed_size, &X_l_ceed_size));
603*c8564c30SJames Wright     PetscCall(MatGetSize(mat_ceed, &Y_g_other_size, &X_g_other_size));
604*c8564c30SJames Wright     PetscCall(MatGetLocalSize(mat_ceed, &Y_l_other_size, &X_l_other_size));
605*c8564c30SJames Wright     PetscCheck((Y_g_ceed_size == Y_g_other_size) && (X_g_ceed_size == X_g_other_size) && (Y_l_ceed_size == Y_l_other_size) &&
606*c8564c30SJames Wright                    (X_l_ceed_size == X_l_other_size),
607*c8564c30SJames Wright                PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ,
608*c8564c30SJames Wright                "mat_ceed and mat_other must have compatible sizes; found mat_ceed (Global: %" PetscInt_FMT ", %" PetscInt_FMT
609*c8564c30SJames Wright                "; Local: %" PetscInt_FMT ", %" PetscInt_FMT ") mat_other (Global: %" PetscInt_FMT ", %" PetscInt_FMT "; Local: %" PetscInt_FMT
610*c8564c30SJames Wright                ", %" PetscInt_FMT ")",
611*c8564c30SJames Wright                Y_g_ceed_size, X_g_ceed_size, Y_l_ceed_size, X_l_ceed_size, Y_g_other_size, X_g_other_size, Y_l_other_size, X_l_other_size);
612*c8564c30SJames Wright   }
613*c8564c30SJames Wright 
614*c8564c30SJames Wright   // Convert
615*c8564c30SJames Wright   {
616*c8564c30SJames Wright     VecType        vec_type;
617*c8564c30SJames Wright     MatCeedContext ctx;
618*c8564c30SJames Wright 
619*c8564c30SJames Wright     PetscCall(PetscObjectChangeTypeName((PetscObject)mat_other, MATCEED));
620*c8564c30SJames Wright     PetscCall(MatShellGetContext(mat_ceed, &ctx));
621*c8564c30SJames Wright     PetscCall(MatCeedContextReference(ctx));
622*c8564c30SJames Wright     PetscCall(MatShellSetContext(mat_other, ctx));
623*c8564c30SJames Wright     PetscCall(MatShellSetContextDestroy(mat_other, (PetscErrorCode(*)(void *))MatCeedContextDestroy));
624*c8564c30SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_MULT, (void (*)(void))MatMult_Ceed));
625*c8564c30SJames Wright     if (ctx->op_mult_transpose) PetscCall(MatShellSetOperation(mat_other, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed));
626*c8564c30SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed));
627*c8564c30SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed));
628*c8564c30SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed));
629*c8564c30SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed));
630*c8564c30SJames Wright     {
631*c8564c30SJames Wright       PetscInt block_size;
632*c8564c30SJames Wright 
633*c8564c30SJames Wright       PetscCall(MatGetBlockSize(mat_ceed, &block_size));
634*c8564c30SJames Wright       if (block_size > 1) PetscCall(MatSetBlockSize(mat_other, block_size));
635*c8564c30SJames Wright     }
636*c8564c30SJames Wright     {
637*c8564c30SJames Wright       PetscInt        num_blocks;
638*c8564c30SJames Wright       const PetscInt *block_sizes;
639*c8564c30SJames Wright 
640*c8564c30SJames Wright       PetscCall(MatGetVariableBlockSizes(mat_ceed, &num_blocks, &block_sizes));
641*c8564c30SJames Wright       if (num_blocks) PetscCall(MatSetVariableBlockSizes(mat_other, num_blocks, (PetscInt *)block_sizes));
642*c8564c30SJames Wright     }
643*c8564c30SJames Wright     PetscCall(DMGetVecType(ctx->dm_x, &vec_type));
644*c8564c30SJames Wright     PetscCall(MatShellSetVecType(mat_other, vec_type));
645*c8564c30SJames Wright   }
646*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
647*c8564c30SJames Wright }
648*c8564c30SJames Wright 
649*c8564c30SJames Wright /**
650*c8564c30SJames Wright   @brief Assemble a `MATCEED` into a `MATAIJ` or similar.
651*c8564c30SJames Wright          The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`.
652*c8564c30SJames Wright          The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail.
653*c8564c30SJames Wright 
654*c8564c30SJames Wright   Collective across MPI processes.
655*c8564c30SJames Wright 
656*c8564c30SJames Wright   @param[in]      mat_ceed  `MATCEED` to assemble
657*c8564c30SJames Wright   @param[in,out]  mat_coo   `MATAIJ` or similar to assemble into
658*c8564c30SJames Wright 
659*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
660*c8564c30SJames Wright **/
661*c8564c30SJames Wright PetscErrorCode MatCeedAssembleCOO(Mat mat_ceed, Mat mat_coo) {
662*c8564c30SJames Wright   MatCeedContext ctx;
663*c8564c30SJames Wright 
664*c8564c30SJames Wright   PetscFunctionBeginUser;
665*c8564c30SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
666*c8564c30SJames Wright 
667*c8564c30SJames Wright   // Check if COO pattern set
668*c8564c30SJames Wright   {
669*c8564c30SJames Wright     PetscInt index = -1;
670*c8564c30SJames Wright 
671*c8564c30SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) {
672*c8564c30SJames Wright       if (ctx->mats_assembled_full[i] == mat_coo) index = i;
673*c8564c30SJames Wright     }
674*c8564c30SJames Wright     if (index == -1) {
675*c8564c30SJames Wright       PetscInt     *rows_petsc = NULL, *cols_petsc = NULL;
676*c8564c30SJames Wright       CeedInt      *rows_ceed, *cols_ceed;
677*c8564c30SJames Wright       PetscCount    num_entries;
678*c8564c30SJames Wright       PetscLogStage stage_amg_setup;
679*c8564c30SJames Wright 
680*c8564c30SJames Wright       // -- Assemble sparsity pattern if mat hasn't been assembled before
681*c8564c30SJames Wright       PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup));
682*c8564c30SJames Wright       if (stage_amg_setup == -1) {
683*c8564c30SJames Wright         PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup));
684*c8564c30SJames Wright       }
685*c8564c30SJames Wright       PetscCall(PetscLogStagePush(stage_amg_setup));
686*c8564c30SJames Wright       PetscCeedCall(ctx->ceed, CeedOperatorLinearAssembleSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed));
687*c8564c30SJames Wright       PetscCall(IntArrayC2P(num_entries, &rows_ceed, &rows_petsc));
688*c8564c30SJames Wright       PetscCall(IntArrayC2P(num_entries, &cols_ceed, &cols_petsc));
689*c8564c30SJames Wright       PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc));
690*c8564c30SJames Wright       free(rows_petsc);
691*c8564c30SJames Wright       free(cols_petsc);
692*c8564c30SJames Wright       if (!ctx->coo_values_full) PetscCeedCall(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_full));
693*c8564c30SJames Wright       PetscCall(PetscRealloc(++ctx->num_mats_assembled_full * sizeof(Mat), &ctx->mats_assembled_full));
694*c8564c30SJames Wright       ctx->mats_assembled_full[ctx->num_mats_assembled_full - 1] = mat_coo;
695*c8564c30SJames Wright       PetscCall(PetscLogStagePop());
696*c8564c30SJames Wright     }
697*c8564c30SJames Wright   }
698*c8564c30SJames Wright 
699*c8564c30SJames Wright   // Assemble mat_ceed
700*c8564c30SJames Wright   PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY));
701*c8564c30SJames Wright   {
702*c8564c30SJames Wright     const CeedScalar *values;
703*c8564c30SJames Wright     MatType           mat_type;
704*c8564c30SJames Wright     CeedMemType       mem_type = CEED_MEM_HOST;
705*c8564c30SJames Wright     PetscBool         is_spd, is_spd_known;
706*c8564c30SJames Wright 
707*c8564c30SJames Wright     PetscCall(MatGetType(mat_coo, &mat_type));
708*c8564c30SJames Wright     if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE;
709*c8564c30SJames Wright     else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE;
710*c8564c30SJames Wright     else mem_type = CEED_MEM_HOST;
711*c8564c30SJames Wright 
712*c8564c30SJames Wright     PetscCeedCall(ctx->ceed, CeedOperatorLinearAssemble(ctx->op_mult, ctx->coo_values_full));
713*c8564c30SJames Wright     PetscCeedCall(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_full, mem_type, &values));
714*c8564c30SJames Wright     PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES));
715*c8564c30SJames Wright     PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd));
716*c8564c30SJames Wright     if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd));
717*c8564c30SJames Wright     PetscCeedCall(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_full, &values));
718*c8564c30SJames Wright   }
719*c8564c30SJames Wright   PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY));
720*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
721*c8564c30SJames Wright }
722*c8564c30SJames Wright 
723*c8564c30SJames Wright /**
724*c8564c30SJames Wright   @brief Set user context for a `MATCEED`.
725*c8564c30SJames Wright 
726*c8564c30SJames Wright   Collective across MPI processes.
727*c8564c30SJames Wright 
728*c8564c30SJames Wright   @param[in,out]  mat  `MATCEED`
729*c8564c30SJames Wright   @param[in]      f    The context destroy function, or NULL
730*c8564c30SJames Wright   @param[in]      ctx  User context, or NULL to unset
731*c8564c30SJames Wright 
732*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
733*c8564c30SJames Wright **/
734*c8564c30SJames Wright PetscErrorCode MatCeedSetContext(Mat mat, PetscErrorCode (*f)(void *), void *ctx) {
735*c8564c30SJames Wright   PetscContainer user_ctx = NULL;
736*c8564c30SJames Wright 
737*c8564c30SJames Wright   PetscFunctionBeginUser;
738*c8564c30SJames Wright   if (ctx) {
739*c8564c30SJames Wright     PetscCall(PetscContainerCreate(PetscObjectComm((PetscObject)mat), &user_ctx));
740*c8564c30SJames Wright     PetscCall(PetscContainerSetPointer(user_ctx, ctx));
741*c8564c30SJames Wright     PetscCall(PetscContainerSetUserDestroy(user_ctx, f));
742*c8564c30SJames Wright   }
743*c8564c30SJames Wright   PetscCall(PetscObjectCompose((PetscObject)mat, "MatCeed user context", (PetscObject)user_ctx));
744*c8564c30SJames Wright   PetscCall(PetscContainerDestroy(&user_ctx));
745*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
746*c8564c30SJames Wright }
747*c8564c30SJames Wright 
748*c8564c30SJames Wright /**
749*c8564c30SJames Wright   @brief Retrieve the user context for a `MATCEED`.
750*c8564c30SJames Wright 
751*c8564c30SJames Wright   Collective across MPI processes.
752*c8564c30SJames Wright 
753*c8564c30SJames Wright   @param[in,out]  mat  `MATCEED`
754*c8564c30SJames Wright   @param[in]      ctx  User context
755*c8564c30SJames Wright 
756*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
757*c8564c30SJames Wright **/
758*c8564c30SJames Wright PetscErrorCode MatCeedGetContext(Mat mat, void *ctx) {
759*c8564c30SJames Wright   PetscContainer user_ctx;
760*c8564c30SJames Wright 
761*c8564c30SJames Wright   PetscFunctionBeginUser;
762*c8564c30SJames Wright   PetscCall(PetscObjectQuery((PetscObject)mat, "MatCeed user context", (PetscObject *)&user_ctx));
763*c8564c30SJames Wright   if (user_ctx) PetscCall(PetscContainerGetPointer(user_ctx, (void **)ctx));
764*c8564c30SJames Wright   else *(void **)ctx = NULL;
765*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
766*c8564c30SJames Wright }
767*c8564c30SJames Wright 
768*c8564c30SJames Wright /**
769*c8564c30SJames Wright   @brief Sets the inner matrix type as a string from the `MATCEED`.
770*c8564c30SJames Wright 
771*c8564c30SJames Wright   Collective across MPI processes.
772*c8564c30SJames Wright 
773*c8564c30SJames Wright   @param[in,out]  mat   `MATCEED`
774*c8564c30SJames Wright   @param[in]      type  Inner `MatType` to set
775*c8564c30SJames Wright 
776*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
777*c8564c30SJames Wright **/
778*c8564c30SJames Wright PetscErrorCode MatCeedSetInnerMatType(Mat mat, MatType type) {
779*c8564c30SJames Wright   MatCeedContext ctx;
780*c8564c30SJames Wright 
781*c8564c30SJames Wright   PetscFunctionBeginUser;
782*c8564c30SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
783*c8564c30SJames Wright   // Check if same
784*c8564c30SJames Wright   {
785*c8564c30SJames Wright     size_t    len_old, len_new;
786*c8564c30SJames Wright     PetscBool is_same = PETSC_FALSE;
787*c8564c30SJames Wright 
788*c8564c30SJames Wright     PetscCall(PetscStrlen(ctx->internal_mat_type, &len_old));
789*c8564c30SJames Wright     PetscCall(PetscStrlen(type, &len_new));
790*c8564c30SJames Wright     if (len_old == len_new) PetscCall(PetscStrncmp(ctx->internal_mat_type, type, len_old, &is_same));
791*c8564c30SJames Wright     if (is_same) PetscFunctionReturn(PETSC_SUCCESS);
792*c8564c30SJames Wright   }
793*c8564c30SJames Wright   // Clean up old mats in different format
794*c8564c30SJames Wright   // LCOV_EXCL_START
795*c8564c30SJames Wright   if (ctx->mat_assembled_full_internal) {
796*c8564c30SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) {
797*c8564c30SJames Wright       if (ctx->mats_assembled_full[i] == ctx->mat_assembled_full_internal) {
798*c8564c30SJames Wright         for (PetscInt j = i + 1; j < ctx->num_mats_assembled_full; j++) {
799*c8564c30SJames Wright           ctx->mats_assembled_full[j - 1] = ctx->mats_assembled_full[j];
800*c8564c30SJames Wright         }
801*c8564c30SJames Wright         ctx->num_mats_assembled_full--;
802*c8564c30SJames Wright         // Note: we'll realloc this array again, so no need to shrink the allocation
803*c8564c30SJames Wright         PetscCall(MatDestroy(&ctx->mat_assembled_full_internal));
804*c8564c30SJames Wright       }
805*c8564c30SJames Wright     }
806*c8564c30SJames Wright   }
807*c8564c30SJames Wright   if (ctx->mat_assembled_pbd_internal) {
808*c8564c30SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) {
809*c8564c30SJames Wright       if (ctx->mats_assembled_pbd[i] == ctx->mat_assembled_pbd_internal) {
810*c8564c30SJames Wright         for (PetscInt j = i + 1; j < ctx->num_mats_assembled_pbd; j++) {
811*c8564c30SJames Wright           ctx->mats_assembled_pbd[j - 1] = ctx->mats_assembled_pbd[j];
812*c8564c30SJames Wright         }
813*c8564c30SJames Wright         // Note: we'll realloc this array again, so no need to shrink the allocation
814*c8564c30SJames Wright         ctx->num_mats_assembled_pbd--;
815*c8564c30SJames Wright         PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal));
816*c8564c30SJames Wright       }
817*c8564c30SJames Wright     }
818*c8564c30SJames Wright   }
819*c8564c30SJames Wright   PetscCall(PetscFree(ctx->internal_mat_type));
820*c8564c30SJames Wright   PetscCall(PetscStrallocpy(type, &ctx->internal_mat_type));
821*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
822*c8564c30SJames Wright   // LCOV_EXCL_STOP
823*c8564c30SJames Wright }
824*c8564c30SJames Wright 
825*c8564c30SJames Wright /**
826*c8564c30SJames Wright   @brief Gets the inner matrix type as a string from the `MATCEED`.
827*c8564c30SJames Wright 
828*c8564c30SJames Wright   Collective across MPI processes.
829*c8564c30SJames Wright 
830*c8564c30SJames Wright   @param[in,out]  mat   `MATCEED`
831*c8564c30SJames Wright   @param[in]      type  Inner `MatType`
832*c8564c30SJames Wright 
833*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
834*c8564c30SJames Wright **/
835*c8564c30SJames Wright PetscErrorCode MatCeedGetInnerMatType(Mat mat, MatType *type) {
836*c8564c30SJames Wright   MatCeedContext ctx;
837*c8564c30SJames Wright 
838*c8564c30SJames Wright   PetscFunctionBeginUser;
839*c8564c30SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
840*c8564c30SJames Wright   *type = ctx->internal_mat_type;
841*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
842*c8564c30SJames Wright }
843*c8564c30SJames Wright 
/**
  @brief Set a user defined matrix operation for a `MATCEED` matrix.

  Within each user-defined routine, the user should call `MatCeedGetContext()` to obtain the user-defined context that was set by
  `MatCeedSetContext()`.

  This forwards directly to `MatShellSetOperation()`.

  Collective across MPI processes.

  @param[in,out]  mat  `MATCEED`
  @param[in]      op   Name of the `MatOperation`
  @param[in]      g    Function that provides the operation

  @return An error code: 0 - success, otherwise - failure
**/
PetscErrorCode MatCeedSetOperation(Mat mat, MatOperation op, void (*g)(void)) {
  PetscFunctionBeginUser;
  PetscCall(MatShellSetOperation(mat, op, g));
  PetscFunctionReturn(PETSC_SUCCESS);
}
863*c8564c30SJames Wright 
864*c8564c30SJames Wright /**
865*c8564c30SJames Wright   @brief Set input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
866*c8564c30SJames Wright 
867*c8564c30SJames Wright   Not collective across MPI processes.
868*c8564c30SJames Wright 
869*c8564c30SJames Wright   @param[in,out]  mat              `MATCEED`
870*c8564c30SJames Wright   @param[in]      X_loc            Input PETSc local vector, or NULL
871*c8564c30SJames Wright   @param[in]      Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
872*c8564c30SJames Wright 
873*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
874*c8564c30SJames Wright **/
875*c8564c30SJames Wright PetscErrorCode MatCeedSetLocalVectors(Mat mat, Vec X_loc, Vec Y_loc_transpose) {
876*c8564c30SJames Wright   MatCeedContext ctx;
877*c8564c30SJames Wright 
878*c8564c30SJames Wright   PetscFunctionBeginUser;
879*c8564c30SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
880*c8564c30SJames Wright   if (X_loc) {
881*c8564c30SJames Wright     PetscInt len_old, len_new;
882*c8564c30SJames Wright 
883*c8564c30SJames Wright     PetscCall(VecGetSize(ctx->X_loc, &len_old));
884*c8564c30SJames Wright     PetscCall(VecGetSize(X_loc, &len_new));
885*c8564c30SJames Wright     PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB, "new X_loc length %" PetscInt_FMT " should match old X_loc length %" PetscInt_FMT,
886*c8564c30SJames Wright                len_new, len_old);
887*c8564c30SJames Wright     PetscCall(VecDestroy(&ctx->X_loc));
888*c8564c30SJames Wright     ctx->X_loc = X_loc;
889*c8564c30SJames Wright     PetscCall(PetscObjectReference((PetscObject)X_loc));
890*c8564c30SJames Wright   }
891*c8564c30SJames Wright   if (Y_loc_transpose) {
892*c8564c30SJames Wright     PetscInt len_old, len_new;
893*c8564c30SJames Wright 
894*c8564c30SJames Wright     PetscCall(VecGetSize(ctx->Y_loc_transpose, &len_old));
895*c8564c30SJames Wright     PetscCall(VecGetSize(Y_loc_transpose, &len_new));
896*c8564c30SJames Wright     PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB,
897*c8564c30SJames Wright                "new Y_loc_transpose length %" PetscInt_FMT " should match old Y_loc_transpose length %" PetscInt_FMT, len_new, len_old);
898*c8564c30SJames Wright     PetscCall(VecDestroy(&ctx->Y_loc_transpose));
899*c8564c30SJames Wright     ctx->Y_loc_transpose = Y_loc_transpose;
900*c8564c30SJames Wright     PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose));
901*c8564c30SJames Wright   }
902*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
903*c8564c30SJames Wright }
904*c8564c30SJames Wright 
905*c8564c30SJames Wright /**
906*c8564c30SJames Wright   @brief Get input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
907*c8564c30SJames Wright 
908*c8564c30SJames Wright   Not collective across MPI processes.
909*c8564c30SJames Wright 
910*c8564c30SJames Wright   @param[in,out]  mat              `MATCEED`
911*c8564c30SJames Wright   @param[out]     X_loc            Input PETSc local vector, or NULL
912*c8564c30SJames Wright   @param[out]     Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
913*c8564c30SJames Wright 
914*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
915*c8564c30SJames Wright **/
916*c8564c30SJames Wright PetscErrorCode MatCeedGetLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) {
917*c8564c30SJames Wright   MatCeedContext ctx;
918*c8564c30SJames Wright 
919*c8564c30SJames Wright   PetscFunctionBeginUser;
920*c8564c30SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
921*c8564c30SJames Wright   if (X_loc) {
922*c8564c30SJames Wright     *X_loc = ctx->X_loc;
923*c8564c30SJames Wright     PetscCall(PetscObjectReference((PetscObject)*X_loc));
924*c8564c30SJames Wright   }
925*c8564c30SJames Wright   if (Y_loc_transpose) {
926*c8564c30SJames Wright     *Y_loc_transpose = ctx->Y_loc_transpose;
927*c8564c30SJames Wright     PetscCall(PetscObjectReference((PetscObject)*Y_loc_transpose));
928*c8564c30SJames Wright   }
929*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
930*c8564c30SJames Wright }
931*c8564c30SJames Wright 
932*c8564c30SJames Wright /**
933*c8564c30SJames Wright   @brief Restore input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
934*c8564c30SJames Wright 
935*c8564c30SJames Wright   Not collective across MPI processes.
936*c8564c30SJames Wright 
937*c8564c30SJames Wright   @param[in,out]  mat              MatCeed
938*c8564c30SJames Wright   @param[out]     X_loc            Input PETSc local vector, or NULL
939*c8564c30SJames Wright   @param[out]     Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
940*c8564c30SJames Wright 
941*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
942*c8564c30SJames Wright **/
943*c8564c30SJames Wright PetscErrorCode MatCeedRestoreLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) {
944*c8564c30SJames Wright   PetscFunctionBeginUser;
945*c8564c30SJames Wright   if (X_loc) PetscCall(VecDestroy(X_loc));
946*c8564c30SJames Wright   if (Y_loc_transpose) PetscCall(VecDestroy(Y_loc_transpose));
947*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
948*c8564c30SJames Wright }
949*c8564c30SJames Wright 
950*c8564c30SJames Wright /**
951*c8564c30SJames Wright   @brief Get libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
952*c8564c30SJames Wright 
953*c8564c30SJames Wright   Not collective across MPI processes.
954*c8564c30SJames Wright 
955*c8564c30SJames Wright   @param[in,out]  mat                MatCeed
956*c8564c30SJames Wright   @param[out]     op_mult            libCEED `CeedOperator` for `MatMult()`, or NULL
957*c8564c30SJames Wright   @param[out]     op_mult_transpose  libCEED `CeedOperator` for `MatMultTranspose()`, or NULL
958*c8564c30SJames Wright 
959*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
960*c8564c30SJames Wright **/
961*c8564c30SJames Wright PetscErrorCode MatCeedGetCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) {
962*c8564c30SJames Wright   MatCeedContext ctx;
963*c8564c30SJames Wright 
964*c8564c30SJames Wright   PetscFunctionBeginUser;
965*c8564c30SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
966*c8564c30SJames Wright   if (op_mult) {
967*c8564c30SJames Wright     *op_mult = NULL;
968*c8564c30SJames Wright     PetscCeedCall(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult, op_mult));
969*c8564c30SJames Wright   }
970*c8564c30SJames Wright   if (op_mult_transpose) {
971*c8564c30SJames Wright     *op_mult_transpose = NULL;
972*c8564c30SJames Wright     PetscCeedCall(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult_transpose, op_mult_transpose));
973*c8564c30SJames Wright   }
974*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
975*c8564c30SJames Wright }
976*c8564c30SJames Wright 
977*c8564c30SJames Wright /**
978*c8564c30SJames Wright   @brief Restore libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
979*c8564c30SJames Wright 
980*c8564c30SJames Wright   Not collective across MPI processes.
981*c8564c30SJames Wright 
982*c8564c30SJames Wright   @param[in,out]  mat                MatCeed
983*c8564c30SJames Wright   @param[out]     op_mult            libCEED `CeedOperator` for `MatMult()`, or NULL
984*c8564c30SJames Wright   @param[out]     op_mult_transpose  libCEED `CeedOperator` for `MatMultTranspose()`, or NULL
985*c8564c30SJames Wright 
986*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
987*c8564c30SJames Wright **/
988*c8564c30SJames Wright PetscErrorCode MatCeedRestoreCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) {
989*c8564c30SJames Wright   MatCeedContext ctx;
990*c8564c30SJames Wright 
991*c8564c30SJames Wright   PetscFunctionBeginUser;
992*c8564c30SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
993*c8564c30SJames Wright   if (op_mult) PetscCeedCall(ctx->ceed, CeedOperatorDestroy(op_mult));
994*c8564c30SJames Wright   if (op_mult_transpose) PetscCeedCall(ctx->ceed, CeedOperatorDestroy(op_mult_transpose));
995*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
996*c8564c30SJames Wright }
997*c8564c30SJames Wright 
998*c8564c30SJames Wright /**
999*c8564c30SJames Wright   @brief Set `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators.
1000*c8564c30SJames Wright 
1001*c8564c30SJames Wright   Not collective across MPI processes.
1002*c8564c30SJames Wright 
1003*c8564c30SJames Wright   @param[in,out]  mat                       MatCeed
1004*c8564c30SJames Wright   @param[out]     log_event_mult            `PetscLogEvent` for forward evaluation, or NULL
1005*c8564c30SJames Wright   @param[out]     log_event_mult_transpose  `PetscLogEvent` for transpose evaluation, or NULL
1006*c8564c30SJames Wright 
1007*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1008*c8564c30SJames Wright **/
1009*c8564c30SJames Wright PetscErrorCode MatCeedSetLogEvents(Mat mat, PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose) {
1010*c8564c30SJames Wright   MatCeedContext ctx;
1011*c8564c30SJames Wright 
1012*c8564c30SJames Wright   PetscFunctionBeginUser;
1013*c8564c30SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
1014*c8564c30SJames Wright   if (log_event_mult) ctx->log_event_mult = log_event_mult;
1015*c8564c30SJames Wright   if (log_event_mult_transpose) ctx->log_event_mult_transpose = log_event_mult_transpose;
1016*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1017*c8564c30SJames Wright }
1018*c8564c30SJames Wright 
1019*c8564c30SJames Wright /**
1020*c8564c30SJames Wright   @brief Get `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators.
1021*c8564c30SJames Wright 
1022*c8564c30SJames Wright   Not collective across MPI processes.
1023*c8564c30SJames Wright 
1024*c8564c30SJames Wright   @param[in,out]  mat                       MatCeed
1025*c8564c30SJames Wright   @param[out]     log_event_mult            `PetscLogEvent` for forward evaluation, or NULL
1026*c8564c30SJames Wright   @param[out]     log_event_mult_transpose  `PetscLogEvent` for transpose evaluation, or NULL
1027*c8564c30SJames Wright 
1028*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1029*c8564c30SJames Wright **/
1030*c8564c30SJames Wright PetscErrorCode MatCeedGetLogEvents(Mat mat, PetscLogEvent *log_event_mult, PetscLogEvent *log_event_mult_transpose) {
1031*c8564c30SJames Wright   MatCeedContext ctx;
1032*c8564c30SJames Wright 
1033*c8564c30SJames Wright   PetscFunctionBeginUser;
1034*c8564c30SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
1035*c8564c30SJames Wright   if (log_event_mult) *log_event_mult = ctx->log_event_mult;
1036*c8564c30SJames Wright   if (log_event_mult_transpose) *log_event_mult_transpose = ctx->log_event_mult_transpose;
1037*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1038*c8564c30SJames Wright }
1039*c8564c30SJames Wright 
1040*c8564c30SJames Wright // -----------------------------------------------------------------------------
1041*c8564c30SJames Wright // Operator context data
1042*c8564c30SJames Wright // -----------------------------------------------------------------------------
1043*c8564c30SJames Wright 
1044*c8564c30SJames Wright /**
1045*c8564c30SJames Wright   @brief Setup context data for operator application.
1046*c8564c30SJames Wright 
1047*c8564c30SJames Wright   Collective across MPI processes.
1048*c8564c30SJames Wright 
1049*c8564c30SJames Wright   @param[in]   dm_x                      Input `DM`
1050*c8564c30SJames Wright   @param[in]   dm_y                      Output `DM`
1051*c8564c30SJames Wright   @param[in]   X_loc                     Input PETSc local vector, or NULL
1052*c8564c30SJames Wright   @param[in]   Y_loc_transpose           Input PETSc local vector for transpose operation, or NULL
1053*c8564c30SJames Wright   @param[in]   op_mult                   `CeedOperator` for forward evaluation
1054*c8564c30SJames Wright   @param[in]   op_mult_transpose         `CeedOperator` for transpose evaluation
1055*c8564c30SJames Wright   @param[in]   log_event_mult            `PetscLogEvent` for forward evaluation
1056*c8564c30SJames Wright   @param[in]   log_event_mult_transpose  `PetscLogEvent` for transpose evaluation
1057*c8564c30SJames Wright   @param[out]  ctx                       Context data for operator evaluation
1058*c8564c30SJames Wright 
1059*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1060*c8564c30SJames Wright **/
1061*c8564c30SJames Wright PetscErrorCode MatCeedContextCreate(DM dm_x, DM dm_y, Vec X_loc, Vec Y_loc_transpose, CeedOperator op_mult, CeedOperator op_mult_transpose,
1062*c8564c30SJames Wright                                     PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose, MatCeedContext *ctx) {
1063*c8564c30SJames Wright   CeedSize x_loc_len, y_loc_len;
1064*c8564c30SJames Wright 
1065*c8564c30SJames Wright   PetscFunctionBeginUser;
1066*c8564c30SJames Wright 
1067*c8564c30SJames Wright   // Allocate
1068*c8564c30SJames Wright   PetscCall(PetscNew(ctx));
1069*c8564c30SJames Wright   (*ctx)->ref_count = 1;
1070*c8564c30SJames Wright 
1071*c8564c30SJames Wright   // Logging
1072*c8564c30SJames Wright   (*ctx)->log_event_mult           = log_event_mult;
1073*c8564c30SJames Wright   (*ctx)->log_event_mult_transpose = log_event_mult_transpose;
1074*c8564c30SJames Wright 
1075*c8564c30SJames Wright   // PETSc objects
1076*c8564c30SJames Wright   PetscCall(PetscObjectReference((PetscObject)dm_x));
1077*c8564c30SJames Wright   (*ctx)->dm_x = dm_x;
1078*c8564c30SJames Wright   PetscCall(PetscObjectReference((PetscObject)dm_y));
1079*c8564c30SJames Wright   (*ctx)->dm_y = dm_y;
1080*c8564c30SJames Wright   if (X_loc) PetscCall(PetscObjectReference((PetscObject)X_loc));
1081*c8564c30SJames Wright   (*ctx)->X_loc = X_loc;
1082*c8564c30SJames Wright   if (Y_loc_transpose) PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose));
1083*c8564c30SJames Wright   (*ctx)->Y_loc_transpose = Y_loc_transpose;
1084*c8564c30SJames Wright 
1085*c8564c30SJames Wright   // Memtype
1086*c8564c30SJames Wright   {
1087*c8564c30SJames Wright     const PetscScalar *x;
1088*c8564c30SJames Wright     Vec                X;
1089*c8564c30SJames Wright 
1090*c8564c30SJames Wright     PetscCall(DMGetLocalVector(dm_x, &X));
1091*c8564c30SJames Wright     PetscCall(VecGetArrayReadAndMemType(X, &x, &(*ctx)->mem_type));
1092*c8564c30SJames Wright     PetscCall(VecRestoreArrayReadAndMemType(X, &x));
1093*c8564c30SJames Wright     PetscCall(DMRestoreLocalVector(dm_x, &X));
1094*c8564c30SJames Wright   }
1095*c8564c30SJames Wright 
1096*c8564c30SJames Wright   // libCEED objects
1097*c8564c30SJames Wright   PetscCheck(CeedOperatorGetCeed(op_mult, &(*ctx)->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB,
1098*c8564c30SJames Wright              "retrieving Ceed context object failed");
1099*c8564c30SJames Wright   PetscCeedCall((*ctx)->ceed, CeedReference((*ctx)->ceed));
1100*c8564c30SJames Wright   PetscCeedCall((*ctx)->ceed, CeedOperatorGetActiveVectorLengths(op_mult, &x_loc_len, &y_loc_len));
1101*c8564c30SJames Wright   PetscCeedCall((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult, &(*ctx)->op_mult));
1102*c8564c30SJames Wright   if (op_mult_transpose) PetscCeedCall((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult_transpose, &(*ctx)->op_mult_transpose));
1103*c8564c30SJames Wright   PetscCeedCall((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, x_loc_len, &(*ctx)->x_loc));
1104*c8564c30SJames Wright   PetscCeedCall((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, y_loc_len, &(*ctx)->y_loc));
1105*c8564c30SJames Wright 
1106*c8564c30SJames Wright   // Flop counting
1107*c8564c30SJames Wright   {
1108*c8564c30SJames Wright     CeedSize ceed_flops_estimate = 0;
1109*c8564c30SJames Wright 
1110*c8564c30SJames Wright     PetscCeedCall((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult, &ceed_flops_estimate));
1111*c8564c30SJames Wright     (*ctx)->flops_mult = ceed_flops_estimate;
1112*c8564c30SJames Wright     if (op_mult_transpose) {
1113*c8564c30SJames Wright       PetscCeedCall((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult_transpose, &ceed_flops_estimate));
1114*c8564c30SJames Wright       (*ctx)->flops_mult_transpose = ceed_flops_estimate;
1115*c8564c30SJames Wright     }
1116*c8564c30SJames Wright   }
1117*c8564c30SJames Wright 
1118*c8564c30SJames Wright   // Check sizes
1119*c8564c30SJames Wright   if (x_loc_len > 0 || y_loc_len > 0) {
1120*c8564c30SJames Wright     CeedSize ctx_x_loc_len, ctx_y_loc_len;
1121*c8564c30SJames Wright     PetscInt X_loc_len, dm_x_loc_len, Y_loc_len, dm_y_loc_len;
1122*c8564c30SJames Wright     Vec      dm_X_loc, dm_Y_loc;
1123*c8564c30SJames Wright 
1124*c8564c30SJames Wright     // -- Input
1125*c8564c30SJames Wright     PetscCall(DMGetLocalVector(dm_x, &dm_X_loc));
1126*c8564c30SJames Wright     PetscCall(VecGetLocalSize(dm_X_loc, &dm_x_loc_len));
1127*c8564c30SJames Wright     PetscCall(DMRestoreLocalVector(dm_x, &dm_X_loc));
1128*c8564c30SJames Wright     if (X_loc) PetscCall(VecGetLocalSize(X_loc, &X_loc_len));
1129*c8564c30SJames Wright     PetscCeedCall((*ctx)->ceed, CeedVectorGetLength((*ctx)->x_loc, &ctx_x_loc_len));
1130*c8564c30SJames Wright     if (X_loc) PetscCheck(X_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "X_loc must match dm_x dimensions");
1131*c8564c30SJames Wright     PetscCheck(x_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op must match dm_x dimensions");
1132*c8564c30SJames Wright     PetscCheck(x_loc_len == ctx_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "x_loc must match op dimensions");
1133*c8564c30SJames Wright 
1134*c8564c30SJames Wright     // -- Output
1135*c8564c30SJames Wright     PetscCall(DMGetLocalVector(dm_y, &dm_Y_loc));
1136*c8564c30SJames Wright     PetscCall(VecGetLocalSize(dm_Y_loc, &dm_y_loc_len));
1137*c8564c30SJames Wright     PetscCall(DMRestoreLocalVector(dm_y, &dm_Y_loc));
1138*c8564c30SJames Wright     PetscCeedCall((*ctx)->ceed, CeedVectorGetLength((*ctx)->y_loc, &ctx_y_loc_len));
1139*c8564c30SJames Wright     PetscCheck(ctx_y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op must match dm_y dimensions");
1140*c8564c30SJames Wright 
1141*c8564c30SJames Wright     // -- Transpose
1142*c8564c30SJames Wright     if (Y_loc_transpose) {
1143*c8564c30SJames Wright       PetscCall(VecGetLocalSize(Y_loc_transpose, &Y_loc_len));
1144*c8564c30SJames Wright       PetscCheck(Y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "Y_loc_transpose must match dm_y dimensions");
1145*c8564c30SJames Wright     }
1146*c8564c30SJames Wright   }
1147*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1148*c8564c30SJames Wright }
1149*c8564c30SJames Wright 
1150*c8564c30SJames Wright /**
1151*c8564c30SJames Wright   @brief Increment reference counter for `MATCEED` context.
1152*c8564c30SJames Wright 
1153*c8564c30SJames Wright   Not collective across MPI processes.
1154*c8564c30SJames Wright 
1155*c8564c30SJames Wright   @param[in,out]  ctx  Context data
1156*c8564c30SJames Wright 
1157*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1158*c8564c30SJames Wright **/
1159*c8564c30SJames Wright PetscErrorCode MatCeedContextReference(MatCeedContext ctx) {
1160*c8564c30SJames Wright   PetscFunctionBeginUser;
1161*c8564c30SJames Wright   ctx->ref_count++;
1162*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1163*c8564c30SJames Wright }
1164*c8564c30SJames Wright 
1165*c8564c30SJames Wright /**
1166*c8564c30SJames Wright   @brief Copy reference for `MATCEED`.
1167*c8564c30SJames Wright          Note: If `ctx_copy` is non-null, it is assumed to be a valid pointer to a `MatCeedContext`.
1168*c8564c30SJames Wright 
1169*c8564c30SJames Wright   Not collective across MPI processes.
1170*c8564c30SJames Wright 
1171*c8564c30SJames Wright   @param[in]   ctx       Context data
1172*c8564c30SJames Wright   @param[out]  ctx_copy  Copy of pointer to context data
1173*c8564c30SJames Wright 
1174*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1175*c8564c30SJames Wright **/
1176*c8564c30SJames Wright PetscErrorCode MatCeedContextReferenceCopy(MatCeedContext ctx, MatCeedContext *ctx_copy) {
1177*c8564c30SJames Wright   PetscFunctionBeginUser;
1178*c8564c30SJames Wright   PetscCall(MatCeedContextReference(ctx));
1179*c8564c30SJames Wright   PetscCall(MatCeedContextDestroy(*ctx_copy));
1180*c8564c30SJames Wright   *ctx_copy = ctx;
1181*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1182*c8564c30SJames Wright }
1183*c8564c30SJames Wright 
1184*c8564c30SJames Wright /**
1185*c8564c30SJames Wright   @brief Destroy context data for operator application.
1186*c8564c30SJames Wright 
1187*c8564c30SJames Wright   Collective across MPI processes.
1188*c8564c30SJames Wright 
1189*c8564c30SJames Wright   @param[in,out]  ctx  Context data for operator evaluation
1190*c8564c30SJames Wright 
1191*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1192*c8564c30SJames Wright **/
1193*c8564c30SJames Wright PetscErrorCode MatCeedContextDestroy(MatCeedContext ctx) {
1194*c8564c30SJames Wright   PetscFunctionBeginUser;
1195*c8564c30SJames Wright   if (!ctx || --ctx->ref_count > 0) PetscFunctionReturn(PETSC_SUCCESS);
1196*c8564c30SJames Wright 
1197*c8564c30SJames Wright   // PETSc objects
1198*c8564c30SJames Wright   PetscCall(DMDestroy(&ctx->dm_x));
1199*c8564c30SJames Wright   PetscCall(DMDestroy(&ctx->dm_y));
1200*c8564c30SJames Wright   PetscCall(VecDestroy(&ctx->X_loc));
1201*c8564c30SJames Wright   PetscCall(VecDestroy(&ctx->Y_loc_transpose));
1202*c8564c30SJames Wright   PetscCall(MatDestroy(&ctx->mat_assembled_full_internal));
1203*c8564c30SJames Wright   PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal));
1204*c8564c30SJames Wright   PetscCall(PetscFree(ctx->internal_mat_type));
1205*c8564c30SJames Wright   PetscCall(PetscFree(ctx->mats_assembled_full));
1206*c8564c30SJames Wright   PetscCall(PetscFree(ctx->mats_assembled_pbd));
1207*c8564c30SJames Wright 
1208*c8564c30SJames Wright   // libCEED objects
1209*c8564c30SJames Wright   PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->x_loc));
1210*c8564c30SJames Wright   PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->y_loc));
1211*c8564c30SJames Wright   PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_full));
1212*c8564c30SJames Wright   PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_pbd));
1213*c8564c30SJames Wright   PetscCeedCall(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult));
1214*c8564c30SJames Wright   PetscCeedCall(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult_transpose));
1215*c8564c30SJames Wright   PetscCheck(CeedDestroy(&ctx->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "destroying libCEED context object failed");
1216*c8564c30SJames Wright 
1217*c8564c30SJames Wright   // Deallocate
1218*c8564c30SJames Wright   ctx->is_destroyed = PETSC_TRUE;  // Flag as destroyed in case someone has stale ref
1219*c8564c30SJames Wright   PetscCall(PetscFree(ctx));
1220*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1221*c8564c30SJames Wright }
1222*c8564c30SJames Wright 
1223*c8564c30SJames Wright /**
1224*c8564c30SJames Wright   @brief Compute the diagonal of an operator via libCEED.
1225*c8564c30SJames Wright 
1226*c8564c30SJames Wright   Collective across MPI processes.
1227*c8564c30SJames Wright 
1228*c8564c30SJames Wright   @param[in]   A  `MATCEED`
1229*c8564c30SJames Wright   @param[out]  D  Vector holding operator diagonal
1230*c8564c30SJames Wright 
1231*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1232*c8564c30SJames Wright **/
1233*c8564c30SJames Wright PetscErrorCode MatGetDiagonal_Ceed(Mat A, Vec D) {
1234*c8564c30SJames Wright   PetscMemType   mem_type;
1235*c8564c30SJames Wright   Vec            D_loc;
1236*c8564c30SJames Wright   MatCeedContext ctx;
1237*c8564c30SJames Wright 
1238*c8564c30SJames Wright   PetscFunctionBeginUser;
1239*c8564c30SJames Wright   PetscCall(MatShellGetContext(A, &ctx));
1240*c8564c30SJames Wright 
1241*c8564c30SJames Wright   // Place PETSc vector in libCEED vector
1242*c8564c30SJames Wright   PetscCall(DMGetLocalVector(ctx->dm_x, &D_loc));
1243*c8564c30SJames Wright   PetscCall(VecP2C(ctx->ceed, D_loc, &mem_type, ctx->x_loc));
1244*c8564c30SJames Wright 
1245*c8564c30SJames Wright   // Compute Diagonal
1246*c8564c30SJames Wright   PetscCeedCall(ctx->ceed, CeedOperatorLinearAssembleDiagonal(ctx->op_mult, ctx->x_loc, CEED_REQUEST_IMMEDIATE));
1247*c8564c30SJames Wright 
1248*c8564c30SJames Wright   // Restore PETSc vector
1249*c8564c30SJames Wright   PetscCall(VecC2P(ctx->ceed, ctx->x_loc, mem_type, D_loc));
1250*c8564c30SJames Wright 
1251*c8564c30SJames Wright   // Local-to-Global
1252*c8564c30SJames Wright   PetscCall(VecZeroEntries(D));
1253*c8564c30SJames Wright   PetscCall(DMLocalToGlobal(ctx->dm_x, D_loc, ADD_VALUES, D));
1254*c8564c30SJames Wright   PetscCall(DMRestoreLocalVector(ctx->dm_x, &D_loc));
1255*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1256*c8564c30SJames Wright }
1257*c8564c30SJames Wright 
1258*c8564c30SJames Wright /**
1259*c8564c30SJames Wright   @brief Compute `A X = Y` for a `MATCEED`.
1260*c8564c30SJames Wright 
1261*c8564c30SJames Wright   Collective across MPI processes.
1262*c8564c30SJames Wright 
1263*c8564c30SJames Wright   @param[in]   A  `MATCEED`
1264*c8564c30SJames Wright   @param[in]   X  Input PETSc vector
1265*c8564c30SJames Wright   @param[out]  Y  Output PETSc vector
1266*c8564c30SJames Wright 
1267*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1268*c8564c30SJames Wright **/
1269*c8564c30SJames Wright PetscErrorCode MatMult_Ceed(Mat A, Vec X, Vec Y) {
1270*c8564c30SJames Wright   MatCeedContext ctx;
1271*c8564c30SJames Wright 
1272*c8564c30SJames Wright   PetscFunctionBeginUser;
1273*c8564c30SJames Wright   PetscCall(MatShellGetContext(A, &ctx));
1274*c8564c30SJames Wright   PetscCall(PetscLogEventBegin(ctx->log_event_mult, A, X, Y, 0));
1275*c8564c30SJames Wright 
1276*c8564c30SJames Wright   {
1277*c8564c30SJames Wright     PetscMemType x_mem_type, y_mem_type;
1278*c8564c30SJames Wright     Vec          X_loc = ctx->X_loc, Y_loc;
1279*c8564c30SJames Wright 
1280*c8564c30SJames Wright     // Get local vectors
1281*c8564c30SJames Wright     if (!ctx->X_loc) PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc));
1282*c8564c30SJames Wright     PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc));
1283*c8564c30SJames Wright 
1284*c8564c30SJames Wright     // Global-to-local
1285*c8564c30SJames Wright     PetscCall(DMGlobalToLocal(ctx->dm_x, X, INSERT_VALUES, X_loc));
1286*c8564c30SJames Wright 
1287*c8564c30SJames Wright     // Setup libCEED vectors
1288*c8564c30SJames Wright     PetscCall(VecReadP2C(ctx->ceed, X_loc, &x_mem_type, ctx->x_loc));
1289*c8564c30SJames Wright     PetscCall(VecZeroEntries(Y_loc));
1290*c8564c30SJames Wright     PetscCall(VecP2C(ctx->ceed, Y_loc, &y_mem_type, ctx->y_loc));
1291*c8564c30SJames Wright 
1292*c8564c30SJames Wright     // Apply libCEED operator
1293*c8564c30SJames Wright     PetscCall(PetscLogGpuTimeBegin());
1294*c8564c30SJames Wright     PetscCeedCall(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult, ctx->x_loc, ctx->y_loc, CEED_REQUEST_IMMEDIATE));
1295*c8564c30SJames Wright     PetscCall(PetscLogGpuTimeEnd());
1296*c8564c30SJames Wright 
1297*c8564c30SJames Wright     // Restore PETSc vectors
1298*c8564c30SJames Wright     PetscCall(VecReadC2P(ctx->ceed, ctx->x_loc, x_mem_type, X_loc));
1299*c8564c30SJames Wright     PetscCall(VecC2P(ctx->ceed, ctx->y_loc, y_mem_type, Y_loc));
1300*c8564c30SJames Wright 
1301*c8564c30SJames Wright     // Local-to-global
1302*c8564c30SJames Wright     PetscCall(VecZeroEntries(Y));
1303*c8564c30SJames Wright     PetscCall(DMLocalToGlobal(ctx->dm_y, Y_loc, ADD_VALUES, Y));
1304*c8564c30SJames Wright 
1305*c8564c30SJames Wright     // Restore local vectors, as needed
1306*c8564c30SJames Wright     if (!ctx->X_loc) PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc));
1307*c8564c30SJames Wright     PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc));
1308*c8564c30SJames Wright   }
1309*c8564c30SJames Wright 
1310*c8564c30SJames Wright   // Log flops
1311*c8564c30SJames Wright   if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult));
1312*c8564c30SJames Wright   else PetscCall(PetscLogFlops(ctx->flops_mult));
1313*c8564c30SJames Wright 
1314*c8564c30SJames Wright   PetscCall(PetscLogEventEnd(ctx->log_event_mult, A, X, Y, 0));
1315*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1316*c8564c30SJames Wright }
1317*c8564c30SJames Wright 
1318*c8564c30SJames Wright /**
1319*c8564c30SJames Wright   @brief Compute `A^T Y = X` for a `MATCEED`.
1320*c8564c30SJames Wright 
1321*c8564c30SJames Wright   Collective across MPI processes.
1322*c8564c30SJames Wright 
1323*c8564c30SJames Wright   @param[in]   A  `MATCEED`
1324*c8564c30SJames Wright   @param[in]   Y  Input PETSc vector
1325*c8564c30SJames Wright   @param[out]  X  Output PETSc vector
1326*c8564c30SJames Wright 
1327*c8564c30SJames Wright   @return An error code: 0 - success, otherwise - failure
1328*c8564c30SJames Wright **/
1329*c8564c30SJames Wright PetscErrorCode MatMultTranspose_Ceed(Mat A, Vec Y, Vec X) {
1330*c8564c30SJames Wright   MatCeedContext ctx;
1331*c8564c30SJames Wright 
1332*c8564c30SJames Wright   PetscFunctionBeginUser;
1333*c8564c30SJames Wright   PetscCall(MatShellGetContext(A, &ctx));
1334*c8564c30SJames Wright   PetscCall(PetscLogEventBegin(ctx->log_event_mult_transpose, A, Y, X, 0));
1335*c8564c30SJames Wright 
1336*c8564c30SJames Wright   {
1337*c8564c30SJames Wright     PetscMemType x_mem_type, y_mem_type;
1338*c8564c30SJames Wright     Vec          X_loc, Y_loc = ctx->Y_loc_transpose;
1339*c8564c30SJames Wright 
1340*c8564c30SJames Wright     // Get local vectors
1341*c8564c30SJames Wright     if (!ctx->Y_loc_transpose) PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc));
1342*c8564c30SJames Wright     PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc));
1343*c8564c30SJames Wright 
1344*c8564c30SJames Wright     // Global-to-local
1345*c8564c30SJames Wright     PetscCall(DMGlobalToLocal(ctx->dm_y, Y, INSERT_VALUES, Y_loc));
1346*c8564c30SJames Wright 
1347*c8564c30SJames Wright     // Setup libCEED vectors
1348*c8564c30SJames Wright     PetscCall(VecReadP2C(ctx->ceed, Y_loc, &y_mem_type, ctx->y_loc));
1349*c8564c30SJames Wright     PetscCall(VecZeroEntries(X_loc));
1350*c8564c30SJames Wright     PetscCall(VecP2C(ctx->ceed, X_loc, &x_mem_type, ctx->x_loc));
1351*c8564c30SJames Wright 
1352*c8564c30SJames Wright     // Apply libCEED operator
1353*c8564c30SJames Wright     PetscCall(PetscLogGpuTimeBegin());
1354*c8564c30SJames Wright     PetscCeedCall(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult_transpose, ctx->y_loc, ctx->x_loc, CEED_REQUEST_IMMEDIATE));
1355*c8564c30SJames Wright     PetscCall(PetscLogGpuTimeEnd());
1356*c8564c30SJames Wright 
1357*c8564c30SJames Wright     // Restore PETSc vectors
1358*c8564c30SJames Wright     PetscCall(VecReadC2P(ctx->ceed, ctx->y_loc, y_mem_type, Y_loc));
1359*c8564c30SJames Wright     PetscCall(VecC2P(ctx->ceed, ctx->x_loc, x_mem_type, X_loc));
1360*c8564c30SJames Wright 
1361*c8564c30SJames Wright     // Local-to-global
1362*c8564c30SJames Wright     PetscCall(VecZeroEntries(X));
1363*c8564c30SJames Wright     PetscCall(DMLocalToGlobal(ctx->dm_x, X_loc, ADD_VALUES, X));
1364*c8564c30SJames Wright 
1365*c8564c30SJames Wright     // Restore local vectors, as needed
1366*c8564c30SJames Wright     if (!ctx->Y_loc_transpose) PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc));
1367*c8564c30SJames Wright     PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc));
1368*c8564c30SJames Wright   }
1369*c8564c30SJames Wright 
1370*c8564c30SJames Wright   // Log flops
1371*c8564c30SJames Wright   if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult_transpose));
1372*c8564c30SJames Wright   else PetscCall(PetscLogFlops(ctx->flops_mult_transpose));
1373*c8564c30SJames Wright 
1374*c8564c30SJames Wright   PetscCall(PetscLogEventEnd(ctx->log_event_mult_transpose, A, Y, X, 0));
1375*c8564c30SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
1376*c8564c30SJames Wright }
1377