xref: /honee/src/mat-ceed.c (revision 50f50432244b41b50d841adc7e498f6ce637289e)
158600ac3SJames Wright /// @file
/// MatCeed and its related operators
358600ac3SJames Wright 
458600ac3SJames Wright #include <ceed.h>
558600ac3SJames Wright #include <ceed/backend.h>
658600ac3SJames Wright #include <mat-ceed-impl.h>
758600ac3SJames Wright #include <mat-ceed.h>
858600ac3SJames Wright #include <petscdmplex.h>
958600ac3SJames Wright #include <stdlib.h>
1058600ac3SJames Wright #include <string.h>
1158600ac3SJames Wright 
1258600ac3SJames Wright PetscClassId  MATCEED_CLASSID;
1358600ac3SJames Wright PetscLogEvent MATCEED_MULT, MATCEED_MULT_TRANSPOSE;
1458600ac3SJames Wright 
1558600ac3SJames Wright /**
1658600ac3SJames Wright   @brief Register MATCEED log events.
1758600ac3SJames Wright 
1858600ac3SJames Wright   Not collective across MPI processes.
1958600ac3SJames Wright 
2058600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
2158600ac3SJames Wright **/
2258600ac3SJames Wright static PetscErrorCode MatCeedRegisterLogEvents() {
2358600ac3SJames Wright   static bool registered = false;
2458600ac3SJames Wright 
2558600ac3SJames Wright   PetscFunctionBeginUser;
2658600ac3SJames Wright   if (registered) PetscFunctionReturn(PETSC_SUCCESS);
2758600ac3SJames Wright   PetscCall(PetscClassIdRegister("MATCEED", &MATCEED_CLASSID));
2858600ac3SJames Wright   PetscCall(PetscLogEventRegister("MATCEED Mult", MATCEED_CLASSID, &MATCEED_MULT));
2958600ac3SJames Wright   PetscCall(PetscLogEventRegister("MATCEED Mult Transpose", MATCEED_CLASSID, &MATCEED_MULT_TRANSPOSE));
3058600ac3SJames Wright   registered = true;
3158600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
3258600ac3SJames Wright }
3358600ac3SJames Wright 
3458600ac3SJames Wright /**
3558600ac3SJames Wright   @brief Translate PetscMemType to CeedMemType
3658600ac3SJames Wright 
3758600ac3SJames Wright   @param[in]  mem_type  PetscMemType
3858600ac3SJames Wright 
3958600ac3SJames Wright   @return Equivalent CeedMemType
4058600ac3SJames Wright **/
4158600ac3SJames Wright static inline CeedMemType MemTypeP2C(PetscMemType mem_type) { return PetscMemTypeDevice(mem_type) ? CEED_MEM_DEVICE : CEED_MEM_HOST; }
4258600ac3SJames Wright 
4358600ac3SJames Wright /**
4458600ac3SJames Wright   @brief Translate array of `CeedInt` to `PetscInt`.
4558600ac3SJames Wright     If the types differ, `array_ceed` is freed with `free()` and `array_petsc` is allocated with `malloc()`.
4658600ac3SJames Wright     Caller is responsible for freeing `array_petsc` with `free()`.
4758600ac3SJames Wright 
4858600ac3SJames Wright   Not collective across MPI processes.
4958600ac3SJames Wright 
5058600ac3SJames Wright   @param[in]      num_entries  Number of array entries
5158600ac3SJames Wright   @param[in,out]  array_ceed   Array of `CeedInt`
5258600ac3SJames Wright   @param[out]     array_petsc  Array of `PetscInt`
5358600ac3SJames Wright 
5458600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
5558600ac3SJames Wright **/
5658600ac3SJames Wright static inline PetscErrorCode IntArrayC2P(PetscInt num_entries, CeedInt **array_ceed, PetscInt **array_petsc) {
5758600ac3SJames Wright   const CeedInt  int_c = 0;
5858600ac3SJames Wright   const PetscInt int_p = 0;
5958600ac3SJames Wright 
6058600ac3SJames Wright   PetscFunctionBeginUser;
6158600ac3SJames Wright   if (sizeof(int_c) == sizeof(int_p)) {
6258600ac3SJames Wright     *array_petsc = (PetscInt *)*array_ceed;
6358600ac3SJames Wright   } else {
6458600ac3SJames Wright     *array_petsc = malloc(num_entries * sizeof(PetscInt));
6558600ac3SJames Wright     for (PetscInt i = 0; i < num_entries; i++) (*array_petsc)[i] = (*array_ceed)[i];
6658600ac3SJames Wright     free(*array_ceed);
6758600ac3SJames Wright   }
6858600ac3SJames Wright   *array_ceed = NULL;
6958600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
7058600ac3SJames Wright }
7158600ac3SJames Wright 
7258600ac3SJames Wright /**
7358600ac3SJames Wright   @brief Transfer array from PETSc `Vec` to `CeedVector`.
7458600ac3SJames Wright 
7558600ac3SJames Wright   Collective across MPI processes.
7658600ac3SJames Wright 
7758600ac3SJames Wright   @param[in]   ceed      libCEED context
7858600ac3SJames Wright   @param[in]   X_petsc   PETSc `Vec`
7958600ac3SJames Wright   @param[out]  mem_type  PETSc `MemType`
8058600ac3SJames Wright   @param[out]  x_ceed    `CeedVector`
8158600ac3SJames Wright 
8258600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
8358600ac3SJames Wright **/
8458600ac3SJames Wright static inline PetscErrorCode VecP2C(Ceed ceed, Vec X_petsc, PetscMemType *mem_type, CeedVector x_ceed) {
8558600ac3SJames Wright   PetscScalar *x;
8658600ac3SJames Wright 
8758600ac3SJames Wright   PetscFunctionBeginUser;
8858600ac3SJames Wright   PetscCall(VecGetArrayAndMemType(X_petsc, &x, mem_type));
89*50f50432SJames Wright   PetscCallCeed(ceed, CeedVectorSetArray(x_ceed, MemTypeP2C(*mem_type), CEED_USE_POINTER, x));
9058600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
9158600ac3SJames Wright }
9258600ac3SJames Wright 
9358600ac3SJames Wright /**
9458600ac3SJames Wright   @brief Transfer array from `CeedVector` to PETSc `Vec`.
9558600ac3SJames Wright 
9658600ac3SJames Wright   Collective across MPI processes.
9758600ac3SJames Wright 
9858600ac3SJames Wright   @param[in]   ceed      libCEED context
9958600ac3SJames Wright   @param[in]   x_ceed    `CeedVector`
10058600ac3SJames Wright   @param[in]   mem_type  PETSc `MemType`
10158600ac3SJames Wright   @param[out]  X_petsc   PETSc `Vec`
10258600ac3SJames Wright 
10358600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
10458600ac3SJames Wright **/
10558600ac3SJames Wright static inline PetscErrorCode VecC2P(Ceed ceed, CeedVector x_ceed, PetscMemType mem_type, Vec X_petsc) {
10658600ac3SJames Wright   PetscScalar *x;
10758600ac3SJames Wright 
10858600ac3SJames Wright   PetscFunctionBeginUser;
109*50f50432SJames Wright   PetscCallCeed(ceed, CeedVectorTakeArray(x_ceed, MemTypeP2C(mem_type), &x));
11058600ac3SJames Wright   PetscCall(VecRestoreArrayAndMemType(X_petsc, &x));
11158600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
11258600ac3SJames Wright }
11358600ac3SJames Wright 
11458600ac3SJames Wright /**
11558600ac3SJames Wright   @brief Transfer read only array from PETSc `Vec` to `CeedVector`.
11658600ac3SJames Wright 
11758600ac3SJames Wright   Collective across MPI processes.
11858600ac3SJames Wright 
11958600ac3SJames Wright   @param[in]   ceed      libCEED context
12058600ac3SJames Wright   @param[in]   X_petsc   PETSc `Vec`
12158600ac3SJames Wright   @param[out]  mem_type  PETSc `MemType`
12258600ac3SJames Wright   @param[out]  x_ceed    `CeedVector`
12358600ac3SJames Wright 
12458600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
12558600ac3SJames Wright **/
12658600ac3SJames Wright static inline PetscErrorCode VecReadP2C(Ceed ceed, Vec X_petsc, PetscMemType *mem_type, CeedVector x_ceed) {
12758600ac3SJames Wright   PetscScalar *x;
12858600ac3SJames Wright 
12958600ac3SJames Wright   PetscFunctionBeginUser;
13058600ac3SJames Wright   PetscCall(VecGetArrayReadAndMemType(X_petsc, (const PetscScalar **)&x, mem_type));
131*50f50432SJames Wright   PetscCallCeed(ceed, CeedVectorSetArray(x_ceed, MemTypeP2C(*mem_type), CEED_USE_POINTER, x));
13258600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
13358600ac3SJames Wright }
13458600ac3SJames Wright 
13558600ac3SJames Wright /**
13658600ac3SJames Wright   @brief Transfer read only array from `CeedVector` to PETSc `Vec`.
13758600ac3SJames Wright 
13858600ac3SJames Wright   Collective across MPI processes.
13958600ac3SJames Wright 
14058600ac3SJames Wright   @param[in]   ceed      libCEED context
14158600ac3SJames Wright   @param[in]   x_ceed    `CeedVector`
14258600ac3SJames Wright   @param[in]   mem_type  PETSc `MemType`
14358600ac3SJames Wright   @param[out]  X_petsc   PETSc `Vec`
14458600ac3SJames Wright 
14558600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
14658600ac3SJames Wright **/
14758600ac3SJames Wright static inline PetscErrorCode VecReadC2P(Ceed ceed, CeedVector x_ceed, PetscMemType mem_type, Vec X_petsc) {
14858600ac3SJames Wright   PetscScalar *x;
14958600ac3SJames Wright 
15058600ac3SJames Wright   PetscFunctionBeginUser;
151*50f50432SJames Wright   PetscCallCeed(ceed, CeedVectorTakeArray(x_ceed, MemTypeP2C(mem_type), &x));
15258600ac3SJames Wright   PetscCall(VecRestoreArrayReadAndMemType(X_petsc, (const PetscScalar **)&x));
15358600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
15458600ac3SJames Wright }
15558600ac3SJames Wright 
15658600ac3SJames Wright /**
15758600ac3SJames Wright   @brief Setup inner `Mat` for `PC` operations not directly supported by libCEED.
15858600ac3SJames Wright 
15958600ac3SJames Wright   Collective across MPI processes.
16058600ac3SJames Wright 
16158600ac3SJames Wright   @param[in]   mat_ceed   `MATCEED` to setup
16258600ac3SJames Wright   @param[out]  mat_inner  Inner `Mat`
16358600ac3SJames Wright 
16458600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
16558600ac3SJames Wright **/
16658600ac3SJames Wright static PetscErrorCode MatCeedSetupInnerMat(Mat mat_ceed, Mat *mat_inner) {
16758600ac3SJames Wright   MatCeedContext ctx;
16858600ac3SJames Wright 
16958600ac3SJames Wright   PetscFunctionBeginUser;
17058600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
17158600ac3SJames Wright 
17258600ac3SJames Wright   PetscCheck(ctx->dm_x == ctx->dm_y, PetscObjectComm((PetscObject)mat_ceed), PETSC_ERR_SUP, "PC only supported for MATCEED on a single DM");
17358600ac3SJames Wright 
17458600ac3SJames Wright   // Check cl mat type
17558600ac3SJames Wright   {
17658600ac3SJames Wright     PetscBool is_internal_mat_type_cl = PETSC_FALSE;
17758600ac3SJames Wright     char      internal_mat_type_cl[64];
17858600ac3SJames Wright 
17958600ac3SJames Wright     // Check for specific CL inner mat type for this Mat
18058600ac3SJames Wright     {
18158600ac3SJames Wright       const char *mat_ceed_prefix = NULL;
18258600ac3SJames Wright 
18358600ac3SJames Wright       PetscCall(MatGetOptionsPrefix(mat_ceed, &mat_ceed_prefix));
18458600ac3SJames Wright       PetscOptionsBegin(PetscObjectComm((PetscObject)mat_ceed), mat_ceed_prefix, "", NULL);
18558600ac3SJames Wright       PetscCall(PetscOptionsFList("-ceed_inner_mat_type", "MATCEED inner assembled MatType for PC support", NULL, MatList, internal_mat_type_cl,
18658600ac3SJames Wright                                   internal_mat_type_cl, sizeof(internal_mat_type_cl), &is_internal_mat_type_cl));
18758600ac3SJames Wright       PetscOptionsEnd();
18858600ac3SJames Wright       if (is_internal_mat_type_cl) {
18958600ac3SJames Wright         PetscCall(PetscFree(ctx->internal_mat_type));
19058600ac3SJames Wright         PetscCall(PetscStrallocpy(internal_mat_type_cl, &ctx->internal_mat_type));
19158600ac3SJames Wright       }
19258600ac3SJames Wright     }
19358600ac3SJames Wright   }
19458600ac3SJames Wright 
19558600ac3SJames Wright   // Create sparse matrix
19658600ac3SJames Wright   {
19758600ac3SJames Wright     MatType dm_mat_type, dm_mat_type_copy;
19858600ac3SJames Wright 
19958600ac3SJames Wright     PetscCall(DMGetMatType(ctx->dm_x, &dm_mat_type));
20058600ac3SJames Wright     PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy));
20158600ac3SJames Wright     PetscCall(DMSetMatType(ctx->dm_x, ctx->internal_mat_type));
20258600ac3SJames Wright     PetscCall(DMCreateMatrix(ctx->dm_x, mat_inner));
20358600ac3SJames Wright     PetscCall(DMSetMatType(ctx->dm_x, dm_mat_type_copy));
20458600ac3SJames Wright     PetscCall(PetscFree(dm_mat_type_copy));
20558600ac3SJames Wright   }
20658600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
20758600ac3SJames Wright }
20858600ac3SJames Wright 
20958600ac3SJames Wright /**
21058600ac3SJames Wright   @brief Assemble the point block diagonal of a `MATCEED` into a `MATAIJ` or similar.
21158600ac3SJames Wright          The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`.
21258600ac3SJames Wright          The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail.
21358600ac3SJames Wright 
21458600ac3SJames Wright   Collective across MPI processes.
21558600ac3SJames Wright 
21658600ac3SJames Wright   @param[in]      mat_ceed  `MATCEED` to assemble
21758600ac3SJames Wright   @param[in,out]  mat_coo   `MATAIJ` or similar to assemble into
21858600ac3SJames Wright 
21958600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
22058600ac3SJames Wright **/
22158600ac3SJames Wright static PetscErrorCode MatCeedAssemblePointBlockDiagonalCOO(Mat mat_ceed, Mat mat_coo) {
22258600ac3SJames Wright   MatCeedContext ctx;
22358600ac3SJames Wright 
22458600ac3SJames Wright   PetscFunctionBeginUser;
22558600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
22658600ac3SJames Wright 
22758600ac3SJames Wright   // Check if COO pattern set
22858600ac3SJames Wright   {
22958600ac3SJames Wright     PetscInt index = -1;
23058600ac3SJames Wright 
23158600ac3SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) {
23258600ac3SJames Wright       if (ctx->mats_assembled_pbd[i] == mat_coo) index = i;
23358600ac3SJames Wright     }
23458600ac3SJames Wright     if (index == -1) {
23558600ac3SJames Wright       PetscInt     *rows_petsc = NULL, *cols_petsc = NULL;
23658600ac3SJames Wright       CeedInt      *rows_ceed, *cols_ceed;
23758600ac3SJames Wright       PetscCount    num_entries;
23858600ac3SJames Wright       PetscLogStage stage_amg_setup;
23958600ac3SJames Wright 
24058600ac3SJames Wright       // -- Assemble sparsity pattern if mat hasn't been assembled before
24158600ac3SJames Wright       PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup));
24258600ac3SJames Wright       if (stage_amg_setup == -1) {
24358600ac3SJames Wright         PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup));
24458600ac3SJames Wright       }
24558600ac3SJames Wright       PetscCall(PetscLogStagePush(stage_amg_setup));
246*50f50432SJames Wright       PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonalSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed));
24758600ac3SJames Wright       PetscCall(IntArrayC2P(num_entries, &rows_ceed, &rows_petsc));
24858600ac3SJames Wright       PetscCall(IntArrayC2P(num_entries, &cols_ceed, &cols_petsc));
24958600ac3SJames Wright       PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc));
25058600ac3SJames Wright       free(rows_petsc);
25158600ac3SJames Wright       free(cols_petsc);
252*50f50432SJames Wright       if (!ctx->coo_values_pbd) PetscCallCeed(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_pbd));
25358600ac3SJames Wright       PetscCall(PetscRealloc(++ctx->num_mats_assembled_pbd * sizeof(Mat), &ctx->mats_assembled_pbd));
25458600ac3SJames Wright       ctx->mats_assembled_pbd[ctx->num_mats_assembled_pbd - 1] = mat_coo;
25558600ac3SJames Wright       PetscCall(PetscLogStagePop());
25658600ac3SJames Wright     }
25758600ac3SJames Wright   }
25858600ac3SJames Wright 
25958600ac3SJames Wright   // Assemble mat_ceed
26058600ac3SJames Wright   PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY));
26158600ac3SJames Wright   {
26258600ac3SJames Wright     const CeedScalar *values;
26358600ac3SJames Wright     MatType           mat_type;
26458600ac3SJames Wright     CeedMemType       mem_type = CEED_MEM_HOST;
26558600ac3SJames Wright     PetscBool         is_spd, is_spd_known;
26658600ac3SJames Wright 
26758600ac3SJames Wright     PetscCall(MatGetType(mat_coo, &mat_type));
26858600ac3SJames Wright     if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE;
26958600ac3SJames Wright     else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE;
27058600ac3SJames Wright     else mem_type = CEED_MEM_HOST;
27158600ac3SJames Wright 
272*50f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonal(ctx->op_mult, ctx->coo_values_pbd, CEED_REQUEST_IMMEDIATE));
273*50f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_pbd, mem_type, &values));
27458600ac3SJames Wright     PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES));
27558600ac3SJames Wright     PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd));
27658600ac3SJames Wright     if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd));
277*50f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_pbd, &values));
27858600ac3SJames Wright   }
27958600ac3SJames Wright   PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY));
28058600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
28158600ac3SJames Wright }
28258600ac3SJames Wright 
28358600ac3SJames Wright /**
28458600ac3SJames Wright   @brief Assemble inner `Mat` for diagonal `PC` operations
28558600ac3SJames Wright 
28658600ac3SJames Wright   Collective across MPI processes.
28758600ac3SJames Wright 
28858600ac3SJames Wright   @param[in]   mat_ceed      `MATCEED` to invert
28958600ac3SJames Wright   @param[in]   use_ceed_pbd  Boolean flag to use libCEED PBD assembly
29058600ac3SJames Wright   @param[out]  mat_inner     Inner `Mat` for diagonal operations
29158600ac3SJames Wright 
29258600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
29358600ac3SJames Wright **/
29458600ac3SJames Wright static PetscErrorCode MatCeedAssembleInnerBlockDiagonalMat(Mat mat_ceed, PetscBool use_ceed_pbd, Mat *mat_inner) {
29558600ac3SJames Wright   MatCeedContext ctx;
29658600ac3SJames Wright 
29758600ac3SJames Wright   PetscFunctionBeginUser;
29858600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
29958600ac3SJames Wright   if (use_ceed_pbd) {
30058600ac3SJames Wright     // Check if COO pattern set
30158600ac3SJames Wright     if (!ctx->mat_assembled_pbd_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_pbd_internal));
30258600ac3SJames Wright 
30358600ac3SJames Wright     // Assemble mat_assembled_full_internal
30458600ac3SJames Wright     PetscCall(MatCeedAssemblePointBlockDiagonalCOO(mat_ceed, ctx->mat_assembled_pbd_internal));
30558600ac3SJames Wright     if (mat_inner) *mat_inner = ctx->mat_assembled_pbd_internal;
30658600ac3SJames Wright   } else {
30758600ac3SJames Wright     // Check if COO pattern set
30858600ac3SJames Wright     if (!ctx->mat_assembled_full_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_full_internal));
30958600ac3SJames Wright 
31058600ac3SJames Wright     // Assemble mat_assembled_full_internal
31158600ac3SJames Wright     PetscCall(MatCeedAssembleCOO(mat_ceed, ctx->mat_assembled_full_internal));
31258600ac3SJames Wright     if (mat_inner) *mat_inner = ctx->mat_assembled_full_internal;
31358600ac3SJames Wright   }
31458600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
31558600ac3SJames Wright }
31658600ac3SJames Wright 
31758600ac3SJames Wright /**
31858600ac3SJames Wright   @brief Get `MATCEED` diagonal block for Jacobi.
31958600ac3SJames Wright 
32058600ac3SJames Wright   Collective across MPI processes.
32158600ac3SJames Wright 
32258600ac3SJames Wright   @param[in]   mat_ceed   `MATCEED` to invert
32358600ac3SJames Wright   @param[out]  mat_block  The diagonal block matrix
32458600ac3SJames Wright 
32558600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
32658600ac3SJames Wright **/
32758600ac3SJames Wright static PetscErrorCode MatGetDiagonalBlock_Ceed(Mat mat_ceed, Mat *mat_block) {
32858600ac3SJames Wright   Mat            mat_inner = NULL;
32958600ac3SJames Wright   MatCeedContext ctx;
33058600ac3SJames Wright 
33158600ac3SJames Wright   PetscFunctionBeginUser;
33258600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
33358600ac3SJames Wright 
33458600ac3SJames Wright   // Assemble inner mat if needed
33558600ac3SJames Wright   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner));
33658600ac3SJames Wright 
33758600ac3SJames Wright   // Get block diagonal
33858600ac3SJames Wright   PetscCall(MatGetDiagonalBlock(mat_inner, mat_block));
33958600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
34058600ac3SJames Wright }
34158600ac3SJames Wright 
34258600ac3SJames Wright /**
34358600ac3SJames Wright   @brief Invert `MATCEED` diagonal block for Jacobi.
34458600ac3SJames Wright 
34558600ac3SJames Wright   Collective across MPI processes.
34658600ac3SJames Wright 
34758600ac3SJames Wright   @param[in]   mat_ceed  `MATCEED` to invert
34858600ac3SJames Wright   @param[out]  values    The block inverses in column major order
34958600ac3SJames Wright 
35058600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
35158600ac3SJames Wright **/
35258600ac3SJames Wright static PetscErrorCode MatInvertBlockDiagonal_Ceed(Mat mat_ceed, const PetscScalar **values) {
35358600ac3SJames Wright   Mat            mat_inner = NULL;
35458600ac3SJames Wright   MatCeedContext ctx;
35558600ac3SJames Wright 
35658600ac3SJames Wright   PetscFunctionBeginUser;
35758600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
35858600ac3SJames Wright 
35958600ac3SJames Wright   // Assemble inner mat if needed
36058600ac3SJames Wright   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner));
36158600ac3SJames Wright 
36258600ac3SJames Wright   // Invert PB diagonal
36358600ac3SJames Wright   PetscCall(MatInvertBlockDiagonal(mat_inner, values));
36458600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
36558600ac3SJames Wright }
36658600ac3SJames Wright 
36758600ac3SJames Wright /**
36858600ac3SJames Wright   @brief Invert `MATCEED` variable diagonal block for Jacobi.
36958600ac3SJames Wright 
37058600ac3SJames Wright   Collective across MPI processes.
37158600ac3SJames Wright 
37258600ac3SJames Wright   @param[in]   mat_ceed     `MATCEED` to invert
37358600ac3SJames Wright   @param[in]   num_blocks   The number of blocks on the process
37458600ac3SJames Wright   @param[in]   block_sizes  The size of each block on the process
37558600ac3SJames Wright   @param[out]  values       The block inverses in column major order
37658600ac3SJames Wright 
37758600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
37858600ac3SJames Wright **/
37958600ac3SJames Wright static PetscErrorCode MatInvertVariableBlockDiagonal_Ceed(Mat mat_ceed, PetscInt num_blocks, const PetscInt *block_sizes, PetscScalar *values) {
38058600ac3SJames Wright   Mat            mat_inner = NULL;
38158600ac3SJames Wright   MatCeedContext ctx;
38258600ac3SJames Wright 
38358600ac3SJames Wright   PetscFunctionBeginUser;
38458600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
38558600ac3SJames Wright 
38658600ac3SJames Wright   // Assemble inner mat if needed
38758600ac3SJames Wright   PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_vpbd_valid, &mat_inner));
38858600ac3SJames Wright 
38958600ac3SJames Wright   // Invert PB diagonal
39058600ac3SJames Wright   PetscCall(MatInvertVariableBlockDiagonal(mat_inner, num_blocks, block_sizes, values));
39158600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
39258600ac3SJames Wright }
39358600ac3SJames Wright 
39458600ac3SJames Wright // -----------------------------------------------------------------------------
39558600ac3SJames Wright // MatCeed
39658600ac3SJames Wright // -----------------------------------------------------------------------------
39758600ac3SJames Wright 
39858600ac3SJames Wright /**
39958600ac3SJames Wright   @brief Create PETSc `Mat` from libCEED operators.
40058600ac3SJames Wright 
40158600ac3SJames Wright   Collective across MPI processes.
40258600ac3SJames Wright 
40358600ac3SJames Wright   @param[in]   dm_x                      Input `DM`
40458600ac3SJames Wright   @param[in]   dm_y                      Output `DM`
40558600ac3SJames Wright   @param[in]   op_mult                   `CeedOperator` for forward evaluation
40658600ac3SJames Wright   @param[in]   op_mult_transpose         `CeedOperator` for transpose evaluation
40758600ac3SJames Wright   @param[out]  mat                        New MatCeed
40858600ac3SJames Wright 
40958600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
41058600ac3SJames Wright **/
41158600ac3SJames Wright PetscErrorCode MatCeedCreate(DM dm_x, DM dm_y, CeedOperator op_mult, CeedOperator op_mult_transpose, Mat *mat) {
41258600ac3SJames Wright   PetscInt       X_l_size, X_g_size, Y_l_size, Y_g_size;
41358600ac3SJames Wright   VecType        vec_type;
41458600ac3SJames Wright   MatCeedContext ctx;
41558600ac3SJames Wright 
41658600ac3SJames Wright   PetscFunctionBeginUser;
41758600ac3SJames Wright   PetscCall(MatCeedRegisterLogEvents());
41858600ac3SJames Wright 
41958600ac3SJames Wright   // Collect context data
42058600ac3SJames Wright   PetscCall(DMGetVecType(dm_x, &vec_type));
42158600ac3SJames Wright   {
42258600ac3SJames Wright     Vec X;
42358600ac3SJames Wright 
42458600ac3SJames Wright     PetscCall(DMGetGlobalVector(dm_x, &X));
42558600ac3SJames Wright     PetscCall(VecGetSize(X, &X_g_size));
42658600ac3SJames Wright     PetscCall(VecGetLocalSize(X, &X_l_size));
42758600ac3SJames Wright     PetscCall(DMRestoreGlobalVector(dm_x, &X));
42858600ac3SJames Wright   }
42958600ac3SJames Wright   if (dm_y) {
43058600ac3SJames Wright     Vec Y;
43158600ac3SJames Wright 
43258600ac3SJames Wright     PetscCall(DMGetGlobalVector(dm_y, &Y));
43358600ac3SJames Wright     PetscCall(VecGetSize(Y, &Y_g_size));
43458600ac3SJames Wright     PetscCall(VecGetLocalSize(Y, &Y_l_size));
43558600ac3SJames Wright     PetscCall(DMRestoreGlobalVector(dm_y, &Y));
43658600ac3SJames Wright   } else {
43758600ac3SJames Wright     dm_y     = dm_x;
43858600ac3SJames Wright     Y_g_size = X_g_size;
43958600ac3SJames Wright     Y_l_size = X_l_size;
44058600ac3SJames Wright   }
44158600ac3SJames Wright   // Create context
44258600ac3SJames Wright   {
44358600ac3SJames Wright     Vec X_loc, Y_loc_transpose = NULL;
44458600ac3SJames Wright 
44558600ac3SJames Wright     PetscCall(DMCreateLocalVector(dm_x, &X_loc));
44658600ac3SJames Wright     PetscCall(VecZeroEntries(X_loc));
44758600ac3SJames Wright     if (op_mult_transpose) {
44858600ac3SJames Wright       PetscCall(DMCreateLocalVector(dm_y, &Y_loc_transpose));
44958600ac3SJames Wright       PetscCall(VecZeroEntries(Y_loc_transpose));
45058600ac3SJames Wright     }
45158600ac3SJames Wright     PetscCall(MatCeedContextCreate(dm_x, dm_y, X_loc, Y_loc_transpose, op_mult, op_mult_transpose, MATCEED_MULT, MATCEED_MULT_TRANSPOSE, &ctx));
45258600ac3SJames Wright     PetscCall(VecDestroy(&X_loc));
45358600ac3SJames Wright     PetscCall(VecDestroy(&Y_loc_transpose));
45458600ac3SJames Wright   }
45558600ac3SJames Wright 
45658600ac3SJames Wright   // Create mat
45758600ac3SJames Wright   PetscCall(MatCreateShell(PetscObjectComm((PetscObject)dm_x), Y_l_size, X_l_size, Y_g_size, X_g_size, ctx, mat));
45858600ac3SJames Wright   PetscCall(PetscObjectChangeTypeName((PetscObject)*mat, MATCEED));
45958600ac3SJames Wright   // -- Set block and variable block sizes
46058600ac3SJames Wright   if (dm_x == dm_y) {
46158600ac3SJames Wright     MatType dm_mat_type, dm_mat_type_copy;
46258600ac3SJames Wright     Mat     temp_mat;
46358600ac3SJames Wright 
46458600ac3SJames Wright     PetscCall(DMGetMatType(dm_x, &dm_mat_type));
46558600ac3SJames Wright     PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy));
46658600ac3SJames Wright     PetscCall(DMSetMatType(dm_x, MATAIJ));
46758600ac3SJames Wright     PetscCall(DMCreateMatrix(dm_x, &temp_mat));
46858600ac3SJames Wright     PetscCall(DMSetMatType(dm_x, dm_mat_type_copy));
46958600ac3SJames Wright     PetscCall(PetscFree(dm_mat_type_copy));
47058600ac3SJames Wright 
47158600ac3SJames Wright     {
47258600ac3SJames Wright       PetscInt        block_size, num_blocks, max_vblock_size = PETSC_INT_MAX;
47358600ac3SJames Wright       const PetscInt *vblock_sizes;
47458600ac3SJames Wright 
47558600ac3SJames Wright       // -- Get block sizes
47658600ac3SJames Wright       PetscCall(MatGetBlockSize(temp_mat, &block_size));
47758600ac3SJames Wright       PetscCall(MatGetVariableBlockSizes(temp_mat, &num_blocks, &vblock_sizes));
47858600ac3SJames Wright       {
47958600ac3SJames Wright         PetscInt local_min_max[2] = {0}, global_min_max[2] = {0, PETSC_INT_MAX};
48058600ac3SJames Wright 
48158600ac3SJames Wright         for (PetscInt i = 0; i < num_blocks; i++) local_min_max[1] = PetscMax(local_min_max[1], vblock_sizes[i]);
48258600ac3SJames Wright         PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_min_max, global_min_max));
48358600ac3SJames Wright         max_vblock_size = global_min_max[1];
48458600ac3SJames Wright       }
48558600ac3SJames Wright 
48658600ac3SJames Wright       // -- Copy block sizes
48758600ac3SJames Wright       if (block_size > 1) PetscCall(MatSetBlockSize(*mat, block_size));
48858600ac3SJames Wright       if (num_blocks) PetscCall(MatSetVariableBlockSizes(*mat, num_blocks, (PetscInt *)vblock_sizes));
48958600ac3SJames Wright 
49058600ac3SJames Wright       // -- Check libCEED compatibility
49158600ac3SJames Wright       {
49258600ac3SJames Wright         bool is_composite;
49358600ac3SJames Wright 
49458600ac3SJames Wright         ctx->is_ceed_pbd_valid  = PETSC_TRUE;
49558600ac3SJames Wright         ctx->is_ceed_vpbd_valid = PETSC_TRUE;
496*50f50432SJames Wright         PetscCallCeed(ctx->ceed, CeedOperatorIsComposite(op_mult, &is_composite));
49758600ac3SJames Wright         if (is_composite) {
49858600ac3SJames Wright           CeedInt       num_sub_operators;
49958600ac3SJames Wright           CeedOperator *sub_operators;
50058600ac3SJames Wright 
501*50f50432SJames Wright           PetscCallCeed(ctx->ceed, CeedCompositeOperatorGetNumSub(op_mult, &num_sub_operators));
502*50f50432SJames Wright           PetscCallCeed(ctx->ceed, CeedCompositeOperatorGetSubList(op_mult, &sub_operators));
50358600ac3SJames Wright           for (CeedInt i = 0; i < num_sub_operators; i++) {
50458600ac3SJames Wright             CeedInt                  num_bases, num_comp;
50558600ac3SJames Wright             CeedBasis               *active_bases;
50658600ac3SJames Wright             CeedOperatorAssemblyData assembly_data;
50758600ac3SJames Wright 
508*50f50432SJames Wright             PetscCallCeed(ctx->ceed, CeedOperatorGetOperatorAssemblyData(sub_operators[i], &assembly_data));
509*50f50432SJames Wright             PetscCallCeed(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL));
510*50f50432SJames Wright             PetscCallCeed(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp));
51158600ac3SJames Wright             if (num_bases > 1) {
51258600ac3SJames Wright               ctx->is_ceed_pbd_valid  = PETSC_FALSE;
51358600ac3SJames Wright               ctx->is_ceed_vpbd_valid = PETSC_FALSE;
51458600ac3SJames Wright             }
51558600ac3SJames Wright             if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE;
51658600ac3SJames Wright             if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE;
51758600ac3SJames Wright           }
51858600ac3SJames Wright         } else {
51958600ac3SJames Wright           // LCOV_EXCL_START
52058600ac3SJames Wright           CeedInt                  num_bases, num_comp;
52158600ac3SJames Wright           CeedBasis               *active_bases;
52258600ac3SJames Wright           CeedOperatorAssemblyData assembly_data;
52358600ac3SJames Wright 
524*50f50432SJames Wright           PetscCallCeed(ctx->ceed, CeedOperatorGetOperatorAssemblyData(op_mult, &assembly_data));
525*50f50432SJames Wright           PetscCallCeed(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL));
526*50f50432SJames Wright           PetscCallCeed(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp));
52758600ac3SJames Wright           if (num_bases > 1) {
52858600ac3SJames Wright             ctx->is_ceed_pbd_valid  = PETSC_FALSE;
52958600ac3SJames Wright             ctx->is_ceed_vpbd_valid = PETSC_FALSE;
53058600ac3SJames Wright           }
53158600ac3SJames Wright           if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE;
53258600ac3SJames Wright           if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE;
53358600ac3SJames Wright           // LCOV_EXCL_STOP
53458600ac3SJames Wright         }
53558600ac3SJames Wright         {
53658600ac3SJames Wright           PetscInt local_is_valid[2], global_is_valid[2];
53758600ac3SJames Wright 
53858600ac3SJames Wright           local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_pbd_valid;
53958600ac3SJames Wright           PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid));
54058600ac3SJames Wright           ctx->is_ceed_pbd_valid = global_is_valid[0];
54158600ac3SJames Wright           local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_vpbd_valid;
54258600ac3SJames Wright           PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid));
54358600ac3SJames Wright           ctx->is_ceed_vpbd_valid = global_is_valid[0];
54458600ac3SJames Wright         }
54558600ac3SJames Wright       }
54658600ac3SJames Wright     }
54758600ac3SJames Wright     PetscCall(MatDestroy(&temp_mat));
54858600ac3SJames Wright   }
54958600ac3SJames Wright   // -- Set internal mat type
55058600ac3SJames Wright   {
55158600ac3SJames Wright     VecType vec_type;
55258600ac3SJames Wright     MatType internal_mat_type = MATAIJ;
55358600ac3SJames Wright 
55458600ac3SJames Wright     PetscCall(VecGetType(ctx->X_loc, &vec_type));
55558600ac3SJames Wright     if (strstr(vec_type, VECCUDA)) internal_mat_type = MATAIJCUSPARSE;
55658600ac3SJames Wright     else if (strstr(vec_type, VECKOKKOS)) internal_mat_type = MATAIJKOKKOS;
55758600ac3SJames Wright     else internal_mat_type = MATAIJ;
55858600ac3SJames Wright     PetscCall(PetscStrallocpy(internal_mat_type, &ctx->internal_mat_type));
55958600ac3SJames Wright   }
56058600ac3SJames Wright   // -- Set mat operations
56158600ac3SJames Wright   PetscCall(MatShellSetContextDestroy(*mat, (PetscErrorCode(*)(void *))MatCeedContextDestroy));
56258600ac3SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_MULT, (void (*)(void))MatMult_Ceed));
56358600ac3SJames Wright   if (op_mult_transpose) PetscCall(MatShellSetOperation(*mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed));
56458600ac3SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed));
56558600ac3SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed));
56658600ac3SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed));
56758600ac3SJames Wright   PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed));
56858600ac3SJames Wright   PetscCall(MatShellSetVecType(*mat, vec_type));
56958600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
57058600ac3SJames Wright }
57158600ac3SJames Wright 
57258600ac3SJames Wright /**
57358600ac3SJames Wright   @brief Copy `MATCEED` into a compatible `Mat` with type `MatShell` or `MATCEED`.
57458600ac3SJames Wright 
57558600ac3SJames Wright   Collective across MPI processes.
57658600ac3SJames Wright 
57758600ac3SJames Wright   @param[in]   mat_ceed   `MATCEED` to copy from
57858600ac3SJames Wright   @param[out]  mat_other  `MatShell` or `MATCEED` to copy into
57958600ac3SJames Wright 
58058600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
58158600ac3SJames Wright **/
58258600ac3SJames Wright PetscErrorCode MatCeedCopy(Mat mat_ceed, Mat mat_other) {
58358600ac3SJames Wright   PetscFunctionBeginUser;
58458600ac3SJames Wright   PetscCall(MatCeedRegisterLogEvents());
58558600ac3SJames Wright 
58658600ac3SJames Wright   // Check type compatibility
58758600ac3SJames Wright   {
58858600ac3SJames Wright     MatType mat_type_ceed, mat_type_other;
58958600ac3SJames Wright 
59058600ac3SJames Wright     PetscCall(MatGetType(mat_ceed, &mat_type_ceed));
59158600ac3SJames Wright     PetscCheck(!strcmp(mat_type_ceed, MATCEED), PETSC_COMM_SELF, PETSC_ERR_LIB, "mat_ceed must have type " MATCEED);
59258600ac3SJames Wright     PetscCall(MatGetType(mat_ceed, &mat_type_other));
59358600ac3SJames Wright     PetscCheck(!strcmp(mat_type_other, MATCEED) || !strcmp(mat_type_other, MATSHELL), PETSC_COMM_SELF, PETSC_ERR_LIB,
59458600ac3SJames Wright                "mat_other must have type " MATCEED " or " MATSHELL);
59558600ac3SJames Wright   }
59658600ac3SJames Wright 
59758600ac3SJames Wright   // Check dimension compatibility
59858600ac3SJames Wright   {
59958600ac3SJames Wright     PetscInt X_l_ceed_size, X_g_ceed_size, Y_l_ceed_size, Y_g_ceed_size, X_l_other_size, X_g_other_size, Y_l_other_size, Y_g_other_size;
60058600ac3SJames Wright 
60158600ac3SJames Wright     PetscCall(MatGetSize(mat_ceed, &Y_g_ceed_size, &X_g_ceed_size));
60258600ac3SJames Wright     PetscCall(MatGetLocalSize(mat_ceed, &Y_l_ceed_size, &X_l_ceed_size));
60358600ac3SJames Wright     PetscCall(MatGetSize(mat_ceed, &Y_g_other_size, &X_g_other_size));
60458600ac3SJames Wright     PetscCall(MatGetLocalSize(mat_ceed, &Y_l_other_size, &X_l_other_size));
60558600ac3SJames Wright     PetscCheck((Y_g_ceed_size == Y_g_other_size) && (X_g_ceed_size == X_g_other_size) && (Y_l_ceed_size == Y_l_other_size) &&
60658600ac3SJames Wright                    (X_l_ceed_size == X_l_other_size),
60758600ac3SJames Wright                PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ,
60858600ac3SJames Wright                "mat_ceed and mat_other must have compatible sizes; found mat_ceed (Global: %" PetscInt_FMT ", %" PetscInt_FMT
60958600ac3SJames Wright                "; Local: %" PetscInt_FMT ", %" PetscInt_FMT ") mat_other (Global: %" PetscInt_FMT ", %" PetscInt_FMT "; Local: %" PetscInt_FMT
61058600ac3SJames Wright                ", %" PetscInt_FMT ")",
61158600ac3SJames Wright                Y_g_ceed_size, X_g_ceed_size, Y_l_ceed_size, X_l_ceed_size, Y_g_other_size, X_g_other_size, Y_l_other_size, X_l_other_size);
61258600ac3SJames Wright   }
61358600ac3SJames Wright 
61458600ac3SJames Wright   // Convert
61558600ac3SJames Wright   {
61658600ac3SJames Wright     VecType        vec_type;
61758600ac3SJames Wright     MatCeedContext ctx;
61858600ac3SJames Wright 
61958600ac3SJames Wright     PetscCall(PetscObjectChangeTypeName((PetscObject)mat_other, MATCEED));
62058600ac3SJames Wright     PetscCall(MatShellGetContext(mat_ceed, &ctx));
62158600ac3SJames Wright     PetscCall(MatCeedContextReference(ctx));
62258600ac3SJames Wright     PetscCall(MatShellSetContext(mat_other, ctx));
62358600ac3SJames Wright     PetscCall(MatShellSetContextDestroy(mat_other, (PetscErrorCode(*)(void *))MatCeedContextDestroy));
62458600ac3SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_MULT, (void (*)(void))MatMult_Ceed));
62558600ac3SJames Wright     if (ctx->op_mult_transpose) PetscCall(MatShellSetOperation(mat_other, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed));
62658600ac3SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed));
62758600ac3SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed));
62858600ac3SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed));
62958600ac3SJames Wright     PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed));
63058600ac3SJames Wright     {
63158600ac3SJames Wright       PetscInt block_size;
63258600ac3SJames Wright 
63358600ac3SJames Wright       PetscCall(MatGetBlockSize(mat_ceed, &block_size));
63458600ac3SJames Wright       if (block_size > 1) PetscCall(MatSetBlockSize(mat_other, block_size));
63558600ac3SJames Wright     }
63658600ac3SJames Wright     {
63758600ac3SJames Wright       PetscInt        num_blocks;
63858600ac3SJames Wright       const PetscInt *block_sizes;
63958600ac3SJames Wright 
64058600ac3SJames Wright       PetscCall(MatGetVariableBlockSizes(mat_ceed, &num_blocks, &block_sizes));
64158600ac3SJames Wright       if (num_blocks) PetscCall(MatSetVariableBlockSizes(mat_other, num_blocks, (PetscInt *)block_sizes));
64258600ac3SJames Wright     }
64358600ac3SJames Wright     PetscCall(DMGetVecType(ctx->dm_x, &vec_type));
64458600ac3SJames Wright     PetscCall(MatShellSetVecType(mat_other, vec_type));
64558600ac3SJames Wright   }
64658600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
64758600ac3SJames Wright }
64858600ac3SJames Wright 
64958600ac3SJames Wright /**
65058600ac3SJames Wright   @brief Assemble a `MATCEED` into a `MATAIJ` or similar.
65158600ac3SJames Wright          The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`.
65258600ac3SJames Wright          The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail.
65358600ac3SJames Wright 
65458600ac3SJames Wright   Collective across MPI processes.
65558600ac3SJames Wright 
65658600ac3SJames Wright   @param[in]      mat_ceed  `MATCEED` to assemble
65758600ac3SJames Wright   @param[in,out]  mat_coo   `MATAIJ` or similar to assemble into
65858600ac3SJames Wright 
65958600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
66058600ac3SJames Wright **/
66158600ac3SJames Wright PetscErrorCode MatCeedAssembleCOO(Mat mat_ceed, Mat mat_coo) {
66258600ac3SJames Wright   MatCeedContext ctx;
66358600ac3SJames Wright 
66458600ac3SJames Wright   PetscFunctionBeginUser;
66558600ac3SJames Wright   PetscCall(MatShellGetContext(mat_ceed, &ctx));
66658600ac3SJames Wright 
66758600ac3SJames Wright   // Check if COO pattern set
66858600ac3SJames Wright   {
66958600ac3SJames Wright     PetscInt index = -1;
67058600ac3SJames Wright 
67158600ac3SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) {
67258600ac3SJames Wright       if (ctx->mats_assembled_full[i] == mat_coo) index = i;
67358600ac3SJames Wright     }
67458600ac3SJames Wright     if (index == -1) {
67558600ac3SJames Wright       PetscInt     *rows_petsc = NULL, *cols_petsc = NULL;
67658600ac3SJames Wright       CeedInt      *rows_ceed, *cols_ceed;
67758600ac3SJames Wright       PetscCount    num_entries;
67858600ac3SJames Wright       PetscLogStage stage_amg_setup;
67958600ac3SJames Wright 
68058600ac3SJames Wright       // -- Assemble sparsity pattern if mat hasn't been assembled before
68158600ac3SJames Wright       PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup));
68258600ac3SJames Wright       if (stage_amg_setup == -1) {
68358600ac3SJames Wright         PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup));
68458600ac3SJames Wright       }
68558600ac3SJames Wright       PetscCall(PetscLogStagePush(stage_amg_setup));
686*50f50432SJames Wright       PetscCallCeed(ctx->ceed, CeedOperatorLinearAssembleSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed));
68758600ac3SJames Wright       PetscCall(IntArrayC2P(num_entries, &rows_ceed, &rows_petsc));
68858600ac3SJames Wright       PetscCall(IntArrayC2P(num_entries, &cols_ceed, &cols_petsc));
68958600ac3SJames Wright       PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc));
69058600ac3SJames Wright       free(rows_petsc);
69158600ac3SJames Wright       free(cols_petsc);
692*50f50432SJames Wright       if (!ctx->coo_values_full) PetscCallCeed(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_full));
69358600ac3SJames Wright       PetscCall(PetscRealloc(++ctx->num_mats_assembled_full * sizeof(Mat), &ctx->mats_assembled_full));
69458600ac3SJames Wright       ctx->mats_assembled_full[ctx->num_mats_assembled_full - 1] = mat_coo;
69558600ac3SJames Wright       PetscCall(PetscLogStagePop());
69658600ac3SJames Wright     }
69758600ac3SJames Wright   }
69858600ac3SJames Wright 
69958600ac3SJames Wright   // Assemble mat_ceed
70058600ac3SJames Wright   PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY));
70158600ac3SJames Wright   {
70258600ac3SJames Wright     const CeedScalar *values;
70358600ac3SJames Wright     MatType           mat_type;
70458600ac3SJames Wright     CeedMemType       mem_type = CEED_MEM_HOST;
70558600ac3SJames Wright     PetscBool         is_spd, is_spd_known;
70658600ac3SJames Wright 
70758600ac3SJames Wright     PetscCall(MatGetType(mat_coo, &mat_type));
70858600ac3SJames Wright     if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE;
70958600ac3SJames Wright     else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE;
71058600ac3SJames Wright     else mem_type = CEED_MEM_HOST;
71158600ac3SJames Wright 
712*50f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemble(ctx->op_mult, ctx->coo_values_full));
713*50f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_full, mem_type, &values));
71458600ac3SJames Wright     PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES));
71558600ac3SJames Wright     PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd));
71658600ac3SJames Wright     if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd));
717*50f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_full, &values));
71858600ac3SJames Wright   }
71958600ac3SJames Wright   PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY));
72058600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
72158600ac3SJames Wright }
72258600ac3SJames Wright 
72358600ac3SJames Wright /**
72458600ac3SJames Wright   @brief Set user context for a `MATCEED`.
72558600ac3SJames Wright 
72658600ac3SJames Wright   Collective across MPI processes.
72758600ac3SJames Wright 
72858600ac3SJames Wright   @param[in,out]  mat  `MATCEED`
72958600ac3SJames Wright   @param[in]      f    The context destroy function, or NULL
73058600ac3SJames Wright   @param[in]      ctx  User context, or NULL to unset
73158600ac3SJames Wright 
73258600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
73358600ac3SJames Wright **/
73458600ac3SJames Wright PetscErrorCode MatCeedSetContext(Mat mat, PetscErrorCode (*f)(void *), void *ctx) {
73558600ac3SJames Wright   PetscContainer user_ctx = NULL;
73658600ac3SJames Wright 
73758600ac3SJames Wright   PetscFunctionBeginUser;
73858600ac3SJames Wright   if (ctx) {
73958600ac3SJames Wright     PetscCall(PetscContainerCreate(PetscObjectComm((PetscObject)mat), &user_ctx));
74058600ac3SJames Wright     PetscCall(PetscContainerSetPointer(user_ctx, ctx));
74158600ac3SJames Wright     PetscCall(PetscContainerSetUserDestroy(user_ctx, f));
74258600ac3SJames Wright   }
74358600ac3SJames Wright   PetscCall(PetscObjectCompose((PetscObject)mat, "MatCeed user context", (PetscObject)user_ctx));
74458600ac3SJames Wright   PetscCall(PetscContainerDestroy(&user_ctx));
74558600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
74658600ac3SJames Wright }
74758600ac3SJames Wright 
74858600ac3SJames Wright /**
74958600ac3SJames Wright   @brief Retrieve the user context for a `MATCEED`.
75058600ac3SJames Wright 
75158600ac3SJames Wright   Collective across MPI processes.
75258600ac3SJames Wright 
75358600ac3SJames Wright   @param[in,out]  mat  `MATCEED`
75458600ac3SJames Wright   @param[in]      ctx  User context
75558600ac3SJames Wright 
75658600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
75758600ac3SJames Wright **/
75858600ac3SJames Wright PetscErrorCode MatCeedGetContext(Mat mat, void *ctx) {
75958600ac3SJames Wright   PetscContainer user_ctx;
76058600ac3SJames Wright 
76158600ac3SJames Wright   PetscFunctionBeginUser;
76258600ac3SJames Wright   PetscCall(PetscObjectQuery((PetscObject)mat, "MatCeed user context", (PetscObject *)&user_ctx));
76358600ac3SJames Wright   if (user_ctx) PetscCall(PetscContainerGetPointer(user_ctx, (void **)ctx));
76458600ac3SJames Wright   else *(void **)ctx = NULL;
76558600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
76658600ac3SJames Wright }
76758600ac3SJames Wright 
76858600ac3SJames Wright /**
76958600ac3SJames Wright   @brief Sets the inner matrix type as a string from the `MATCEED`.
77058600ac3SJames Wright 
77158600ac3SJames Wright   Collective across MPI processes.
77258600ac3SJames Wright 
77358600ac3SJames Wright   @param[in,out]  mat   `MATCEED`
77458600ac3SJames Wright   @param[in]      type  Inner `MatType` to set
77558600ac3SJames Wright 
77658600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
77758600ac3SJames Wright **/
77858600ac3SJames Wright PetscErrorCode MatCeedSetInnerMatType(Mat mat, MatType type) {
77958600ac3SJames Wright   MatCeedContext ctx;
78058600ac3SJames Wright 
78158600ac3SJames Wright   PetscFunctionBeginUser;
78258600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
78358600ac3SJames Wright   // Check if same
78458600ac3SJames Wright   {
78558600ac3SJames Wright     size_t    len_old, len_new;
78658600ac3SJames Wright     PetscBool is_same = PETSC_FALSE;
78758600ac3SJames Wright 
78858600ac3SJames Wright     PetscCall(PetscStrlen(ctx->internal_mat_type, &len_old));
78958600ac3SJames Wright     PetscCall(PetscStrlen(type, &len_new));
79058600ac3SJames Wright     if (len_old == len_new) PetscCall(PetscStrncmp(ctx->internal_mat_type, type, len_old, &is_same));
79158600ac3SJames Wright     if (is_same) PetscFunctionReturn(PETSC_SUCCESS);
79258600ac3SJames Wright   }
79358600ac3SJames Wright   // Clean up old mats in different format
79458600ac3SJames Wright   // LCOV_EXCL_START
79558600ac3SJames Wright   if (ctx->mat_assembled_full_internal) {
79658600ac3SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) {
79758600ac3SJames Wright       if (ctx->mats_assembled_full[i] == ctx->mat_assembled_full_internal) {
79858600ac3SJames Wright         for (PetscInt j = i + 1; j < ctx->num_mats_assembled_full; j++) {
79958600ac3SJames Wright           ctx->mats_assembled_full[j - 1] = ctx->mats_assembled_full[j];
80058600ac3SJames Wright         }
80158600ac3SJames Wright         ctx->num_mats_assembled_full--;
80258600ac3SJames Wright         // Note: we'll realloc this array again, so no need to shrink the allocation
80358600ac3SJames Wright         PetscCall(MatDestroy(&ctx->mat_assembled_full_internal));
80458600ac3SJames Wright       }
80558600ac3SJames Wright     }
80658600ac3SJames Wright   }
80758600ac3SJames Wright   if (ctx->mat_assembled_pbd_internal) {
80858600ac3SJames Wright     for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) {
80958600ac3SJames Wright       if (ctx->mats_assembled_pbd[i] == ctx->mat_assembled_pbd_internal) {
81058600ac3SJames Wright         for (PetscInt j = i + 1; j < ctx->num_mats_assembled_pbd; j++) {
81158600ac3SJames Wright           ctx->mats_assembled_pbd[j - 1] = ctx->mats_assembled_pbd[j];
81258600ac3SJames Wright         }
81358600ac3SJames Wright         // Note: we'll realloc this array again, so no need to shrink the allocation
81458600ac3SJames Wright         ctx->num_mats_assembled_pbd--;
81558600ac3SJames Wright         PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal));
81658600ac3SJames Wright       }
81758600ac3SJames Wright     }
81858600ac3SJames Wright   }
81958600ac3SJames Wright   PetscCall(PetscFree(ctx->internal_mat_type));
82058600ac3SJames Wright   PetscCall(PetscStrallocpy(type, &ctx->internal_mat_type));
82158600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
82258600ac3SJames Wright   // LCOV_EXCL_STOP
82358600ac3SJames Wright }
82458600ac3SJames Wright 
82558600ac3SJames Wright /**
82658600ac3SJames Wright   @brief Gets the inner matrix type as a string from the `MATCEED`.
82758600ac3SJames Wright 
82858600ac3SJames Wright   Collective across MPI processes.
82958600ac3SJames Wright 
83058600ac3SJames Wright   @param[in,out]  mat   `MATCEED`
83158600ac3SJames Wright   @param[in]      type  Inner `MatType`
83258600ac3SJames Wright 
83358600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
83458600ac3SJames Wright **/
83558600ac3SJames Wright PetscErrorCode MatCeedGetInnerMatType(Mat mat, MatType *type) {
83658600ac3SJames Wright   MatCeedContext ctx;
83758600ac3SJames Wright 
83858600ac3SJames Wright   PetscFunctionBeginUser;
83958600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
84058600ac3SJames Wright   *type = ctx->internal_mat_type;
84158600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
84258600ac3SJames Wright }
84358600ac3SJames Wright 
84458600ac3SJames Wright /**
84558600ac3SJames Wright   @brief Set a user defined matrix operation for a `MATCEED` matrix.
84658600ac3SJames Wright 
84758600ac3SJames Wright   Within each user-defined routine, the user should call `MatCeedGetContext()` to obtain the user-defined context that was set by
84858600ac3SJames Wright `MatCeedSetContext()`.
84958600ac3SJames Wright 
85058600ac3SJames Wright   Collective across MPI processes.
85158600ac3SJames Wright 
85258600ac3SJames Wright   @param[in,out]  mat  `MATCEED`
85358600ac3SJames Wright   @param[in]      op   Name of the `MatOperation`
85458600ac3SJames Wright   @param[in]      g    Function that provides the operation
85558600ac3SJames Wright 
85658600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
85758600ac3SJames Wright **/
85858600ac3SJames Wright PetscErrorCode MatCeedSetOperation(Mat mat, MatOperation op, void (*g)(void)) {
85958600ac3SJames Wright   PetscFunctionBeginUser;
86058600ac3SJames Wright   PetscCall(MatShellSetOperation(mat, op, g));
86158600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
86258600ac3SJames Wright }
86358600ac3SJames Wright 
86458600ac3SJames Wright /**
86558600ac3SJames Wright   @brief Set input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
86658600ac3SJames Wright 
86758600ac3SJames Wright   Not collective across MPI processes.
86858600ac3SJames Wright 
86958600ac3SJames Wright   @param[in,out]  mat              `MATCEED`
87058600ac3SJames Wright   @param[in]      X_loc            Input PETSc local vector, or NULL
87158600ac3SJames Wright   @param[in]      Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
87258600ac3SJames Wright 
87358600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
87458600ac3SJames Wright **/
87558600ac3SJames Wright PetscErrorCode MatCeedSetLocalVectors(Mat mat, Vec X_loc, Vec Y_loc_transpose) {
87658600ac3SJames Wright   MatCeedContext ctx;
87758600ac3SJames Wright 
87858600ac3SJames Wright   PetscFunctionBeginUser;
87958600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
88058600ac3SJames Wright   if (X_loc) {
88158600ac3SJames Wright     PetscInt len_old, len_new;
88258600ac3SJames Wright 
88358600ac3SJames Wright     PetscCall(VecGetSize(ctx->X_loc, &len_old));
88458600ac3SJames Wright     PetscCall(VecGetSize(X_loc, &len_new));
88558600ac3SJames Wright     PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB, "new X_loc length %" PetscInt_FMT " should match old X_loc length %" PetscInt_FMT,
88658600ac3SJames Wright                len_new, len_old);
88758600ac3SJames Wright     PetscCall(VecDestroy(&ctx->X_loc));
88858600ac3SJames Wright     ctx->X_loc = X_loc;
88958600ac3SJames Wright     PetscCall(PetscObjectReference((PetscObject)X_loc));
89058600ac3SJames Wright   }
89158600ac3SJames Wright   if (Y_loc_transpose) {
89258600ac3SJames Wright     PetscInt len_old, len_new;
89358600ac3SJames Wright 
89458600ac3SJames Wright     PetscCall(VecGetSize(ctx->Y_loc_transpose, &len_old));
89558600ac3SJames Wright     PetscCall(VecGetSize(Y_loc_transpose, &len_new));
89658600ac3SJames Wright     PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB,
89758600ac3SJames Wright                "new Y_loc_transpose length %" PetscInt_FMT " should match old Y_loc_transpose length %" PetscInt_FMT, len_new, len_old);
89858600ac3SJames Wright     PetscCall(VecDestroy(&ctx->Y_loc_transpose));
89958600ac3SJames Wright     ctx->Y_loc_transpose = Y_loc_transpose;
90058600ac3SJames Wright     PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose));
90158600ac3SJames Wright   }
90258600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
90358600ac3SJames Wright }
90458600ac3SJames Wright 
90558600ac3SJames Wright /**
90658600ac3SJames Wright   @brief Get input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
90758600ac3SJames Wright 
90858600ac3SJames Wright   Not collective across MPI processes.
90958600ac3SJames Wright 
91058600ac3SJames Wright   @param[in,out]  mat              `MATCEED`
91158600ac3SJames Wright   @param[out]     X_loc            Input PETSc local vector, or NULL
91258600ac3SJames Wright   @param[out]     Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
91358600ac3SJames Wright 
91458600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
91558600ac3SJames Wright **/
91658600ac3SJames Wright PetscErrorCode MatCeedGetLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) {
91758600ac3SJames Wright   MatCeedContext ctx;
91858600ac3SJames Wright 
91958600ac3SJames Wright   PetscFunctionBeginUser;
92058600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
92158600ac3SJames Wright   if (X_loc) {
92258600ac3SJames Wright     *X_loc = ctx->X_loc;
92358600ac3SJames Wright     PetscCall(PetscObjectReference((PetscObject)*X_loc));
92458600ac3SJames Wright   }
92558600ac3SJames Wright   if (Y_loc_transpose) {
92658600ac3SJames Wright     *Y_loc_transpose = ctx->Y_loc_transpose;
92758600ac3SJames Wright     PetscCall(PetscObjectReference((PetscObject)*Y_loc_transpose));
92858600ac3SJames Wright   }
92958600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
93058600ac3SJames Wright }
93158600ac3SJames Wright 
93258600ac3SJames Wright /**
93358600ac3SJames Wright   @brief Restore input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
93458600ac3SJames Wright 
93558600ac3SJames Wright   Not collective across MPI processes.
93658600ac3SJames Wright 
93758600ac3SJames Wright   @param[in,out]  mat              MatCeed
93858600ac3SJames Wright   @param[out]     X_loc            Input PETSc local vector, or NULL
93958600ac3SJames Wright   @param[out]     Y_loc_transpose  Input PETSc local vector for transpose operation, or NULL
94058600ac3SJames Wright 
94158600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
94258600ac3SJames Wright **/
94358600ac3SJames Wright PetscErrorCode MatCeedRestoreLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) {
94458600ac3SJames Wright   PetscFunctionBeginUser;
94558600ac3SJames Wright   if (X_loc) PetscCall(VecDestroy(X_loc));
94658600ac3SJames Wright   if (Y_loc_transpose) PetscCall(VecDestroy(Y_loc_transpose));
94758600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
94858600ac3SJames Wright }
94958600ac3SJames Wright 
95058600ac3SJames Wright /**
95158600ac3SJames Wright   @brief Get libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
95258600ac3SJames Wright 
95358600ac3SJames Wright   Not collective across MPI processes.
95458600ac3SJames Wright 
95558600ac3SJames Wright   @param[in,out]  mat                MatCeed
95658600ac3SJames Wright   @param[out]     op_mult            libCEED `CeedOperator` for `MatMult()`, or NULL
95758600ac3SJames Wright   @param[out]     op_mult_transpose  libCEED `CeedOperator` for `MatMultTranspose()`, or NULL
95858600ac3SJames Wright 
95958600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
96058600ac3SJames Wright **/
96158600ac3SJames Wright PetscErrorCode MatCeedGetCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) {
96258600ac3SJames Wright   MatCeedContext ctx;
96358600ac3SJames Wright 
96458600ac3SJames Wright   PetscFunctionBeginUser;
96558600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
96658600ac3SJames Wright   if (op_mult) {
96758600ac3SJames Wright     *op_mult = NULL;
968*50f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult, op_mult));
96958600ac3SJames Wright   }
97058600ac3SJames Wright   if (op_mult_transpose) {
97158600ac3SJames Wright     *op_mult_transpose = NULL;
972*50f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult_transpose, op_mult_transpose));
97358600ac3SJames Wright   }
97458600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
97558600ac3SJames Wright }
97658600ac3SJames Wright 
97758600ac3SJames Wright /**
97858600ac3SJames Wright   @brief Restore libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations.
97958600ac3SJames Wright 
98058600ac3SJames Wright   Not collective across MPI processes.
98158600ac3SJames Wright 
98258600ac3SJames Wright   @param[in,out]  mat                MatCeed
98358600ac3SJames Wright   @param[out]     op_mult            libCEED `CeedOperator` for `MatMult()`, or NULL
98458600ac3SJames Wright   @param[out]     op_mult_transpose  libCEED `CeedOperator` for `MatMultTranspose()`, or NULL
98558600ac3SJames Wright 
98658600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
98758600ac3SJames Wright **/
98858600ac3SJames Wright PetscErrorCode MatCeedRestoreCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) {
98958600ac3SJames Wright   MatCeedContext ctx;
99058600ac3SJames Wright 
99158600ac3SJames Wright   PetscFunctionBeginUser;
99258600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
993*50f50432SJames Wright   if (op_mult) PetscCallCeed(ctx->ceed, CeedOperatorDestroy(op_mult));
994*50f50432SJames Wright   if (op_mult_transpose) PetscCallCeed(ctx->ceed, CeedOperatorDestroy(op_mult_transpose));
99558600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
99658600ac3SJames Wright }
99758600ac3SJames Wright 
99858600ac3SJames Wright /**
99958600ac3SJames Wright   @brief Set `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators.
100058600ac3SJames Wright 
100158600ac3SJames Wright   Not collective across MPI processes.
100258600ac3SJames Wright 
100358600ac3SJames Wright   @param[in,out]  mat                       MatCeed
100458600ac3SJames Wright   @param[out]     log_event_mult            `PetscLogEvent` for forward evaluation, or NULL
100558600ac3SJames Wright   @param[out]     log_event_mult_transpose  `PetscLogEvent` for transpose evaluation, or NULL
100658600ac3SJames Wright 
100758600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
100858600ac3SJames Wright **/
100958600ac3SJames Wright PetscErrorCode MatCeedSetLogEvents(Mat mat, PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose) {
101058600ac3SJames Wright   MatCeedContext ctx;
101158600ac3SJames Wright 
101258600ac3SJames Wright   PetscFunctionBeginUser;
101358600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
101458600ac3SJames Wright   if (log_event_mult) ctx->log_event_mult = log_event_mult;
101558600ac3SJames Wright   if (log_event_mult_transpose) ctx->log_event_mult_transpose = log_event_mult_transpose;
101658600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
101758600ac3SJames Wright }
101858600ac3SJames Wright 
101958600ac3SJames Wright /**
102058600ac3SJames Wright   @brief Get `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators.
102158600ac3SJames Wright 
102258600ac3SJames Wright   Not collective across MPI processes.
102358600ac3SJames Wright 
102458600ac3SJames Wright   @param[in,out]  mat                       MatCeed
102558600ac3SJames Wright   @param[out]     log_event_mult            `PetscLogEvent` for forward evaluation, or NULL
102658600ac3SJames Wright   @param[out]     log_event_mult_transpose  `PetscLogEvent` for transpose evaluation, or NULL
102758600ac3SJames Wright 
102858600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
102958600ac3SJames Wright **/
103058600ac3SJames Wright PetscErrorCode MatCeedGetLogEvents(Mat mat, PetscLogEvent *log_event_mult, PetscLogEvent *log_event_mult_transpose) {
103158600ac3SJames Wright   MatCeedContext ctx;
103258600ac3SJames Wright 
103358600ac3SJames Wright   PetscFunctionBeginUser;
103458600ac3SJames Wright   PetscCall(MatShellGetContext(mat, &ctx));
103558600ac3SJames Wright   if (log_event_mult) *log_event_mult = ctx->log_event_mult;
103658600ac3SJames Wright   if (log_event_mult_transpose) *log_event_mult_transpose = ctx->log_event_mult_transpose;
103758600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
103858600ac3SJames Wright }
103958600ac3SJames Wright 
104058600ac3SJames Wright // -----------------------------------------------------------------------------
104158600ac3SJames Wright // Operator context data
104258600ac3SJames Wright // -----------------------------------------------------------------------------
104358600ac3SJames Wright 
104458600ac3SJames Wright /**
104558600ac3SJames Wright   @brief Setup context data for operator application.
104658600ac3SJames Wright 
104758600ac3SJames Wright   Collective across MPI processes.
104858600ac3SJames Wright 
104958600ac3SJames Wright   @param[in]   dm_x                      Input `DM`
105058600ac3SJames Wright   @param[in]   dm_y                      Output `DM`
105158600ac3SJames Wright   @param[in]   X_loc                     Input PETSc local vector, or NULL
105258600ac3SJames Wright   @param[in]   Y_loc_transpose           Input PETSc local vector for transpose operation, or NULL
105358600ac3SJames Wright   @param[in]   op_mult                   `CeedOperator` for forward evaluation
105458600ac3SJames Wright   @param[in]   op_mult_transpose         `CeedOperator` for transpose evaluation
105558600ac3SJames Wright   @param[in]   log_event_mult            `PetscLogEvent` for forward evaluation
105658600ac3SJames Wright   @param[in]   log_event_mult_transpose  `PetscLogEvent` for transpose evaluation
105758600ac3SJames Wright   @param[out]  ctx                       Context data for operator evaluation
105858600ac3SJames Wright 
105958600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
106058600ac3SJames Wright **/
106158600ac3SJames Wright PetscErrorCode MatCeedContextCreate(DM dm_x, DM dm_y, Vec X_loc, Vec Y_loc_transpose, CeedOperator op_mult, CeedOperator op_mult_transpose,
106258600ac3SJames Wright                                     PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose, MatCeedContext *ctx) {
106358600ac3SJames Wright   CeedSize x_loc_len, y_loc_len;
106458600ac3SJames Wright 
106558600ac3SJames Wright   PetscFunctionBeginUser;
106658600ac3SJames Wright 
106758600ac3SJames Wright   // Allocate
106858600ac3SJames Wright   PetscCall(PetscNew(ctx));
106958600ac3SJames Wright   (*ctx)->ref_count = 1;
107058600ac3SJames Wright 
107158600ac3SJames Wright   // Logging
107258600ac3SJames Wright   (*ctx)->log_event_mult           = log_event_mult;
107358600ac3SJames Wright   (*ctx)->log_event_mult_transpose = log_event_mult_transpose;
107458600ac3SJames Wright 
107558600ac3SJames Wright   // PETSc objects
107658600ac3SJames Wright   PetscCall(PetscObjectReference((PetscObject)dm_x));
107758600ac3SJames Wright   (*ctx)->dm_x = dm_x;
107858600ac3SJames Wright   PetscCall(PetscObjectReference((PetscObject)dm_y));
107958600ac3SJames Wright   (*ctx)->dm_y = dm_y;
108058600ac3SJames Wright   if (X_loc) PetscCall(PetscObjectReference((PetscObject)X_loc));
108158600ac3SJames Wright   (*ctx)->X_loc = X_loc;
108258600ac3SJames Wright   if (Y_loc_transpose) PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose));
108358600ac3SJames Wright   (*ctx)->Y_loc_transpose = Y_loc_transpose;
108458600ac3SJames Wright 
108558600ac3SJames Wright   // Memtype
108658600ac3SJames Wright   {
108758600ac3SJames Wright     const PetscScalar *x;
108858600ac3SJames Wright     Vec                X;
108958600ac3SJames Wright 
109058600ac3SJames Wright     PetscCall(DMGetLocalVector(dm_x, &X));
109158600ac3SJames Wright     PetscCall(VecGetArrayReadAndMemType(X, &x, &(*ctx)->mem_type));
109258600ac3SJames Wright     PetscCall(VecRestoreArrayReadAndMemType(X, &x));
109358600ac3SJames Wright     PetscCall(DMRestoreLocalVector(dm_x, &X));
109458600ac3SJames Wright   }
109558600ac3SJames Wright 
109658600ac3SJames Wright   // libCEED objects
109758600ac3SJames Wright   PetscCheck(CeedOperatorGetCeed(op_mult, &(*ctx)->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB,
109858600ac3SJames Wright              "retrieving Ceed context object failed");
1099*50f50432SJames Wright   PetscCallCeed((*ctx)->ceed, CeedReference((*ctx)->ceed));
1100*50f50432SJames Wright   PetscCallCeed((*ctx)->ceed, CeedOperatorGetActiveVectorLengths(op_mult, &x_loc_len, &y_loc_len));
1101*50f50432SJames Wright   PetscCallCeed((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult, &(*ctx)->op_mult));
1102*50f50432SJames Wright   if (op_mult_transpose) PetscCallCeed((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult_transpose, &(*ctx)->op_mult_transpose));
1103*50f50432SJames Wright   PetscCallCeed((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, x_loc_len, &(*ctx)->x_loc));
1104*50f50432SJames Wright   PetscCallCeed((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, y_loc_len, &(*ctx)->y_loc));
110558600ac3SJames Wright 
110658600ac3SJames Wright   // Flop counting
110758600ac3SJames Wright   {
110858600ac3SJames Wright     CeedSize ceed_flops_estimate = 0;
110958600ac3SJames Wright 
1110*50f50432SJames Wright     PetscCallCeed((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult, &ceed_flops_estimate));
111158600ac3SJames Wright     (*ctx)->flops_mult = ceed_flops_estimate;
111258600ac3SJames Wright     if (op_mult_transpose) {
1113*50f50432SJames Wright       PetscCallCeed((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult_transpose, &ceed_flops_estimate));
111458600ac3SJames Wright       (*ctx)->flops_mult_transpose = ceed_flops_estimate;
111558600ac3SJames Wright     }
111658600ac3SJames Wright   }
111758600ac3SJames Wright 
111858600ac3SJames Wright   // Check sizes
111958600ac3SJames Wright   if (x_loc_len > 0 || y_loc_len > 0) {
112058600ac3SJames Wright     CeedSize ctx_x_loc_len, ctx_y_loc_len;
112158600ac3SJames Wright     PetscInt X_loc_len, dm_x_loc_len, Y_loc_len, dm_y_loc_len;
112258600ac3SJames Wright     Vec      dm_X_loc, dm_Y_loc;
112358600ac3SJames Wright 
112458600ac3SJames Wright     // -- Input
112558600ac3SJames Wright     PetscCall(DMGetLocalVector(dm_x, &dm_X_loc));
112658600ac3SJames Wright     PetscCall(VecGetLocalSize(dm_X_loc, &dm_x_loc_len));
112758600ac3SJames Wright     PetscCall(DMRestoreLocalVector(dm_x, &dm_X_loc));
112858600ac3SJames Wright     if (X_loc) PetscCall(VecGetLocalSize(X_loc, &X_loc_len));
1129*50f50432SJames Wright     PetscCallCeed((*ctx)->ceed, CeedVectorGetLength((*ctx)->x_loc, &ctx_x_loc_len));
113058600ac3SJames Wright     if (X_loc) PetscCheck(X_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "X_loc must match dm_x dimensions");
113158600ac3SJames Wright     PetscCheck(x_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op must match dm_x dimensions");
113258600ac3SJames Wright     PetscCheck(x_loc_len == ctx_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "x_loc must match op dimensions");
113358600ac3SJames Wright 
113458600ac3SJames Wright     // -- Output
113558600ac3SJames Wright     PetscCall(DMGetLocalVector(dm_y, &dm_Y_loc));
113658600ac3SJames Wright     PetscCall(VecGetLocalSize(dm_Y_loc, &dm_y_loc_len));
113758600ac3SJames Wright     PetscCall(DMRestoreLocalVector(dm_y, &dm_Y_loc));
1138*50f50432SJames Wright     PetscCallCeed((*ctx)->ceed, CeedVectorGetLength((*ctx)->y_loc, &ctx_y_loc_len));
113958600ac3SJames Wright     PetscCheck(ctx_y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op must match dm_y dimensions");
114058600ac3SJames Wright 
114158600ac3SJames Wright     // -- Transpose
114258600ac3SJames Wright     if (Y_loc_transpose) {
114358600ac3SJames Wright       PetscCall(VecGetLocalSize(Y_loc_transpose, &Y_loc_len));
114458600ac3SJames Wright       PetscCheck(Y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "Y_loc_transpose must match dm_y dimensions");
114558600ac3SJames Wright     }
114658600ac3SJames Wright   }
114758600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
114858600ac3SJames Wright }
114958600ac3SJames Wright 
115058600ac3SJames Wright /**
115158600ac3SJames Wright   @brief Increment reference counter for `MATCEED` context.
115258600ac3SJames Wright 
115358600ac3SJames Wright   Not collective across MPI processes.
115458600ac3SJames Wright 
115558600ac3SJames Wright   @param[in,out]  ctx  Context data
115658600ac3SJames Wright 
115758600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
115858600ac3SJames Wright **/
115958600ac3SJames Wright PetscErrorCode MatCeedContextReference(MatCeedContext ctx) {
116058600ac3SJames Wright   PetscFunctionBeginUser;
116158600ac3SJames Wright   ctx->ref_count++;
116258600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
116358600ac3SJames Wright }
116458600ac3SJames Wright 
116558600ac3SJames Wright /**
116658600ac3SJames Wright   @brief Copy reference for `MATCEED`.
116758600ac3SJames Wright          Note: If `ctx_copy` is non-null, it is assumed to be a valid pointer to a `MatCeedContext`.
116858600ac3SJames Wright 
116958600ac3SJames Wright   Not collective across MPI processes.
117058600ac3SJames Wright 
117158600ac3SJames Wright   @param[in]   ctx       Context data
117258600ac3SJames Wright   @param[out]  ctx_copy  Copy of pointer to context data
117358600ac3SJames Wright 
117458600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
117558600ac3SJames Wright **/
117658600ac3SJames Wright PetscErrorCode MatCeedContextReferenceCopy(MatCeedContext ctx, MatCeedContext *ctx_copy) {
117758600ac3SJames Wright   PetscFunctionBeginUser;
117858600ac3SJames Wright   PetscCall(MatCeedContextReference(ctx));
117958600ac3SJames Wright   PetscCall(MatCeedContextDestroy(*ctx_copy));
118058600ac3SJames Wright   *ctx_copy = ctx;
118158600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
118258600ac3SJames Wright }
118358600ac3SJames Wright 
/**
  @brief Release one reference to context data for operator application, freeing it when the last reference is dropped.

  Collective across MPI processes.

  `ctx` is passed by value, so the caller's pointer is not reset by this function.

  @param[in,out]  ctx  Context data for operator evaluation, or NULL (no-op)

  @return An error code: 0 - success, otherwise - failure
**/
PetscErrorCode MatCeedContextDestroy(MatCeedContext ctx) {
  PetscFunctionBeginUser;
  // NULL is a no-op; otherwise drop one reference and only tear down once none remain
  if (!ctx || --ctx->ref_count > 0) PetscFunctionReturn(PETSC_SUCCESS);

  // PETSc objects
  PetscCall(DMDestroy(&ctx->dm_x));
  PetscCall(DMDestroy(&ctx->dm_y));
  PetscCall(VecDestroy(&ctx->X_loc));
  PetscCall(VecDestroy(&ctx->Y_loc_transpose));
  PetscCall(MatDestroy(&ctx->mat_assembled_full_internal));
  PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal));
  PetscCall(PetscFree(ctx->internal_mat_type));
  PetscCall(PetscFree(ctx->mats_assembled_full));
  PetscCall(PetscFree(ctx->mats_assembled_pbd));

  // libCEED objects; every PetscCallCeed() below is given ctx->ceed, so the Ceed context itself is released last
  PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->x_loc));
  PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->y_loc));
  PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_full));
  PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_pbd));
  PetscCallCeed(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult));
  PetscCallCeed(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult_transpose));
  PetscCheck(CeedDestroy(&ctx->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "destroying libCEED context object failed");

  // Deallocate
  // NOTE(review): ctx is freed on the next line, so a stale pointer cannot reliably observe this flag afterwards
  // (reading freed memory is undefined behavior); treat it as a debugging aid at most
  ctx->is_destroyed = PETSC_TRUE;  // Flag as destroyed in case someone has stale ref
  PetscCall(PetscFree(ctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}
122258600ac3SJames Wright 
122358600ac3SJames Wright /**
122458600ac3SJames Wright   @brief Compute the diagonal of an operator via libCEED.
122558600ac3SJames Wright 
122658600ac3SJames Wright   Collective across MPI processes.
122758600ac3SJames Wright 
122858600ac3SJames Wright   @param[in]   A  `MATCEED`
122958600ac3SJames Wright   @param[out]  D  Vector holding operator diagonal
123058600ac3SJames Wright 
123158600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
123258600ac3SJames Wright **/
123358600ac3SJames Wright PetscErrorCode MatGetDiagonal_Ceed(Mat A, Vec D) {
123458600ac3SJames Wright   PetscMemType   mem_type;
123558600ac3SJames Wright   Vec            D_loc;
123658600ac3SJames Wright   MatCeedContext ctx;
123758600ac3SJames Wright 
123858600ac3SJames Wright   PetscFunctionBeginUser;
123958600ac3SJames Wright   PetscCall(MatShellGetContext(A, &ctx));
124058600ac3SJames Wright 
124158600ac3SJames Wright   // Place PETSc vector in libCEED vector
124258600ac3SJames Wright   PetscCall(DMGetLocalVector(ctx->dm_x, &D_loc));
124358600ac3SJames Wright   PetscCall(VecP2C(ctx->ceed, D_loc, &mem_type, ctx->x_loc));
124458600ac3SJames Wright 
124558600ac3SJames Wright   // Compute Diagonal
1246*50f50432SJames Wright   PetscCallCeed(ctx->ceed, CeedOperatorLinearAssembleDiagonal(ctx->op_mult, ctx->x_loc, CEED_REQUEST_IMMEDIATE));
124758600ac3SJames Wright 
124858600ac3SJames Wright   // Restore PETSc vector
124958600ac3SJames Wright   PetscCall(VecC2P(ctx->ceed, ctx->x_loc, mem_type, D_loc));
125058600ac3SJames Wright 
125158600ac3SJames Wright   // Local-to-Global
125258600ac3SJames Wright   PetscCall(VecZeroEntries(D));
125358600ac3SJames Wright   PetscCall(DMLocalToGlobal(ctx->dm_x, D_loc, ADD_VALUES, D));
125458600ac3SJames Wright   PetscCall(DMRestoreLocalVector(ctx->dm_x, &D_loc));
125558600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
125658600ac3SJames Wright }
125758600ac3SJames Wright 
125858600ac3SJames Wright /**
125958600ac3SJames Wright   @brief Compute `A X = Y` for a `MATCEED`.
126058600ac3SJames Wright 
126158600ac3SJames Wright   Collective across MPI processes.
126258600ac3SJames Wright 
126358600ac3SJames Wright   @param[in]   A  `MATCEED`
126458600ac3SJames Wright   @param[in]   X  Input PETSc vector
126558600ac3SJames Wright   @param[out]  Y  Output PETSc vector
126658600ac3SJames Wright 
126758600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
126858600ac3SJames Wright **/
126958600ac3SJames Wright PetscErrorCode MatMult_Ceed(Mat A, Vec X, Vec Y) {
127058600ac3SJames Wright   MatCeedContext ctx;
127158600ac3SJames Wright 
127258600ac3SJames Wright   PetscFunctionBeginUser;
127358600ac3SJames Wright   PetscCall(MatShellGetContext(A, &ctx));
127458600ac3SJames Wright   PetscCall(PetscLogEventBegin(ctx->log_event_mult, A, X, Y, 0));
127558600ac3SJames Wright 
127658600ac3SJames Wright   {
127758600ac3SJames Wright     PetscMemType x_mem_type, y_mem_type;
127858600ac3SJames Wright     Vec          X_loc = ctx->X_loc, Y_loc;
127958600ac3SJames Wright 
128058600ac3SJames Wright     // Get local vectors
128158600ac3SJames Wright     if (!ctx->X_loc) PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc));
128258600ac3SJames Wright     PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc));
128358600ac3SJames Wright 
128458600ac3SJames Wright     // Global-to-local
128558600ac3SJames Wright     PetscCall(DMGlobalToLocal(ctx->dm_x, X, INSERT_VALUES, X_loc));
128658600ac3SJames Wright 
128758600ac3SJames Wright     // Setup libCEED vectors
128858600ac3SJames Wright     PetscCall(VecReadP2C(ctx->ceed, X_loc, &x_mem_type, ctx->x_loc));
128958600ac3SJames Wright     PetscCall(VecZeroEntries(Y_loc));
129058600ac3SJames Wright     PetscCall(VecP2C(ctx->ceed, Y_loc, &y_mem_type, ctx->y_loc));
129158600ac3SJames Wright 
129258600ac3SJames Wright     // Apply libCEED operator
129358600ac3SJames Wright     PetscCall(PetscLogGpuTimeBegin());
1294*50f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult, ctx->x_loc, ctx->y_loc, CEED_REQUEST_IMMEDIATE));
129558600ac3SJames Wright     PetscCall(PetscLogGpuTimeEnd());
129658600ac3SJames Wright 
129758600ac3SJames Wright     // Restore PETSc vectors
129858600ac3SJames Wright     PetscCall(VecReadC2P(ctx->ceed, ctx->x_loc, x_mem_type, X_loc));
129958600ac3SJames Wright     PetscCall(VecC2P(ctx->ceed, ctx->y_loc, y_mem_type, Y_loc));
130058600ac3SJames Wright 
130158600ac3SJames Wright     // Local-to-global
130258600ac3SJames Wright     PetscCall(VecZeroEntries(Y));
130358600ac3SJames Wright     PetscCall(DMLocalToGlobal(ctx->dm_y, Y_loc, ADD_VALUES, Y));
130458600ac3SJames Wright 
130558600ac3SJames Wright     // Restore local vectors, as needed
130658600ac3SJames Wright     if (!ctx->X_loc) PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc));
130758600ac3SJames Wright     PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc));
130858600ac3SJames Wright   }
130958600ac3SJames Wright 
131058600ac3SJames Wright   // Log flops
131158600ac3SJames Wright   if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult));
131258600ac3SJames Wright   else PetscCall(PetscLogFlops(ctx->flops_mult));
131358600ac3SJames Wright 
131458600ac3SJames Wright   PetscCall(PetscLogEventEnd(ctx->log_event_mult, A, X, Y, 0));
131558600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
131658600ac3SJames Wright }
131758600ac3SJames Wright 
131858600ac3SJames Wright /**
131958600ac3SJames Wright   @brief Compute `A^T Y = X` for a `MATCEED`.
132058600ac3SJames Wright 
132158600ac3SJames Wright   Collective across MPI processes.
132258600ac3SJames Wright 
132358600ac3SJames Wright   @param[in]   A  `MATCEED`
132458600ac3SJames Wright   @param[in]   Y  Input PETSc vector
132558600ac3SJames Wright   @param[out]  X  Output PETSc vector
132658600ac3SJames Wright 
132758600ac3SJames Wright   @return An error code: 0 - success, otherwise - failure
132858600ac3SJames Wright **/
132958600ac3SJames Wright PetscErrorCode MatMultTranspose_Ceed(Mat A, Vec Y, Vec X) {
133058600ac3SJames Wright   MatCeedContext ctx;
133158600ac3SJames Wright 
133258600ac3SJames Wright   PetscFunctionBeginUser;
133358600ac3SJames Wright   PetscCall(MatShellGetContext(A, &ctx));
133458600ac3SJames Wright   PetscCall(PetscLogEventBegin(ctx->log_event_mult_transpose, A, Y, X, 0));
133558600ac3SJames Wright 
133658600ac3SJames Wright   {
133758600ac3SJames Wright     PetscMemType x_mem_type, y_mem_type;
133858600ac3SJames Wright     Vec          X_loc, Y_loc = ctx->Y_loc_transpose;
133958600ac3SJames Wright 
134058600ac3SJames Wright     // Get local vectors
134158600ac3SJames Wright     if (!ctx->Y_loc_transpose) PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc));
134258600ac3SJames Wright     PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc));
134358600ac3SJames Wright 
134458600ac3SJames Wright     // Global-to-local
134558600ac3SJames Wright     PetscCall(DMGlobalToLocal(ctx->dm_y, Y, INSERT_VALUES, Y_loc));
134658600ac3SJames Wright 
134758600ac3SJames Wright     // Setup libCEED vectors
134858600ac3SJames Wright     PetscCall(VecReadP2C(ctx->ceed, Y_loc, &y_mem_type, ctx->y_loc));
134958600ac3SJames Wright     PetscCall(VecZeroEntries(X_loc));
135058600ac3SJames Wright     PetscCall(VecP2C(ctx->ceed, X_loc, &x_mem_type, ctx->x_loc));
135158600ac3SJames Wright 
135258600ac3SJames Wright     // Apply libCEED operator
135358600ac3SJames Wright     PetscCall(PetscLogGpuTimeBegin());
1354*50f50432SJames Wright     PetscCallCeed(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult_transpose, ctx->y_loc, ctx->x_loc, CEED_REQUEST_IMMEDIATE));
135558600ac3SJames Wright     PetscCall(PetscLogGpuTimeEnd());
135658600ac3SJames Wright 
135758600ac3SJames Wright     // Restore PETSc vectors
135858600ac3SJames Wright     PetscCall(VecReadC2P(ctx->ceed, ctx->y_loc, y_mem_type, Y_loc));
135958600ac3SJames Wright     PetscCall(VecC2P(ctx->ceed, ctx->x_loc, x_mem_type, X_loc));
136058600ac3SJames Wright 
136158600ac3SJames Wright     // Local-to-global
136258600ac3SJames Wright     PetscCall(VecZeroEntries(X));
136358600ac3SJames Wright     PetscCall(DMLocalToGlobal(ctx->dm_x, X_loc, ADD_VALUES, X));
136458600ac3SJames Wright 
136558600ac3SJames Wright     // Restore local vectors, as needed
136658600ac3SJames Wright     if (!ctx->Y_loc_transpose) PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc));
136758600ac3SJames Wright     PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc));
136858600ac3SJames Wright   }
136958600ac3SJames Wright 
137058600ac3SJames Wright   // Log flops
137158600ac3SJames Wright   if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult_transpose));
137258600ac3SJames Wright   else PetscCall(PetscLogFlops(ctx->flops_mult_transpose));
137358600ac3SJames Wright 
137458600ac3SJames Wright   PetscCall(PetscLogEventEnd(ctx->log_event_mult_transpose, A, Y, X, 0));
137558600ac3SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
137658600ac3SJames Wright }
1377