158600ac3SJames Wright /// @file 258600ac3SJames Wright /// MatCeed and it's related operators 358600ac3SJames Wright 4*a7dac1d5SJames Wright #include <ceed-utils.h> 558600ac3SJames Wright #include <ceed.h> 658600ac3SJames Wright #include <ceed/backend.h> 758600ac3SJames Wright #include <mat-ceed-impl.h> 858600ac3SJames Wright #include <mat-ceed.h> 958600ac3SJames Wright #include <petscdmplex.h> 1058600ac3SJames Wright #include <stdlib.h> 1158600ac3SJames Wright #include <string.h> 1258600ac3SJames Wright 1358600ac3SJames Wright PetscClassId MATCEED_CLASSID; 1458600ac3SJames Wright PetscLogEvent MATCEED_MULT, MATCEED_MULT_TRANSPOSE; 1558600ac3SJames Wright 1658600ac3SJames Wright /** 1758600ac3SJames Wright @brief Register MATCEED log events. 1858600ac3SJames Wright 1958600ac3SJames Wright Not collective across MPI processes. 2058600ac3SJames Wright 2158600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 2258600ac3SJames Wright **/ 2358600ac3SJames Wright static PetscErrorCode MatCeedRegisterLogEvents() { 2458600ac3SJames Wright static bool registered = false; 2558600ac3SJames Wright 2658600ac3SJames Wright PetscFunctionBeginUser; 2758600ac3SJames Wright if (registered) PetscFunctionReturn(PETSC_SUCCESS); 2858600ac3SJames Wright PetscCall(PetscClassIdRegister("MATCEED", &MATCEED_CLASSID)); 2958600ac3SJames Wright PetscCall(PetscLogEventRegister("MATCEED Mult", MATCEED_CLASSID, &MATCEED_MULT)); 3058600ac3SJames Wright PetscCall(PetscLogEventRegister("MATCEED Mult Transpose", MATCEED_CLASSID, &MATCEED_MULT_TRANSPOSE)); 3158600ac3SJames Wright registered = true; 3258600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 3358600ac3SJames Wright } 3458600ac3SJames Wright 3558600ac3SJames Wright /** 3658600ac3SJames Wright @brief Setup inner `Mat` for `PC` operations not directly supported by libCEED. 3758600ac3SJames Wright 3858600ac3SJames Wright Collective across MPI processes. 3958600ac3SJames Wright 4058600ac3SJames Wright @param[in] mat_ceed `MATCEED` to setup 4158600ac3SJames Wright @param[out] mat_inner Inner `Mat` 4258600ac3SJames Wright 4358600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 4458600ac3SJames Wright **/ 4558600ac3SJames Wright static PetscErrorCode MatCeedSetupInnerMat(Mat mat_ceed, Mat *mat_inner) { 4658600ac3SJames Wright MatCeedContext ctx; 4758600ac3SJames Wright 4858600ac3SJames Wright PetscFunctionBeginUser; 4958600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 5058600ac3SJames Wright 5158600ac3SJames Wright PetscCheck(ctx->dm_x == ctx->dm_y, PetscObjectComm((PetscObject)mat_ceed), PETSC_ERR_SUP, "PC only supported for MATCEED on a single DM"); 5258600ac3SJames Wright 5358600ac3SJames Wright // Check cl mat type 5458600ac3SJames Wright { 5558600ac3SJames Wright PetscBool is_internal_mat_type_cl = PETSC_FALSE; 5658600ac3SJames Wright char internal_mat_type_cl[64]; 5758600ac3SJames Wright 5858600ac3SJames Wright // Check for specific CL inner mat type for this Mat 5958600ac3SJames Wright { 6058600ac3SJames Wright const char *mat_ceed_prefix = NULL; 6158600ac3SJames Wright 6258600ac3SJames Wright PetscCall(MatGetOptionsPrefix(mat_ceed, &mat_ceed_prefix)); 6358600ac3SJames Wright PetscOptionsBegin(PetscObjectComm((PetscObject)mat_ceed), mat_ceed_prefix, "", NULL); 6458600ac3SJames Wright PetscCall(PetscOptionsFList("-ceed_inner_mat_type", "MATCEED inner assembled MatType for PC support", NULL, MatList, internal_mat_type_cl, 6558600ac3SJames Wright internal_mat_type_cl, sizeof(internal_mat_type_cl), &is_internal_mat_type_cl)); 6658600ac3SJames Wright PetscOptionsEnd(); 6758600ac3SJames Wright if (is_internal_mat_type_cl) { 6858600ac3SJames Wright PetscCall(PetscFree(ctx->internal_mat_type)); 6958600ac3SJames Wright PetscCall(PetscStrallocpy(internal_mat_type_cl, &ctx->internal_mat_type)); 7058600ac3SJames Wright } 7158600ac3SJames Wright } 7258600ac3SJames Wright } 7358600ac3SJames Wright 7458600ac3SJames Wright // Create sparse matrix 7558600ac3SJames Wright { 7658600ac3SJames Wright MatType dm_mat_type, dm_mat_type_copy; 7758600ac3SJames Wright 7858600ac3SJames Wright PetscCall(DMGetMatType(ctx->dm_x, &dm_mat_type)); 7958600ac3SJames Wright PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy)); 8058600ac3SJames Wright PetscCall(DMSetMatType(ctx->dm_x, ctx->internal_mat_type)); 8158600ac3SJames Wright PetscCall(DMCreateMatrix(ctx->dm_x, mat_inner)); 8258600ac3SJames Wright PetscCall(DMSetMatType(ctx->dm_x, dm_mat_type_copy)); 8358600ac3SJames Wright PetscCall(PetscFree(dm_mat_type_copy)); 8458600ac3SJames Wright } 8558600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 8658600ac3SJames Wright } 8758600ac3SJames Wright 8858600ac3SJames Wright /** 8958600ac3SJames Wright @brief Assemble the point block diagonal of a `MATCEED` into a `MATAIJ` or similar. 9058600ac3SJames Wright The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`. 9158600ac3SJames Wright The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail. 9258600ac3SJames Wright 9358600ac3SJames Wright Collective across MPI processes. 9458600ac3SJames Wright 9558600ac3SJames Wright @param[in] mat_ceed `MATCEED` to assemble 9658600ac3SJames Wright @param[in,out] mat_coo `MATAIJ` or similar to assemble into 9758600ac3SJames Wright 9858600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 9958600ac3SJames Wright **/ 10058600ac3SJames Wright static PetscErrorCode MatCeedAssemblePointBlockDiagonalCOO(Mat mat_ceed, Mat mat_coo) { 10158600ac3SJames Wright MatCeedContext ctx; 10258600ac3SJames Wright 10358600ac3SJames Wright PetscFunctionBeginUser; 10458600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 10558600ac3SJames Wright 10658600ac3SJames Wright // Check if COO pattern set 10758600ac3SJames Wright { 10858600ac3SJames Wright PetscInt index = -1; 10958600ac3SJames Wright 11058600ac3SJames Wright for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) { 11158600ac3SJames Wright if (ctx->mats_assembled_pbd[i] == mat_coo) index = i; 11258600ac3SJames Wright } 11358600ac3SJames Wright if (index == -1) { 11458600ac3SJames Wright PetscInt *rows_petsc = NULL, *cols_petsc = NULL; 11558600ac3SJames Wright CeedInt *rows_ceed, *cols_ceed; 11658600ac3SJames Wright PetscCount num_entries; 11758600ac3SJames Wright PetscLogStage stage_amg_setup; 11858600ac3SJames Wright 11958600ac3SJames Wright // -- Assemble sparsity pattern if mat hasn't been assembled before 12058600ac3SJames Wright PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup)); 12158600ac3SJames Wright if (stage_amg_setup == -1) { 12258600ac3SJames Wright PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup)); 12358600ac3SJames Wright } 12458600ac3SJames Wright PetscCall(PetscLogStagePush(stage_amg_setup)); 12550f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonalSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed)); 126*a7dac1d5SJames Wright PetscCall(IntArrayCeedToPetsc(num_entries, &rows_ceed, &rows_petsc)); 127*a7dac1d5SJames Wright PetscCall(IntArrayCeedToPetsc(num_entries, &cols_ceed, &cols_petsc)); 12858600ac3SJames Wright PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc)); 12958600ac3SJames Wright free(rows_petsc); 13058600ac3SJames Wright free(cols_petsc); 13150f50432SJames Wright if (!ctx->coo_values_pbd) PetscCallCeed(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_pbd)); 13258600ac3SJames Wright PetscCall(PetscRealloc(++ctx->num_mats_assembled_pbd * sizeof(Mat), &ctx->mats_assembled_pbd)); 13358600ac3SJames Wright ctx->mats_assembled_pbd[ctx->num_mats_assembled_pbd - 1] = mat_coo; 13458600ac3SJames Wright PetscCall(PetscLogStagePop()); 13558600ac3SJames Wright } 13658600ac3SJames Wright } 13758600ac3SJames Wright 13858600ac3SJames Wright // Assemble mat_ceed 13958600ac3SJames Wright PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY)); 14058600ac3SJames Wright { 14158600ac3SJames Wright const CeedScalar *values; 14258600ac3SJames Wright MatType mat_type; 14358600ac3SJames Wright CeedMemType mem_type = CEED_MEM_HOST; 14458600ac3SJames Wright PetscBool is_spd, is_spd_known; 14558600ac3SJames Wright 14658600ac3SJames Wright PetscCall(MatGetType(mat_coo, &mat_type)); 14758600ac3SJames Wright if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE; 14858600ac3SJames Wright else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE; 14958600ac3SJames Wright else mem_type = CEED_MEM_HOST; 15058600ac3SJames Wright 15150f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonal(ctx->op_mult, ctx->coo_values_pbd, CEED_REQUEST_IMMEDIATE)); 15250f50432SJames Wright PetscCallCeed(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_pbd, mem_type, &values)); 15358600ac3SJames Wright PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES)); 15458600ac3SJames Wright PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd)); 15558600ac3SJames Wright if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd)); 15650f50432SJames Wright PetscCallCeed(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_pbd, &values)); 15758600ac3SJames Wright } 15858600ac3SJames Wright PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY)); 15958600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 16058600ac3SJames Wright } 16158600ac3SJames Wright 16258600ac3SJames Wright /** 16358600ac3SJames Wright @brief Assemble inner `Mat` for diagonal `PC` operations 16458600ac3SJames Wright 16558600ac3SJames Wright Collective across MPI processes. 16658600ac3SJames Wright 16758600ac3SJames Wright @param[in] mat_ceed `MATCEED` to invert 16858600ac3SJames Wright @param[in] use_ceed_pbd Boolean flag to use libCEED PBD assembly 16958600ac3SJames Wright @param[out] mat_inner Inner `Mat` for diagonal operations 17058600ac3SJames Wright 17158600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 17258600ac3SJames Wright **/ 17358600ac3SJames Wright static PetscErrorCode MatCeedAssembleInnerBlockDiagonalMat(Mat mat_ceed, PetscBool use_ceed_pbd, Mat *mat_inner) { 17458600ac3SJames Wright MatCeedContext ctx; 17558600ac3SJames Wright 17658600ac3SJames Wright PetscFunctionBeginUser; 17758600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 17858600ac3SJames Wright if (use_ceed_pbd) { 17958600ac3SJames Wright // Check if COO pattern set 18058600ac3SJames Wright if (!ctx->mat_assembled_pbd_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_pbd_internal)); 18158600ac3SJames Wright 18258600ac3SJames Wright // Assemble mat_assembled_full_internal 18358600ac3SJames Wright PetscCall(MatCeedAssemblePointBlockDiagonalCOO(mat_ceed, ctx->mat_assembled_pbd_internal)); 18458600ac3SJames Wright if (mat_inner) *mat_inner = ctx->mat_assembled_pbd_internal; 18558600ac3SJames Wright } else { 18658600ac3SJames Wright // Check if COO pattern set 18758600ac3SJames Wright if (!ctx->mat_assembled_full_internal) PetscCall(MatCeedSetupInnerMat(mat_ceed, &ctx->mat_assembled_full_internal)); 18858600ac3SJames Wright 18958600ac3SJames Wright // Assemble mat_assembled_full_internal 19058600ac3SJames Wright PetscCall(MatCeedAssembleCOO(mat_ceed, ctx->mat_assembled_full_internal)); 19158600ac3SJames Wright if (mat_inner) *mat_inner = ctx->mat_assembled_full_internal; 19258600ac3SJames Wright } 19358600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 19458600ac3SJames Wright } 19558600ac3SJames Wright 19658600ac3SJames Wright /** 19758600ac3SJames Wright @brief Get `MATCEED` diagonal block for Jacobi. 19858600ac3SJames Wright 19958600ac3SJames Wright Collective across MPI processes. 20058600ac3SJames Wright 20158600ac3SJames Wright @param[in] mat_ceed `MATCEED` to invert 20258600ac3SJames Wright @param[out] mat_block The diagonal block matrix 20358600ac3SJames Wright 20458600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 20558600ac3SJames Wright **/ 20658600ac3SJames Wright static PetscErrorCode MatGetDiagonalBlock_Ceed(Mat mat_ceed, Mat *mat_block) { 20758600ac3SJames Wright Mat mat_inner = NULL; 20858600ac3SJames Wright MatCeedContext ctx; 20958600ac3SJames Wright 21058600ac3SJames Wright PetscFunctionBeginUser; 21158600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 21258600ac3SJames Wright 21358600ac3SJames Wright // Assemble inner mat if needed 21458600ac3SJames Wright PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner)); 21558600ac3SJames Wright 21658600ac3SJames Wright // Get block diagonal 21758600ac3SJames Wright PetscCall(MatGetDiagonalBlock(mat_inner, mat_block)); 21858600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 21958600ac3SJames Wright } 22058600ac3SJames Wright 22158600ac3SJames Wright /** 22258600ac3SJames Wright @brief Invert `MATCEED` diagonal block for Jacobi. 22358600ac3SJames Wright 22458600ac3SJames Wright Collective across MPI processes. 22558600ac3SJames Wright 22658600ac3SJames Wright @param[in] mat_ceed `MATCEED` to invert 22758600ac3SJames Wright @param[out] values The block inverses in column major order 22858600ac3SJames Wright 22958600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 23058600ac3SJames Wright **/ 23158600ac3SJames Wright static PetscErrorCode MatInvertBlockDiagonal_Ceed(Mat mat_ceed, const PetscScalar **values) { 23258600ac3SJames Wright Mat mat_inner = NULL; 23358600ac3SJames Wright MatCeedContext ctx; 23458600ac3SJames Wright 23558600ac3SJames Wright PetscFunctionBeginUser; 23658600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 23758600ac3SJames Wright 23858600ac3SJames Wright // Assemble inner mat if needed 23958600ac3SJames Wright PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner)); 24058600ac3SJames Wright 24158600ac3SJames Wright // Invert PB diagonal 24258600ac3SJames Wright PetscCall(MatInvertBlockDiagonal(mat_inner, values)); 24358600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 24458600ac3SJames Wright } 24558600ac3SJames Wright 24658600ac3SJames Wright /** 24758600ac3SJames Wright @brief Invert `MATCEED` variable diagonal block for Jacobi. 24858600ac3SJames Wright 24958600ac3SJames Wright Collective across MPI processes. 25058600ac3SJames Wright 25158600ac3SJames Wright @param[in] mat_ceed `MATCEED` to invert 25258600ac3SJames Wright @param[in] num_blocks The number of blocks on the process 25358600ac3SJames Wright @param[in] block_sizes The size of each block on the process 25458600ac3SJames Wright @param[out] values The block inverses in column major order 25558600ac3SJames Wright 25658600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 25758600ac3SJames Wright **/ 25858600ac3SJames Wright static PetscErrorCode MatInvertVariableBlockDiagonal_Ceed(Mat mat_ceed, PetscInt num_blocks, const PetscInt *block_sizes, PetscScalar *values) { 25958600ac3SJames Wright Mat mat_inner = NULL; 26058600ac3SJames Wright MatCeedContext ctx; 26158600ac3SJames Wright 26258600ac3SJames Wright PetscFunctionBeginUser; 26358600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 26458600ac3SJames Wright 26558600ac3SJames Wright // Assemble inner mat if needed 26658600ac3SJames Wright PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_vpbd_valid, &mat_inner)); 26758600ac3SJames Wright 26858600ac3SJames Wright // Invert PB diagonal 26958600ac3SJames Wright PetscCall(MatInvertVariableBlockDiagonal(mat_inner, num_blocks, block_sizes, values)); 27058600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 27158600ac3SJames Wright } 27258600ac3SJames Wright 27358600ac3SJames Wright // ----------------------------------------------------------------------------- 27458600ac3SJames Wright // MatCeed 27558600ac3SJames Wright // ----------------------------------------------------------------------------- 27658600ac3SJames Wright 27758600ac3SJames Wright /** 27858600ac3SJames Wright @brief Create PETSc `Mat` from libCEED operators. 27958600ac3SJames Wright 28058600ac3SJames Wright Collective across MPI processes. 28158600ac3SJames Wright 28258600ac3SJames Wright @param[in] dm_x Input `DM` 28358600ac3SJames Wright @param[in] dm_y Output `DM` 28458600ac3SJames Wright @param[in] op_mult `CeedOperator` for forward evaluation 28558600ac3SJames Wright @param[in] op_mult_transpose `CeedOperator` for transpose evaluation 28658600ac3SJames Wright @param[out] mat New MatCeed 28758600ac3SJames Wright 28858600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 28958600ac3SJames Wright **/ 29058600ac3SJames Wright PetscErrorCode MatCeedCreate(DM dm_x, DM dm_y, CeedOperator op_mult, CeedOperator op_mult_transpose, Mat *mat) { 29158600ac3SJames Wright PetscInt X_l_size, X_g_size, Y_l_size, Y_g_size; 29258600ac3SJames Wright VecType vec_type; 29358600ac3SJames Wright MatCeedContext ctx; 29458600ac3SJames Wright 29558600ac3SJames Wright PetscFunctionBeginUser; 29658600ac3SJames Wright PetscCall(MatCeedRegisterLogEvents()); 29758600ac3SJames Wright 29858600ac3SJames Wright // Collect context data 29958600ac3SJames Wright PetscCall(DMGetVecType(dm_x, &vec_type)); 30058600ac3SJames Wright { 30158600ac3SJames Wright Vec X; 30258600ac3SJames Wright 30358600ac3SJames Wright PetscCall(DMGetGlobalVector(dm_x, &X)); 30458600ac3SJames Wright PetscCall(VecGetSize(X, &X_g_size)); 30558600ac3SJames Wright PetscCall(VecGetLocalSize(X, &X_l_size)); 30658600ac3SJames Wright PetscCall(DMRestoreGlobalVector(dm_x, &X)); 30758600ac3SJames Wright } 30858600ac3SJames Wright if (dm_y) { 30958600ac3SJames Wright Vec Y; 31058600ac3SJames Wright 31158600ac3SJames Wright PetscCall(DMGetGlobalVector(dm_y, &Y)); 31258600ac3SJames Wright PetscCall(VecGetSize(Y, &Y_g_size)); 31358600ac3SJames Wright PetscCall(VecGetLocalSize(Y, &Y_l_size)); 31458600ac3SJames Wright PetscCall(DMRestoreGlobalVector(dm_y, &Y)); 31558600ac3SJames Wright } else { 31658600ac3SJames Wright dm_y = dm_x; 31758600ac3SJames Wright Y_g_size = X_g_size; 31858600ac3SJames Wright Y_l_size = X_l_size; 31958600ac3SJames Wright } 32058600ac3SJames Wright // Create context 32158600ac3SJames Wright { 32258600ac3SJames Wright Vec X_loc, Y_loc_transpose = NULL; 32358600ac3SJames Wright 32458600ac3SJames Wright PetscCall(DMCreateLocalVector(dm_x, &X_loc)); 32558600ac3SJames Wright PetscCall(VecZeroEntries(X_loc)); 32658600ac3SJames Wright if (op_mult_transpose) { 32758600ac3SJames Wright PetscCall(DMCreateLocalVector(dm_y, &Y_loc_transpose)); 32858600ac3SJames Wright PetscCall(VecZeroEntries(Y_loc_transpose)); 32958600ac3SJames Wright } 33058600ac3SJames Wright PetscCall(MatCeedContextCreate(dm_x, dm_y, X_loc, Y_loc_transpose, op_mult, op_mult_transpose, MATCEED_MULT, MATCEED_MULT_TRANSPOSE, &ctx)); 33158600ac3SJames Wright PetscCall(VecDestroy(&X_loc)); 33258600ac3SJames Wright PetscCall(VecDestroy(&Y_loc_transpose)); 33358600ac3SJames Wright } 33458600ac3SJames Wright 33558600ac3SJames Wright // Create mat 33658600ac3SJames Wright PetscCall(MatCreateShell(PetscObjectComm((PetscObject)dm_x), Y_l_size, X_l_size, Y_g_size, X_g_size, ctx, mat)); 33758600ac3SJames Wright PetscCall(PetscObjectChangeTypeName((PetscObject)*mat, MATCEED)); 33858600ac3SJames Wright // -- Set block and variable block sizes 33958600ac3SJames Wright if (dm_x == dm_y) { 34058600ac3SJames Wright MatType dm_mat_type, dm_mat_type_copy; 34158600ac3SJames Wright Mat temp_mat; 34258600ac3SJames Wright 34358600ac3SJames Wright PetscCall(DMGetMatType(dm_x, &dm_mat_type)); 34458600ac3SJames Wright PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy)); 34558600ac3SJames Wright PetscCall(DMSetMatType(dm_x, MATAIJ)); 34658600ac3SJames Wright PetscCall(DMCreateMatrix(dm_x, &temp_mat)); 34758600ac3SJames Wright PetscCall(DMSetMatType(dm_x, dm_mat_type_copy)); 34858600ac3SJames Wright PetscCall(PetscFree(dm_mat_type_copy)); 34958600ac3SJames Wright 35058600ac3SJames Wright { 35158600ac3SJames Wright PetscInt block_size, num_blocks, max_vblock_size = PETSC_INT_MAX; 35258600ac3SJames Wright const PetscInt *vblock_sizes; 35358600ac3SJames Wright 35458600ac3SJames Wright // -- Get block sizes 35558600ac3SJames Wright PetscCall(MatGetBlockSize(temp_mat, &block_size)); 35658600ac3SJames Wright PetscCall(MatGetVariableBlockSizes(temp_mat, &num_blocks, &vblock_sizes)); 35758600ac3SJames Wright { 35858600ac3SJames Wright PetscInt local_min_max[2] = {0}, global_min_max[2] = {0, PETSC_INT_MAX}; 35958600ac3SJames Wright 36058600ac3SJames Wright for (PetscInt i = 0; i < num_blocks; i++) local_min_max[1] = PetscMax(local_min_max[1], vblock_sizes[i]); 36158600ac3SJames Wright PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_min_max, global_min_max)); 36258600ac3SJames Wright max_vblock_size = global_min_max[1]; 36358600ac3SJames Wright } 36458600ac3SJames Wright 36558600ac3SJames Wright // -- Copy block sizes 36658600ac3SJames Wright if (block_size > 1) PetscCall(MatSetBlockSize(*mat, block_size)); 36758600ac3SJames Wright if (num_blocks) PetscCall(MatSetVariableBlockSizes(*mat, num_blocks, (PetscInt *)vblock_sizes)); 36858600ac3SJames Wright 36958600ac3SJames Wright // -- Check libCEED compatibility 37058600ac3SJames Wright { 37158600ac3SJames Wright bool is_composite; 37258600ac3SJames Wright 37358600ac3SJames Wright ctx->is_ceed_pbd_valid = PETSC_TRUE; 37458600ac3SJames Wright ctx->is_ceed_vpbd_valid = PETSC_TRUE; 37550f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorIsComposite(op_mult, &is_composite)); 37658600ac3SJames Wright if (is_composite) { 37758600ac3SJames Wright CeedInt num_sub_operators; 37858600ac3SJames Wright CeedOperator *sub_operators; 37958600ac3SJames Wright 38050f50432SJames Wright PetscCallCeed(ctx->ceed, CeedCompositeOperatorGetNumSub(op_mult, &num_sub_operators)); 38150f50432SJames Wright PetscCallCeed(ctx->ceed, CeedCompositeOperatorGetSubList(op_mult, &sub_operators)); 38258600ac3SJames Wright for (CeedInt i = 0; i < num_sub_operators; i++) { 38358600ac3SJames Wright CeedInt num_bases, num_comp; 38458600ac3SJames Wright CeedBasis *active_bases; 38558600ac3SJames Wright CeedOperatorAssemblyData assembly_data; 38658600ac3SJames Wright 38750f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorGetOperatorAssemblyData(sub_operators[i], &assembly_data)); 38850f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL)); 38950f50432SJames Wright PetscCallCeed(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp)); 39058600ac3SJames Wright if (num_bases > 1) { 39158600ac3SJames Wright ctx->is_ceed_pbd_valid = PETSC_FALSE; 39258600ac3SJames Wright ctx->is_ceed_vpbd_valid = PETSC_FALSE; 39358600ac3SJames Wright } 39458600ac3SJames Wright if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE; 39558600ac3SJames Wright if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE; 39658600ac3SJames Wright } 39758600ac3SJames Wright } else { 39858600ac3SJames Wright // LCOV_EXCL_START 39958600ac3SJames Wright CeedInt num_bases, num_comp; 40058600ac3SJames Wright CeedBasis *active_bases; 40158600ac3SJames Wright CeedOperatorAssemblyData assembly_data; 40258600ac3SJames Wright 40350f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorGetOperatorAssemblyData(op_mult, &assembly_data)); 40450f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL)); 40550f50432SJames Wright PetscCallCeed(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp)); 40658600ac3SJames Wright if (num_bases > 1) { 40758600ac3SJames Wright ctx->is_ceed_pbd_valid = PETSC_FALSE; 40858600ac3SJames Wright ctx->is_ceed_vpbd_valid = PETSC_FALSE; 40958600ac3SJames Wright } 41058600ac3SJames Wright if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE; 41158600ac3SJames Wright if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE; 41258600ac3SJames Wright // LCOV_EXCL_STOP 41358600ac3SJames Wright } 41458600ac3SJames Wright { 41558600ac3SJames Wright PetscInt local_is_valid[2], global_is_valid[2]; 41658600ac3SJames Wright 41758600ac3SJames Wright local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_pbd_valid; 41858600ac3SJames Wright PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid)); 41958600ac3SJames Wright ctx->is_ceed_pbd_valid = global_is_valid[0]; 42058600ac3SJames Wright local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_vpbd_valid; 42158600ac3SJames Wright PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid)); 42258600ac3SJames Wright ctx->is_ceed_vpbd_valid = global_is_valid[0]; 42358600ac3SJames Wright } 42458600ac3SJames Wright } 42558600ac3SJames Wright } 42658600ac3SJames Wright PetscCall(MatDestroy(&temp_mat)); 42758600ac3SJames Wright } 42858600ac3SJames Wright // -- Set internal mat type 42958600ac3SJames Wright { 43058600ac3SJames Wright VecType vec_type; 43158600ac3SJames Wright MatType internal_mat_type = MATAIJ; 43258600ac3SJames Wright 43358600ac3SJames Wright PetscCall(VecGetType(ctx->X_loc, &vec_type)); 43458600ac3SJames Wright if (strstr(vec_type, VECCUDA)) internal_mat_type = MATAIJCUSPARSE; 43558600ac3SJames Wright else if (strstr(vec_type, VECKOKKOS)) internal_mat_type = MATAIJKOKKOS; 43658600ac3SJames Wright else internal_mat_type = MATAIJ; 43758600ac3SJames Wright PetscCall(PetscStrallocpy(internal_mat_type, &ctx->internal_mat_type)); 43858600ac3SJames Wright } 43958600ac3SJames Wright // -- Set mat operations 44058600ac3SJames Wright PetscCall(MatShellSetContextDestroy(*mat, (PetscErrorCode(*)(void *))MatCeedContextDestroy)); 44158600ac3SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_MULT, (void (*)(void))MatMult_Ceed)); 44258600ac3SJames Wright if (op_mult_transpose) PetscCall(MatShellSetOperation(*mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed)); 44358600ac3SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed)); 44458600ac3SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed)); 44558600ac3SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed)); 44658600ac3SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed)); 44758600ac3SJames Wright PetscCall(MatShellSetVecType(*mat, vec_type)); 44858600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 44958600ac3SJames Wright } 45058600ac3SJames Wright 45158600ac3SJames Wright /** 45258600ac3SJames Wright @brief Copy `MATCEED` into a compatible `Mat` with type `MatShell` or `MATCEED`. 45358600ac3SJames Wright 45458600ac3SJames Wright Collective across MPI processes. 45558600ac3SJames Wright 45658600ac3SJames Wright @param[in] mat_ceed `MATCEED` to copy from 45758600ac3SJames Wright @param[out] mat_other `MatShell` or `MATCEED` to copy into 45858600ac3SJames Wright 45958600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 46058600ac3SJames Wright **/ 46158600ac3SJames Wright PetscErrorCode MatCeedCopy(Mat mat_ceed, Mat mat_other) { 46258600ac3SJames Wright PetscFunctionBeginUser; 46358600ac3SJames Wright PetscCall(MatCeedRegisterLogEvents()); 46458600ac3SJames Wright 46558600ac3SJames Wright // Check type compatibility 46658600ac3SJames Wright { 46758600ac3SJames Wright MatType mat_type_ceed, mat_type_other; 46858600ac3SJames Wright 46958600ac3SJames Wright PetscCall(MatGetType(mat_ceed, &mat_type_ceed)); 47058600ac3SJames Wright PetscCheck(!strcmp(mat_type_ceed, MATCEED), PETSC_COMM_SELF, PETSC_ERR_LIB, "mat_ceed must have type " MATCEED); 47158600ac3SJames Wright PetscCall(MatGetType(mat_ceed, &mat_type_other)); 47258600ac3SJames Wright PetscCheck(!strcmp(mat_type_other, MATCEED) || !strcmp(mat_type_other, MATSHELL), PETSC_COMM_SELF, PETSC_ERR_LIB, 47358600ac3SJames Wright "mat_other must have type " MATCEED " or " MATSHELL); 47458600ac3SJames Wright } 47558600ac3SJames Wright 47658600ac3SJames Wright // Check dimension compatibility 47758600ac3SJames Wright { 47858600ac3SJames Wright PetscInt X_l_ceed_size, X_g_ceed_size, Y_l_ceed_size, Y_g_ceed_size, X_l_other_size, X_g_other_size, Y_l_other_size, Y_g_other_size; 47958600ac3SJames Wright 48058600ac3SJames Wright PetscCall(MatGetSize(mat_ceed, &Y_g_ceed_size, &X_g_ceed_size)); 48158600ac3SJames Wright PetscCall(MatGetLocalSize(mat_ceed, &Y_l_ceed_size, &X_l_ceed_size)); 48258600ac3SJames Wright PetscCall(MatGetSize(mat_ceed, &Y_g_other_size, &X_g_other_size)); 48358600ac3SJames Wright PetscCall(MatGetLocalSize(mat_ceed, &Y_l_other_size, &X_l_other_size)); 48458600ac3SJames Wright PetscCheck((Y_g_ceed_size == Y_g_other_size) && (X_g_ceed_size == X_g_other_size) && (Y_l_ceed_size == Y_l_other_size) && 48558600ac3SJames Wright (X_l_ceed_size == X_l_other_size), 48658600ac3SJames Wright PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, 48758600ac3SJames Wright "mat_ceed and mat_other must have compatible sizes; found mat_ceed (Global: %" PetscInt_FMT ", %" PetscInt_FMT 48858600ac3SJames Wright "; Local: %" PetscInt_FMT ", %" PetscInt_FMT ") mat_other (Global: %" PetscInt_FMT ", %" PetscInt_FMT "; Local: %" PetscInt_FMT 48958600ac3SJames Wright ", %" PetscInt_FMT ")", 49058600ac3SJames Wright Y_g_ceed_size, X_g_ceed_size, Y_l_ceed_size, X_l_ceed_size, Y_g_other_size, X_g_other_size, Y_l_other_size, X_l_other_size); 49158600ac3SJames Wright } 49258600ac3SJames Wright 49358600ac3SJames Wright // Convert 49458600ac3SJames Wright { 49558600ac3SJames Wright VecType vec_type; 49658600ac3SJames Wright MatCeedContext ctx; 49758600ac3SJames Wright 49858600ac3SJames Wright PetscCall(PetscObjectChangeTypeName((PetscObject)mat_other, MATCEED)); 49958600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 50058600ac3SJames Wright PetscCall(MatCeedContextReference(ctx)); 50158600ac3SJames Wright PetscCall(MatShellSetContext(mat_other, ctx)); 50258600ac3SJames Wright PetscCall(MatShellSetContextDestroy(mat_other, (PetscErrorCode(*)(void *))MatCeedContextDestroy)); 50358600ac3SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_MULT, (void (*)(void))MatMult_Ceed)); 50458600ac3SJames Wright if (ctx->op_mult_transpose) PetscCall(MatShellSetOperation(mat_other, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed)); 50558600ac3SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed)); 50658600ac3SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed)); 50758600ac3SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed)); 50858600ac3SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed)); 50958600ac3SJames Wright { 51058600ac3SJames Wright PetscInt block_size; 51158600ac3SJames Wright 51258600ac3SJames Wright PetscCall(MatGetBlockSize(mat_ceed, &block_size)); 51358600ac3SJames Wright if (block_size > 1) PetscCall(MatSetBlockSize(mat_other, block_size)); 51458600ac3SJames Wright } 51558600ac3SJames Wright { 51658600ac3SJames Wright PetscInt num_blocks; 51758600ac3SJames Wright const PetscInt *block_sizes; 51858600ac3SJames Wright 51958600ac3SJames Wright PetscCall(MatGetVariableBlockSizes(mat_ceed, &num_blocks, &block_sizes)); 52058600ac3SJames Wright if (num_blocks) PetscCall(MatSetVariableBlockSizes(mat_other, num_blocks, (PetscInt *)block_sizes)); 52158600ac3SJames Wright } 52258600ac3SJames Wright PetscCall(DMGetVecType(ctx->dm_x, &vec_type)); 52358600ac3SJames Wright PetscCall(MatShellSetVecType(mat_other, vec_type)); 52458600ac3SJames Wright } 52558600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 52658600ac3SJames Wright } 52758600ac3SJames Wright 52858600ac3SJames Wright /** 52958600ac3SJames Wright @brief Assemble a `MATCEED` into a `MATAIJ` or similar. 53058600ac3SJames Wright The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`. 53158600ac3SJames Wright The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail. 53258600ac3SJames Wright 53358600ac3SJames Wright Collective across MPI processes. 53458600ac3SJames Wright 53558600ac3SJames Wright @param[in] mat_ceed `MATCEED` to assemble 53658600ac3SJames Wright @param[in,out] mat_coo `MATAIJ` or similar to assemble into 53758600ac3SJames Wright 53858600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 53958600ac3SJames Wright **/ 54058600ac3SJames Wright PetscErrorCode MatCeedAssembleCOO(Mat mat_ceed, Mat mat_coo) { 54158600ac3SJames Wright MatCeedContext ctx; 54258600ac3SJames Wright 54358600ac3SJames Wright PetscFunctionBeginUser; 54458600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 54558600ac3SJames Wright 54658600ac3SJames Wright // Check if COO pattern set 54758600ac3SJames Wright { 54858600ac3SJames Wright PetscInt index = -1; 54958600ac3SJames Wright 55058600ac3SJames Wright for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) { 55158600ac3SJames Wright if (ctx->mats_assembled_full[i] == mat_coo) index = i; 55258600ac3SJames Wright } 55358600ac3SJames Wright if (index == -1) { 55458600ac3SJames Wright PetscInt *rows_petsc = NULL, *cols_petsc = NULL; 55558600ac3SJames Wright CeedInt *rows_ceed, *cols_ceed; 55658600ac3SJames Wright PetscCount num_entries; 55758600ac3SJames Wright PetscLogStage stage_amg_setup; 55858600ac3SJames Wright 55958600ac3SJames Wright // -- Assemble sparsity pattern if mat hasn't been assembled before 56058600ac3SJames Wright PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup)); 56158600ac3SJames Wright if (stage_amg_setup == -1) { 56258600ac3SJames Wright PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup)); 56358600ac3SJames Wright } 56458600ac3SJames Wright PetscCall(PetscLogStagePush(stage_amg_setup)); 56550f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorLinearAssembleSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed)); 566*a7dac1d5SJames Wright PetscCall(IntArrayCeedToPetsc(num_entries, &rows_ceed, &rows_petsc)); 567*a7dac1d5SJames Wright PetscCall(IntArrayCeedToPetsc(num_entries, &cols_ceed, &cols_petsc)); 56858600ac3SJames Wright PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc)); 56958600ac3SJames Wright free(rows_petsc); 57058600ac3SJames Wright free(cols_petsc); 57150f50432SJames Wright if (!ctx->coo_values_full) PetscCallCeed(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_full)); 57258600ac3SJames Wright PetscCall(PetscRealloc(++ctx->num_mats_assembled_full * sizeof(Mat), &ctx->mats_assembled_full)); 57358600ac3SJames Wright ctx->mats_assembled_full[ctx->num_mats_assembled_full - 1] = mat_coo; 57458600ac3SJames Wright PetscCall(PetscLogStagePop()); 57558600ac3SJames Wright } 57658600ac3SJames Wright } 57758600ac3SJames Wright 57858600ac3SJames Wright // Assemble mat_ceed 57958600ac3SJames Wright PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY)); 58058600ac3SJames Wright { 58158600ac3SJames Wright const CeedScalar *values; 58258600ac3SJames Wright MatType mat_type; 58358600ac3SJames Wright CeedMemType mem_type = CEED_MEM_HOST; 58458600ac3SJames Wright PetscBool is_spd, is_spd_known; 58558600ac3SJames Wright 58658600ac3SJames Wright PetscCall(MatGetType(mat_coo, &mat_type)); 58758600ac3SJames Wright if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE; 58858600ac3SJames Wright else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE; 58958600ac3SJames Wright else mem_type = CEED_MEM_HOST; 59058600ac3SJames Wright 59150f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemble(ctx->op_mult, ctx->coo_values_full)); 59250f50432SJames Wright PetscCallCeed(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_full, mem_type, &values)); 59358600ac3SJames Wright PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES)); 59458600ac3SJames Wright PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd)); 59558600ac3SJames Wright if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd)); 59650f50432SJames Wright PetscCallCeed(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_full, &values)); 59758600ac3SJames Wright } 59858600ac3SJames Wright PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY)); 59958600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 60058600ac3SJames Wright } 60158600ac3SJames Wright 60258600ac3SJames Wright /** 60358600ac3SJames Wright @brief Set user context for a `MATCEED`. 60458600ac3SJames Wright 60558600ac3SJames Wright Collective across MPI processes. 60658600ac3SJames Wright 60758600ac3SJames Wright @param[in,out] mat `MATCEED` 60858600ac3SJames Wright @param[in] f The context destroy function, or NULL 60958600ac3SJames Wright @param[in] ctx User context, or NULL to unset 61058600ac3SJames Wright 61158600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 61258600ac3SJames Wright **/ 61358600ac3SJames Wright PetscErrorCode MatCeedSetContext(Mat mat, PetscErrorCode (*f)(void *), void *ctx) { 61458600ac3SJames Wright PetscContainer user_ctx = NULL; 61558600ac3SJames Wright 61658600ac3SJames Wright PetscFunctionBeginUser; 61758600ac3SJames Wright if (ctx) { 61858600ac3SJames Wright PetscCall(PetscContainerCreate(PetscObjectComm((PetscObject)mat), &user_ctx)); 61958600ac3SJames Wright PetscCall(PetscContainerSetPointer(user_ctx, ctx)); 62058600ac3SJames Wright PetscCall(PetscContainerSetUserDestroy(user_ctx, f)); 62158600ac3SJames Wright } 62258600ac3SJames Wright PetscCall(PetscObjectCompose((PetscObject)mat, "MatCeed user context", (PetscObject)user_ctx)); 62358600ac3SJames Wright PetscCall(PetscContainerDestroy(&user_ctx)); 62458600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 62558600ac3SJames Wright } 62658600ac3SJames Wright 62758600ac3SJames Wright /** 62858600ac3SJames Wright @brief Retrieve the user context for a `MATCEED`. 62958600ac3SJames Wright 63058600ac3SJames Wright Collective across MPI processes. 63158600ac3SJames Wright 63258600ac3SJames Wright @param[in,out] mat `MATCEED` 63358600ac3SJames Wright @param[in] ctx User context 63458600ac3SJames Wright 63558600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 63658600ac3SJames Wright **/ 63758600ac3SJames Wright PetscErrorCode MatCeedGetContext(Mat mat, void *ctx) { 63858600ac3SJames Wright PetscContainer user_ctx; 63958600ac3SJames Wright 64058600ac3SJames Wright PetscFunctionBeginUser; 64158600ac3SJames Wright PetscCall(PetscObjectQuery((PetscObject)mat, "MatCeed user context", (PetscObject *)&user_ctx)); 64258600ac3SJames Wright if (user_ctx) PetscCall(PetscContainerGetPointer(user_ctx, (void **)ctx)); 64358600ac3SJames Wright else *(void **)ctx = NULL; 64458600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 64558600ac3SJames Wright } 64658600ac3SJames Wright 64758600ac3SJames Wright /** 64858600ac3SJames Wright @brief Sets the inner matrix type as a string from the `MATCEED`. 64958600ac3SJames Wright 65058600ac3SJames Wright Collective across MPI processes. 65158600ac3SJames Wright 65258600ac3SJames Wright @param[in,out] mat `MATCEED` 65358600ac3SJames Wright @param[in] type Inner `MatType` to set 65458600ac3SJames Wright 65558600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 65658600ac3SJames Wright **/ 65758600ac3SJames Wright PetscErrorCode MatCeedSetInnerMatType(Mat mat, MatType type) { 65858600ac3SJames Wright MatCeedContext ctx; 65958600ac3SJames Wright 66058600ac3SJames Wright PetscFunctionBeginUser; 66158600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 66258600ac3SJames Wright // Check if same 66358600ac3SJames Wright { 66458600ac3SJames Wright size_t len_old, len_new; 66558600ac3SJames Wright PetscBool is_same = PETSC_FALSE; 66658600ac3SJames Wright 66758600ac3SJames Wright PetscCall(PetscStrlen(ctx->internal_mat_type, &len_old)); 66858600ac3SJames Wright PetscCall(PetscStrlen(type, &len_new)); 66958600ac3SJames Wright if (len_old == len_new) PetscCall(PetscStrncmp(ctx->internal_mat_type, type, len_old, &is_same)); 67058600ac3SJames Wright if (is_same) PetscFunctionReturn(PETSC_SUCCESS); 67158600ac3SJames Wright } 67258600ac3SJames Wright // Clean up old mats in different format 67358600ac3SJames Wright // LCOV_EXCL_START 67458600ac3SJames Wright if (ctx->mat_assembled_full_internal) { 67558600ac3SJames Wright for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) { 67658600ac3SJames Wright if (ctx->mats_assembled_full[i] == ctx->mat_assembled_full_internal) { 67758600ac3SJames Wright for (PetscInt j = i + 1; j < ctx->num_mats_assembled_full; j++) { 67858600ac3SJames Wright ctx->mats_assembled_full[j - 1] = ctx->mats_assembled_full[j]; 67958600ac3SJames Wright } 68058600ac3SJames Wright ctx->num_mats_assembled_full--; 68158600ac3SJames Wright // Note: we'll realloc this array again, so no need to shrink the allocation 68258600ac3SJames Wright PetscCall(MatDestroy(&ctx->mat_assembled_full_internal)); 68358600ac3SJames Wright } 68458600ac3SJames Wright } 68558600ac3SJames Wright } 68658600ac3SJames Wright if (ctx->mat_assembled_pbd_internal) { 68758600ac3SJames Wright for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) { 68858600ac3SJames Wright if (ctx->mats_assembled_pbd[i] == ctx->mat_assembled_pbd_internal) { 68958600ac3SJames Wright for (PetscInt j = i + 1; j < ctx->num_mats_assembled_pbd; j++) { 69058600ac3SJames Wright ctx->mats_assembled_pbd[j - 1] = ctx->mats_assembled_pbd[j]; 69158600ac3SJames Wright } 69258600ac3SJames Wright // Note: we'll realloc this array again, so no need to shrink the allocation 69358600ac3SJames Wright ctx->num_mats_assembled_pbd--; 69458600ac3SJames Wright PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal)); 69558600ac3SJames Wright } 69658600ac3SJames Wright } 69758600ac3SJames Wright } 69858600ac3SJames Wright PetscCall(PetscFree(ctx->internal_mat_type)); 69958600ac3SJames Wright PetscCall(PetscStrallocpy(type, &ctx->internal_mat_type)); 70058600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 70158600ac3SJames Wright // LCOV_EXCL_STOP 70258600ac3SJames Wright } 70358600ac3SJames Wright 70458600ac3SJames Wright /** 70558600ac3SJames Wright @brief Gets the inner matrix type as a string from the `MATCEED`. 70658600ac3SJames Wright 70758600ac3SJames Wright Collective across MPI processes. 70858600ac3SJames Wright 70958600ac3SJames Wright @param[in,out] mat `MATCEED` 71058600ac3SJames Wright @param[in] type Inner `MatType` 71158600ac3SJames Wright 71258600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 71358600ac3SJames Wright **/ 71458600ac3SJames Wright PetscErrorCode MatCeedGetInnerMatType(Mat mat, MatType *type) { 71558600ac3SJames Wright MatCeedContext ctx; 71658600ac3SJames Wright 71758600ac3SJames Wright PetscFunctionBeginUser; 71858600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 71958600ac3SJames Wright *type = ctx->internal_mat_type; 72058600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 72158600ac3SJames Wright } 72258600ac3SJames Wright 72358600ac3SJames Wright /** 72458600ac3SJames Wright @brief Set a user defined matrix operation for a `MATCEED` matrix. 72558600ac3SJames Wright 72658600ac3SJames Wright Within each user-defined routine, the user should call `MatCeedGetContext()` to obtain the user-defined context that was set by 72758600ac3SJames Wright `MatCeedSetContext()`. 72858600ac3SJames Wright 72958600ac3SJames Wright Collective across MPI processes. 73058600ac3SJames Wright 73158600ac3SJames Wright @param[in,out] mat `MATCEED` 73258600ac3SJames Wright @param[in] op Name of the `MatOperation` 73358600ac3SJames Wright @param[in] g Function that provides the operation 73458600ac3SJames Wright 73558600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 73658600ac3SJames Wright **/ 73758600ac3SJames Wright PetscErrorCode MatCeedSetOperation(Mat mat, MatOperation op, void (*g)(void)) { 73858600ac3SJames Wright PetscFunctionBeginUser; 73958600ac3SJames Wright PetscCall(MatShellSetOperation(mat, op, g)); 74058600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 74158600ac3SJames Wright } 74258600ac3SJames Wright 74358600ac3SJames Wright /** 74458600ac3SJames Wright @brief Set input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 74558600ac3SJames Wright 74658600ac3SJames Wright Not collective across MPI processes. 74758600ac3SJames Wright 74858600ac3SJames Wright @param[in,out] mat `MATCEED` 74958600ac3SJames Wright @param[in] X_loc Input PETSc local vector, or NULL 75058600ac3SJames Wright @param[in] Y_loc_transpose Input PETSc local vector for transpose operation, or NULL 75158600ac3SJames Wright 75258600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 75358600ac3SJames Wright **/ 75458600ac3SJames Wright PetscErrorCode MatCeedSetLocalVectors(Mat mat, Vec X_loc, Vec Y_loc_transpose) { 75558600ac3SJames Wright MatCeedContext ctx; 75658600ac3SJames Wright 75758600ac3SJames Wright PetscFunctionBeginUser; 75858600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 75958600ac3SJames Wright if (X_loc) { 76058600ac3SJames Wright PetscInt len_old, len_new; 76158600ac3SJames Wright 76258600ac3SJames Wright PetscCall(VecGetSize(ctx->X_loc, &len_old)); 76358600ac3SJames Wright PetscCall(VecGetSize(X_loc, &len_new)); 76458600ac3SJames Wright PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB, "new X_loc length %" PetscInt_FMT " should match old X_loc length %" PetscInt_FMT, 76558600ac3SJames Wright len_new, len_old); 76658600ac3SJames Wright PetscCall(VecDestroy(&ctx->X_loc)); 76758600ac3SJames Wright ctx->X_loc = X_loc; 76858600ac3SJames Wright PetscCall(PetscObjectReference((PetscObject)X_loc)); 76958600ac3SJames Wright } 77058600ac3SJames Wright if (Y_loc_transpose) { 77158600ac3SJames Wright PetscInt len_old, len_new; 77258600ac3SJames Wright 77358600ac3SJames Wright PetscCall(VecGetSize(ctx->Y_loc_transpose, &len_old)); 77458600ac3SJames Wright PetscCall(VecGetSize(Y_loc_transpose, &len_new)); 77558600ac3SJames Wright PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB, 77658600ac3SJames Wright "new Y_loc_transpose length %" PetscInt_FMT " should match old Y_loc_transpose length %" PetscInt_FMT, len_new, len_old); 77758600ac3SJames Wright PetscCall(VecDestroy(&ctx->Y_loc_transpose)); 77858600ac3SJames Wright ctx->Y_loc_transpose = Y_loc_transpose; 77958600ac3SJames Wright PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose)); 78058600ac3SJames Wright } 78158600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 78258600ac3SJames Wright } 78358600ac3SJames Wright 78458600ac3SJames Wright /** 78558600ac3SJames Wright @brief Get input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 78658600ac3SJames Wright 78758600ac3SJames Wright Not collective across MPI processes. 78858600ac3SJames Wright 78958600ac3SJames Wright @param[in,out] mat `MATCEED` 79058600ac3SJames Wright @param[out] X_loc Input PETSc local vector, or NULL 79158600ac3SJames Wright @param[out] Y_loc_transpose Input PETSc local vector for transpose operation, or NULL 79258600ac3SJames Wright 79358600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 79458600ac3SJames Wright **/ 79558600ac3SJames Wright PetscErrorCode MatCeedGetLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) { 79658600ac3SJames Wright MatCeedContext ctx; 79758600ac3SJames Wright 79858600ac3SJames Wright PetscFunctionBeginUser; 79958600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 80058600ac3SJames Wright if (X_loc) { 80158600ac3SJames Wright *X_loc = ctx->X_loc; 80258600ac3SJames Wright PetscCall(PetscObjectReference((PetscObject)*X_loc)); 80358600ac3SJames Wright } 80458600ac3SJames Wright if (Y_loc_transpose) { 80558600ac3SJames Wright *Y_loc_transpose = ctx->Y_loc_transpose; 80658600ac3SJames Wright PetscCall(PetscObjectReference((PetscObject)*Y_loc_transpose)); 80758600ac3SJames Wright } 80858600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 80958600ac3SJames Wright } 81058600ac3SJames Wright 81158600ac3SJames Wright /** 81258600ac3SJames Wright @brief Restore input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 81358600ac3SJames Wright 81458600ac3SJames Wright Not collective across MPI processes. 81558600ac3SJames Wright 81658600ac3SJames Wright @param[in,out] mat MatCeed 81758600ac3SJames Wright @param[out] X_loc Input PETSc local vector, or NULL 81858600ac3SJames Wright @param[out] Y_loc_transpose Input PETSc local vector for transpose operation, or NULL 81958600ac3SJames Wright 82058600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 82158600ac3SJames Wright **/ 82258600ac3SJames Wright PetscErrorCode MatCeedRestoreLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) { 82358600ac3SJames Wright PetscFunctionBeginUser; 82458600ac3SJames Wright if (X_loc) PetscCall(VecDestroy(X_loc)); 82558600ac3SJames Wright if (Y_loc_transpose) PetscCall(VecDestroy(Y_loc_transpose)); 82658600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 82758600ac3SJames Wright } 82858600ac3SJames Wright 82958600ac3SJames Wright /** 83058600ac3SJames Wright @brief Get libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 83158600ac3SJames Wright 83258600ac3SJames Wright Not collective across MPI processes. 83358600ac3SJames Wright 83458600ac3SJames Wright @param[in,out] mat MatCeed 83558600ac3SJames Wright @param[out] op_mult libCEED `CeedOperator` for `MatMult()`, or NULL 83658600ac3SJames Wright @param[out] op_mult_transpose libCEED `CeedOperator` for `MatMultTranspose()`, or NULL 83758600ac3SJames Wright 83858600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 83958600ac3SJames Wright **/ 84058600ac3SJames Wright PetscErrorCode MatCeedGetCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) { 84158600ac3SJames Wright MatCeedContext ctx; 84258600ac3SJames Wright 84358600ac3SJames Wright PetscFunctionBeginUser; 84458600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 84558600ac3SJames Wright if (op_mult) { 84658600ac3SJames Wright *op_mult = NULL; 84750f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult, op_mult)); 84858600ac3SJames Wright } 84958600ac3SJames Wright if (op_mult_transpose) { 85058600ac3SJames Wright *op_mult_transpose = NULL; 85150f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult_transpose, op_mult_transpose)); 85258600ac3SJames Wright } 85358600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 85458600ac3SJames Wright } 85558600ac3SJames Wright 85658600ac3SJames Wright /** 85758600ac3SJames Wright @brief Restore libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 85858600ac3SJames Wright 85958600ac3SJames Wright Not collective across MPI processes. 86058600ac3SJames Wright 86158600ac3SJames Wright @param[in,out] mat MatCeed 86258600ac3SJames Wright @param[out] op_mult libCEED `CeedOperator` for `MatMult()`, or NULL 86358600ac3SJames Wright @param[out] op_mult_transpose libCEED `CeedOperator` for `MatMultTranspose()`, or NULL 86458600ac3SJames Wright 86558600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 86658600ac3SJames Wright **/ 86758600ac3SJames Wright PetscErrorCode MatCeedRestoreCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) { 86858600ac3SJames Wright MatCeedContext ctx; 86958600ac3SJames Wright 87058600ac3SJames Wright PetscFunctionBeginUser; 87158600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 87250f50432SJames Wright if (op_mult) PetscCallCeed(ctx->ceed, CeedOperatorDestroy(op_mult)); 87350f50432SJames Wright if (op_mult_transpose) PetscCallCeed(ctx->ceed, CeedOperatorDestroy(op_mult_transpose)); 87458600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 87558600ac3SJames Wright } 87658600ac3SJames Wright 87758600ac3SJames Wright /** 87858600ac3SJames Wright @brief Set `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators. 87958600ac3SJames Wright 88058600ac3SJames Wright Not collective across MPI processes. 88158600ac3SJames Wright 88258600ac3SJames Wright @param[in,out] mat MatCeed 88358600ac3SJames Wright @param[out] log_event_mult `PetscLogEvent` for forward evaluation, or NULL 88458600ac3SJames Wright @param[out] log_event_mult_transpose `PetscLogEvent` for transpose evaluation, or NULL 88558600ac3SJames Wright 88658600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 88758600ac3SJames Wright **/ 88858600ac3SJames Wright PetscErrorCode MatCeedSetLogEvents(Mat mat, PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose) { 88958600ac3SJames Wright MatCeedContext ctx; 89058600ac3SJames Wright 89158600ac3SJames Wright PetscFunctionBeginUser; 89258600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 89358600ac3SJames Wright if (log_event_mult) ctx->log_event_mult = log_event_mult; 89458600ac3SJames Wright if (log_event_mult_transpose) ctx->log_event_mult_transpose = log_event_mult_transpose; 89558600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 89658600ac3SJames Wright } 89758600ac3SJames Wright 89858600ac3SJames Wright /** 89958600ac3SJames Wright @brief Get `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators. 90058600ac3SJames Wright 90158600ac3SJames Wright Not collective across MPI processes. 90258600ac3SJames Wright 90358600ac3SJames Wright @param[in,out] mat MatCeed 90458600ac3SJames Wright @param[out] log_event_mult `PetscLogEvent` for forward evaluation, or NULL 90558600ac3SJames Wright @param[out] log_event_mult_transpose `PetscLogEvent` for transpose evaluation, or NULL 90658600ac3SJames Wright 90758600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 90858600ac3SJames Wright **/ 90958600ac3SJames Wright PetscErrorCode MatCeedGetLogEvents(Mat mat, PetscLogEvent *log_event_mult, PetscLogEvent *log_event_mult_transpose) { 91058600ac3SJames Wright MatCeedContext ctx; 91158600ac3SJames Wright 91258600ac3SJames Wright PetscFunctionBeginUser; 91358600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 91458600ac3SJames Wright if (log_event_mult) *log_event_mult = ctx->log_event_mult; 91558600ac3SJames Wright if (log_event_mult_transpose) *log_event_mult_transpose = ctx->log_event_mult_transpose; 91658600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 91758600ac3SJames Wright } 91858600ac3SJames Wright 91958600ac3SJames Wright // ----------------------------------------------------------------------------- 92058600ac3SJames Wright // Operator context data 92158600ac3SJames Wright // ----------------------------------------------------------------------------- 92258600ac3SJames Wright 92358600ac3SJames Wright /** 92458600ac3SJames Wright @brief Setup context data for operator application. 92558600ac3SJames Wright 92658600ac3SJames Wright Collective across MPI processes. 92758600ac3SJames Wright 92858600ac3SJames Wright @param[in] dm_x Input `DM` 92958600ac3SJames Wright @param[in] dm_y Output `DM` 93058600ac3SJames Wright @param[in] X_loc Input PETSc local vector, or NULL 93158600ac3SJames Wright @param[in] Y_loc_transpose Input PETSc local vector for transpose operation, or NULL 93258600ac3SJames Wright @param[in] op_mult `CeedOperator` for forward evaluation 93358600ac3SJames Wright @param[in] op_mult_transpose `CeedOperator` for transpose evaluation 93458600ac3SJames Wright @param[in] log_event_mult `PetscLogEvent` for forward evaluation 93558600ac3SJames Wright @param[in] log_event_mult_transpose `PetscLogEvent` for transpose evaluation 93658600ac3SJames Wright @param[out] ctx Context data for operator evaluation 93758600ac3SJames Wright 93858600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 93958600ac3SJames Wright **/ 94058600ac3SJames Wright PetscErrorCode MatCeedContextCreate(DM dm_x, DM dm_y, Vec X_loc, Vec Y_loc_transpose, CeedOperator op_mult, CeedOperator op_mult_transpose, 94158600ac3SJames Wright PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose, MatCeedContext *ctx) { 94258600ac3SJames Wright CeedSize x_loc_len, y_loc_len; 94358600ac3SJames Wright 94458600ac3SJames Wright PetscFunctionBeginUser; 94558600ac3SJames Wright 94658600ac3SJames Wright // Allocate 94758600ac3SJames Wright PetscCall(PetscNew(ctx)); 94858600ac3SJames Wright (*ctx)->ref_count = 1; 94958600ac3SJames Wright 95058600ac3SJames Wright // Logging 95158600ac3SJames Wright (*ctx)->log_event_mult = log_event_mult; 95258600ac3SJames Wright (*ctx)->log_event_mult_transpose = log_event_mult_transpose; 95358600ac3SJames Wright 95458600ac3SJames Wright // PETSc objects 95558600ac3SJames Wright PetscCall(PetscObjectReference((PetscObject)dm_x)); 95658600ac3SJames Wright (*ctx)->dm_x = dm_x; 95758600ac3SJames Wright PetscCall(PetscObjectReference((PetscObject)dm_y)); 95858600ac3SJames Wright (*ctx)->dm_y = dm_y; 95958600ac3SJames Wright if (X_loc) PetscCall(PetscObjectReference((PetscObject)X_loc)); 96058600ac3SJames Wright (*ctx)->X_loc = X_loc; 96158600ac3SJames Wright if (Y_loc_transpose) PetscCall(PetscObjectReference((PetscObject)Y_loc_transpose)); 96258600ac3SJames Wright (*ctx)->Y_loc_transpose = Y_loc_transpose; 96358600ac3SJames Wright 96458600ac3SJames Wright // Memtype 96558600ac3SJames Wright { 96658600ac3SJames Wright const PetscScalar *x; 96758600ac3SJames Wright Vec X; 96858600ac3SJames Wright 96958600ac3SJames Wright PetscCall(DMGetLocalVector(dm_x, &X)); 97058600ac3SJames Wright PetscCall(VecGetArrayReadAndMemType(X, &x, &(*ctx)->mem_type)); 97158600ac3SJames Wright PetscCall(VecRestoreArrayReadAndMemType(X, &x)); 97258600ac3SJames Wright PetscCall(DMRestoreLocalVector(dm_x, &X)); 97358600ac3SJames Wright } 97458600ac3SJames Wright 97558600ac3SJames Wright // libCEED objects 97658600ac3SJames Wright PetscCheck(CeedOperatorGetCeed(op_mult, &(*ctx)->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, 97758600ac3SJames Wright "retrieving Ceed context object failed"); 97850f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedReference((*ctx)->ceed)); 97950f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedOperatorGetActiveVectorLengths(op_mult, &x_loc_len, &y_loc_len)); 98050f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult, &(*ctx)->op_mult)); 98150f50432SJames Wright if (op_mult_transpose) PetscCallCeed((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult_transpose, &(*ctx)->op_mult_transpose)); 98250f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, x_loc_len, &(*ctx)->x_loc)); 98350f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, y_loc_len, &(*ctx)->y_loc)); 98458600ac3SJames Wright 98558600ac3SJames Wright // Flop counting 98658600ac3SJames Wright { 98758600ac3SJames Wright CeedSize ceed_flops_estimate = 0; 98858600ac3SJames Wright 98950f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult, &ceed_flops_estimate)); 99058600ac3SJames Wright (*ctx)->flops_mult = ceed_flops_estimate; 99158600ac3SJames Wright if (op_mult_transpose) { 99250f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult_transpose, &ceed_flops_estimate)); 99358600ac3SJames Wright (*ctx)->flops_mult_transpose = ceed_flops_estimate; 99458600ac3SJames Wright } 99558600ac3SJames Wright } 99658600ac3SJames Wright 99758600ac3SJames Wright // Check sizes 99858600ac3SJames Wright if (x_loc_len > 0 || y_loc_len > 0) { 99958600ac3SJames Wright CeedSize ctx_x_loc_len, ctx_y_loc_len; 100058600ac3SJames Wright PetscInt X_loc_len, dm_x_loc_len, Y_loc_len, dm_y_loc_len; 100158600ac3SJames Wright Vec dm_X_loc, dm_Y_loc; 100258600ac3SJames Wright 100358600ac3SJames Wright // -- Input 100458600ac3SJames Wright PetscCall(DMGetLocalVector(dm_x, &dm_X_loc)); 100558600ac3SJames Wright PetscCall(VecGetLocalSize(dm_X_loc, &dm_x_loc_len)); 100658600ac3SJames Wright PetscCall(DMRestoreLocalVector(dm_x, &dm_X_loc)); 100758600ac3SJames Wright if (X_loc) PetscCall(VecGetLocalSize(X_loc, &X_loc_len)); 100850f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedVectorGetLength((*ctx)->x_loc, &ctx_x_loc_len)); 100958600ac3SJames Wright if (X_loc) PetscCheck(X_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "X_loc must match dm_x dimensions"); 101058600ac3SJames Wright PetscCheck(x_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op must match dm_x dimensions"); 101158600ac3SJames Wright PetscCheck(x_loc_len == ctx_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "x_loc must match op dimensions"); 101258600ac3SJames Wright 101358600ac3SJames Wright // -- Output 101458600ac3SJames Wright PetscCall(DMGetLocalVector(dm_y, &dm_Y_loc)); 101558600ac3SJames Wright PetscCall(VecGetLocalSize(dm_Y_loc, &dm_y_loc_len)); 101658600ac3SJames Wright PetscCall(DMRestoreLocalVector(dm_y, &dm_Y_loc)); 101750f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedVectorGetLength((*ctx)->y_loc, &ctx_y_loc_len)); 101858600ac3SJames Wright PetscCheck(ctx_y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op must match dm_y dimensions"); 101958600ac3SJames Wright 102058600ac3SJames Wright // -- Transpose 102158600ac3SJames Wright if (Y_loc_transpose) { 102258600ac3SJames Wright PetscCall(VecGetLocalSize(Y_loc_transpose, &Y_loc_len)); 102358600ac3SJames Wright PetscCheck(Y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "Y_loc_transpose must match dm_y dimensions"); 102458600ac3SJames Wright } 102558600ac3SJames Wright } 102658600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 102758600ac3SJames Wright } 102858600ac3SJames Wright 102958600ac3SJames Wright /** 103058600ac3SJames Wright @brief Increment reference counter for `MATCEED` context. 103158600ac3SJames Wright 103258600ac3SJames Wright Not collective across MPI processes. 103358600ac3SJames Wright 103458600ac3SJames Wright @param[in,out] ctx Context data 103558600ac3SJames Wright 103658600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 103758600ac3SJames Wright **/ 103858600ac3SJames Wright PetscErrorCode MatCeedContextReference(MatCeedContext ctx) { 103958600ac3SJames Wright PetscFunctionBeginUser; 104058600ac3SJames Wright ctx->ref_count++; 104158600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 104258600ac3SJames Wright } 104358600ac3SJames Wright 104458600ac3SJames Wright /** 104558600ac3SJames Wright @brief Copy reference for `MATCEED`. 104658600ac3SJames Wright Note: If `ctx_copy` is non-null, it is assumed to be a valid pointer to a `MatCeedContext`. 104758600ac3SJames Wright 104858600ac3SJames Wright Not collective across MPI processes. 104958600ac3SJames Wright 105058600ac3SJames Wright @param[in] ctx Context data 105158600ac3SJames Wright @param[out] ctx_copy Copy of pointer to context data 105258600ac3SJames Wright 105358600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 105458600ac3SJames Wright **/ 105558600ac3SJames Wright PetscErrorCode MatCeedContextReferenceCopy(MatCeedContext ctx, MatCeedContext *ctx_copy) { 105658600ac3SJames Wright PetscFunctionBeginUser; 105758600ac3SJames Wright PetscCall(MatCeedContextReference(ctx)); 105858600ac3SJames Wright PetscCall(MatCeedContextDestroy(*ctx_copy)); 105958600ac3SJames Wright *ctx_copy = ctx; 106058600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 106158600ac3SJames Wright } 106258600ac3SJames Wright 106358600ac3SJames Wright /** 106458600ac3SJames Wright @brief Destroy context data for operator application. 106558600ac3SJames Wright 106658600ac3SJames Wright Collective across MPI processes. 106758600ac3SJames Wright 106858600ac3SJames Wright @param[in,out] ctx Context data for operator evaluation 106958600ac3SJames Wright 107058600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 107158600ac3SJames Wright **/ 107258600ac3SJames Wright PetscErrorCode MatCeedContextDestroy(MatCeedContext ctx) { 107358600ac3SJames Wright PetscFunctionBeginUser; 107458600ac3SJames Wright if (!ctx || --ctx->ref_count > 0) PetscFunctionReturn(PETSC_SUCCESS); 107558600ac3SJames Wright 107658600ac3SJames Wright // PETSc objects 107758600ac3SJames Wright PetscCall(DMDestroy(&ctx->dm_x)); 107858600ac3SJames Wright PetscCall(DMDestroy(&ctx->dm_y)); 107958600ac3SJames Wright PetscCall(VecDestroy(&ctx->X_loc)); 108058600ac3SJames Wright PetscCall(VecDestroy(&ctx->Y_loc_transpose)); 108158600ac3SJames Wright PetscCall(MatDestroy(&ctx->mat_assembled_full_internal)); 108258600ac3SJames Wright PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal)); 108358600ac3SJames Wright PetscCall(PetscFree(ctx->internal_mat_type)); 108458600ac3SJames Wright PetscCall(PetscFree(ctx->mats_assembled_full)); 108558600ac3SJames Wright PetscCall(PetscFree(ctx->mats_assembled_pbd)); 108658600ac3SJames Wright 108758600ac3SJames Wright // libCEED objects 108850f50432SJames Wright PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->x_loc)); 108950f50432SJames Wright PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->y_loc)); 109050f50432SJames Wright PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_full)); 109150f50432SJames Wright PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_pbd)); 109250f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult)); 109350f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult_transpose)); 109458600ac3SJames Wright PetscCheck(CeedDestroy(&ctx->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "destroying libCEED context object failed"); 109558600ac3SJames Wright 109658600ac3SJames Wright // Deallocate 109758600ac3SJames Wright ctx->is_destroyed = PETSC_TRUE; // Flag as destroyed in case someone has stale ref 109858600ac3SJames Wright PetscCall(PetscFree(ctx)); 109958600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 110058600ac3SJames Wright } 110158600ac3SJames Wright 110258600ac3SJames Wright /** 110358600ac3SJames Wright @brief Compute the diagonal of an operator via libCEED. 110458600ac3SJames Wright 110558600ac3SJames Wright Collective across MPI processes. 110658600ac3SJames Wright 110758600ac3SJames Wright @param[in] A `MATCEED` 110858600ac3SJames Wright @param[out] D Vector holding operator diagonal 110958600ac3SJames Wright 111058600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 111158600ac3SJames Wright **/ 111258600ac3SJames Wright PetscErrorCode MatGetDiagonal_Ceed(Mat A, Vec D) { 111358600ac3SJames Wright PetscMemType mem_type; 111458600ac3SJames Wright Vec D_loc; 111558600ac3SJames Wright MatCeedContext ctx; 111658600ac3SJames Wright 111758600ac3SJames Wright PetscFunctionBeginUser; 111858600ac3SJames Wright PetscCall(MatShellGetContext(A, &ctx)); 111958600ac3SJames Wright 112058600ac3SJames Wright // Place PETSc vector in libCEED vector 112158600ac3SJames Wright PetscCall(DMGetLocalVector(ctx->dm_x, &D_loc)); 1122*a7dac1d5SJames Wright PetscCall(VecPetscToCeed(D_loc, &mem_type, ctx->x_loc)); 112358600ac3SJames Wright 112458600ac3SJames Wright // Compute Diagonal 112550f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorLinearAssembleDiagonal(ctx->op_mult, ctx->x_loc, CEED_REQUEST_IMMEDIATE)); 112658600ac3SJames Wright 112758600ac3SJames Wright // Restore PETSc vector 1128*a7dac1d5SJames Wright PetscCall(VecCeedToPetsc(ctx->x_loc, mem_type, D_loc)); 112958600ac3SJames Wright 113058600ac3SJames Wright // Local-to-Global 113158600ac3SJames Wright PetscCall(VecZeroEntries(D)); 113258600ac3SJames Wright PetscCall(DMLocalToGlobal(ctx->dm_x, D_loc, ADD_VALUES, D)); 113358600ac3SJames Wright PetscCall(DMRestoreLocalVector(ctx->dm_x, &D_loc)); 113458600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 113558600ac3SJames Wright } 113658600ac3SJames Wright 113758600ac3SJames Wright /** 113858600ac3SJames Wright @brief Compute `A X = Y` for a `MATCEED`. 113958600ac3SJames Wright 114058600ac3SJames Wright Collective across MPI processes. 114158600ac3SJames Wright 114258600ac3SJames Wright @param[in] A `MATCEED` 114358600ac3SJames Wright @param[in] X Input PETSc vector 114458600ac3SJames Wright @param[out] Y Output PETSc vector 114558600ac3SJames Wright 114658600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 114758600ac3SJames Wright **/ 114858600ac3SJames Wright PetscErrorCode MatMult_Ceed(Mat A, Vec X, Vec Y) { 114958600ac3SJames Wright MatCeedContext ctx; 115058600ac3SJames Wright 115158600ac3SJames Wright PetscFunctionBeginUser; 115258600ac3SJames Wright PetscCall(MatShellGetContext(A, &ctx)); 115358600ac3SJames Wright PetscCall(PetscLogEventBegin(ctx->log_event_mult, A, X, Y, 0)); 115458600ac3SJames Wright 115558600ac3SJames Wright { 115658600ac3SJames Wright PetscMemType x_mem_type, y_mem_type; 115758600ac3SJames Wright Vec X_loc = ctx->X_loc, Y_loc; 115858600ac3SJames Wright 115958600ac3SJames Wright // Get local vectors 116058600ac3SJames Wright if (!ctx->X_loc) PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc)); 116158600ac3SJames Wright PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc)); 116258600ac3SJames Wright 116358600ac3SJames Wright // Global-to-local 116458600ac3SJames Wright PetscCall(DMGlobalToLocal(ctx->dm_x, X, INSERT_VALUES, X_loc)); 116558600ac3SJames Wright 116658600ac3SJames Wright // Setup libCEED vectors 1167*a7dac1d5SJames Wright PetscCall(VecReadPetscToCeed(X_loc, &x_mem_type, ctx->x_loc)); 116858600ac3SJames Wright PetscCall(VecZeroEntries(Y_loc)); 1169*a7dac1d5SJames Wright PetscCall(VecPetscToCeed(Y_loc, &y_mem_type, ctx->y_loc)); 117058600ac3SJames Wright 117158600ac3SJames Wright // Apply libCEED operator 117258600ac3SJames Wright PetscCall(PetscLogGpuTimeBegin()); 117350f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult, ctx->x_loc, ctx->y_loc, CEED_REQUEST_IMMEDIATE)); 117458600ac3SJames Wright PetscCall(PetscLogGpuTimeEnd()); 117558600ac3SJames Wright 117658600ac3SJames Wright // Restore PETSc vectors 1177*a7dac1d5SJames Wright PetscCall(VecReadCeedToPetsc(ctx->x_loc, x_mem_type, X_loc)); 1178*a7dac1d5SJames Wright PetscCall(VecCeedToPetsc(ctx->y_loc, y_mem_type, Y_loc)); 117958600ac3SJames Wright 118058600ac3SJames Wright // Local-to-global 118158600ac3SJames Wright PetscCall(VecZeroEntries(Y)); 118258600ac3SJames Wright PetscCall(DMLocalToGlobal(ctx->dm_y, Y_loc, ADD_VALUES, Y)); 118358600ac3SJames Wright 118458600ac3SJames Wright // Restore local vectors, as needed 118558600ac3SJames Wright if (!ctx->X_loc) PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc)); 118658600ac3SJames Wright PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc)); 118758600ac3SJames Wright } 118858600ac3SJames Wright 118958600ac3SJames Wright // Log flops 119058600ac3SJames Wright if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult)); 119158600ac3SJames Wright else PetscCall(PetscLogFlops(ctx->flops_mult)); 119258600ac3SJames Wright 119358600ac3SJames Wright PetscCall(PetscLogEventEnd(ctx->log_event_mult, A, X, Y, 0)); 119458600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 119558600ac3SJames Wright } 119658600ac3SJames Wright 119758600ac3SJames Wright /** 119858600ac3SJames Wright @brief Compute `A^T Y = X` for a `MATCEED`. 119958600ac3SJames Wright 120058600ac3SJames Wright Collective across MPI processes. 120158600ac3SJames Wright 120258600ac3SJames Wright @param[in] A `MATCEED` 120358600ac3SJames Wright @param[in] Y Input PETSc vector 120458600ac3SJames Wright @param[out] X Output PETSc vector 120558600ac3SJames Wright 120658600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 120758600ac3SJames Wright **/ 120858600ac3SJames Wright PetscErrorCode MatMultTranspose_Ceed(Mat A, Vec Y, Vec X) { 120958600ac3SJames Wright MatCeedContext ctx; 121058600ac3SJames Wright 121158600ac3SJames Wright PetscFunctionBeginUser; 121258600ac3SJames Wright PetscCall(MatShellGetContext(A, &ctx)); 121358600ac3SJames Wright PetscCall(PetscLogEventBegin(ctx->log_event_mult_transpose, A, Y, X, 0)); 121458600ac3SJames Wright 121558600ac3SJames Wright { 121658600ac3SJames Wright PetscMemType x_mem_type, y_mem_type; 121758600ac3SJames Wright Vec X_loc, Y_loc = ctx->Y_loc_transpose; 121858600ac3SJames Wright 121958600ac3SJames Wright // Get local vectors 122058600ac3SJames Wright if (!ctx->Y_loc_transpose) PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc)); 122158600ac3SJames Wright PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc)); 122258600ac3SJames Wright 122358600ac3SJames Wright // Global-to-local 122458600ac3SJames Wright PetscCall(DMGlobalToLocal(ctx->dm_y, Y, INSERT_VALUES, Y_loc)); 122558600ac3SJames Wright 122658600ac3SJames Wright // Setup libCEED vectors 1227*a7dac1d5SJames Wright PetscCall(VecReadPetscToCeed(Y_loc, &y_mem_type, ctx->y_loc)); 122858600ac3SJames Wright PetscCall(VecZeroEntries(X_loc)); 1229*a7dac1d5SJames Wright PetscCall(VecPetscToCeed(X_loc, &x_mem_type, ctx->x_loc)); 123058600ac3SJames Wright 123158600ac3SJames Wright // Apply libCEED operator 123258600ac3SJames Wright PetscCall(PetscLogGpuTimeBegin()); 123350f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult_transpose, ctx->y_loc, ctx->x_loc, CEED_REQUEST_IMMEDIATE)); 123458600ac3SJames Wright PetscCall(PetscLogGpuTimeEnd()); 123558600ac3SJames Wright 123658600ac3SJames Wright // Restore PETSc vectors 1237*a7dac1d5SJames Wright PetscCall(VecReadCeedToPetsc(ctx->y_loc, y_mem_type, Y_loc)); 1238*a7dac1d5SJames Wright PetscCall(VecCeedToPetsc(ctx->x_loc, x_mem_type, X_loc)); 123958600ac3SJames Wright 124058600ac3SJames Wright // Local-to-global 124158600ac3SJames Wright PetscCall(VecZeroEntries(X)); 124258600ac3SJames Wright PetscCall(DMLocalToGlobal(ctx->dm_x, X_loc, ADD_VALUES, X)); 124358600ac3SJames Wright 124458600ac3SJames Wright // Restore local vectors, as needed 124558600ac3SJames Wright if (!ctx->Y_loc_transpose) PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc)); 124658600ac3SJames Wright PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc)); 124758600ac3SJames Wright } 124858600ac3SJames Wright 124958600ac3SJames Wright // Log flops 125058600ac3SJames Wright if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult_transpose)); 125158600ac3SJames Wright else PetscCall(PetscLogFlops(ctx->flops_mult_transpose)); 125258600ac3SJames Wright 125358600ac3SJames Wright PetscCall(PetscLogEventEnd(ctx->log_event_mult_transpose, A, Y, X, 0)); 125458600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 125558600ac3SJames Wright } 1256