1*c8564c30SJames Wright /// @file 2*c8564c30SJames Wright /// MatCeed and it's related operators 3*c8564c30SJames Wright 4*c8564c30SJames Wright #include <ceed.h> 5*c8564c30SJames Wright #include <ceed/backend.h> 6*c8564c30SJames Wright #include <mat-ceed-impl.h> 7*c8564c30SJames Wright #include <mat-ceed.h> 8*c8564c30SJames Wright #include <petscdmplex.h> 9*c8564c30SJames Wright #include <stdlib.h> 10*c8564c30SJames Wright #include <string.h> 11*c8564c30SJames Wright 12*c8564c30SJames Wright PetscClassId MATCEED_CLASSID; 13*c8564c30SJames Wright PetscLogEvent MATCEED_MULT, MATCEED_MULT_TRANSPOSE; 14*c8564c30SJames Wright 15*c8564c30SJames Wright /** 16*c8564c30SJames Wright @brief Register MATCEED log events. 17*c8564c30SJames Wright 18*c8564c30SJames Wright Not collective across MPI processes. 19*c8564c30SJames Wright 20*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 21*c8564c30SJames Wright **/ 22*c8564c30SJames Wright static PetscErrorCode MatCeedRegisterLogEvents() { 23*c8564c30SJames Wright static bool registered = false; 24*c8564c30SJames Wright 25*c8564c30SJames Wright PetscFunctionBeginUser; 26*c8564c30SJames Wright if (registered) PetscFunctionReturn(PETSC_SUCCESS); 27*c8564c30SJames Wright PetscCall(PetscClassIdRegister("MATCEED", &MATCEED_CLASSID)); 28*c8564c30SJames Wright PetscCall(PetscLogEventRegister("MATCEED Mult", MATCEED_CLASSID, &MATCEED_MULT)); 29*c8564c30SJames Wright PetscCall(PetscLogEventRegister("MATCEED Mult Transpose", MATCEED_CLASSID, &MATCEED_MULT_TRANSPOSE)); 30*c8564c30SJames Wright registered = true; 31*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 32*c8564c30SJames Wright } 33*c8564c30SJames Wright 34*c8564c30SJames Wright /** 35*c8564c30SJames Wright @brief Translate PetscMemType to CeedMemType 36*c8564c30SJames Wright 37*c8564c30SJames Wright @param[in] mem_type PetscMemType 38*c8564c30SJames Wright 39*c8564c30SJames Wright @return Equivalent CeedMemType 40*c8564c30SJames Wright **/ 41*c8564c30SJames Wright static inline CeedMemType MemTypeP2C(PetscMemType mem_type) { return PetscMemTypeDevice(mem_type) ? CEED_MEM_DEVICE : CEED_MEM_HOST; } 42*c8564c30SJames Wright 43*c8564c30SJames Wright /** 44*c8564c30SJames Wright @brief Translate array of `CeedInt` to `PetscInt`. 45*c8564c30SJames Wright If the types differ, `array_ceed` is freed with `free()` and `array_petsc` is allocated with `malloc()`. 46*c8564c30SJames Wright Caller is responsible for freeing `array_petsc` with `free()`. 47*c8564c30SJames Wright 48*c8564c30SJames Wright Not collective across MPI processes. 49*c8564c30SJames Wright 50*c8564c30SJames Wright @param[in] num_entries Number of array entries 51*c8564c30SJames Wright @param[in,out] array_ceed Array of `CeedInt` 52*c8564c30SJames Wright @param[out] array_petsc Array of `PetscInt` 53*c8564c30SJames Wright 54*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 55*c8564c30SJames Wright **/ 56*c8564c30SJames Wright static inline PetscErrorCode IntArrayC2P(PetscInt num_entries, CeedInt **array_ceed, PetscInt **array_petsc) { 57*c8564c30SJames Wright const CeedInt int_c = 0; 58*c8564c30SJames Wright const PetscInt int_p = 0; 59*c8564c30SJames Wright 60*c8564c30SJames Wright PetscFunctionBeginUser; 61*c8564c30SJames Wright if (sizeof(int_c) == sizeof(int_p)) { 62*c8564c30SJames Wright *array_petsc = (PetscInt *)*array_ceed; 63*c8564c30SJames Wright } else { 64*c8564c30SJames Wright *array_petsc = malloc(num_entries * sizeof(PetscInt)); 65*c8564c30SJames Wright for (PetscInt i = 0; i < num_entries; i++) (*array_petsc)[i] = (*array_ceed)[i]; 66*c8564c30SJames Wright free(*array_ceed); 67*c8564c30SJames Wright } 68*c8564c30SJames Wright *array_ceed = NULL; 69*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 70*c8564c30SJames Wright } 71*c8564c30SJames Wright 72*c8564c30SJames Wright /** 73*c8564c30SJames Wright @brief Transfer array from PETSc `Vec` to `CeedVector`. 74*c8564c30SJames Wright 75*c8564c30SJames Wright Collective across MPI processes. 76*c8564c30SJames Wright 77*c8564c30SJames Wright @param[in] ceed libCEED context 78*c8564c30SJames Wright @param[in] X_petsc PETSc `Vec` 79*c8564c30SJames Wright @param[out] mem_type PETSc `MemType` 80*c8564c30SJames Wright @param[out] x_ceed `CeedVector` 81*c8564c30SJames Wright 82*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 83*c8564c30SJames Wright **/ 84*c8564c30SJames Wright static inline PetscErrorCode VecP2C(Ceed ceed, Vec X_petsc, PetscMemType *mem_type, CeedVector x_ceed) { 85*c8564c30SJames Wright PetscScalar *x; 86*c8564c30SJames Wright 87*c8564c30SJames Wright PetscFunctionBeginUser; 88*c8564c30SJames Wright PetscCall(VecGetArrayAndMemType(X_petsc, &x, mem_type)); 89*c8564c30SJames Wright PetscCeedCall(ceed, CeedVectorSetArray(x_ceed, MemTypeP2C(*mem_type), CEED_USE_POINTER, x)); 90*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 91*c8564c30SJames Wright } 92*c8564c30SJames Wright 93*c8564c30SJames Wright /** 94*c8564c30SJames Wright @brief Transfer array from `CeedVector` to PETSc `Vec`. 95*c8564c30SJames Wright 96*c8564c30SJames Wright Collective across MPI processes. 97*c8564c30SJames Wright 98*c8564c30SJames Wright @param[in] ceed libCEED context 99*c8564c30SJames Wright @param[in] x_ceed `CeedVector` 100*c8564c30SJames Wright @param[in] mem_type PETSc `MemType` 101*c8564c30SJames Wright @param[out] X_petsc PETSc `Vec` 102*c8564c30SJames Wright 103*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 104*c8564c30SJames Wright **/ 105*c8564c30SJames Wright static inline PetscErrorCode VecC2P(Ceed ceed, CeedVector x_ceed, PetscMemType mem_type, Vec X_petsc) { 106*c8564c30SJames Wright PetscScalar *x; 107*c8564c30SJames Wright 108*c8564c30SJames Wright PetscFunctionBeginUser; 109*c8564c30SJames Wright PetscCeedCall(ceed, CeedVectorTakeArray(x_ceed, MemTypeP2C(mem_type), &x)); 110*c8564c30SJames Wright PetscCall(VecRestoreArrayAndMemType(X_petsc, &x)); 111*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 112*c8564c30SJames Wright } 113*c8564c30SJames Wright 114*c8564c30SJames Wright /** 115*c8564c30SJames Wright @brief Transfer read only array from PETSc `Vec` to `CeedVector`. 116*c8564c30SJames Wright 117*c8564c30SJames Wright Collective across MPI processes. 118*c8564c30SJames Wright 119*c8564c30SJames Wright @param[in] ceed libCEED context 120*c8564c30SJames Wright @param[in] X_petsc PETSc `Vec` 121*c8564c30SJames Wright @param[out] mem_type PETSc `MemType` 122*c8564c30SJames Wright @param[out] x_ceed `CeedVector` 123*c8564c30SJames Wright 124*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 125*c8564c30SJames Wright **/ 126*c8564c30SJames Wright static inline PetscErrorCode VecReadP2C(Ceed ceed, Vec X_petsc, PetscMemType *mem_type, CeedVector x_ceed) { 127*c8564c30SJames Wright PetscScalar *x; 128*c8564c30SJames Wright 129*c8564c30SJames Wright PetscFunctionBeginUser; 130*c8564c30SJames Wright PetscCall(VecGetArrayReadAndMemType(X_petsc, (const PetscScalar **)&x, mem_type)); 131*c8564c30SJames Wright PetscCeedCall(ceed, CeedVectorSetArray(x_ceed, MemTypeP2C(*mem_type), CEED_USE_POINTER, x)); 132*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 133*c8564c30SJames Wright } 134*c8564c30SJames Wright 135*c8564c30SJames Wright /** 136*c8564c30SJames Wright @brief Transfer read only array from `CeedVector` to PETSc `Vec`. 137*c8564c30SJames Wright 138*c8564c30SJames Wright Collective across MPI processes. 139*c8564c30SJames Wright 140*c8564c30SJames Wright @param[in] ceed libCEED context 141*c8564c30SJames Wright @param[in] x_ceed `CeedVector` 142*c8564c30SJames Wright @param[in] mem_type PETSc `MemType` 143*c8564c30SJames Wright @param[out] X_petsc PETSc `Vec` 144*c8564c30SJames Wright 145*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 146*c8564c30SJames Wright **/ 147*c8564c30SJames Wright static inline PetscErrorCode VecReadC2P(Ceed ceed, CeedVector x_ceed, PetscMemType mem_type, Vec X_petsc) { 148*c8564c30SJames Wright PetscScalar *x; 149*c8564c30SJames Wright 150*c8564c30SJames Wright PetscFunctionBeginUser; 151*c8564c30SJames Wright PetscCeedCall(ceed, CeedVectorTakeArray(x_ceed, MemTypeP2C(mem_type), &x)); 152*c8564c30SJames Wright PetscCall(VecRestoreArrayReadAndMemType(X_petsc, (const PetscScalar **)&x)); 153*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 154*c8564c30SJames Wright } 155*c8564c30SJames Wright 156*c8564c30SJames Wright /** 157*c8564c30SJames Wright @brief Setup inner `Mat` for `PC` operations not directly supported by libCEED. 158*c8564c30SJames Wright 159*c8564c30SJames Wright Collective across MPI processes. 160*c8564c30SJames Wright 161*c8564c30SJames Wright @param[in] mat_ceed `MATCEED` to setup 162*c8564c30SJames Wright @param[out] mat_inner Inner `Mat` 163*c8564c30SJames Wright 164*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 165*c8564c30SJames Wright **/ 166*c8564c30SJames Wright static PetscErrorCode MatCeedSetupInnerMat(Mat mat_ceed, Mat *mat_inner) { 167*c8564c30SJames Wright MatCeedContext ctx; 168*c8564c30SJames Wright 169*c8564c30SJames Wright PetscFunctionBeginUser; 170*c8564c30SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 171*c8564c30SJames Wright 172*c8564c30SJames Wright PetscCheck(ctx->dm_x == ctx->dm_y, PetscObjectComm((PetscObject)mat_ceed), PETSC_ERR_SUP, "PC only supported for MATCEED on a single DM"); 173*c8564c30SJames Wright 174*c8564c30SJames Wright // Check cl mat type 175*c8564c30SJames Wright { 176*c8564c30SJames Wright PetscBool is_internal_mat_type_cl = PETSC_FALSE; 177*c8564c30SJames Wright char internal_mat_type_cl[64]; 178*c8564c30SJames Wright 179*c8564c30SJames Wright // Check for specific CL inner mat type for this Mat 180*c8564c30SJames Wright { 181*c8564c30SJames Wright const char *mat_ceed_prefix = NULL; 182*c8564c30SJames Wright 183*c8564c30SJames Wright PetscCall(MatGetOptionsPrefix(mat_ceed, &mat_ceed_prefix)); 184*c8564c30SJames Wright PetscOptionsBegin(PetscObjectComm((PetscObject)mat_ceed), mat_ceed_prefix, "", NULL); 185*c8564c30SJames Wright PetscCall(PetscOptionsFList("-ceed_inner_mat_type", "MATCEED inner assembled MatType for PC support", NULL, MatList, internal_mat_type_cl, 186*c8564c30SJames Wright internal_mat_type_cl, sizeof(internal_mat_type_cl), &is_internal_mat_type_cl)); 187*c8564c30SJames Wright PetscOptionsEnd(); 188*c8564c30SJames Wright if (is_internal_mat_type_cl) { 189*c8564c30SJames Wright PetscCall(PetscFree(ctx->internal_mat_type)); 190*c8564c30SJames Wright PetscCall(PetscStrallocpy(internal_mat_type_cl, &ctx->internal_mat_type)); 191*c8564c30SJames Wright } 192*c8564c30SJames Wright } 193*c8564c30SJames Wright } 194*c8564c30SJames Wright 195*c8564c30SJames Wright // Create sparse matrix 196*c8564c30SJames Wright { 197*c8564c30SJames Wright MatType dm_mat_type, dm_mat_type_copy; 198*c8564c30SJames Wright 199*c8564c30SJames Wright PetscCall(DMGetMatType(ctx->dm_x, &dm_mat_type)); 200*c8564c30SJames Wright PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy)); 201*c8564c30SJames Wright PetscCall(DMSetMatType(ctx->dm_x, ctx->internal_mat_type)); 202*c8564c30SJames Wright PetscCall(DMCreateMatrix(ctx->dm_x, mat_inner)); 203*c8564c30SJames Wright PetscCall(DMSetMatType(ctx->dm_x, dm_mat_type_copy)); 204*c8564c30SJames Wright PetscCall(PetscFree(dm_mat_type_copy)); 205*c8564c30SJames Wright } 206*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 207*c8564c30SJames Wright } 208*c8564c30SJames Wright 209*c8564c30SJames Wright /** 210*c8564c30SJames Wright @brief Assemble the point block diagonal of a `MATCEED` into a `MATAIJ` or similar. 211*c8564c30SJames Wright The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`. 212*c8564c30SJames Wright The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail. 213*c8564c30SJames Wright 214*c8564c30SJames Wright Collective across MPI processes. 215*c8564c30SJames Wright 216*c8564c30SJames Wright @param[in] mat_ceed `MATCEED` to assemble 217*c8564c30SJames Wright @param[in,out] mat_coo `MATAIJ` or similar to assemble into 218*c8564c30SJames Wright 219*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 220*c8564c30SJames Wright **/ 221*c8564c30SJames Wright static PetscErrorCode MatCeedAssemblePointBlockDiagonalCOO(Mat mat_ceed, Mat mat_coo) { 222*c8564c30SJames Wright MatCeedContext ctx; 223*c8564c30SJames Wright 224*c8564c30SJames Wright PetscFunctionBeginUser; 225*c8564c30SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 226*c8564c30SJames Wright 227*c8564c30SJames Wright // Check if COO pattern set 228*c8564c30SJames Wright { 229*c8564c30SJames Wright PetscInt index = -1; 230*c8564c30SJames Wright 231*c8564c30SJames Wright for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) { 232*c8564c30SJames Wright if (ctx->mats_assembled_pbd[i] == mat_coo) index = i; 233*c8564c30SJames Wright } 234*c8564c30SJames Wright if (index == -1) { 235*c8564c30SJames Wright PetscInt *rows_petsc = NULL, *cols_petsc = NULL; 236*c8564c30SJames Wright CeedInt *rows_ceed, *cols_ceed; 237*c8564c30SJames Wright PetscCount num_entries; 238*c8564c30SJames Wright PetscLogStage stage_amg_setup; 239*c8564c30SJames Wright 240*c8564c30SJames Wright // -- Assemble sparsity pattern if mat hasn't been assembled before 241*c8564c30SJames Wright PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup)); 242*c8564c30SJames Wright if (stage_amg_setup == -1) { 243*c8564c30SJames Wright PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup)); 244*c8564c30SJames Wright } 245*c8564c30SJames Wright PetscCall(PetscLogStagePush(stage_amg_setup)); 246*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonalSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed)); 247*c8564c30SJames Wright PetscCall(IntArrayC2P(num_entries, &rows_ceed, &rows_petsc)); 248*c8564c30SJames Wright PetscCall(IntArrayC2P(num_entries, &cols_ceed, &cols_petsc)); 249*c8564c30SJames Wright PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc)); 250*c8564c30SJames Wright free(rows_petsc); 251*c8564c30SJames Wright free(cols_petsc); 252*c8564c30SJames Wright if (!ctx->coo_values_pbd) PetscCeedCall(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_pbd)); 253*c8564c30SJames Wright PetscCall(PetscRealloc(++ctx->num_mats_assembled_pbd * sizeof(Mat), &ctx->mats_assembled_pbd)); 254*c8564c30SJames Wright ctx->mats_assembled_pbd[ctx->num_mats_assembled_pbd - 1] = mat_coo; 255*c8564c30SJames Wright PetscCall(PetscLogStagePop()); 256*c8564c30SJames Wright } 257*c8564c30SJames Wright } 258*c8564c30SJames Wright 259*c8564c30SJames Wright // Assemble mat_ceed 260*c8564c30SJames Wright PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY)); 261*c8564c30SJames Wright { 262*c8564c30SJames Wright const CeedScalar *values; 263*c8564c30SJames Wright MatType mat_type; 264*c8564c30SJames Wright CeedMemType mem_type = CEED_MEM_HOST; 265*c8564c30SJames Wright PetscBool is_spd, is_spd_known; 266*c8564c30SJames Wright 267*c8564c30SJames Wright PetscCall(MatGetType(mat_coo, &mat_type)); 268*c8564c30SJames Wright if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE; 269*c8564c30SJames Wright else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE; 270*c8564c30SJames Wright else mem_type = CEED_MEM_HOST; 271*c8564c30SJames Wright 272*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonal(ctx->op_mult, ctx->coo_values_pbd, CEED_REQUEST_IMMEDIATE)); 273*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_pbd, mem_type, &values)); 274*c8564c30SJames Wright PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES)); 275*c8564c30SJames Wright PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd)); 276*c8564c30SJames Wright if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd)); 277*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_pbd, &values)); 278*c8564c30SJames Wright } 279*c8564c30SJames Wright PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY)); 280*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 281*c8564c30SJames Wright } 282*c8564c30SJames Wright 283*c8564c30SJames Wright /** 284*c8564c30SJames Wright @brief Assemble inner `Mat` for diagonal `PC` operations 285*c8564c30SJames Wright 286*c8564c30SJames Wright Collective across MPI processes. 287*c8564c30SJames Wright 288*c8564c30SJames Wright @param[in] mat_ceed `MATCEED` to invert 289*c8564c30SJames Wright @param[in] use_ceed_pbd Boolean flag to use libCEED PBD assembly 290*c8564c30SJames Wright @param[out] mat_inner Inner `Mat` for diagonal operations 291*c8564c30SJames Wright 292*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 293*c8564c30SJames Wright **/ 294*c8564c30SJames Wright static PetscErrorCode MatCeedAssembleInnerBlockDiagonalMat(Mat mat_ceed, PetscBool use_ceed_pbd, Mat *mat_inner) { 295*c8564c30SJames Wright MatCeedContext ctx; 296*c8564c30SJames Wright 297*c8564c30SJames Wright PetscFunctionBeginUser; 298*c8564c30SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 299*c8564c30SJames Wright if (use_ceed_pbd) { 300*c8564c30SJames Wright // Check if COO pattern set 301*c8564c30SJames Wright if (!ctx->mat_assembled_pbd_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_pbd_internal)); 302*c8564c30SJames Wright 303*c8564c30SJames Wright // Assemble mat_assembled_full_internal 304*c8564c30SJames Wright PetscCall(MatCeedAssemblePointBlockDiagonalCOO(mat_ceed, ctx->mat_assembled_pbd_internal)); 305*c8564c30SJames Wright if (mat_inner) *mat_inner = ctx->mat_assembled_pbd_internal; 306*c8564c30SJames Wright } else { 307*c8564c30SJames Wright // Check if COO pattern set 308*c8564c30SJames Wright if (!ctx->mat_assembled_full_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_full_internal)); 309*c8564c30SJames Wright 310*c8564c30SJames Wright // Assemble mat_assembled_full_internal 311*c8564c30SJames Wright PetscCall(MatCeedAssembleCOO(mat_ceed, ctx->mat_assembled_full_internal)); 312*c8564c30SJames Wright if (mat_inner) *mat_inner = ctx->mat_assembled_full_internal; 313*c8564c30SJames Wright } 314*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 315*c8564c30SJames Wright } 316*c8564c30SJames Wright 317*c8564c30SJames Wright /** 318*c8564c30SJames Wright @brief Get `MATCEED` diagonal block for Jacobi. 319*c8564c30SJames Wright 320*c8564c30SJames Wright Collective across MPI processes. 321*c8564c30SJames Wright 322*c8564c30SJames Wright @param[in] mat_ceed `MATCEED` to invert 323*c8564c30SJames Wright @param[out] mat_block The diagonal block matrix 324*c8564c30SJames Wright 325*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 326*c8564c30SJames Wright **/ 327*c8564c30SJames Wright static PetscErrorCode MatGetDiagonalBlock_Ceed(Mat mat_ceed, Mat *mat_block) { 328*c8564c30SJames Wright Mat mat_inner = NULL; 329*c8564c30SJames Wright MatCeedContext ctx; 330*c8564c30SJames Wright 331*c8564c30SJames Wright PetscFunctionBeginUser; 332*c8564c30SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 333*c8564c30SJames Wright 334*c8564c30SJames Wright // Assemble inner mat if needed 335*c8564c30SJames Wright PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner)); 336*c8564c30SJames Wright 337*c8564c30SJames Wright // Get block diagonal 338*c8564c30SJames Wright PetscCall(MatGetDiagonalBlock(mat_inner, mat_block)); 339*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 340*c8564c30SJames Wright } 341*c8564c30SJames Wright 342*c8564c30SJames Wright /** 343*c8564c30SJames Wright @brief Invert `MATCEED` diagonal block for Jacobi. 344*c8564c30SJames Wright 345*c8564c30SJames Wright Collective across MPI processes. 346*c8564c30SJames Wright 347*c8564c30SJames Wright @param[in] mat_ceed `MATCEED` to invert 348*c8564c30SJames Wright @param[out] values The block inverses in column major order 349*c8564c30SJames Wright 350*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 351*c8564c30SJames Wright **/ 352*c8564c30SJames Wright static PetscErrorCode MatInvertBlockDiagonal_Ceed(Mat mat_ceed, const PetscScalar **values) { 353*c8564c30SJames Wright Mat mat_inner = NULL; 354*c8564c30SJames Wright MatCeedContext ctx; 355*c8564c30SJames Wright 356*c8564c30SJames Wright PetscFunctionBeginUser; 357*c8564c30SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 358*c8564c30SJames Wright 359*c8564c30SJames Wright // Assemble inner mat if needed 360*c8564c30SJames Wright PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner)); 361*c8564c30SJames Wright 362*c8564c30SJames Wright // Invert PB diagonal 363*c8564c30SJames Wright PetscCall(MatInvertBlockDiagonal(mat_inner, values)); 364*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 365*c8564c30SJames Wright } 366*c8564c30SJames Wright 367*c8564c30SJames Wright /** 368*c8564c30SJames Wright @brief Invert `MATCEED` variable diagonal block for Jacobi. 369*c8564c30SJames Wright 370*c8564c30SJames Wright Collective across MPI processes. 371*c8564c30SJames Wright 372*c8564c30SJames Wright @param[in] mat_ceed `MATCEED` to invert 373*c8564c30SJames Wright @param[in] num_blocks The number of blocks on the process 374*c8564c30SJames Wright @param[in] block_sizes The size of each block on the process 375*c8564c30SJames Wright @param[out] values The block inverses in column major order 376*c8564c30SJames Wright 377*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 378*c8564c30SJames Wright **/ 379*c8564c30SJames Wright static PetscErrorCode MatInvertVariableBlockDiagonal_Ceed(Mat mat_ceed, PetscInt num_blocks, const PetscInt *block_sizes, PetscScalar *values) { 380*c8564c30SJames Wright Mat mat_inner = NULL; 381*c8564c30SJames Wright MatCeedContext ctx; 382*c8564c30SJames Wright 383*c8564c30SJames Wright PetscFunctionBeginUser; 384*c8564c30SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 385*c8564c30SJames Wright 386*c8564c30SJames Wright // Assemble inner mat if needed 387*c8564c30SJames Wright PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_vpbd_valid, &mat_inner)); 388*c8564c30SJames Wright 389*c8564c30SJames Wright // Invert PB diagonal 390*c8564c30SJames Wright PetscCall(MatInvertVariableBlockDiagonal(mat_inner, num_blocks, block_sizes, values)); 391*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 392*c8564c30SJames Wright } 393*c8564c30SJames Wright 394*c8564c30SJames Wright // ----------------------------------------------------------------------------- 395*c8564c30SJames Wright // MatCeed 396*c8564c30SJames Wright // ----------------------------------------------------------------------------- 397*c8564c30SJames Wright 398*c8564c30SJames Wright /** 399*c8564c30SJames Wright @brief Create PETSc `Mat` from libCEED operators. 400*c8564c30SJames Wright 401*c8564c30SJames Wright Collective across MPI processes. 402*c8564c30SJames Wright 403*c8564c30SJames Wright @param[in] dm_x Input `DM` 404*c8564c30SJames Wright @param[in] dm_y Output `DM` 405*c8564c30SJames Wright @param[in] op_mult `CeedOperator` for forward evaluation 406*c8564c30SJames Wright @param[in] op_mult_transpose `CeedOperator` for transpose evaluation 407*c8564c30SJames Wright @param[out] mat New MatCeed 408*c8564c30SJames Wright 409*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 410*c8564c30SJames Wright **/ 411*c8564c30SJames Wright PetscErrorCode MatCeedCreate(DM dm_x, DM dm_y, CeedOperator op_mult, CeedOperator op_mult_transpose, Mat *mat) { 412*c8564c30SJames Wright PetscInt X_l_size, X_g_size, Y_l_size, Y_g_size; 413*c8564c30SJames Wright VecType vec_type; 414*c8564c30SJames Wright MatCeedContext ctx; 415*c8564c30SJames Wright 416*c8564c30SJames Wright PetscFunctionBeginUser; 417*c8564c30SJames Wright PetscCall(MatCeedRegisterLogEvents()); 418*c8564c30SJames Wright 419*c8564c30SJames Wright // Collect context data 420*c8564c30SJames Wright PetscCall(DMGetVecType(dm_x, &vec_type)); 421*c8564c30SJames Wright { 422*c8564c30SJames Wright Vec X; 423*c8564c30SJames Wright 424*c8564c30SJames Wright PetscCall(DMGetGlobalVector(dm_x, &X)); 425*c8564c30SJames Wright PetscCall(VecGetSize(X, &X_g_size)); 426*c8564c30SJames Wright PetscCall(VecGetLocalSize(X, &X_l_size)); 427*c8564c30SJames Wright PetscCall(DMRestoreGlobalVector(dm_x, &X)); 428*c8564c30SJames Wright } 429*c8564c30SJames Wright if (dm_y) { 430*c8564c30SJames Wright Vec Y; 431*c8564c30SJames Wright 432*c8564c30SJames Wright PetscCall(DMGetGlobalVector(dm_y, &Y)); 433*c8564c30SJames Wright PetscCall(VecGetSize(Y, &Y_g_size)); 434*c8564c30SJames Wright PetscCall(VecGetLocalSize(Y, &Y_l_size)); 435*c8564c30SJames Wright PetscCall(DMRestoreGlobalVector(dm_y, &Y)); 436*c8564c30SJames Wright } else { 437*c8564c30SJames Wright dm_y = dm_x; 438*c8564c30SJames Wright Y_g_size = X_g_size; 439*c8564c30SJames Wright Y_l_size = X_l_size; 440*c8564c30SJames Wright } 441*c8564c30SJames Wright // Create context 442*c8564c30SJames Wright { 443*c8564c30SJames Wright Vec X_loc, Y_loc_transpose = NULL; 444*c8564c30SJames Wright 445*c8564c30SJames Wright PetscCall(DMCreateLocalVector(dm_x, &X_loc)); 446*c8564c30SJames Wright PetscCall(VecZeroEntries(X_loc)); 447*c8564c30SJames Wright if (op_mult_transpose) { 448*c8564c30SJames Wright PetscCall(DMCreateLocalVector(dm_y, &Y_loc_transpose)); 449*c8564c30SJames Wright PetscCall(VecZeroEntries(Y_loc_transpose)); 450*c8564c30SJames Wright } 451*c8564c30SJames Wright PetscCall(MatCeedContextCreate(dm_x, dm_y, X_loc, Y_loc_transpose, op_mult, op_mult_transpose, MATCEED_MULT, MATCEED_MULT_TRANSPOSE, &ctx)); 452*c8564c30SJames Wright PetscCall(VecDestroy(&X_loc)); 453*c8564c30SJames Wright PetscCall(VecDestroy(&Y_loc_transpose)); 454*c8564c30SJames Wright } 455*c8564c30SJames Wright 456*c8564c30SJames Wright // Create mat 457*c8564c30SJames Wright PetscCall(MatCreateShell(PetscObjectComm((PetscObject)dm_x), Y_l_size, X_l_size, Y_g_size, X_g_size, ctx, mat)); 458*c8564c30SJames Wright PetscCall(PetscObjectChangeTypeName((PetscObject)*mat, MATCEED)); 459*c8564c30SJames Wright // -- Set block and variable block sizes 460*c8564c30SJames Wright if (dm_x == dm_y) { 461*c8564c30SJames Wright MatType dm_mat_type, dm_mat_type_copy; 462*c8564c30SJames Wright Mat temp_mat; 463*c8564c30SJames Wright 464*c8564c30SJames Wright PetscCall(DMGetMatType(dm_x, &dm_mat_type)); 465*c8564c30SJames Wright PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy)); 466*c8564c30SJames Wright PetscCall(DMSetMatType(dm_x, MATAIJ)); 467*c8564c30SJames Wright PetscCall(DMCreateMatrix(dm_x, &temp_mat)); 468*c8564c30SJames Wright PetscCall(DMSetMatType(dm_x, dm_mat_type_copy)); 469*c8564c30SJames Wright PetscCall(PetscFree(dm_mat_type_copy)); 470*c8564c30SJames Wright 471*c8564c30SJames Wright { 472*c8564c30SJames Wright PetscInt block_size, num_blocks, max_vblock_size = PETSC_INT_MAX; 473*c8564c30SJames Wright const PetscInt *vblock_sizes; 474*c8564c30SJames Wright 475*c8564c30SJames Wright // -- Get block sizes 476*c8564c30SJames Wright PetscCall(MatGetBlockSize(temp_mat, &block_size)); 477*c8564c30SJames Wright PetscCall(MatGetVariableBlockSizes(temp_mat, &num_blocks, &vblock_sizes)); 478*c8564c30SJames Wright { 479*c8564c30SJames Wright PetscInt local_min_max[2] = {0}, global_min_max[2] = {0, PETSC_INT_MAX}; 480*c8564c30SJames Wright 481*c8564c30SJames Wright for (PetscInt i = 0; i < num_blocks; i++) local_min_max[1] = PetscMax(local_min_max[1], vblock_sizes[i]); 482*c8564c30SJames Wright PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_min_max, global_min_max)); 483*c8564c30SJames Wright max_vblock_size = global_min_max[1]; 484*c8564c30SJames Wright } 485*c8564c30SJames Wright 486*c8564c30SJames Wright // -- Copy block sizes 487*c8564c30SJames Wright if (block_size > 1) PetscCall(MatSetBlockSize(*mat, block_size)); 488*c8564c30SJames Wright if (num_blocks) PetscCall(MatSetVariableBlockSizes(*mat, num_blocks, (PetscInt *)vblock_sizes)); 489*c8564c30SJames Wright 490*c8564c30SJames Wright // -- Check libCEED compatibility 491*c8564c30SJames Wright { 492*c8564c30SJames Wright bool is_composite; 493*c8564c30SJames Wright 494*c8564c30SJames Wright ctx->is_ceed_pbd_valid = PETSC_TRUE; 495*c8564c30SJames Wright ctx->is_ceed_vpbd_valid = PETSC_TRUE; 496*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorIsComposite(op_mult, &is_composite)); 497*c8564c30SJames Wright if (is_composite) { 498*c8564c30SJames Wright CeedInt num_sub_operators; 499*c8564c30SJames Wright CeedOperator *sub_operators; 500*c8564c30SJames Wright 501*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedCompositeOperatorGetNumSub(op_mult, &num_sub_operators)); 502*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedCompositeOperatorGetSubList(op_mult, &sub_operators)); 503*c8564c30SJames Wright for (CeedInt i = 0; i < num_sub_operators; i++) { 504*c8564c30SJames Wright CeedInt num_bases, num_comp; 505*c8564c30SJames Wright CeedBasis *active_bases; 506*c8564c30SJames Wright CeedOperatorAssemblyData assembly_data; 507*c8564c30SJames Wright 508*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorGetOperatorAssemblyData(sub_operators[i], &assembly_data)); 509*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL)); 510*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp)); 511*c8564c30SJames Wright if (num_bases > 1) { 512*c8564c30SJames Wright ctx->is_ceed_pbd_valid = PETSC_FALSE; 513*c8564c30SJames Wright ctx->is_ceed_vpbd_valid = PETSC_FALSE; 514*c8564c30SJames Wright } 515*c8564c30SJames Wright if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE; 516*c8564c30SJames Wright if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE; 517*c8564c30SJames Wright } 518*c8564c30SJames Wright } else { 519*c8564c30SJames Wright // LCOV_EXCL_START 520*c8564c30SJames Wright CeedInt num_bases, num_comp; 521*c8564c30SJames Wright CeedBasis *active_bases; 522*c8564c30SJames Wright CeedOperatorAssemblyData assembly_data; 523*c8564c30SJames Wright 524*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorGetOperatorAssemblyData(op_mult, &assembly_data)); 525*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL)); 526*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp)); 527*c8564c30SJames Wright if (num_bases > 1) { 528*c8564c30SJames Wright ctx->is_ceed_pbd_valid = PETSC_FALSE; 529*c8564c30SJames Wright ctx->is_ceed_vpbd_valid = PETSC_FALSE; 530*c8564c30SJames Wright } 531*c8564c30SJames Wright if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE; 532*c8564c30SJames Wright if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE; 533*c8564c30SJames Wright // LCOV_EXCL_STOP 534*c8564c30SJames Wright } 535*c8564c30SJames Wright { 536*c8564c30SJames Wright PetscInt local_is_valid[2], global_is_valid[2]; 537*c8564c30SJames Wright 538*c8564c30SJames Wright local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_pbd_valid; 539*c8564c30SJames Wright PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid)); 540*c8564c30SJames Wright ctx->is_ceed_pbd_valid = global_is_valid[0]; 541*c8564c30SJames Wright local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_vpbd_valid; 542*c8564c30SJames Wright PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid)); 543*c8564c30SJames Wright ctx->is_ceed_vpbd_valid = global_is_valid[0]; 544*c8564c30SJames Wright } 545*c8564c30SJames Wright } 546*c8564c30SJames Wright } 547*c8564c30SJames Wright PetscCall(MatDestroy(&temp_mat)); 548*c8564c30SJames Wright } 549*c8564c30SJames Wright // -- Set internal mat type 550*c8564c30SJames Wright { 551*c8564c30SJames Wright VecType vec_type; 552*c8564c30SJames Wright MatType internal_mat_type = MATAIJ; 553*c8564c30SJames Wright 554*c8564c30SJames Wright PetscCall(VecGetType(ctx->X_loc, &vec_type)); 555*c8564c30SJames Wright if (strstr(vec_type, VECCUDA)) internal_mat_type = MATAIJCUSPARSE; 556*c8564c30SJames Wright else if (strstr(vec_type, VECKOKKOS)) internal_mat_type = MATAIJKOKKOS; 557*c8564c30SJames Wright else internal_mat_type = MATAIJ; 558*c8564c30SJames Wright PetscCall(PetscStrallocpy(internal_mat_type, &ctx->internal_mat_type)); 559*c8564c30SJames Wright } 560*c8564c30SJames Wright // -- Set mat operations 561*c8564c30SJames Wright PetscCall(MatShellSetContextDestroy(*mat, (PetscErrorCode(*)(void *))MatCeedContextDestroy)); 562*c8564c30SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_MULT, (void (*)(void))MatMult_Ceed)); 563*c8564c30SJames Wright if (op_mult_transpose) PetscCall(MatShellSetOperation(*mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed)); 564*c8564c30SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed)); 565*c8564c30SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed)); 566*c8564c30SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed)); 567*c8564c30SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed)); 568*c8564c30SJames Wright PetscCall(MatShellSetVecType(*mat, vec_type)); 569*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 570*c8564c30SJames Wright } 571*c8564c30SJames Wright 572*c8564c30SJames Wright /** 573*c8564c30SJames Wright @brief Copy `MATCEED` into a compatible `Mat` with type `MatShell` or `MATCEED`. 574*c8564c30SJames Wright 575*c8564c30SJames Wright Collective across MPI processes. 576*c8564c30SJames Wright 577*c8564c30SJames Wright @param[in] mat_ceed `MATCEED` to copy from 578*c8564c30SJames Wright @param[out] mat_other `MatShell` or `MATCEED` to copy into 579*c8564c30SJames Wright 580*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 581*c8564c30SJames Wright **/ 582*c8564c30SJames Wright PetscErrorCode MatCeedCopy(Mat mat_ceed, Mat mat_other) { 583*c8564c30SJames Wright PetscFunctionBeginUser; 584*c8564c30SJames Wright PetscCall(MatCeedRegisterLogEvents()); 585*c8564c30SJames Wright 586*c8564c30SJames Wright // Check type compatibility 587*c8564c30SJames Wright { 588*c8564c30SJames Wright MatType mat_type_ceed, mat_type_other; 589*c8564c30SJames Wright 590*c8564c30SJames Wright PetscCall(MatGetType(mat_ceed, &mat_type_ceed)); 591*c8564c30SJames Wright PetscCheck(!strcmp(mat_type_ceed, MATCEED), PETSC_COMM_SELF, PETSC_ERR_LIB, "mat_ceed must have type " MATCEED); 592*c8564c30SJames Wright PetscCall(MatGetType(mat_ceed, &mat_type_other)); 593*c8564c30SJames Wright PetscCheck(!strcmp(mat_type_other, MATCEED) || !strcmp(mat_type_other, MATSHELL), PETSC_COMM_SELF, PETSC_ERR_LIB, 594*c8564c30SJames Wright "mat_other must have type " MATCEED " or " MATSHELL); 595*c8564c30SJames Wright } 596*c8564c30SJames Wright 597*c8564c30SJames Wright // Check dimension compatibility 598*c8564c30SJames Wright { 599*c8564c30SJames Wright PetscInt X_l_ceed_size, X_g_ceed_size, Y_l_ceed_size, Y_g_ceed_size, X_l_other_size, X_g_other_size, Y_l_other_size, Y_g_other_size; 600*c8564c30SJames Wright 601*c8564c30SJames Wright PetscCall(MatGetSize(mat_ceed, &Y_g_ceed_size, &X_g_ceed_size)); 602*c8564c30SJames Wright PetscCall(MatGetLocalSize(mat_ceed, &Y_l_ceed_size, &X_l_ceed_size)); 603*c8564c30SJames Wright PetscCall(MatGetSize(mat_ceed, &Y_g_other_size, &X_g_other_size)); 604*c8564c30SJames Wright PetscCall(MatGetLocalSize(mat_ceed, &Y_l_other_size, &X_l_other_size)); 605*c8564c30SJames Wright PetscCheck((Y_g_ceed_size == Y_g_other_size) && (X_g_ceed_size == X_g_other_size) && (Y_l_ceed_size == Y_l_other_size) && 606*c8564c30SJames Wright (X_l_ceed_size == X_l_other_size), 607*c8564c30SJames Wright PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, 608*c8564c30SJames Wright "mat_ceed and mat_other must have compatible sizes; found mat_ceed (Global: %" PetscInt_FMT ", %" PetscInt_FMT 609*c8564c30SJames Wright "; Local: %" PetscInt_FMT ", %" PetscInt_FMT ") mat_other (Global: %" PetscInt_FMT ", %" PetscInt_FMT "; Local: %" PetscInt_FMT 610*c8564c30SJames Wright ", %" PetscInt_FMT ")", 611*c8564c30SJames Wright Y_g_ceed_size, X_g_ceed_size, Y_l_ceed_size, X_l_ceed_size, Y_g_other_size, X_g_other_size, Y_l_other_size, X_l_other_size); 612*c8564c30SJames Wright } 613*c8564c30SJames Wright 614*c8564c30SJames Wright // Convert 615*c8564c30SJames Wright { 616*c8564c30SJames Wright VecType vec_type; 617*c8564c30SJames Wright MatCeedContext ctx; 618*c8564c30SJames Wright 619*c8564c30SJames Wright PetscCall(PetscObjectChangeTypeName((PetscObject)mat_other, MATCEED)); 620*c8564c30SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 621*c8564c30SJames Wright PetscCall(MatCeedContextReference(ctx)); 622*c8564c30SJames Wright PetscCall(MatShellSetContext(mat_other, ctx)); 623*c8564c30SJames Wright PetscCall(MatShellSetContextDestroy(mat_other, (PetscErrorCode(*)(void *))MatCeedContextDestroy)); 624*c8564c30SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_MULT, (void (*)(void))MatMult_Ceed)); 625*c8564c30SJames Wright if (ctx->op_mult_transpose) PetscCall(MatShellSetOperation(mat_other, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed)); 626*c8564c30SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed)); 627*c8564c30SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed)); 628*c8564c30SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed)); 629*c8564c30SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed)); 630*c8564c30SJames Wright { 631*c8564c30SJames Wright PetscInt block_size; 632*c8564c30SJames Wright 633*c8564c30SJames Wright PetscCall(MatGetBlockSize(mat_ceed, &block_size)); 634*c8564c30SJames Wright if (block_size > 1) PetscCall(MatSetBlockSize(mat_other, block_size)); 635*c8564c30SJames Wright } 636*c8564c30SJames Wright { 637*c8564c30SJames Wright PetscInt num_blocks; 638*c8564c30SJames Wright const PetscInt *block_sizes; 639*c8564c30SJames Wright 640*c8564c30SJames Wright PetscCall(MatGetVariableBlockSizes(mat_ceed, &num_blocks, &block_sizes)); 641*c8564c30SJames Wright if (num_blocks) PetscCall(MatSetVariableBlockSizes(mat_other, num_blocks, (PetscInt *)block_sizes)); 642*c8564c30SJames Wright } 643*c8564c30SJames Wright PetscCall(DMGetVecType(ctx->dm_x, &vec_type)); 644*c8564c30SJames Wright PetscCall(MatShellSetVecType(mat_other, vec_type)); 645*c8564c30SJames Wright } 646*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 647*c8564c30SJames Wright } 648*c8564c30SJames Wright 649*c8564c30SJames Wright /** 650*c8564c30SJames Wright @brief Assemble a `MATCEED` into a `MATAIJ` or similar. 651*c8564c30SJames Wright The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`. 652*c8564c30SJames Wright The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail. 653*c8564c30SJames Wright 654*c8564c30SJames Wright Collective across MPI processes. 655*c8564c30SJames Wright 656*c8564c30SJames Wright @param[in] mat_ceed `MATCEED` to assemble 657*c8564c30SJames Wright @param[in,out] mat_coo `MATAIJ` or similar to assemble into 658*c8564c30SJames Wright 659*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 660*c8564c30SJames Wright **/ 661*c8564c30SJames Wright PetscErrorCode MatCeedAssembleCOO(Mat mat_ceed, Mat mat_coo) { 662*c8564c30SJames Wright MatCeedContext ctx; 663*c8564c30SJames Wright 664*c8564c30SJames Wright PetscFunctionBeginUser; 665*c8564c30SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 666*c8564c30SJames Wright 667*c8564c30SJames Wright // Check if COO pattern set 668*c8564c30SJames Wright { 669*c8564c30SJames Wright PetscInt index = -1; 670*c8564c30SJames Wright 671*c8564c30SJames Wright for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) { 672*c8564c30SJames Wright if (ctx->mats_assembled_full[i] == mat_coo) index = i; 673*c8564c30SJames Wright } 674*c8564c30SJames Wright if (index == -1) { 675*c8564c30SJames Wright PetscInt *rows_petsc = NULL, *cols_petsc = NULL; 676*c8564c30SJames Wright CeedInt *rows_ceed, *cols_ceed; 677*c8564c30SJames Wright PetscCount num_entries; 678*c8564c30SJames Wright PetscLogStage stage_amg_setup; 679*c8564c30SJames Wright 680*c8564c30SJames Wright // -- Assemble sparsity pattern if mat hasn't been assembled before 681*c8564c30SJames Wright PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup)); 682*c8564c30SJames Wright if (stage_amg_setup == -1) { 683*c8564c30SJames Wright PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup)); 684*c8564c30SJames Wright } 685*c8564c30SJames Wright PetscCall(PetscLogStagePush(stage_amg_setup)); 686*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorLinearAssembleSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed)); 687*c8564c30SJames Wright PetscCall(IntArrayC2P(num_entries, &rows_ceed, &rows_petsc)); 688*c8564c30SJames Wright PetscCall(IntArrayC2P(num_entries, &cols_ceed, &cols_petsc)); 689*c8564c30SJames Wright PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc)); 690*c8564c30SJames Wright free(rows_petsc); 691*c8564c30SJames Wright free(cols_petsc); 692*c8564c30SJames Wright if (!ctx->coo_values_full) PetscCeedCall(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_full)); 693*c8564c30SJames Wright PetscCall(PetscRealloc(++ctx->num_mats_assembled_full * sizeof(Mat), &ctx->mats_assembled_full)); 694*c8564c30SJames Wright ctx->mats_assembled_full[ctx->num_mats_assembled_full - 1] = mat_coo; 695*c8564c30SJames Wright PetscCall(PetscLogStagePop()); 696*c8564c30SJames Wright } 697*c8564c30SJames Wright } 698*c8564c30SJames Wright 699*c8564c30SJames Wright // Assemble mat_ceed 700*c8564c30SJames Wright PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY)); 701*c8564c30SJames Wright { 702*c8564c30SJames Wright const CeedScalar *values; 703*c8564c30SJames Wright MatType mat_type; 704*c8564c30SJames Wright CeedMemType mem_type = CEED_MEM_HOST; 705*c8564c30SJames Wright PetscBool is_spd, is_spd_known; 706*c8564c30SJames Wright 707*c8564c30SJames Wright PetscCall(MatGetType(mat_coo, &mat_type)); 708*c8564c30SJames Wright if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE; 709*c8564c30SJames Wright else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE; 710*c8564c30SJames Wright else mem_type = CEED_MEM_HOST; 711*c8564c30SJames Wright 712*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorLinearAssemble(ctx->op_mult, ctx->coo_values_full)); 713*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_full, mem_type, &values)); 714*c8564c30SJames Wright PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES)); 715*c8564c30SJames Wright PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd)); 716*c8564c30SJames Wright if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd)); 717*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_full, &values)); 718*c8564c30SJames Wright } 719*c8564c30SJames Wright PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY)); 720*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 721*c8564c30SJames Wright } 722*c8564c30SJames Wright 723*c8564c30SJames Wright /** 724*c8564c30SJames Wright @brief Set user context for a `MATCEED`. 725*c8564c30SJames Wright 726*c8564c30SJames Wright Collective across MPI processes. 727*c8564c30SJames Wright 728*c8564c30SJames Wright @param[in,out] mat `MATCEED` 729*c8564c30SJames Wright @param[in] f The context destroy function, or NULL 730*c8564c30SJames Wright @param[in] ctx User context, or NULL to unset 731*c8564c30SJames Wright 732*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 733*c8564c30SJames Wright **/ 734*c8564c30SJames Wright PetscErrorCode MatCeedSetContext(Mat mat, PetscErrorCode (*f)(void *), void *ctx) { 735*c8564c30SJames Wright PetscContainer user_ctx = NULL; 736*c8564c30SJames Wright 737*c8564c30SJames Wright PetscFunctionBeginUser; 738*c8564c30SJames Wright if (ctx) { 739*c8564c30SJames Wright PetscCall(PetscContainerCreate(PetscObjectComm((PetscObject)mat), &user_ctx)); 740*c8564c30SJames Wright PetscCall(PetscContainerSetPointer(user_ctx, ctx)); 741*c8564c30SJames Wright PetscCall(PetscContainerSetUserDestroy(user_ctx, f)); 742*c8564c30SJames Wright } 743*c8564c30SJames Wright PetscCall(PetscObjectCompose((PetscObject)mat, "MatCeed user context", (PetscObject)user_ctx)); 744*c8564c30SJames Wright PetscCall(PetscContainerDestroy(&user_ctx)); 745*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 746*c8564c30SJames Wright } 747*c8564c30SJames Wright 748*c8564c30SJames Wright /** 749*c8564c30SJames Wright @brief Retrieve the user context for a `MATCEED`. 750*c8564c30SJames Wright 751*c8564c30SJames Wright Collective across MPI processes. 752*c8564c30SJames Wright 753*c8564c30SJames Wright @param[in,out] mat `MATCEED` 754*c8564c30SJames Wright @param[in] ctx User context 755*c8564c30SJames Wright 756*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 757*c8564c30SJames Wright **/ 758*c8564c30SJames Wright PetscErrorCode MatCeedGetContext(Mat mat, void *ctx) { 759*c8564c30SJames Wright PetscContainer user_ctx; 760*c8564c30SJames Wright 761*c8564c30SJames Wright PetscFunctionBeginUser; 762*c8564c30SJames Wright PetscCall(PetscObjectQuery((PetscObject)mat, "MatCeed user context", (PetscObject *)&user_ctx)); 763*c8564c30SJames Wright if (user_ctx) PetscCall(PetscContainerGetPointer(user_ctx, (void **)ctx)); 764*c8564c30SJames Wright else *(void **)ctx = NULL; 765*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 766*c8564c30SJames Wright } 767*c8564c30SJames Wright 768*c8564c30SJames Wright /** 769*c8564c30SJames Wright @brief Sets the inner matrix type as a string from the `MATCEED`. 770*c8564c30SJames Wright 771*c8564c30SJames Wright Collective across MPI processes. 772*c8564c30SJames Wright 773*c8564c30SJames Wright @param[in,out] mat `MATCEED` 774*c8564c30SJames Wright @param[in] type Inner `MatType` to set 775*c8564c30SJames Wright 776*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 777*c8564c30SJames Wright **/ 778*c8564c30SJames Wright PetscErrorCode MatCeedSetInnerMatType(Mat mat, MatType type) { 779*c8564c30SJames Wright MatCeedContext ctx; 780*c8564c30SJames Wright 781*c8564c30SJames Wright PetscFunctionBeginUser; 782*c8564c30SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 783*c8564c30SJames Wright // Check if same 784*c8564c30SJames Wright { 785*c8564c30SJames Wright size_t len_old, len_new; 786*c8564c30SJames Wright PetscBool is_same = PETSC_FALSE; 787*c8564c30SJames Wright 788*c8564c30SJames Wright PetscCall(PetscStrlen(ctx->internal_mat_type, &len_old)); 789*c8564c30SJames Wright PetscCall(PetscStrlen(type, &len_new)); 790*c8564c30SJames Wright if (len_old == len_new) PetscCall(PetscStrncmp(ctx->internal_mat_type, type, len_old, &is_same)); 791*c8564c30SJames Wright if (is_same) PetscFunctionReturn(PETSC_SUCCESS); 792*c8564c30SJames Wright } 793*c8564c30SJames Wright // Clean up old mats in different format 794*c8564c30SJames Wright // LCOV_EXCL_START 795*c8564c30SJames Wright if (ctx->mat_assembled_full_internal) { 796*c8564c30SJames Wright for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) { 797*c8564c30SJames Wright if (ctx->mats_assembled_full[i] == ctx->mat_assembled_full_internal) { 798*c8564c30SJames Wright for (PetscInt j = i + 1; j < ctx->num_mats_assembled_full; j++) { 799*c8564c30SJames Wright ctx->mats_assembled_full[j - 1] = ctx->mats_assembled_full[j]; 800*c8564c30SJames Wright } 801*c8564c30SJames Wright ctx->num_mats_assembled_full--; 802*c8564c30SJames Wright // Note: we'll realloc this array again, so no need to shrink the allocation 803*c8564c30SJames Wright PetscCall(MatDestroy(&ctx->mat_assembled_full_internal)); 804*c8564c30SJames Wright } 805*c8564c30SJames Wright } 806*c8564c30SJames Wright } 807*c8564c30SJames Wright if (ctx->mat_assembled_pbd_internal) { 808*c8564c30SJames Wright for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) { 809*c8564c30SJames Wright if (ctx->mats_assembled_pbd[i] == ctx->mat_assembled_pbd_internal) { 810*c8564c30SJames Wright for (PetscInt j = i + 1; j < ctx->num_mats_assembled_pbd; j++) { 811*c8564c30SJames Wright ctx->mats_assembled_pbd[j - 1] = ctx->mats_assembled_pbd[j]; 812*c8564c30SJames Wright } 813*c8564c30SJames Wright // Note: we'll realloc this array again, so no need to shrink the allocation 814*c8564c30SJames Wright ctx->num_mats_assembled_pbd--; 815*c8564c30SJames Wright PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal)); 816*c8564c30SJames Wright } 817*c8564c30SJames Wright } 818*c8564c30SJames Wright } 819*c8564c30SJames Wright PetscCall(PetscFree(ctx->internal_mat_type)); 820*c8564c30SJames Wright PetscCall(PetscStrallocpy(type, &ctx->internal_mat_type)); 821*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 822*c8564c30SJames Wright // LCOV_EXCL_STOP 823*c8564c30SJames Wright } 824*c8564c30SJames Wright 825*c8564c30SJames Wright /** 826*c8564c30SJames Wright @brief Gets the inner matrix type as a string from the `MATCEED`. 827*c8564c30SJames Wright 828*c8564c30SJames Wright Collective across MPI processes. 829*c8564c30SJames Wright 830*c8564c30SJames Wright @param[in,out] mat `MATCEED` 831*c8564c30SJames Wright @param[in] type Inner `MatType` 832*c8564c30SJames Wright 833*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 834*c8564c30SJames Wright **/ 835*c8564c30SJames Wright PetscErrorCode MatCeedGetInnerMatType(Mat mat, MatType *type) { 836*c8564c30SJames Wright MatCeedContext ctx; 837*c8564c30SJames Wright 838*c8564c30SJames Wright PetscFunctionBeginUser; 839*c8564c30SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 840*c8564c30SJames Wright *type = ctx->internal_mat_type; 841*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 842*c8564c30SJames Wright } 843*c8564c30SJames Wright 844*c8564c30SJames Wright /** 845*c8564c30SJames Wright @brief Set a user defined matrix operation for a `MATCEED` matrix. 846*c8564c30SJames Wright 847*c8564c30SJames Wright Within each user-defined routine, the user should call `MatCeedGetContext()` to obtain the user-defined context that was set by 848*c8564c30SJames Wright `MatCeedSetContext()`. 849*c8564c30SJames Wright 850*c8564c30SJames Wright Collective across MPI processes. 851*c8564c30SJames Wright 852*c8564c30SJames Wright @param[in,out] mat `MATCEED` 853*c8564c30SJames Wright @param[in] op Name of the `MatOperation` 854*c8564c30SJames Wright @param[in] g Function that provides the operation 855*c8564c30SJames Wright 856*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 857*c8564c30SJames Wright **/ 858*c8564c30SJames Wright PetscErrorCode MatCeedSetOperation(Mat mat, MatOperation op, void (*g)(void)) { 859*c8564c30SJames Wright PetscFunctionBeginUser; 860*c8564c30SJames Wright PetscCall(MatShellSetOperation(mat, op, g)); 861*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 862*c8564c30SJames Wright } 863*c8564c30SJames Wright 864*c8564c30SJames Wright /** 865*c8564c30SJames Wright @brief Set input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 866*c8564c30SJames Wright 867*c8564c30SJames Wright Not collective across MPI processes. 868*c8564c30SJames Wright 869*c8564c30SJames Wright @param[in,out] mat `MATCEED` 870*c8564c30SJames Wright @param[in] X_loc Input PETSc local vector, or NULL 871*c8564c30SJames Wright @param[in] Y_loc_transpose Input PETSc local vector for transpose operation, or NULL 872*c8564c30SJames Wright 873*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 874*c8564c30SJames Wright **/ 875*c8564c30SJames Wright PetscErrorCode MatCeedSetLocalVectors(Mat mat, Vec X_loc, Vec Y_loc_transpose) { 876*c8564c30SJames Wright MatCeedContext ctx; 877*c8564c30SJames Wright 878*c8564c30SJames Wright PetscFunctionBeginUser; 879*c8564c30SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 880*c8564c30SJames Wright if (X_loc) { 881*c8564c30SJames Wright PetscInt len_old, len_new; 882*c8564c30SJames Wright 883*c8564c30SJames Wright PetscCall(VecGetSize(ctx->X_loc, &len_old)); 884*c8564c30SJames Wright PetscCall(VecGetSize(X_loc, &len_new)); 885*c8564c30SJames Wright PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB, "new X_loc length %" PetscInt_FMT " should match old X_loc length %" PetscInt_FMT, 886*c8564c30SJames Wright len_new, len_old); 887*c8564c30SJames Wright PetscCall(VecDestroy(&ctx->X_loc)); 888*c8564c30SJames Wright ctx->X_loc = X_loc; 889*c8564c30SJames Wright PetscCall(PetscObjectReference((PetscObject)X_loc)); 890*c8564c30SJames Wright } 891*c8564c30SJames Wright if (Y_loc_transpose) { 892*c8564c30SJames Wright PetscInt len_old, len_new; 893*c8564c30SJames Wright 894*c8564c30SJames Wright PetscCall(VecGetSize(ctx->Y_loc_transpose, &len_old)); 895*c8564c30SJames Wright PetscCall(VecGetSize(Y_loc_transpose, &len_new)); 896*c8564c30SJames Wright PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB, 897*c8564c30SJames Wright "new Y_loc_transpose length %" PetscInt_FMT " should match old Y_loc_transpose length %" PetscInt_FMT, len_new, len_old); 898*c8564c30SJames Wright PetscCall(VecDestroy(&ctx->Y_loc_transpose)); 899*c8564c30SJames Wright ctx->Y_loc_transpose = Y_loc_transpose; 900*c8564c30SJames Wright PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose)); 901*c8564c30SJames Wright } 902*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 903*c8564c30SJames Wright } 904*c8564c30SJames Wright 905*c8564c30SJames Wright /** 906*c8564c30SJames Wright @brief Get input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 907*c8564c30SJames Wright 908*c8564c30SJames Wright Not collective across MPI processes. 909*c8564c30SJames Wright 910*c8564c30SJames Wright @param[in,out] mat `MATCEED` 911*c8564c30SJames Wright @param[out] X_loc Input PETSc local vector, or NULL 912*c8564c30SJames Wright @param[out] Y_loc_transpose Input PETSc local vector for transpose operation, or NULL 913*c8564c30SJames Wright 914*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 915*c8564c30SJames Wright **/ 916*c8564c30SJames Wright PetscErrorCode MatCeedGetLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) { 917*c8564c30SJames Wright MatCeedContext ctx; 918*c8564c30SJames Wright 919*c8564c30SJames Wright PetscFunctionBeginUser; 920*c8564c30SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 921*c8564c30SJames Wright if (X_loc) { 922*c8564c30SJames Wright *X_loc = ctx->X_loc; 923*c8564c30SJames Wright PetscCall(PetscObjectReference((PetscObject)*X_loc)); 924*c8564c30SJames Wright } 925*c8564c30SJames Wright if (Y_loc_transpose) { 926*c8564c30SJames Wright *Y_loc_transpose = ctx->Y_loc_transpose; 927*c8564c30SJames Wright PetscCall(PetscObjectReference((PetscObject)*Y_loc_transpose)); 928*c8564c30SJames Wright } 929*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 930*c8564c30SJames Wright } 931*c8564c30SJames Wright 932*c8564c30SJames Wright /** 933*c8564c30SJames Wright @brief Restore input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 934*c8564c30SJames Wright 935*c8564c30SJames Wright Not collective across MPI processes. 936*c8564c30SJames Wright 937*c8564c30SJames Wright @param[in,out] mat MatCeed 938*c8564c30SJames Wright @param[out] X_loc Input PETSc local vector, or NULL 939*c8564c30SJames Wright @param[out] Y_loc_transpose Input PETSc local vector for transpose operation, or NULL 940*c8564c30SJames Wright 941*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 942*c8564c30SJames Wright **/ 943*c8564c30SJames Wright PetscErrorCode MatCeedRestoreLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) { 944*c8564c30SJames Wright PetscFunctionBeginUser; 945*c8564c30SJames Wright if (X_loc) PetscCall(VecDestroy(X_loc)); 946*c8564c30SJames Wright if (Y_loc_transpose) PetscCall(VecDestroy(Y_loc_transpose)); 947*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 948*c8564c30SJames Wright } 949*c8564c30SJames Wright 950*c8564c30SJames Wright /** 951*c8564c30SJames Wright @brief Get libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 952*c8564c30SJames Wright 953*c8564c30SJames Wright Not collective across MPI processes. 954*c8564c30SJames Wright 955*c8564c30SJames Wright @param[in,out] mat MatCeed 956*c8564c30SJames Wright @param[out] op_mult libCEED `CeedOperator` for `MatMult()`, or NULL 957*c8564c30SJames Wright @param[out] op_mult_transpose libCEED `CeedOperator` for `MatMultTranspose()`, or NULL 958*c8564c30SJames Wright 959*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 960*c8564c30SJames Wright **/ 961*c8564c30SJames Wright PetscErrorCode MatCeedGetCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) { 962*c8564c30SJames Wright MatCeedContext ctx; 963*c8564c30SJames Wright 964*c8564c30SJames Wright PetscFunctionBeginUser; 965*c8564c30SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 966*c8564c30SJames Wright if (op_mult) { 967*c8564c30SJames Wright *op_mult = NULL; 968*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult, op_mult)); 969*c8564c30SJames Wright } 970*c8564c30SJames Wright if (op_mult_transpose) { 971*c8564c30SJames Wright *op_mult_transpose = NULL; 972*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult_transpose, op_mult_transpose)); 973*c8564c30SJames Wright } 974*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 975*c8564c30SJames Wright } 976*c8564c30SJames Wright 977*c8564c30SJames Wright /** 978*c8564c30SJames Wright @brief Restore libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 979*c8564c30SJames Wright 980*c8564c30SJames Wright Not collective across MPI processes. 981*c8564c30SJames Wright 982*c8564c30SJames Wright @param[in,out] mat MatCeed 983*c8564c30SJames Wright @param[out] op_mult libCEED `CeedOperator` for `MatMult()`, or NULL 984*c8564c30SJames Wright @param[out] op_mult_transpose libCEED `CeedOperator` for `MatMultTranspose()`, or NULL 985*c8564c30SJames Wright 986*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 987*c8564c30SJames Wright **/ 988*c8564c30SJames Wright PetscErrorCode MatCeedRestoreCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) { 989*c8564c30SJames Wright MatCeedContext ctx; 990*c8564c30SJames Wright 991*c8564c30SJames Wright PetscFunctionBeginUser; 992*c8564c30SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 993*c8564c30SJames Wright if (op_mult) PetscCeedCall(ctx->ceed, CeedOperatorDestroy(op_mult)); 994*c8564c30SJames Wright if (op_mult_transpose) PetscCeedCall(ctx->ceed, CeedOperatorDestroy(op_mult_transpose)); 995*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 996*c8564c30SJames Wright } 997*c8564c30SJames Wright 998*c8564c30SJames Wright /** 999*c8564c30SJames Wright @brief Set `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators. 1000*c8564c30SJames Wright 1001*c8564c30SJames Wright Not collective across MPI processes. 1002*c8564c30SJames Wright 1003*c8564c30SJames Wright @param[in,out] mat MatCeed 1004*c8564c30SJames Wright @param[out] log_event_mult `PetscLogEvent` for forward evaluation, or NULL 1005*c8564c30SJames Wright @param[out] log_event_mult_transpose `PetscLogEvent` for transpose evaluation, or NULL 1006*c8564c30SJames Wright 1007*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 1008*c8564c30SJames Wright **/ 1009*c8564c30SJames Wright PetscErrorCode MatCeedSetLogEvents(Mat mat, PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose) { 1010*c8564c30SJames Wright MatCeedContext ctx; 1011*c8564c30SJames Wright 1012*c8564c30SJames Wright PetscFunctionBeginUser; 1013*c8564c30SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 1014*c8564c30SJames Wright if (log_event_mult) ctx->log_event_mult = log_event_mult; 1015*c8564c30SJames Wright if (log_event_mult_transpose) ctx->log_event_mult_transpose = log_event_mult_transpose; 1016*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1017*c8564c30SJames Wright } 1018*c8564c30SJames Wright 1019*c8564c30SJames Wright /** 1020*c8564c30SJames Wright @brief Get `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators. 1021*c8564c30SJames Wright 1022*c8564c30SJames Wright Not collective across MPI processes. 1023*c8564c30SJames Wright 1024*c8564c30SJames Wright @param[in,out] mat MatCeed 1025*c8564c30SJames Wright @param[out] log_event_mult `PetscLogEvent` for forward evaluation, or NULL 1026*c8564c30SJames Wright @param[out] log_event_mult_transpose `PetscLogEvent` for transpose evaluation, or NULL 1027*c8564c30SJames Wright 1028*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 1029*c8564c30SJames Wright **/ 1030*c8564c30SJames Wright PetscErrorCode MatCeedGetLogEvents(Mat mat, PetscLogEvent *log_event_mult, PetscLogEvent *log_event_mult_transpose) { 1031*c8564c30SJames Wright MatCeedContext ctx; 1032*c8564c30SJames Wright 1033*c8564c30SJames Wright PetscFunctionBeginUser; 1034*c8564c30SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 1035*c8564c30SJames Wright if (log_event_mult) *log_event_mult = ctx->log_event_mult; 1036*c8564c30SJames Wright if (log_event_mult_transpose) *log_event_mult_transpose = ctx->log_event_mult_transpose; 1037*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1038*c8564c30SJames Wright } 1039*c8564c30SJames Wright 1040*c8564c30SJames Wright // ----------------------------------------------------------------------------- 1041*c8564c30SJames Wright // Operator context data 1042*c8564c30SJames Wright // ----------------------------------------------------------------------------- 1043*c8564c30SJames Wright 1044*c8564c30SJames Wright /** 1045*c8564c30SJames Wright @brief Setup context data for operator application. 1046*c8564c30SJames Wright 1047*c8564c30SJames Wright Collective across MPI processes. 1048*c8564c30SJames Wright 1049*c8564c30SJames Wright @param[in] dm_x Input `DM` 1050*c8564c30SJames Wright @param[in] dm_y Output `DM` 1051*c8564c30SJames Wright @param[in] X_loc Input PETSc local vector, or NULL 1052*c8564c30SJames Wright @param[in] Y_loc_transpose Input PETSc local vector for transpose operation, or NULL 1053*c8564c30SJames Wright @param[in] op_mult `CeedOperator` for forward evaluation 1054*c8564c30SJames Wright @param[in] op_mult_transpose `CeedOperator` for transpose evaluation 1055*c8564c30SJames Wright @param[in] log_event_mult `PetscLogEvent` for forward evaluation 1056*c8564c30SJames Wright @param[in] log_event_mult_transpose `PetscLogEvent` for transpose evaluation 1057*c8564c30SJames Wright @param[out] ctx Context data for operator evaluation 1058*c8564c30SJames Wright 1059*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 1060*c8564c30SJames Wright **/ 1061*c8564c30SJames Wright PetscErrorCode MatCeedContextCreate(DM dm_x, DM dm_y, Vec X_loc, Vec Y_loc_transpose, CeedOperator op_mult, CeedOperator op_mult_transpose, 1062*c8564c30SJames Wright PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose, MatCeedContext *ctx) { 1063*c8564c30SJames Wright CeedSize x_loc_len, y_loc_len; 1064*c8564c30SJames Wright 1065*c8564c30SJames Wright PetscFunctionBeginUser; 1066*c8564c30SJames Wright 1067*c8564c30SJames Wright // Allocate 1068*c8564c30SJames Wright PetscCall(PetscNew(ctx)); 1069*c8564c30SJames Wright (*ctx)->ref_count = 1; 1070*c8564c30SJames Wright 1071*c8564c30SJames Wright // Logging 1072*c8564c30SJames Wright (*ctx)->log_event_mult = log_event_mult; 1073*c8564c30SJames Wright (*ctx)->log_event_mult_transpose = log_event_mult_transpose; 1074*c8564c30SJames Wright 1075*c8564c30SJames Wright // PETSc objects 1076*c8564c30SJames Wright PetscCall(PetscObjectReference((PetscObject)dm_x)); 1077*c8564c30SJames Wright (*ctx)->dm_x = dm_x; 1078*c8564c30SJames Wright PetscCall(PetscObjectReference((PetscObject)dm_y)); 1079*c8564c30SJames Wright (*ctx)->dm_y = dm_y; 1080*c8564c30SJames Wright if (X_loc) PetscCall(PetscObjectReference((PetscObject)X_loc)); 1081*c8564c30SJames Wright (*ctx)->X_loc = X_loc; 1082*c8564c30SJames Wright if (Y_loc_transpose) PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose)); 1083*c8564c30SJames Wright (*ctx)->Y_loc_transpose = Y_loc_transpose; 1084*c8564c30SJames Wright 1085*c8564c30SJames Wright // Memtype 1086*c8564c30SJames Wright { 1087*c8564c30SJames Wright const PetscScalar *x; 1088*c8564c30SJames Wright Vec X; 1089*c8564c30SJames Wright 1090*c8564c30SJames Wright PetscCall(DMGetLocalVector(dm_x, &X)); 1091*c8564c30SJames Wright PetscCall(VecGetArrayReadAndMemType(X, &x, &(*ctx)->mem_type)); 1092*c8564c30SJames Wright PetscCall(VecRestoreArrayReadAndMemType(X, &x)); 1093*c8564c30SJames Wright PetscCall(DMRestoreLocalVector(dm_x, &X)); 1094*c8564c30SJames Wright } 1095*c8564c30SJames Wright 1096*c8564c30SJames Wright // libCEED objects 1097*c8564c30SJames Wright PetscCheck(CeedOperatorGetCeed(op_mult, &(*ctx)->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, 1098*c8564c30SJames Wright "retrieving Ceed context object failed"); 1099*c8564c30SJames Wright PetscCeedCall((*ctx)->ceed, CeedReference((*ctx)->ceed)); 1100*c8564c30SJames Wright PetscCeedCall((*ctx)->ceed, CeedOperatorGetActiveVectorLengths(op_mult, &x_loc_len, &y_loc_len)); 1101*c8564c30SJames Wright PetscCeedCall((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult, &(*ctx)->op_mult)); 1102*c8564c30SJames Wright if (op_mult_transpose) PetscCeedCall((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult_transpose, &(*ctx)->op_mult_transpose)); 1103*c8564c30SJames Wright PetscCeedCall((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, x_loc_len, &(*ctx)->x_loc)); 1104*c8564c30SJames Wright PetscCeedCall((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, y_loc_len, &(*ctx)->y_loc)); 1105*c8564c30SJames Wright 1106*c8564c30SJames Wright // Flop counting 1107*c8564c30SJames Wright { 1108*c8564c30SJames Wright CeedSize ceed_flops_estimate = 0; 1109*c8564c30SJames Wright 1110*c8564c30SJames Wright PetscCeedCall((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult, &ceed_flops_estimate)); 1111*c8564c30SJames Wright (*ctx)->flops_mult = ceed_flops_estimate; 1112*c8564c30SJames Wright if (op_mult_transpose) { 1113*c8564c30SJames Wright PetscCeedCall((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult_transpose, &ceed_flops_estimate)); 1114*c8564c30SJames Wright (*ctx)->flops_mult_transpose = ceed_flops_estimate; 1115*c8564c30SJames Wright } 1116*c8564c30SJames Wright } 1117*c8564c30SJames Wright 1118*c8564c30SJames Wright // Check sizes 1119*c8564c30SJames Wright if (x_loc_len > 0 || y_loc_len > 0) { 1120*c8564c30SJames Wright CeedSize ctx_x_loc_len, ctx_y_loc_len; 1121*c8564c30SJames Wright PetscInt X_loc_len, dm_x_loc_len, Y_loc_len, dm_y_loc_len; 1122*c8564c30SJames Wright Vec dm_X_loc, dm_Y_loc; 1123*c8564c30SJames Wright 1124*c8564c30SJames Wright // -- Input 1125*c8564c30SJames Wright PetscCall(DMGetLocalVector(dm_x, &dm_X_loc)); 1126*c8564c30SJames Wright PetscCall(VecGetLocalSize(dm_X_loc, &dm_x_loc_len)); 1127*c8564c30SJames Wright PetscCall(DMRestoreLocalVector(dm_x, &dm_X_loc)); 1128*c8564c30SJames Wright if (X_loc) PetscCall(VecGetLocalSize(X_loc, &X_loc_len)); 1129*c8564c30SJames Wright PetscCeedCall((*ctx)->ceed, CeedVectorGetLength((*ctx)->x_loc, &ctx_x_loc_len)); 1130*c8564c30SJames Wright if (X_loc) PetscCheck(X_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "X_loc must match dm_x dimensions"); 1131*c8564c30SJames Wright PetscCheck(x_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op must match dm_x dimensions"); 1132*c8564c30SJames Wright PetscCheck(x_loc_len == ctx_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "x_loc must match op dimensions"); 1133*c8564c30SJames Wright 1134*c8564c30SJames Wright // -- Output 1135*c8564c30SJames Wright PetscCall(DMGetLocalVector(dm_y, &dm_Y_loc)); 1136*c8564c30SJames Wright PetscCall(VecGetLocalSize(dm_Y_loc, &dm_y_loc_len)); 1137*c8564c30SJames Wright PetscCall(DMRestoreLocalVector(dm_y, &dm_Y_loc)); 1138*c8564c30SJames Wright PetscCeedCall((*ctx)->ceed, CeedVectorGetLength((*ctx)->y_loc, &ctx_y_loc_len)); 1139*c8564c30SJames Wright PetscCheck(ctx_y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op must match dm_y dimensions"); 1140*c8564c30SJames Wright 1141*c8564c30SJames Wright // -- Transpose 1142*c8564c30SJames Wright if (Y_loc_transpose) { 1143*c8564c30SJames Wright PetscCall(VecGetLocalSize(Y_loc_transpose, &Y_loc_len)); 1144*c8564c30SJames Wright PetscCheck(Y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "Y_loc_transpose must match dm_y dimensions"); 1145*c8564c30SJames Wright } 1146*c8564c30SJames Wright } 1147*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1148*c8564c30SJames Wright } 1149*c8564c30SJames Wright 1150*c8564c30SJames Wright /** 1151*c8564c30SJames Wright @brief Increment reference counter for `MATCEED` context. 1152*c8564c30SJames Wright 1153*c8564c30SJames Wright Not collective across MPI processes. 1154*c8564c30SJames Wright 1155*c8564c30SJames Wright @param[in,out] ctx Context data 1156*c8564c30SJames Wright 1157*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 1158*c8564c30SJames Wright **/ 1159*c8564c30SJames Wright PetscErrorCode MatCeedContextReference(MatCeedContext ctx) { 1160*c8564c30SJames Wright PetscFunctionBeginUser; 1161*c8564c30SJames Wright ctx->ref_count++; 1162*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1163*c8564c30SJames Wright } 1164*c8564c30SJames Wright 1165*c8564c30SJames Wright /** 1166*c8564c30SJames Wright @brief Copy reference for `MATCEED`. 1167*c8564c30SJames Wright Note: If `ctx_copy` is non-null, it is assumed to be a valid pointer to a `MatCeedContext`. 1168*c8564c30SJames Wright 1169*c8564c30SJames Wright Not collective across MPI processes. 1170*c8564c30SJames Wright 1171*c8564c30SJames Wright @param[in] ctx Context data 1172*c8564c30SJames Wright @param[out] ctx_copy Copy of pointer to context data 1173*c8564c30SJames Wright 1174*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 1175*c8564c30SJames Wright **/ 1176*c8564c30SJames Wright PetscErrorCode MatCeedContextReferenceCopy(MatCeedContext ctx, MatCeedContext *ctx_copy) { 1177*c8564c30SJames Wright PetscFunctionBeginUser; 1178*c8564c30SJames Wright PetscCall(MatCeedContextReference(ctx)); 1179*c8564c30SJames Wright PetscCall(MatCeedContextDestroy(*ctx_copy)); 1180*c8564c30SJames Wright *ctx_copy = ctx; 1181*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1182*c8564c30SJames Wright } 1183*c8564c30SJames Wright 1184*c8564c30SJames Wright /** 1185*c8564c30SJames Wright @brief Destroy context data for operator application. 1186*c8564c30SJames Wright 1187*c8564c30SJames Wright Collective across MPI processes. 1188*c8564c30SJames Wright 1189*c8564c30SJames Wright @param[in,out] ctx Context data for operator evaluation 1190*c8564c30SJames Wright 1191*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 1192*c8564c30SJames Wright **/ 1193*c8564c30SJames Wright PetscErrorCode MatCeedContextDestroy(MatCeedContext ctx) { 1194*c8564c30SJames Wright PetscFunctionBeginUser; 1195*c8564c30SJames Wright if (!ctx || --ctx->ref_count > 0) PetscFunctionReturn(PETSC_SUCCESS); 1196*c8564c30SJames Wright 1197*c8564c30SJames Wright // PETSc objects 1198*c8564c30SJames Wright PetscCall(DMDestroy(&ctx->dm_x)); 1199*c8564c30SJames Wright PetscCall(DMDestroy(&ctx->dm_y)); 1200*c8564c30SJames Wright PetscCall(VecDestroy(&ctx->X_loc)); 1201*c8564c30SJames Wright PetscCall(VecDestroy(&ctx->Y_loc_transpose)); 1202*c8564c30SJames Wright PetscCall(MatDestroy(&ctx->mat_assembled_full_internal)); 1203*c8564c30SJames Wright PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal)); 1204*c8564c30SJames Wright PetscCall(PetscFree(ctx->internal_mat_type)); 1205*c8564c30SJames Wright PetscCall(PetscFree(ctx->mats_assembled_full)); 1206*c8564c30SJames Wright PetscCall(PetscFree(ctx->mats_assembled_pbd)); 1207*c8564c30SJames Wright 1208*c8564c30SJames Wright // libCEED objects 1209*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->x_loc)); 1210*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->y_loc)); 1211*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_full)); 1212*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_pbd)); 1213*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult)); 1214*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult_transpose)); 1215*c8564c30SJames Wright PetscCheck(CeedDestroy(&ctx->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "destroying libCEED context object failed"); 1216*c8564c30SJames Wright 1217*c8564c30SJames Wright // Deallocate 1218*c8564c30SJames Wright ctx->is_destroyed = PETSC_TRUE; // Flag as destroyed in case someone has stale ref 1219*c8564c30SJames Wright PetscCall(PetscFree(ctx)); 1220*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1221*c8564c30SJames Wright } 1222*c8564c30SJames Wright 1223*c8564c30SJames Wright /** 1224*c8564c30SJames Wright @brief Compute the diagonal of an operator via libCEED. 1225*c8564c30SJames Wright 1226*c8564c30SJames Wright Collective across MPI processes. 1227*c8564c30SJames Wright 1228*c8564c30SJames Wright @param[in] A `MATCEED` 1229*c8564c30SJames Wright @param[out] D Vector holding operator diagonal 1230*c8564c30SJames Wright 1231*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 1232*c8564c30SJames Wright **/ 1233*c8564c30SJames Wright PetscErrorCode MatGetDiagonal_Ceed(Mat A, Vec D) { 1234*c8564c30SJames Wright PetscMemType mem_type; 1235*c8564c30SJames Wright Vec D_loc; 1236*c8564c30SJames Wright MatCeedContext ctx; 1237*c8564c30SJames Wright 1238*c8564c30SJames Wright PetscFunctionBeginUser; 1239*c8564c30SJames Wright PetscCall(MatShellGetContext(A, &ctx)); 1240*c8564c30SJames Wright 1241*c8564c30SJames Wright // Place PETSc vector in libCEED vector 1242*c8564c30SJames Wright PetscCall(DMGetLocalVector(ctx->dm_x, &D_loc)); 1243*c8564c30SJames Wright PetscCall(VecP2C(ctx->ceed, D_loc, &mem_type, ctx->x_loc)); 1244*c8564c30SJames Wright 1245*c8564c30SJames Wright // Compute Diagonal 1246*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorLinearAssembleDiagonal(ctx->op_mult, ctx->x_loc, CEED_REQUEST_IMMEDIATE)); 1247*c8564c30SJames Wright 1248*c8564c30SJames Wright // Restore PETSc vector 1249*c8564c30SJames Wright PetscCall(VecC2P(ctx->ceed, ctx->x_loc, mem_type, D_loc)); 1250*c8564c30SJames Wright 1251*c8564c30SJames Wright // Local-to-Global 1252*c8564c30SJames Wright PetscCall(VecZeroEntries(D)); 1253*c8564c30SJames Wright PetscCall(DMLocalToGlobal(ctx->dm_x, D_loc, ADD_VALUES, D)); 1254*c8564c30SJames Wright PetscCall(DMRestoreLocalVector(ctx->dm_x, &D_loc)); 1255*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1256*c8564c30SJames Wright } 1257*c8564c30SJames Wright 1258*c8564c30SJames Wright /** 1259*c8564c30SJames Wright @brief Compute `A X = Y` for a `MATCEED`. 1260*c8564c30SJames Wright 1261*c8564c30SJames Wright Collective across MPI processes. 1262*c8564c30SJames Wright 1263*c8564c30SJames Wright @param[in] A `MATCEED` 1264*c8564c30SJames Wright @param[in] X Input PETSc vector 1265*c8564c30SJames Wright @param[out] Y Output PETSc vector 1266*c8564c30SJames Wright 1267*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 1268*c8564c30SJames Wright **/ 1269*c8564c30SJames Wright PetscErrorCode MatMult_Ceed(Mat A, Vec X, Vec Y) { 1270*c8564c30SJames Wright MatCeedContext ctx; 1271*c8564c30SJames Wright 1272*c8564c30SJames Wright PetscFunctionBeginUser; 1273*c8564c30SJames Wright PetscCall(MatShellGetContext(A, &ctx)); 1274*c8564c30SJames Wright PetscCall(PetscLogEventBegin(ctx->log_event_mult, A, X, Y, 0)); 1275*c8564c30SJames Wright 1276*c8564c30SJames Wright { 1277*c8564c30SJames Wright PetscMemType x_mem_type, y_mem_type; 1278*c8564c30SJames Wright Vec X_loc = ctx->X_loc, Y_loc; 1279*c8564c30SJames Wright 1280*c8564c30SJames Wright // Get local vectors 1281*c8564c30SJames Wright if (!ctx->X_loc) PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc)); 1282*c8564c30SJames Wright PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc)); 1283*c8564c30SJames Wright 1284*c8564c30SJames Wright // Global-to-local 1285*c8564c30SJames Wright PetscCall(DMGlobalToLocal(ctx->dm_x, X, INSERT_VALUES, X_loc)); 1286*c8564c30SJames Wright 1287*c8564c30SJames Wright // Setup libCEED vectors 1288*c8564c30SJames Wright PetscCall(VecReadP2C(ctx->ceed, X_loc, &x_mem_type, ctx->x_loc)); 1289*c8564c30SJames Wright PetscCall(VecZeroEntries(Y_loc)); 1290*c8564c30SJames Wright PetscCall(VecP2C(ctx->ceed, Y_loc, &y_mem_type, ctx->y_loc)); 1291*c8564c30SJames Wright 1292*c8564c30SJames Wright // Apply libCEED operator 1293*c8564c30SJames Wright PetscCall(PetscLogGpuTimeBegin()); 1294*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult, ctx->x_loc, ctx->y_loc, CEED_REQUEST_IMMEDIATE)); 1295*c8564c30SJames Wright PetscCall(PetscLogGpuTimeEnd()); 1296*c8564c30SJames Wright 1297*c8564c30SJames Wright // Restore PETSc vectors 1298*c8564c30SJames Wright PetscCall(VecReadC2P(ctx->ceed, ctx->x_loc, x_mem_type, X_loc)); 1299*c8564c30SJames Wright PetscCall(VecC2P(ctx->ceed, ctx->y_loc, y_mem_type, Y_loc)); 1300*c8564c30SJames Wright 1301*c8564c30SJames Wright // Local-to-global 1302*c8564c30SJames Wright PetscCall(VecZeroEntries(Y)); 1303*c8564c30SJames Wright PetscCall(DMLocalToGlobal(ctx->dm_y, Y_loc, ADD_VALUES, Y)); 1304*c8564c30SJames Wright 1305*c8564c30SJames Wright // Restore local vectors, as needed 1306*c8564c30SJames Wright if (!ctx->X_loc) PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc)); 1307*c8564c30SJames Wright PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc)); 1308*c8564c30SJames Wright } 1309*c8564c30SJames Wright 1310*c8564c30SJames Wright // Log flops 1311*c8564c30SJames Wright if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult)); 1312*c8564c30SJames Wright else PetscCall(PetscLogFlops(ctx->flops_mult)); 1313*c8564c30SJames Wright 1314*c8564c30SJames Wright PetscCall(PetscLogEventEnd(ctx->log_event_mult, A, X, Y, 0)); 1315*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1316*c8564c30SJames Wright } 1317*c8564c30SJames Wright 1318*c8564c30SJames Wright /** 1319*c8564c30SJames Wright @brief Compute `A^T Y = X` for a `MATCEED`. 1320*c8564c30SJames Wright 1321*c8564c30SJames Wright Collective across MPI processes. 1322*c8564c30SJames Wright 1323*c8564c30SJames Wright @param[in] A `MATCEED` 1324*c8564c30SJames Wright @param[in] Y Input PETSc vector 1325*c8564c30SJames Wright @param[out] X Output PETSc vector 1326*c8564c30SJames Wright 1327*c8564c30SJames Wright @return An error code: 0 - success, otherwise - failure 1328*c8564c30SJames Wright **/ 1329*c8564c30SJames Wright PetscErrorCode MatMultTranspose_Ceed(Mat A, Vec Y, Vec X) { 1330*c8564c30SJames Wright MatCeedContext ctx; 1331*c8564c30SJames Wright 1332*c8564c30SJames Wright PetscFunctionBeginUser; 1333*c8564c30SJames Wright PetscCall(MatShellGetContext(A, &ctx)); 1334*c8564c30SJames Wright PetscCall(PetscLogEventBegin(ctx->log_event_mult_transpose, A, Y, X, 0)); 1335*c8564c30SJames Wright 1336*c8564c30SJames Wright { 1337*c8564c30SJames Wright PetscMemType x_mem_type, y_mem_type; 1338*c8564c30SJames Wright Vec X_loc, Y_loc = ctx->Y_loc_transpose; 1339*c8564c30SJames Wright 1340*c8564c30SJames Wright // Get local vectors 1341*c8564c30SJames Wright if (!ctx->Y_loc_transpose) PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc)); 1342*c8564c30SJames Wright PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc)); 1343*c8564c30SJames Wright 1344*c8564c30SJames Wright // Global-to-local 1345*c8564c30SJames Wright PetscCall(DMGlobalToLocal(ctx->dm_y, Y, INSERT_VALUES, Y_loc)); 1346*c8564c30SJames Wright 1347*c8564c30SJames Wright // Setup libCEED vectors 1348*c8564c30SJames Wright PetscCall(VecReadP2C(ctx->ceed, Y_loc, &y_mem_type, ctx->y_loc)); 1349*c8564c30SJames Wright PetscCall(VecZeroEntries(X_loc)); 1350*c8564c30SJames Wright PetscCall(VecP2C(ctx->ceed, X_loc, &x_mem_type, ctx->x_loc)); 1351*c8564c30SJames Wright 1352*c8564c30SJames Wright // Apply libCEED operator 1353*c8564c30SJames Wright PetscCall(PetscLogGpuTimeBegin()); 1354*c8564c30SJames Wright PetscCeedCall(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult_transpose, ctx->y_loc, ctx->x_loc, CEED_REQUEST_IMMEDIATE)); 1355*c8564c30SJames Wright PetscCall(PetscLogGpuTimeEnd()); 1356*c8564c30SJames Wright 1357*c8564c30SJames Wright // Restore PETSc vectors 1358*c8564c30SJames Wright PetscCall(VecReadC2P(ctx->ceed, ctx->y_loc, y_mem_type, Y_loc)); 1359*c8564c30SJames Wright PetscCall(VecC2P(ctx->ceed, ctx->x_loc, x_mem_type, X_loc)); 1360*c8564c30SJames Wright 1361*c8564c30SJames Wright // Local-to-global 1362*c8564c30SJames Wright PetscCall(VecZeroEntries(X)); 1363*c8564c30SJames Wright PetscCall(DMLocalToGlobal(ctx->dm_x, X_loc, ADD_VALUES, X)); 1364*c8564c30SJames Wright 1365*c8564c30SJames Wright // Restore local vectors, as needed 1366*c8564c30SJames Wright if (!ctx->Y_loc_transpose) PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc)); 1367*c8564c30SJames Wright PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc)); 1368*c8564c30SJames Wright } 1369*c8564c30SJames Wright 1370*c8564c30SJames Wright // Log flops 1371*c8564c30SJames Wright if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult_transpose)); 1372*c8564c30SJames Wright else PetscCall(PetscLogFlops(ctx->flops_mult_transpose)); 1373*c8564c30SJames Wright 1374*c8564c30SJames Wright PetscCall(PetscLogEventEnd(ctx->log_event_mult_transpose, A, Y, X, 0)); 1375*c8564c30SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 1376*c8564c30SJames Wright } 1377