158600ac3SJames Wright /// @file 2*40d80af1SJames Wright /// MatCEED implementation 358600ac3SJames Wright 458600ac3SJames Wright #include <ceed.h> 558600ac3SJames Wright #include <ceed/backend.h> 658600ac3SJames Wright #include <mat-ceed-impl.h> 758600ac3SJames Wright #include <mat-ceed.h> 8*40d80af1SJames Wright #include <petsc-ceed-utils.h> 9*40d80af1SJames Wright #include <petsc-ceed.h> 1058600ac3SJames Wright #include <petscdmplex.h> 1158600ac3SJames Wright #include <stdlib.h> 1258600ac3SJames Wright #include <string.h> 1358600ac3SJames Wright 1458600ac3SJames Wright PetscClassId MATCEED_CLASSID; 1558600ac3SJames Wright PetscLogEvent MATCEED_MULT, MATCEED_MULT_TRANSPOSE; 1658600ac3SJames Wright 1758600ac3SJames Wright /** 1858600ac3SJames Wright @brief Register MATCEED log events. 1958600ac3SJames Wright 2058600ac3SJames Wright Not collective across MPI processes. 2158600ac3SJames Wright 2258600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 2358600ac3SJames Wright **/ 2458600ac3SJames Wright static PetscErrorCode MatCeedRegisterLogEvents() { 25*40d80af1SJames Wright static PetscBool registered = PETSC_FALSE; 2658600ac3SJames Wright 2758600ac3SJames Wright PetscFunctionBeginUser; 2858600ac3SJames Wright if (registered) PetscFunctionReturn(PETSC_SUCCESS); 2958600ac3SJames Wright PetscCall(PetscClassIdRegister("MATCEED", &MATCEED_CLASSID)); 3058600ac3SJames Wright PetscCall(PetscLogEventRegister("MATCEED Mult", MATCEED_CLASSID, &MATCEED_MULT)); 3158600ac3SJames Wright PetscCall(PetscLogEventRegister("MATCEED Mult Transpose", MATCEED_CLASSID, &MATCEED_MULT_TRANSPOSE)); 32*40d80af1SJames Wright registered = PETSC_TRUE; 3358600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 3458600ac3SJames Wright } 3558600ac3SJames Wright 3658600ac3SJames Wright /** 3758600ac3SJames Wright @brief Assemble the point block diagonal of a `MATCEED` into a `MATAIJ` or similar. 3858600ac3SJames Wright The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`. 3958600ac3SJames Wright The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail. 4058600ac3SJames Wright 4158600ac3SJames Wright Collective across MPI processes. 4258600ac3SJames Wright 4358600ac3SJames Wright @param[in] mat_ceed `MATCEED` to assemble 4458600ac3SJames Wright @param[in,out] mat_coo `MATAIJ` or similar to assemble into 4558600ac3SJames Wright 4658600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 4758600ac3SJames Wright **/ 4858600ac3SJames Wright static PetscErrorCode MatCeedAssemblePointBlockDiagonalCOO(Mat mat_ceed, Mat mat_coo) { 4958600ac3SJames Wright MatCeedContext ctx; 5058600ac3SJames Wright 5158600ac3SJames Wright PetscFunctionBeginUser; 5258600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 5358600ac3SJames Wright 5458600ac3SJames Wright // Check if COO pattern set 5558600ac3SJames Wright { 5658600ac3SJames Wright PetscInt index = -1; 5758600ac3SJames Wright 5858600ac3SJames Wright for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) { 5958600ac3SJames Wright if (ctx->mats_assembled_pbd[i] == mat_coo) index = i; 6058600ac3SJames Wright } 6158600ac3SJames Wright if (index == -1) { 6258600ac3SJames Wright PetscInt *rows_petsc = NULL, *cols_petsc = NULL; 6358600ac3SJames Wright CeedInt *rows_ceed, *cols_ceed; 6458600ac3SJames Wright PetscCount num_entries; 6558600ac3SJames Wright PetscLogStage stage_amg_setup; 6658600ac3SJames Wright 6758600ac3SJames Wright // -- Assemble sparsity pattern if mat hasn't been assembled before 6858600ac3SJames Wright PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup)); 6958600ac3SJames Wright if (stage_amg_setup == -1) { 7058600ac3SJames Wright PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup)); 7158600ac3SJames Wright } 7258600ac3SJames Wright PetscCall(PetscLogStagePush(stage_amg_setup)); 7350f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonalSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed)); 74a7dac1d5SJames Wright PetscCall(IntArrayCeedToPetsc(num_entries, &rows_ceed, &rows_petsc)); 75a7dac1d5SJames Wright PetscCall(IntArrayCeedToPetsc(num_entries, &cols_ceed, &cols_petsc)); 7658600ac3SJames Wright PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc)); 7758600ac3SJames Wright free(rows_petsc); 7858600ac3SJames Wright free(cols_petsc); 7950f50432SJames Wright if (!ctx->coo_values_pbd) PetscCallCeed(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_pbd)); 8058600ac3SJames Wright PetscCall(PetscRealloc(++ctx->num_mats_assembled_pbd * sizeof(Mat), &ctx->mats_assembled_pbd)); 8158600ac3SJames Wright ctx->mats_assembled_pbd[ctx->num_mats_assembled_pbd - 1] = mat_coo; 8258600ac3SJames Wright PetscCall(PetscLogStagePop()); 8358600ac3SJames Wright } 8458600ac3SJames Wright } 8558600ac3SJames Wright 8658600ac3SJames Wright // Assemble mat_ceed 8758600ac3SJames Wright PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY)); 8858600ac3SJames Wright { 8958600ac3SJames Wright const CeedScalar *values; 9058600ac3SJames Wright MatType mat_type; 9158600ac3SJames Wright CeedMemType mem_type = CEED_MEM_HOST; 9258600ac3SJames Wright PetscBool is_spd, is_spd_known; 9358600ac3SJames Wright 9458600ac3SJames Wright PetscCall(MatGetType(mat_coo, &mat_type)); 9558600ac3SJames Wright if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE; 9658600ac3SJames Wright else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE; 9758600ac3SJames Wright else mem_type = CEED_MEM_HOST; 9858600ac3SJames Wright 9950f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemblePointBlockDiagonal(ctx->op_mult, ctx->coo_values_pbd, CEED_REQUEST_IMMEDIATE)); 10050f50432SJames Wright PetscCallCeed(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_pbd, mem_type, &values)); 10158600ac3SJames Wright PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES)); 10258600ac3SJames Wright PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd)); 10358600ac3SJames Wright if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd)); 10450f50432SJames Wright PetscCallCeed(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_pbd, &values)); 10558600ac3SJames Wright } 10658600ac3SJames Wright PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY)); 10758600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 10858600ac3SJames Wright } 10958600ac3SJames Wright 11058600ac3SJames Wright /** 11158600ac3SJames Wright @brief Assemble inner `Mat` for diagonal `PC` operations 11258600ac3SJames Wright 11358600ac3SJames Wright Collective across MPI processes. 11458600ac3SJames Wright 11558600ac3SJames Wright @param[in] mat_ceed `MATCEED` to invert 11658600ac3SJames Wright @param[in] use_ceed_pbd Boolean flag to use libCEED PBD assembly 11758600ac3SJames Wright @param[out] mat_inner Inner `Mat` for diagonal operations 11858600ac3SJames Wright 11958600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 12058600ac3SJames Wright **/ 12158600ac3SJames Wright static PetscErrorCode MatCeedAssembleInnerBlockDiagonalMat(Mat mat_ceed, PetscBool use_ceed_pbd, Mat *mat_inner) { 12258600ac3SJames Wright MatCeedContext ctx; 12358600ac3SJames Wright 12458600ac3SJames Wright PetscFunctionBeginUser; 12558600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 12658600ac3SJames Wright if (use_ceed_pbd) { 12758600ac3SJames Wright // Check if COO pattern set 128*40d80af1SJames Wright if (!ctx->mat_assembled_pbd_internal) PetscCall(MatCeedCreateMatCOO(mat_ceed, &ctx->mat_assembled_pbd_internal)); 12958600ac3SJames Wright 13058600ac3SJames Wright // Assemble mat_assembled_full_internal 13158600ac3SJames Wright PetscCall(MatCeedAssemblePointBlockDiagonalCOO(mat_ceed, ctx->mat_assembled_pbd_internal)); 13258600ac3SJames Wright if (mat_inner) *mat_inner = ctx->mat_assembled_pbd_internal; 13358600ac3SJames Wright } else { 13458600ac3SJames Wright // Check if COO pattern set 135*40d80af1SJames Wright if (!ctx->mat_assembled_full_internal) PetscCall(MatCeedCreateMatCOO(mat_ceed, &ctx->mat_assembled_full_internal)); 13658600ac3SJames Wright 13758600ac3SJames Wright // Assemble mat_assembled_full_internal 13858600ac3SJames Wright PetscCall(MatCeedAssembleCOO(mat_ceed, ctx->mat_assembled_full_internal)); 13958600ac3SJames Wright if (mat_inner) *mat_inner = ctx->mat_assembled_full_internal; 14058600ac3SJames Wright } 14158600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 14258600ac3SJames Wright } 14358600ac3SJames Wright 14458600ac3SJames Wright /** 14558600ac3SJames Wright @brief Get `MATCEED` diagonal block for Jacobi. 14658600ac3SJames Wright 14758600ac3SJames Wright Collective across MPI processes. 14858600ac3SJames Wright 14958600ac3SJames Wright @param[in] mat_ceed `MATCEED` to invert 15058600ac3SJames Wright @param[out] mat_block The diagonal block matrix 15158600ac3SJames Wright 15258600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 15358600ac3SJames Wright **/ 15458600ac3SJames Wright static PetscErrorCode MatGetDiagonalBlock_Ceed(Mat mat_ceed, Mat *mat_block) { 15558600ac3SJames Wright Mat mat_inner = NULL; 15658600ac3SJames Wright MatCeedContext ctx; 15758600ac3SJames Wright 15858600ac3SJames Wright PetscFunctionBeginUser; 15958600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 16058600ac3SJames Wright 16158600ac3SJames Wright // Assemble inner mat if needed 16258600ac3SJames Wright PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner)); 16358600ac3SJames Wright 16458600ac3SJames Wright // Get block diagonal 16558600ac3SJames Wright PetscCall(MatGetDiagonalBlock(mat_inner, mat_block)); 16658600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 16758600ac3SJames Wright } 16858600ac3SJames Wright 16958600ac3SJames Wright /** 17058600ac3SJames Wright @brief Invert `MATCEED` diagonal block for Jacobi. 17158600ac3SJames Wright 17258600ac3SJames Wright Collective across MPI processes. 17358600ac3SJames Wright 17458600ac3SJames Wright @param[in] mat_ceed `MATCEED` to invert 17558600ac3SJames Wright @param[out] values The block inverses in column major order 17658600ac3SJames Wright 17758600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 17858600ac3SJames Wright **/ 17958600ac3SJames Wright static PetscErrorCode MatInvertBlockDiagonal_Ceed(Mat mat_ceed, const PetscScalar **values) { 18058600ac3SJames Wright Mat mat_inner = NULL; 18158600ac3SJames Wright MatCeedContext ctx; 18258600ac3SJames Wright 18358600ac3SJames Wright PetscFunctionBeginUser; 18458600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 18558600ac3SJames Wright 18658600ac3SJames Wright // Assemble inner mat if needed 18758600ac3SJames Wright PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_pbd_valid, &mat_inner)); 18858600ac3SJames Wright 18958600ac3SJames Wright // Invert PB diagonal 19058600ac3SJames Wright PetscCall(MatInvertBlockDiagonal(mat_inner, values)); 19158600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 19258600ac3SJames Wright } 19358600ac3SJames Wright 19458600ac3SJames Wright /** 19558600ac3SJames Wright @brief Invert `MATCEED` variable diagonal block for Jacobi. 19658600ac3SJames Wright 19758600ac3SJames Wright Collective across MPI processes. 19858600ac3SJames Wright 19958600ac3SJames Wright @param[in] mat_ceed `MATCEED` to invert 20058600ac3SJames Wright @param[in] num_blocks The number of blocks on the process 20158600ac3SJames Wright @param[in] block_sizes The size of each block on the process 20258600ac3SJames Wright @param[out] values The block inverses in column major order 20358600ac3SJames Wright 20458600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 20558600ac3SJames Wright **/ 20658600ac3SJames Wright static PetscErrorCode MatInvertVariableBlockDiagonal_Ceed(Mat mat_ceed, PetscInt num_blocks, const PetscInt *block_sizes, PetscScalar *values) { 20758600ac3SJames Wright Mat mat_inner = NULL; 20858600ac3SJames Wright MatCeedContext ctx; 20958600ac3SJames Wright 21058600ac3SJames Wright PetscFunctionBeginUser; 21158600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 21258600ac3SJames Wright 21358600ac3SJames Wright // Assemble inner mat if needed 21458600ac3SJames Wright PetscCall(MatCeedAssembleInnerBlockDiagonalMat(mat_ceed, ctx->is_ceed_vpbd_valid, &mat_inner)); 21558600ac3SJames Wright 21658600ac3SJames Wright // Invert PB diagonal 21758600ac3SJames Wright PetscCall(MatInvertVariableBlockDiagonal(mat_inner, num_blocks, block_sizes, values)); 21858600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 21958600ac3SJames Wright } 22058600ac3SJames Wright 221e90c2ceeSJames Wright /** 222e90c2ceeSJames Wright @brief View `MATCEED`. 223e90c2ceeSJames Wright 224e90c2ceeSJames Wright Collective across MPI processes. 225e90c2ceeSJames Wright 226e90c2ceeSJames Wright @param[in] mat_ceed `MATCEED` to view 227e90c2ceeSJames Wright @param[in] viewer The visualization context 228e90c2ceeSJames Wright 229e90c2ceeSJames Wright @return An error code: 0 - success, otherwise - failure 230e90c2ceeSJames Wright **/ 231e90c2ceeSJames Wright static PetscErrorCode MatView_Ceed(Mat mat_ceed, PetscViewer viewer) { 232e90c2ceeSJames Wright PetscBool is_ascii; 233e90c2ceeSJames Wright PetscViewerFormat format; 234e90c2ceeSJames Wright PetscMPIInt size; 235e90c2ceeSJames Wright MatCeedContext ctx; 236e90c2ceeSJames Wright 237e90c2ceeSJames Wright PetscFunctionBeginUser; 238e90c2ceeSJames Wright PetscValidHeaderSpecific(viewer, PETSC_VIEWER_CLASSID, 2); 239e90c2ceeSJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 240e90c2ceeSJames Wright if (!viewer) PetscCall(PetscViewerASCIIGetStdout(PetscObjectComm((PetscObject)mat_ceed), &viewer)); 241e90c2ceeSJames Wright 242e90c2ceeSJames Wright PetscCall(PetscViewerGetFormat(viewer, &format)); 243e90c2ceeSJames Wright PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat_ceed), &size)); 244e90c2ceeSJames Wright if (size == 1 && format == PETSC_VIEWER_LOAD_BALANCE) PetscFunctionReturn(PETSC_SUCCESS); 245e90c2ceeSJames Wright 246e90c2ceeSJames Wright PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &is_ascii)); 247e90c2ceeSJames Wright { 248e90c2ceeSJames Wright FILE *file; 249e90c2ceeSJames Wright 250*40d80af1SJames Wright PetscCall(PetscViewerASCIIPrintf(viewer, "MatCEED:\n Default COO MatType:%s\n", ctx->coo_mat_type)); 251e90c2ceeSJames Wright PetscCall(PetscViewerASCIIGetPointer(viewer, &file)); 252e90c2ceeSJames Wright PetscCall(PetscViewerASCIIPrintf(viewer, " libCEED Operator:\n")); 253e90c2ceeSJames Wright PetscCallCeed(ctx->ceed, CeedOperatorView(ctx->op_mult, file)); 254e90c2ceeSJames Wright if (ctx->op_mult_transpose) { 255e90c2ceeSJames Wright PetscCall(PetscViewerASCIIPrintf(viewer, " libCEED Transpose Operator:\n")); 256e90c2ceeSJames Wright PetscCallCeed(ctx->ceed, CeedOperatorView(ctx->op_mult_transpose, file)); 257e90c2ceeSJames Wright } 258e90c2ceeSJames Wright } 259e90c2ceeSJames Wright PetscFunctionReturn(PETSC_SUCCESS); 260e90c2ceeSJames Wright } 261e90c2ceeSJames Wright 26258600ac3SJames Wright // ----------------------------------------------------------------------------- 26358600ac3SJames Wright // MatCeed 26458600ac3SJames Wright // ----------------------------------------------------------------------------- 26558600ac3SJames Wright 26658600ac3SJames Wright /** 26758600ac3SJames Wright @brief Create PETSc `Mat` from libCEED operators. 26858600ac3SJames Wright 26958600ac3SJames Wright Collective across MPI processes. 27058600ac3SJames Wright 27158600ac3SJames Wright @param[in] dm_x Input `DM` 27258600ac3SJames Wright @param[in] dm_y Output `DM` 27358600ac3SJames Wright @param[in] op_mult `CeedOperator` for forward evaluation 27458600ac3SJames Wright @param[in] op_mult_transpose `CeedOperator` for transpose evaluation 27558600ac3SJames Wright @param[out] mat New MatCeed 27658600ac3SJames Wright 27758600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 27858600ac3SJames Wright **/ 27958600ac3SJames Wright PetscErrorCode MatCeedCreate(DM dm_x, DM dm_y, CeedOperator op_mult, CeedOperator op_mult_transpose, Mat *mat) { 28058600ac3SJames Wright PetscInt X_l_size, X_g_size, Y_l_size, Y_g_size; 28158600ac3SJames Wright VecType vec_type; 28258600ac3SJames Wright MatCeedContext ctx; 28358600ac3SJames Wright 28458600ac3SJames Wright PetscFunctionBeginUser; 28558600ac3SJames Wright PetscCall(MatCeedRegisterLogEvents()); 28658600ac3SJames Wright 28758600ac3SJames Wright // Collect context data 28858600ac3SJames Wright PetscCall(DMGetVecType(dm_x, &vec_type)); 28958600ac3SJames Wright { 29058600ac3SJames Wright Vec X; 29158600ac3SJames Wright 29258600ac3SJames Wright PetscCall(DMGetGlobalVector(dm_x, &X)); 29358600ac3SJames Wright PetscCall(VecGetSize(X, &X_g_size)); 29458600ac3SJames Wright PetscCall(VecGetLocalSize(X, &X_l_size)); 29558600ac3SJames Wright PetscCall(DMRestoreGlobalVector(dm_x, &X)); 29658600ac3SJames Wright } 29758600ac3SJames Wright if (dm_y) { 29858600ac3SJames Wright Vec Y; 29958600ac3SJames Wright 30058600ac3SJames Wright PetscCall(DMGetGlobalVector(dm_y, &Y)); 30158600ac3SJames Wright PetscCall(VecGetSize(Y, &Y_g_size)); 30258600ac3SJames Wright PetscCall(VecGetLocalSize(Y, &Y_l_size)); 30358600ac3SJames Wright PetscCall(DMRestoreGlobalVector(dm_y, &Y)); 30458600ac3SJames Wright } else { 30558600ac3SJames Wright dm_y = dm_x; 30658600ac3SJames Wright Y_g_size = X_g_size; 30758600ac3SJames Wright Y_l_size = X_l_size; 30858600ac3SJames Wright } 309*40d80af1SJames Wright 31058600ac3SJames Wright // Create context 31158600ac3SJames Wright { 31258600ac3SJames Wright Vec X_loc, Y_loc_transpose = NULL; 31358600ac3SJames Wright 31458600ac3SJames Wright PetscCall(DMCreateLocalVector(dm_x, &X_loc)); 31558600ac3SJames Wright PetscCall(VecZeroEntries(X_loc)); 31658600ac3SJames Wright if (op_mult_transpose) { 31758600ac3SJames Wright PetscCall(DMCreateLocalVector(dm_y, &Y_loc_transpose)); 31858600ac3SJames Wright PetscCall(VecZeroEntries(Y_loc_transpose)); 31958600ac3SJames Wright } 32058600ac3SJames Wright PetscCall(MatCeedContextCreate(dm_x, dm_y, X_loc, Y_loc_transpose, op_mult, op_mult_transpose, MATCEED_MULT, MATCEED_MULT_TRANSPOSE, &ctx)); 32158600ac3SJames Wright PetscCall(VecDestroy(&X_loc)); 32258600ac3SJames Wright PetscCall(VecDestroy(&Y_loc_transpose)); 32358600ac3SJames Wright } 32458600ac3SJames Wright 32558600ac3SJames Wright // Create mat 32658600ac3SJames Wright PetscCall(MatCreateShell(PetscObjectComm((PetscObject)dm_x), Y_l_size, X_l_size, Y_g_size, X_g_size, ctx, mat)); 32758600ac3SJames Wright PetscCall(PetscObjectChangeTypeName((PetscObject)*mat, MATCEED)); 32858600ac3SJames Wright // -- Set block and variable block sizes 32958600ac3SJames Wright if (dm_x == dm_y) { 33058600ac3SJames Wright MatType dm_mat_type, dm_mat_type_copy; 33158600ac3SJames Wright Mat temp_mat; 33258600ac3SJames Wright 33358600ac3SJames Wright PetscCall(DMGetMatType(dm_x, &dm_mat_type)); 33458600ac3SJames Wright PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy)); 33558600ac3SJames Wright PetscCall(DMSetMatType(dm_x, MATAIJ)); 33658600ac3SJames Wright PetscCall(DMCreateMatrix(dm_x, &temp_mat)); 33758600ac3SJames Wright PetscCall(DMSetMatType(dm_x, dm_mat_type_copy)); 33858600ac3SJames Wright PetscCall(PetscFree(dm_mat_type_copy)); 33958600ac3SJames Wright 34058600ac3SJames Wright { 34158600ac3SJames Wright PetscInt block_size, num_blocks, max_vblock_size = PETSC_INT_MAX; 34258600ac3SJames Wright const PetscInt *vblock_sizes; 34358600ac3SJames Wright 34458600ac3SJames Wright // -- Get block sizes 34558600ac3SJames Wright PetscCall(MatGetBlockSize(temp_mat, &block_size)); 34658600ac3SJames Wright PetscCall(MatGetVariableBlockSizes(temp_mat, &num_blocks, &vblock_sizes)); 34758600ac3SJames Wright { 34858600ac3SJames Wright PetscInt local_min_max[2] = {0}, global_min_max[2] = {0, PETSC_INT_MAX}; 34958600ac3SJames Wright 35058600ac3SJames Wright for (PetscInt i = 0; i < num_blocks; i++) local_min_max[1] = PetscMax(local_min_max[1], vblock_sizes[i]); 35158600ac3SJames Wright PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_min_max, global_min_max)); 35258600ac3SJames Wright max_vblock_size = global_min_max[1]; 35358600ac3SJames Wright } 35458600ac3SJames Wright 35558600ac3SJames Wright // -- Copy block sizes 35658600ac3SJames Wright if (block_size > 1) PetscCall(MatSetBlockSize(*mat, block_size)); 35758600ac3SJames Wright if (num_blocks) PetscCall(MatSetVariableBlockSizes(*mat, num_blocks, (PetscInt *)vblock_sizes)); 35858600ac3SJames Wright 35958600ac3SJames Wright // -- Check libCEED compatibility 36058600ac3SJames Wright { 36158600ac3SJames Wright bool is_composite; 36258600ac3SJames Wright 36358600ac3SJames Wright ctx->is_ceed_pbd_valid = PETSC_TRUE; 36458600ac3SJames Wright ctx->is_ceed_vpbd_valid = PETSC_TRUE; 36550f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorIsComposite(op_mult, &is_composite)); 36658600ac3SJames Wright if (is_composite) { 36758600ac3SJames Wright CeedInt num_sub_operators; 36858600ac3SJames Wright CeedOperator *sub_operators; 36958600ac3SJames Wright 37050f50432SJames Wright PetscCallCeed(ctx->ceed, CeedCompositeOperatorGetNumSub(op_mult, &num_sub_operators)); 37150f50432SJames Wright PetscCallCeed(ctx->ceed, CeedCompositeOperatorGetSubList(op_mult, &sub_operators)); 37258600ac3SJames Wright for (CeedInt i = 0; i < num_sub_operators; i++) { 37358600ac3SJames Wright CeedInt num_bases, num_comp; 37458600ac3SJames Wright CeedBasis *active_bases; 37558600ac3SJames Wright CeedOperatorAssemblyData assembly_data; 37658600ac3SJames Wright 37750f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorGetOperatorAssemblyData(sub_operators[i], &assembly_data)); 37850f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL)); 37950f50432SJames Wright PetscCallCeed(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp)); 38058600ac3SJames Wright if (num_bases > 1) { 38158600ac3SJames Wright ctx->is_ceed_pbd_valid = PETSC_FALSE; 38258600ac3SJames Wright ctx->is_ceed_vpbd_valid = PETSC_FALSE; 38358600ac3SJames Wright } 38458600ac3SJames Wright if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE; 38558600ac3SJames Wright if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE; 38658600ac3SJames Wright } 38758600ac3SJames Wright } else { 38858600ac3SJames Wright // LCOV_EXCL_START 38958600ac3SJames Wright CeedInt num_bases, num_comp; 39058600ac3SJames Wright CeedBasis *active_bases; 39158600ac3SJames Wright CeedOperatorAssemblyData assembly_data; 39258600ac3SJames Wright 39350f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorGetOperatorAssemblyData(op_mult, &assembly_data)); 39450f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorAssemblyDataGetBases(assembly_data, &num_bases, &active_bases, NULL, NULL, NULL, NULL)); 39550f50432SJames Wright PetscCallCeed(ctx->ceed, CeedBasisGetNumComponents(active_bases[0], &num_comp)); 39658600ac3SJames Wright if (num_bases > 1) { 39758600ac3SJames Wright ctx->is_ceed_pbd_valid = PETSC_FALSE; 39858600ac3SJames Wright ctx->is_ceed_vpbd_valid = PETSC_FALSE; 39958600ac3SJames Wright } 40058600ac3SJames Wright if (num_comp != block_size) ctx->is_ceed_pbd_valid = PETSC_FALSE; 40158600ac3SJames Wright if (num_comp < max_vblock_size) ctx->is_ceed_vpbd_valid = PETSC_FALSE; 40258600ac3SJames Wright // LCOV_EXCL_STOP 40358600ac3SJames Wright } 40458600ac3SJames Wright { 40558600ac3SJames Wright PetscInt local_is_valid[2], global_is_valid[2]; 40658600ac3SJames Wright 40758600ac3SJames Wright local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_pbd_valid; 40858600ac3SJames Wright PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid)); 40958600ac3SJames Wright ctx->is_ceed_pbd_valid = global_is_valid[0]; 41058600ac3SJames Wright local_is_valid[0] = local_is_valid[1] = ctx->is_ceed_vpbd_valid; 41158600ac3SJames Wright PetscCall(PetscGlobalMinMaxInt(PetscObjectComm((PetscObject)dm_x), local_is_valid, global_is_valid)); 41258600ac3SJames Wright ctx->is_ceed_vpbd_valid = global_is_valid[0]; 41358600ac3SJames Wright } 41458600ac3SJames Wright } 41558600ac3SJames Wright } 41658600ac3SJames Wright PetscCall(MatDestroy(&temp_mat)); 41758600ac3SJames Wright } 41858600ac3SJames Wright // -- Set internal mat type 41958600ac3SJames Wright { 42058600ac3SJames Wright VecType vec_type; 421*40d80af1SJames Wright MatType coo_mat_type; 42258600ac3SJames Wright 42358600ac3SJames Wright PetscCall(VecGetType(ctx->X_loc, &vec_type)); 424*40d80af1SJames Wright if (strstr(vec_type, VECCUDA)) coo_mat_type = MATAIJCUSPARSE; 425*40d80af1SJames Wright else if (strstr(vec_type, VECKOKKOS)) coo_mat_type = MATAIJKOKKOS; 426*40d80af1SJames Wright else coo_mat_type = MATAIJ; 427*40d80af1SJames Wright PetscCall(PetscStrallocpy(coo_mat_type, &ctx->coo_mat_type)); 42858600ac3SJames Wright } 42958600ac3SJames Wright // -- Set mat operations 43058600ac3SJames Wright PetscCall(MatShellSetContextDestroy(*mat, (PetscErrorCode(*)(void *))MatCeedContextDestroy)); 431e90c2ceeSJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_VIEW, (void (*)(void))MatView_Ceed)); 43258600ac3SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_MULT, (void (*)(void))MatMult_Ceed)); 43358600ac3SJames Wright if (op_mult_transpose) PetscCall(MatShellSetOperation(*mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed)); 43458600ac3SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed)); 43558600ac3SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed)); 43658600ac3SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed)); 43758600ac3SJames Wright PetscCall(MatShellSetOperation(*mat, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed)); 43858600ac3SJames Wright PetscCall(MatShellSetVecType(*mat, vec_type)); 43958600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 44058600ac3SJames Wright } 44158600ac3SJames Wright 44258600ac3SJames Wright /** 44358600ac3SJames Wright @brief Copy `MATCEED` into a compatible `Mat` with type `MatShell` or `MATCEED`. 44458600ac3SJames Wright 44558600ac3SJames Wright Collective across MPI processes. 44658600ac3SJames Wright 44758600ac3SJames Wright @param[in] mat_ceed `MATCEED` to copy from 44858600ac3SJames Wright @param[out] mat_other `MatShell` or `MATCEED` to copy into 44958600ac3SJames Wright 45058600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 45158600ac3SJames Wright **/ 45258600ac3SJames Wright PetscErrorCode MatCeedCopy(Mat mat_ceed, Mat mat_other) { 45358600ac3SJames Wright PetscFunctionBeginUser; 45458600ac3SJames Wright PetscCall(MatCeedRegisterLogEvents()); 45558600ac3SJames Wright 45658600ac3SJames Wright // Check type compatibility 45758600ac3SJames Wright { 458*40d80af1SJames Wright PetscBool is_matceed = PETSC_FALSE, is_matshell = PETSC_FALSE; 45958600ac3SJames Wright MatType mat_type_ceed, mat_type_other; 46058600ac3SJames Wright 46158600ac3SJames Wright PetscCall(MatGetType(mat_ceed, &mat_type_ceed)); 462*40d80af1SJames Wright PetscCall(PetscStrcmp(mat_type_ceed, MATCEED, &is_matceed)); 463*40d80af1SJames Wright PetscCheck(is_matceed, PETSC_COMM_SELF, PETSC_ERR_LIB, "mat_ceed must have type " MATCEED); 464*40d80af1SJames Wright PetscCall(MatGetType(mat_other, &mat_type_other)); 465*40d80af1SJames Wright PetscCall(PetscStrcmp(mat_type_other, MATCEED, &is_matceed)); 466*40d80af1SJames Wright PetscCall(PetscStrcmp(mat_type_other, MATSHELL, &is_matceed)); 467*40d80af1SJames Wright PetscCheck(is_matceed || is_matshell, PETSC_COMM_SELF, PETSC_ERR_LIB, "mat_other must have type " MATCEED " or " MATSHELL); 46858600ac3SJames Wright } 46958600ac3SJames Wright 47058600ac3SJames Wright // Check dimension compatibility 47158600ac3SJames Wright { 47258600ac3SJames Wright PetscInt X_l_ceed_size, X_g_ceed_size, Y_l_ceed_size, Y_g_ceed_size, X_l_other_size, X_g_other_size, Y_l_other_size, Y_g_other_size; 47358600ac3SJames Wright 47458600ac3SJames Wright PetscCall(MatGetSize(mat_ceed, &Y_g_ceed_size, &X_g_ceed_size)); 47558600ac3SJames Wright PetscCall(MatGetLocalSize(mat_ceed, &Y_l_ceed_size, &X_l_ceed_size)); 47658600ac3SJames Wright PetscCall(MatGetSize(mat_ceed, &Y_g_other_size, &X_g_other_size)); 47758600ac3SJames Wright PetscCall(MatGetLocalSize(mat_ceed, &Y_l_other_size, &X_l_other_size)); 47858600ac3SJames Wright PetscCheck((Y_g_ceed_size == Y_g_other_size) && (X_g_ceed_size == X_g_other_size) && (Y_l_ceed_size == Y_l_other_size) && 47958600ac3SJames Wright (X_l_ceed_size == X_l_other_size), 48058600ac3SJames Wright PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, 48158600ac3SJames Wright "mat_ceed and mat_other must have compatible sizes; found mat_ceed (Global: %" PetscInt_FMT ", %" PetscInt_FMT 48258600ac3SJames Wright "; Local: %" PetscInt_FMT ", %" PetscInt_FMT ") mat_other (Global: %" PetscInt_FMT ", %" PetscInt_FMT "; Local: %" PetscInt_FMT 48358600ac3SJames Wright ", %" PetscInt_FMT ")", 48458600ac3SJames Wright Y_g_ceed_size, X_g_ceed_size, Y_l_ceed_size, X_l_ceed_size, Y_g_other_size, X_g_other_size, Y_l_other_size, X_l_other_size); 48558600ac3SJames Wright } 48658600ac3SJames Wright 48758600ac3SJames Wright // Convert 48858600ac3SJames Wright { 48958600ac3SJames Wright VecType vec_type; 49058600ac3SJames Wright MatCeedContext ctx; 49158600ac3SJames Wright 49258600ac3SJames Wright PetscCall(PetscObjectChangeTypeName((PetscObject)mat_other, MATCEED)); 49358600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 49458600ac3SJames Wright PetscCall(MatCeedContextReference(ctx)); 49558600ac3SJames Wright PetscCall(MatShellSetContext(mat_other, ctx)); 49658600ac3SJames Wright PetscCall(MatShellSetContextDestroy(mat_other, (PetscErrorCode(*)(void *))MatCeedContextDestroy)); 497e90c2ceeSJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_VIEW, (void (*)(void))MatView_Ceed)); 49858600ac3SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_MULT, (void (*)(void))MatMult_Ceed)); 49958600ac3SJames Wright if (ctx->op_mult_transpose) PetscCall(MatShellSetOperation(mat_other, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Ceed)); 50058600ac3SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Ceed)); 50158600ac3SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_GET_DIAGONAL_BLOCK, (void (*)(void))MatGetDiagonalBlock_Ceed)); 50258600ac3SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_BLOCK_DIAGONAL, (void (*)(void))MatInvertBlockDiagonal_Ceed)); 50358600ac3SJames Wright PetscCall(MatShellSetOperation(mat_other, MATOP_INVERT_VBLOCK_DIAGONAL, (void (*)(void))MatInvertVariableBlockDiagonal_Ceed)); 50458600ac3SJames Wright { 50558600ac3SJames Wright PetscInt block_size; 50658600ac3SJames Wright 50758600ac3SJames Wright PetscCall(MatGetBlockSize(mat_ceed, &block_size)); 50858600ac3SJames Wright if (block_size > 1) PetscCall(MatSetBlockSize(mat_other, block_size)); 50958600ac3SJames Wright } 51058600ac3SJames Wright { 51158600ac3SJames Wright PetscInt num_blocks; 51258600ac3SJames Wright const PetscInt *block_sizes; 51358600ac3SJames Wright 51458600ac3SJames Wright PetscCall(MatGetVariableBlockSizes(mat_ceed, &num_blocks, &block_sizes)); 51558600ac3SJames Wright if (num_blocks) PetscCall(MatSetVariableBlockSizes(mat_other, num_blocks, (PetscInt *)block_sizes)); 51658600ac3SJames Wright } 51758600ac3SJames Wright PetscCall(DMGetVecType(ctx->dm_x, &vec_type)); 51858600ac3SJames Wright PetscCall(MatShellSetVecType(mat_other, vec_type)); 51958600ac3SJames Wright } 52058600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 52158600ac3SJames Wright } 52258600ac3SJames Wright 52358600ac3SJames Wright /** 524*40d80af1SJames Wright @brief Setup a `Mat` with the same COO pattern as a `MatCEED`. 525*40d80af1SJames Wright 526*40d80af1SJames Wright Collective across MPI processes. 527*40d80af1SJames Wright 528*40d80af1SJames Wright @param[in] mat_ceed `MATCEED` 529*40d80af1SJames Wright @param[out] mat_coo Sparse `Mat` with same COO pattern 530*40d80af1SJames Wright 531*40d80af1SJames Wright @return An error code: 0 - success, otherwise - failure 532*40d80af1SJames Wright **/ 533*40d80af1SJames Wright PetscErrorCode MatCeedCreateMatCOO(Mat mat_ceed, Mat *mat_coo) { 534*40d80af1SJames Wright MatCeedContext ctx; 535*40d80af1SJames Wright 536*40d80af1SJames Wright PetscFunctionBeginUser; 537*40d80af1SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 538*40d80af1SJames Wright 539*40d80af1SJames Wright PetscCheck(ctx->dm_x == ctx->dm_y, PetscObjectComm((PetscObject)mat_ceed), PETSC_ERR_SUP, "COO assembly only supported for MATCEED on a single DM"); 540*40d80af1SJames Wright 541*40d80af1SJames Wright // Check cl mat type 542*40d80af1SJames Wright { 543*40d80af1SJames Wright PetscBool is_coo_mat_type_cl = PETSC_FALSE; 544*40d80af1SJames Wright char coo_mat_type_cl[64]; 545*40d80af1SJames Wright 546*40d80af1SJames Wright // Check for specific CL coo mat type for this Mat 547*40d80af1SJames Wright { 548*40d80af1SJames Wright const char *mat_ceed_prefix = NULL; 549*40d80af1SJames Wright 550*40d80af1SJames Wright PetscCall(MatGetOptionsPrefix(mat_ceed, &mat_ceed_prefix)); 551*40d80af1SJames Wright PetscOptionsBegin(PetscObjectComm((PetscObject)mat_ceed), mat_ceed_prefix, "", NULL); 552*40d80af1SJames Wright PetscCall(PetscOptionsFList("-ceed_coo_mat_type", "Default MATCEED COO assembly MatType", NULL, MatList, coo_mat_type_cl, coo_mat_type_cl, 553*40d80af1SJames Wright sizeof(coo_mat_type_cl), &is_coo_mat_type_cl)); 554*40d80af1SJames Wright PetscOptionsEnd(); 555*40d80af1SJames Wright if (is_coo_mat_type_cl) { 556*40d80af1SJames Wright PetscCall(PetscFree(ctx->coo_mat_type)); 557*40d80af1SJames Wright PetscCall(PetscStrallocpy(coo_mat_type_cl, &ctx->coo_mat_type)); 558*40d80af1SJames Wright } 559*40d80af1SJames Wright } 560*40d80af1SJames Wright } 561*40d80af1SJames Wright 562*40d80af1SJames Wright // Create sparse matrix 563*40d80af1SJames Wright { 564*40d80af1SJames Wright MatType dm_mat_type, dm_mat_type_copy; 565*40d80af1SJames Wright 566*40d80af1SJames Wright PetscCall(DMGetMatType(ctx->dm_x, &dm_mat_type)); 567*40d80af1SJames Wright PetscCall(PetscStrallocpy(dm_mat_type, (char **)&dm_mat_type_copy)); 568*40d80af1SJames Wright PetscCall(DMSetMatType(ctx->dm_x, ctx->coo_mat_type)); 569*40d80af1SJames Wright PetscCall(DMCreateMatrix(ctx->dm_x, mat_coo)); 570*40d80af1SJames Wright PetscCall(DMSetMatType(ctx->dm_x, dm_mat_type_copy)); 571*40d80af1SJames Wright PetscCall(PetscFree(dm_mat_type_copy)); 572*40d80af1SJames Wright } 573*40d80af1SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 574*40d80af1SJames Wright } 575*40d80af1SJames Wright 576*40d80af1SJames Wright /** 577*40d80af1SJames Wright @brief Setup the COO preallocation `MATCEED` into a `MATAIJ` or similar. 57858600ac3SJames Wright The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail. 57958600ac3SJames Wright 58058600ac3SJames Wright Collective across MPI processes. 58158600ac3SJames Wright 58258600ac3SJames Wright @param[in] mat_ceed `MATCEED` to assemble 58358600ac3SJames Wright @param[in,out] mat_coo `MATAIJ` or similar to assemble into 58458600ac3SJames Wright 58558600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 58658600ac3SJames Wright **/ 587*40d80af1SJames Wright PetscErrorCode MatCeedSetPreallocationCOO(Mat mat_ceed, Mat mat_coo) { 58858600ac3SJames Wright MatCeedContext ctx; 58958600ac3SJames Wright 59058600ac3SJames Wright PetscFunctionBeginUser; 59158600ac3SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 59258600ac3SJames Wright 59358600ac3SJames Wright { 59458600ac3SJames Wright PetscInt *rows_petsc = NULL, *cols_petsc = NULL; 59558600ac3SJames Wright CeedInt *rows_ceed, *cols_ceed; 59658600ac3SJames Wright PetscCount num_entries; 59758600ac3SJames Wright PetscLogStage stage_amg_setup; 59858600ac3SJames Wright 59958600ac3SJames Wright // -- Assemble sparsity pattern if mat hasn't been assembled before 60058600ac3SJames Wright PetscCall(PetscLogStageGetId("MATCEED Assembly Setup", &stage_amg_setup)); 60158600ac3SJames Wright if (stage_amg_setup == -1) { 60258600ac3SJames Wright PetscCall(PetscLogStageRegister("MATCEED Assembly Setup", &stage_amg_setup)); 60358600ac3SJames Wright } 60458600ac3SJames Wright PetscCall(PetscLogStagePush(stage_amg_setup)); 60550f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorLinearAssembleSymbolic(ctx->op_mult, &num_entries, &rows_ceed, &cols_ceed)); 606a7dac1d5SJames Wright PetscCall(IntArrayCeedToPetsc(num_entries, &rows_ceed, &rows_petsc)); 607a7dac1d5SJames Wright PetscCall(IntArrayCeedToPetsc(num_entries, &cols_ceed, &cols_petsc)); 60858600ac3SJames Wright PetscCall(MatSetPreallocationCOOLocal(mat_coo, num_entries, rows_petsc, cols_petsc)); 60958600ac3SJames Wright free(rows_petsc); 61058600ac3SJames Wright free(cols_petsc); 61150f50432SJames Wright if (!ctx->coo_values_full) PetscCallCeed(ctx->ceed, CeedVectorCreate(ctx->ceed, num_entries, &ctx->coo_values_full)); 61258600ac3SJames Wright PetscCall(PetscRealloc(++ctx->num_mats_assembled_full * sizeof(Mat), &ctx->mats_assembled_full)); 61358600ac3SJames Wright ctx->mats_assembled_full[ctx->num_mats_assembled_full - 1] = mat_coo; 61458600ac3SJames Wright PetscCall(PetscLogStagePop()); 61558600ac3SJames Wright } 616*40d80af1SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 617*40d80af1SJames Wright } 618*40d80af1SJames Wright 619*40d80af1SJames Wright /** 620*40d80af1SJames Wright @brief Assemble a `MATCEED` into a `MATAIJ` or similar. 621*40d80af1SJames Wright The `mat_coo` preallocation is set to match the sparsity pattern of `mat_ceed`. 622*40d80af1SJames Wright The caller is responsible for assuring the global and local sizes are compatible, otherwise this function will fail. 623*40d80af1SJames Wright 624*40d80af1SJames Wright Collective across MPI processes. 625*40d80af1SJames Wright 626*40d80af1SJames Wright @param[in] mat_ceed `MATCEED` to assemble 627*40d80af1SJames Wright @param[in,out] mat_coo `MATAIJ` or similar to assemble into 628*40d80af1SJames Wright 629*40d80af1SJames Wright @return An error code: 0 - success, otherwise - failure 630*40d80af1SJames Wright **/ 631*40d80af1SJames Wright PetscErrorCode MatCeedAssembleCOO(Mat mat_ceed, Mat mat_coo) { 632*40d80af1SJames Wright MatCeedContext ctx; 633*40d80af1SJames Wright 634*40d80af1SJames Wright PetscFunctionBeginUser; 635*40d80af1SJames Wright PetscCall(MatShellGetContext(mat_ceed, &ctx)); 636*40d80af1SJames Wright 637*40d80af1SJames Wright // Set COO pattern if needed 638*40d80af1SJames Wright { 639*40d80af1SJames Wright CeedInt index = -1; 640*40d80af1SJames Wright 641*40d80af1SJames Wright for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) { 642*40d80af1SJames Wright if (ctx->mats_assembled_full[i] == mat_coo) index = i; 643*40d80af1SJames Wright } 644*40d80af1SJames Wright if (index == -1) PetscCall(MatCeedSetPreallocationCOO(mat_ceed, mat_coo)); 64558600ac3SJames Wright } 64658600ac3SJames Wright 64758600ac3SJames Wright // Assemble mat_ceed 64858600ac3SJames Wright PetscCall(MatAssemblyBegin(mat_coo, MAT_FINAL_ASSEMBLY)); 64958600ac3SJames Wright { 65058600ac3SJames Wright const CeedScalar *values; 65158600ac3SJames Wright MatType mat_type; 65258600ac3SJames Wright CeedMemType mem_type = CEED_MEM_HOST; 65358600ac3SJames Wright PetscBool is_spd, is_spd_known; 65458600ac3SJames Wright 65558600ac3SJames Wright PetscCall(MatGetType(mat_coo, &mat_type)); 65658600ac3SJames Wright if (strstr(mat_type, "cusparse")) mem_type = CEED_MEM_DEVICE; 65758600ac3SJames Wright else if (strstr(mat_type, "kokkos")) mem_type = CEED_MEM_DEVICE; 65858600ac3SJames Wright else mem_type = CEED_MEM_HOST; 65958600ac3SJames Wright 66050f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorLinearAssemble(ctx->op_mult, ctx->coo_values_full)); 66150f50432SJames Wright PetscCallCeed(ctx->ceed, CeedVectorGetArrayRead(ctx->coo_values_full, mem_type, &values)); 66258600ac3SJames Wright PetscCall(MatSetValuesCOO(mat_coo, values, INSERT_VALUES)); 66358600ac3SJames Wright PetscCall(MatIsSPDKnown(mat_ceed, &is_spd_known, &is_spd)); 66458600ac3SJames Wright if (is_spd_known) PetscCall(MatSetOption(mat_coo, MAT_SPD, is_spd)); 66550f50432SJames Wright PetscCallCeed(ctx->ceed, CeedVectorRestoreArrayRead(ctx->coo_values_full, &values)); 66658600ac3SJames Wright } 66758600ac3SJames Wright PetscCall(MatAssemblyEnd(mat_coo, MAT_FINAL_ASSEMBLY)); 66858600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 66958600ac3SJames Wright } 67058600ac3SJames Wright 67158600ac3SJames Wright /** 672*40d80af1SJames Wright @brief Set the current value of a context field for a `MatCEED`. 673*40d80af1SJames Wright 674*40d80af1SJames Wright Not collective across MPI processes. 675*40d80af1SJames Wright 676*40d80af1SJames Wright @param[in,out] mat `MatCEED` 677*40d80af1SJames Wright @param[in] name Name of the context field 678*40d80af1SJames Wright @param[in] value New context field value 679*40d80af1SJames Wright 680*40d80af1SJames Wright @return An error code: 0 - success, otherwise - failure 681*40d80af1SJames Wright **/ 682*40d80af1SJames Wright PetscErrorCode MatCeedSetContextDouble(Mat mat, const char *name, double value) { 683*40d80af1SJames Wright PetscBool was_updated = PETSC_FALSE; 684*40d80af1SJames Wright MatCeedContext ctx; 685*40d80af1SJames Wright 686*40d80af1SJames Wright PetscFunctionBeginUser; 687*40d80af1SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 688*40d80af1SJames Wright { 689*40d80af1SJames Wright CeedContextFieldLabel label = NULL; 690*40d80af1SJames Wright 691*40d80af1SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorGetContextFieldLabel(ctx->op_mult, name, &label)); 692*40d80af1SJames Wright if (label) { 693*40d80af1SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorSetContextDouble(ctx->op_mult, label, &value)); 694*40d80af1SJames Wright was_updated = PETSC_TRUE; 695*40d80af1SJames Wright } 696*40d80af1SJames Wright if (ctx->op_mult_transpose) { 697*40d80af1SJames Wright label = NULL; 698*40d80af1SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorGetContextFieldLabel(ctx->op_mult_transpose, name, &label)); 699*40d80af1SJames Wright if (label) { 700*40d80af1SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorSetContextDouble(ctx->op_mult_transpose, label, &value)); 701*40d80af1SJames Wright was_updated = PETSC_TRUE; 702*40d80af1SJames Wright } 703*40d80af1SJames Wright } 704*40d80af1SJames Wright } 705*40d80af1SJames Wright if (was_updated) { 706*40d80af1SJames Wright PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY)); 707*40d80af1SJames Wright PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY)); 708*40d80af1SJames Wright } 709*40d80af1SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 710*40d80af1SJames Wright } 711*40d80af1SJames Wright 712*40d80af1SJames Wright /** 713*40d80af1SJames Wright @brief Get the current value of a context field for a `MatCEED`. 714*40d80af1SJames Wright 715*40d80af1SJames Wright Not collective across MPI processes. 716*40d80af1SJames Wright 717*40d80af1SJames Wright @param[in] mat `MatCEED` 718*40d80af1SJames Wright @param[in] name Name of the context field 719*40d80af1SJames Wright @param[out] value Current context field value 720*40d80af1SJames Wright 721*40d80af1SJames Wright @return An error code: 0 - success, otherwise - failure 722*40d80af1SJames Wright **/ 723*40d80af1SJames Wright PetscErrorCode MatCeedGetContextDouble(Mat mat, const char *name, double *value) { 724*40d80af1SJames Wright MatCeedContext ctx; 725*40d80af1SJames Wright 726*40d80af1SJames Wright PetscFunctionBeginUser; 727*40d80af1SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 728*40d80af1SJames Wright { 729*40d80af1SJames Wright CeedContextFieldLabel label = NULL; 730*40d80af1SJames Wright CeedOperator op = ctx->op_mult; 731*40d80af1SJames Wright 732*40d80af1SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorGetContextFieldLabel(op, name, &label)); 733*40d80af1SJames Wright if (!label && ctx->op_mult_transpose) { 734*40d80af1SJames Wright op = ctx->op_mult_transpose; 735*40d80af1SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorGetContextFieldLabel(op, name, &label)); 736*40d80af1SJames Wright } 737*40d80af1SJames Wright if (label) { 738*40d80af1SJames Wright PetscSizeT num_values; 739*40d80af1SJames Wright const double *values_ceed; 740*40d80af1SJames Wright 741*40d80af1SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorGetContextDoubleRead(op, label, &num_values, &values_ceed)); 742*40d80af1SJames Wright *value = values_ceed[0]; 743*40d80af1SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorRestoreContextDoubleRead(op, label, &values_ceed)); 744*40d80af1SJames Wright } 745*40d80af1SJames Wright } 746*40d80af1SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 747*40d80af1SJames Wright } 748*40d80af1SJames Wright 749*40d80af1SJames Wright /** 75058600ac3SJames Wright @brief Set user context for a `MATCEED`. 75158600ac3SJames Wright 75258600ac3SJames Wright Collective across MPI processes. 75358600ac3SJames Wright 75458600ac3SJames Wright @param[in,out] mat `MATCEED` 75558600ac3SJames Wright @param[in] f The context destroy function, or NULL 75658600ac3SJames Wright @param[in] ctx User context, or NULL to unset 75758600ac3SJames Wright 75858600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 75958600ac3SJames Wright **/ 76058600ac3SJames Wright PetscErrorCode MatCeedSetContext(Mat mat, PetscErrorCode (*f)(void *), void *ctx) { 76158600ac3SJames Wright PetscContainer user_ctx = NULL; 76258600ac3SJames Wright 76358600ac3SJames Wright PetscFunctionBeginUser; 76458600ac3SJames Wright if (ctx) { 76558600ac3SJames Wright PetscCall(PetscContainerCreate(PetscObjectComm((PetscObject)mat), &user_ctx)); 76658600ac3SJames Wright PetscCall(PetscContainerSetPointer(user_ctx, ctx)); 76758600ac3SJames Wright PetscCall(PetscContainerSetUserDestroy(user_ctx, f)); 76858600ac3SJames Wright } 76958600ac3SJames Wright PetscCall(PetscObjectCompose((PetscObject)mat, "MatCeed user context", (PetscObject)user_ctx)); 77058600ac3SJames Wright PetscCall(PetscContainerDestroy(&user_ctx)); 77158600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 77258600ac3SJames Wright } 77358600ac3SJames Wright 77458600ac3SJames Wright /** 77558600ac3SJames Wright @brief Retrieve the user context for a `MATCEED`. 77658600ac3SJames Wright 77758600ac3SJames Wright Collective across MPI processes. 77858600ac3SJames Wright 77958600ac3SJames Wright @param[in,out] mat `MATCEED` 78058600ac3SJames Wright @param[in] ctx User context 78158600ac3SJames Wright 78258600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 78358600ac3SJames Wright **/ 78458600ac3SJames Wright PetscErrorCode MatCeedGetContext(Mat mat, void *ctx) { 78558600ac3SJames Wright PetscContainer user_ctx; 78658600ac3SJames Wright 78758600ac3SJames Wright PetscFunctionBeginUser; 78858600ac3SJames Wright PetscCall(PetscObjectQuery((PetscObject)mat, "MatCeed user context", (PetscObject *)&user_ctx)); 78958600ac3SJames Wright if (user_ctx) PetscCall(PetscContainerGetPointer(user_ctx, (void **)ctx)); 79058600ac3SJames Wright else *(void **)ctx = NULL; 79158600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 79258600ac3SJames Wright } 79358600ac3SJames Wright /** 794*40d80af1SJames Wright @brief Set a user defined matrix operation for a `MATCEED` matrix. 795*40d80af1SJames Wright 796*40d80af1SJames Wright Within each user-defined routine, the user should call `MatCeedGetContext()` to obtain the user-defined context that was set by 797*40d80af1SJames Wright `MatCeedSetContext()`. 79858600ac3SJames Wright 79958600ac3SJames Wright Collective across MPI processes. 80058600ac3SJames Wright 80158600ac3SJames Wright @param[in,out] mat `MATCEED` 802*40d80af1SJames Wright @param[in] op Name of the `MatOperation` 803*40d80af1SJames Wright @param[in] g Function that provides the operation 80458600ac3SJames Wright 80558600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 80658600ac3SJames Wright **/ 807*40d80af1SJames Wright PetscErrorCode MatCeedSetOperation(Mat mat, MatOperation op, void (*g)(void)) { 808*40d80af1SJames Wright PetscFunctionBeginUser; 809*40d80af1SJames Wright PetscCall(MatShellSetOperation(mat, op, g)); 810*40d80af1SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 811*40d80af1SJames Wright } 812*40d80af1SJames Wright 813*40d80af1SJames Wright /** 814*40d80af1SJames Wright @brief Sets the default COO matrix type as a string from the `MATCEED`. 815*40d80af1SJames Wright 816*40d80af1SJames Wright Collective across MPI processes. 817*40d80af1SJames Wright 818*40d80af1SJames Wright @param[in,out] mat `MATCEED` 819*40d80af1SJames Wright @param[in] type COO `MatType` to set 820*40d80af1SJames Wright 821*40d80af1SJames Wright @return An error code: 0 - success, otherwise - failure 822*40d80af1SJames Wright **/ 823*40d80af1SJames Wright PetscErrorCode MatCeedSetCOOMatType(Mat mat, MatType type) { 82458600ac3SJames Wright MatCeedContext ctx; 82558600ac3SJames Wright 82658600ac3SJames Wright PetscFunctionBeginUser; 82758600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 82858600ac3SJames Wright // Check if same 82958600ac3SJames Wright { 83058600ac3SJames Wright size_t len_old, len_new; 83158600ac3SJames Wright PetscBool is_same = PETSC_FALSE; 83258600ac3SJames Wright 833*40d80af1SJames Wright PetscCall(PetscStrlen(ctx->coo_mat_type, &len_old)); 83458600ac3SJames Wright PetscCall(PetscStrlen(type, &len_new)); 835*40d80af1SJames Wright if (len_old == len_new) PetscCall(PetscStrncmp(ctx->coo_mat_type, type, len_old, &is_same)); 83658600ac3SJames Wright if (is_same) PetscFunctionReturn(PETSC_SUCCESS); 83758600ac3SJames Wright } 83858600ac3SJames Wright // Clean up old mats in different format 83958600ac3SJames Wright // LCOV_EXCL_START 84058600ac3SJames Wright if (ctx->mat_assembled_full_internal) { 84158600ac3SJames Wright for (PetscInt i = 0; i < ctx->num_mats_assembled_full; i++) { 84258600ac3SJames Wright if (ctx->mats_assembled_full[i] == ctx->mat_assembled_full_internal) { 84358600ac3SJames Wright for (PetscInt j = i + 1; j < ctx->num_mats_assembled_full; j++) { 84458600ac3SJames Wright ctx->mats_assembled_full[j - 1] = ctx->mats_assembled_full[j]; 84558600ac3SJames Wright } 84658600ac3SJames Wright ctx->num_mats_assembled_full--; 84758600ac3SJames Wright // Note: we'll realloc this array again, so no need to shrink the allocation 84858600ac3SJames Wright PetscCall(MatDestroy(&ctx->mat_assembled_full_internal)); 84958600ac3SJames Wright } 85058600ac3SJames Wright } 85158600ac3SJames Wright } 85258600ac3SJames Wright if (ctx->mat_assembled_pbd_internal) { 85358600ac3SJames Wright for (PetscInt i = 0; i < ctx->num_mats_assembled_pbd; i++) { 85458600ac3SJames Wright if (ctx->mats_assembled_pbd[i] == ctx->mat_assembled_pbd_internal) { 85558600ac3SJames Wright for (PetscInt j = i + 1; j < ctx->num_mats_assembled_pbd; j++) { 85658600ac3SJames Wright ctx->mats_assembled_pbd[j - 1] = ctx->mats_assembled_pbd[j]; 85758600ac3SJames Wright } 85858600ac3SJames Wright // Note: we'll realloc this array again, so no need to shrink the allocation 85958600ac3SJames Wright ctx->num_mats_assembled_pbd--; 86058600ac3SJames Wright PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal)); 86158600ac3SJames Wright } 86258600ac3SJames Wright } 86358600ac3SJames Wright } 864*40d80af1SJames Wright PetscCall(PetscFree(ctx->coo_mat_type)); 865*40d80af1SJames Wright PetscCall(PetscStrallocpy(type, &ctx->coo_mat_type)); 86658600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 86758600ac3SJames Wright // LCOV_EXCL_STOP 86858600ac3SJames Wright } 86958600ac3SJames Wright 87058600ac3SJames Wright /** 871*40d80af1SJames Wright @brief Gets the default COO matrix type as a string from the `MATCEED`. 87258600ac3SJames Wright 87358600ac3SJames Wright Collective across MPI processes. 87458600ac3SJames Wright 87558600ac3SJames Wright @param[in,out] mat `MATCEED` 876*40d80af1SJames Wright @param[in] type COO `MatType` 87758600ac3SJames Wright 87858600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 87958600ac3SJames Wright **/ 880*40d80af1SJames Wright PetscErrorCode MatCeedGetCOOMatType(Mat mat, MatType *type) { 88158600ac3SJames Wright MatCeedContext ctx; 88258600ac3SJames Wright 88358600ac3SJames Wright PetscFunctionBeginUser; 88458600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 885*40d80af1SJames Wright *type = ctx->coo_mat_type; 88658600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 88758600ac3SJames Wright } 88858600ac3SJames Wright 88958600ac3SJames Wright /** 89058600ac3SJames Wright @brief Set input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 89158600ac3SJames Wright 89258600ac3SJames Wright Not collective across MPI processes. 89358600ac3SJames Wright 89458600ac3SJames Wright @param[in,out] mat `MATCEED` 89558600ac3SJames Wright @param[in] X_loc Input PETSc local vector, or NULL 89658600ac3SJames Wright @param[in] Y_loc_transpose Input PETSc local vector for transpose operation, or NULL 89758600ac3SJames Wright 89858600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 89958600ac3SJames Wright **/ 90058600ac3SJames Wright PetscErrorCode MatCeedSetLocalVectors(Mat mat, Vec X_loc, Vec Y_loc_transpose) { 90158600ac3SJames Wright MatCeedContext ctx; 90258600ac3SJames Wright 90358600ac3SJames Wright PetscFunctionBeginUser; 90458600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 90558600ac3SJames Wright if (X_loc) { 90658600ac3SJames Wright PetscInt len_old, len_new; 90758600ac3SJames Wright 90858600ac3SJames Wright PetscCall(VecGetSize(ctx->X_loc, &len_old)); 90958600ac3SJames Wright PetscCall(VecGetSize(X_loc, &len_new)); 91058600ac3SJames Wright PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB, "new X_loc length %" PetscInt_FMT " should match old X_loc length %" PetscInt_FMT, 91158600ac3SJames Wright len_new, len_old); 912*40d80af1SJames Wright PetscCall(VecReferenceCopy(X_loc, &ctx->X_loc)); 91358600ac3SJames Wright } 91458600ac3SJames Wright if (Y_loc_transpose) { 91558600ac3SJames Wright PetscInt len_old, len_new; 91658600ac3SJames Wright 91758600ac3SJames Wright PetscCall(VecGetSize(ctx->Y_loc_transpose, &len_old)); 91858600ac3SJames Wright PetscCall(VecGetSize(Y_loc_transpose, &len_new)); 91958600ac3SJames Wright PetscCheck(len_old == len_new, PETSC_COMM_SELF, PETSC_ERR_LIB, 92058600ac3SJames Wright "new Y_loc_transpose length %" PetscInt_FMT " should match old Y_loc_transpose length %" PetscInt_FMT, len_new, len_old); 921*40d80af1SJames Wright PetscCall(VecReferenceCopy(Y_loc_transpose, &ctx->Y_loc_transpose)); 92258600ac3SJames Wright } 92358600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 92458600ac3SJames Wright } 92558600ac3SJames Wright 92658600ac3SJames Wright /** 92758600ac3SJames Wright @brief Get input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 92858600ac3SJames Wright 92958600ac3SJames Wright Not collective across MPI processes. 93058600ac3SJames Wright 93158600ac3SJames Wright @param[in,out] mat `MATCEED` 93258600ac3SJames Wright @param[out] X_loc Input PETSc local vector, or NULL 93358600ac3SJames Wright @param[out] Y_loc_transpose Input PETSc local vector for transpose operation, or NULL 93458600ac3SJames Wright 93558600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 93658600ac3SJames Wright **/ 93758600ac3SJames Wright PetscErrorCode MatCeedGetLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) { 93858600ac3SJames Wright MatCeedContext ctx; 93958600ac3SJames Wright 94058600ac3SJames Wright PetscFunctionBeginUser; 94158600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 94258600ac3SJames Wright if (X_loc) { 943*40d80af1SJames Wright *X_loc = NULL; 944*40d80af1SJames Wright PetscCall(VecReferenceCopy(ctx->X_loc, X_loc)); 94558600ac3SJames Wright } 94658600ac3SJames Wright if (Y_loc_transpose) { 947*40d80af1SJames Wright *Y_loc_transpose = NULL; 948*40d80af1SJames Wright PetscCall(VecReferenceCopy(ctx->Y_loc_transpose, Y_loc_transpose)); 94958600ac3SJames Wright } 95058600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 95158600ac3SJames Wright } 95258600ac3SJames Wright 95358600ac3SJames Wright /** 95458600ac3SJames Wright @brief Restore input local vectors for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 95558600ac3SJames Wright 95658600ac3SJames Wright Not collective across MPI processes. 95758600ac3SJames Wright 95858600ac3SJames Wright @param[in,out] mat MatCeed 95958600ac3SJames Wright @param[out] X_loc Input PETSc local vector, or NULL 96058600ac3SJames Wright @param[out] Y_loc_transpose Input PETSc local vector for transpose operation, or NULL 96158600ac3SJames Wright 96258600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 96358600ac3SJames Wright **/ 96458600ac3SJames Wright PetscErrorCode MatCeedRestoreLocalVectors(Mat mat, Vec *X_loc, Vec *Y_loc_transpose) { 96558600ac3SJames Wright PetscFunctionBeginUser; 96658600ac3SJames Wright if (X_loc) PetscCall(VecDestroy(X_loc)); 96758600ac3SJames Wright if (Y_loc_transpose) PetscCall(VecDestroy(Y_loc_transpose)); 96858600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 96958600ac3SJames Wright } 97058600ac3SJames Wright 97158600ac3SJames Wright /** 97258600ac3SJames Wright @brief Get libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 97358600ac3SJames Wright 97458600ac3SJames Wright Not collective across MPI processes. 97558600ac3SJames Wright 97658600ac3SJames Wright @param[in,out] mat MatCeed 97758600ac3SJames Wright @param[out] op_mult libCEED `CeedOperator` for `MatMult()`, or NULL 97858600ac3SJames Wright @param[out] op_mult_transpose libCEED `CeedOperator` for `MatMultTranspose()`, or NULL 97958600ac3SJames Wright 98058600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 98158600ac3SJames Wright **/ 98258600ac3SJames Wright PetscErrorCode MatCeedGetCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) { 98358600ac3SJames Wright MatCeedContext ctx; 98458600ac3SJames Wright 98558600ac3SJames Wright PetscFunctionBeginUser; 98658600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 98758600ac3SJames Wright if (op_mult) { 98858600ac3SJames Wright *op_mult = NULL; 98950f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult, op_mult)); 99058600ac3SJames Wright } 99158600ac3SJames Wright if (op_mult_transpose) { 99258600ac3SJames Wright *op_mult_transpose = NULL; 99350f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorReferenceCopy(ctx->op_mult_transpose, op_mult_transpose)); 99458600ac3SJames Wright } 99558600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 99658600ac3SJames Wright } 99758600ac3SJames Wright 99858600ac3SJames Wright /** 99958600ac3SJames Wright @brief Restore libCEED `CeedOperator` for `MATCEED` `MatMult()` and `MatMultTranspose()` operations. 100058600ac3SJames Wright 100158600ac3SJames Wright Not collective across MPI processes. 100258600ac3SJames Wright 100358600ac3SJames Wright @param[in,out] mat MatCeed 100458600ac3SJames Wright @param[out] op_mult libCEED `CeedOperator` for `MatMult()`, or NULL 100558600ac3SJames Wright @param[out] op_mult_transpose libCEED `CeedOperator` for `MatMultTranspose()`, or NULL 100658600ac3SJames Wright 100758600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 100858600ac3SJames Wright **/ 100958600ac3SJames Wright PetscErrorCode MatCeedRestoreCeedOperators(Mat mat, CeedOperator *op_mult, CeedOperator *op_mult_transpose) { 101058600ac3SJames Wright MatCeedContext ctx; 101158600ac3SJames Wright 101258600ac3SJames Wright PetscFunctionBeginUser; 101358600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 101450f50432SJames Wright if (op_mult) PetscCallCeed(ctx->ceed, CeedOperatorDestroy(op_mult)); 101550f50432SJames Wright if (op_mult_transpose) PetscCallCeed(ctx->ceed, CeedOperatorDestroy(op_mult_transpose)); 101658600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 101758600ac3SJames Wright } 101858600ac3SJames Wright 101958600ac3SJames Wright /** 102058600ac3SJames Wright @brief Set `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators. 102158600ac3SJames Wright 102258600ac3SJames Wright Not collective across MPI processes. 102358600ac3SJames Wright 102458600ac3SJames Wright @param[in,out] mat MatCeed 102558600ac3SJames Wright @param[out] log_event_mult `PetscLogEvent` for forward evaluation, or NULL 102658600ac3SJames Wright @param[out] log_event_mult_transpose `PetscLogEvent` for transpose evaluation, or NULL 102758600ac3SJames Wright 102858600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 102958600ac3SJames Wright **/ 103058600ac3SJames Wright PetscErrorCode MatCeedSetLogEvents(Mat mat, PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose) { 103158600ac3SJames Wright MatCeedContext ctx; 103258600ac3SJames Wright 103358600ac3SJames Wright PetscFunctionBeginUser; 103458600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 103558600ac3SJames Wright if (log_event_mult) ctx->log_event_mult = log_event_mult; 103658600ac3SJames Wright if (log_event_mult_transpose) ctx->log_event_mult_transpose = log_event_mult_transpose; 103758600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 103858600ac3SJames Wright } 103958600ac3SJames Wright 104058600ac3SJames Wright /** 104158600ac3SJames Wright @brief Get `PetscLogEvent` for `MATCEED` `MatMult()` and `MatMultTranspose()` operators. 104258600ac3SJames Wright 104358600ac3SJames Wright Not collective across MPI processes. 104458600ac3SJames Wright 104558600ac3SJames Wright @param[in,out] mat MatCeed 104658600ac3SJames Wright @param[out] log_event_mult `PetscLogEvent` for forward evaluation, or NULL 104758600ac3SJames Wright @param[out] log_event_mult_transpose `PetscLogEvent` for transpose evaluation, or NULL 104858600ac3SJames Wright 104958600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 105058600ac3SJames Wright **/ 105158600ac3SJames Wright PetscErrorCode MatCeedGetLogEvents(Mat mat, PetscLogEvent *log_event_mult, PetscLogEvent *log_event_mult_transpose) { 105258600ac3SJames Wright MatCeedContext ctx; 105358600ac3SJames Wright 105458600ac3SJames Wright PetscFunctionBeginUser; 105558600ac3SJames Wright PetscCall(MatShellGetContext(mat, &ctx)); 105658600ac3SJames Wright if (log_event_mult) *log_event_mult = ctx->log_event_mult; 105758600ac3SJames Wright if (log_event_mult_transpose) *log_event_mult_transpose = ctx->log_event_mult_transpose; 105858600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 105958600ac3SJames Wright } 106058600ac3SJames Wright 106158600ac3SJames Wright // ----------------------------------------------------------------------------- 106258600ac3SJames Wright // Operator context data 106358600ac3SJames Wright // ----------------------------------------------------------------------------- 106458600ac3SJames Wright 106558600ac3SJames Wright /** 106658600ac3SJames Wright @brief Setup context data for operator application. 106758600ac3SJames Wright 106858600ac3SJames Wright Collective across MPI processes. 106958600ac3SJames Wright 107058600ac3SJames Wright @param[in] dm_x Input `DM` 107158600ac3SJames Wright @param[in] dm_y Output `DM` 107258600ac3SJames Wright @param[in] X_loc Input PETSc local vector, or NULL 107358600ac3SJames Wright @param[in] Y_loc_transpose Input PETSc local vector for transpose operation, or NULL 107458600ac3SJames Wright @param[in] op_mult `CeedOperator` for forward evaluation 107558600ac3SJames Wright @param[in] op_mult_transpose `CeedOperator` for transpose evaluation 107658600ac3SJames Wright @param[in] log_event_mult `PetscLogEvent` for forward evaluation 107758600ac3SJames Wright @param[in] log_event_mult_transpose `PetscLogEvent` for transpose evaluation 107858600ac3SJames Wright @param[out] ctx Context data for operator evaluation 107958600ac3SJames Wright 108058600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 108158600ac3SJames Wright **/ 108258600ac3SJames Wright PetscErrorCode MatCeedContextCreate(DM dm_x, DM dm_y, Vec X_loc, Vec Y_loc_transpose, CeedOperator op_mult, CeedOperator op_mult_transpose, 108358600ac3SJames Wright PetscLogEvent log_event_mult, PetscLogEvent log_event_mult_transpose, MatCeedContext *ctx) { 108458600ac3SJames Wright CeedSize x_loc_len, y_loc_len; 108558600ac3SJames Wright 108658600ac3SJames Wright PetscFunctionBeginUser; 108758600ac3SJames Wright 108858600ac3SJames Wright // Allocate 108958600ac3SJames Wright PetscCall(PetscNew(ctx)); 109058600ac3SJames Wright (*ctx)->ref_count = 1; 109158600ac3SJames Wright 109258600ac3SJames Wright // Logging 109358600ac3SJames Wright (*ctx)->log_event_mult = log_event_mult; 109458600ac3SJames Wright (*ctx)->log_event_mult_transpose = log_event_mult_transpose; 109558600ac3SJames Wright 109658600ac3SJames Wright // PETSc objects 1097*40d80af1SJames Wright PetscCall(DMReferenceCopy(dm_x, &(*ctx)->dm_x)); 1098*40d80af1SJames Wright PetscCall(DMReferenceCopy(dm_y, &(*ctx)->dm_y)); 1099*40d80af1SJames Wright if (X_loc) PetscCall(VecReferenceCopy(X_loc, &(*ctx)->X_loc)); 1100*40d80af1SJames Wright if (Y_loc_transpose) PetscCall(VecReferenceCopy(Y_loc_transpose, &(*ctx)->Y_loc_transpose)); 110158600ac3SJames Wright 110258600ac3SJames Wright // Memtype 110358600ac3SJames Wright { 110458600ac3SJames Wright const PetscScalar *x; 110558600ac3SJames Wright Vec X; 110658600ac3SJames Wright 110758600ac3SJames Wright PetscCall(DMGetLocalVector(dm_x, &X)); 110858600ac3SJames Wright PetscCall(VecGetArrayReadAndMemType(X, &x, &(*ctx)->mem_type)); 110958600ac3SJames Wright PetscCall(VecRestoreArrayReadAndMemType(X, &x)); 111058600ac3SJames Wright PetscCall(DMRestoreLocalVector(dm_x, &X)); 111158600ac3SJames Wright } 111258600ac3SJames Wright 111358600ac3SJames Wright // libCEED objects 111458600ac3SJames Wright PetscCheck(CeedOperatorGetCeed(op_mult, &(*ctx)->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, 111558600ac3SJames Wright "retrieving Ceed context object failed"); 111650f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedReference((*ctx)->ceed)); 111750f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedOperatorGetActiveVectorLengths(op_mult, &x_loc_len, &y_loc_len)); 111850f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult, &(*ctx)->op_mult)); 111950f50432SJames Wright if (op_mult_transpose) PetscCallCeed((*ctx)->ceed, CeedOperatorReferenceCopy(op_mult_transpose, &(*ctx)->op_mult_transpose)); 112050f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, x_loc_len, &(*ctx)->x_loc)); 112150f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedVectorCreate((*ctx)->ceed, y_loc_len, &(*ctx)->y_loc)); 112258600ac3SJames Wright 112358600ac3SJames Wright // Flop counting 112458600ac3SJames Wright { 112558600ac3SJames Wright CeedSize ceed_flops_estimate = 0; 112658600ac3SJames Wright 112750f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult, &ceed_flops_estimate)); 112858600ac3SJames Wright (*ctx)->flops_mult = ceed_flops_estimate; 112958600ac3SJames Wright if (op_mult_transpose) { 113050f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedOperatorGetFlopsEstimate(op_mult_transpose, &ceed_flops_estimate)); 113158600ac3SJames Wright (*ctx)->flops_mult_transpose = ceed_flops_estimate; 113258600ac3SJames Wright } 113358600ac3SJames Wright } 113458600ac3SJames Wright 113558600ac3SJames Wright // Check sizes 113658600ac3SJames Wright if (x_loc_len > 0 || y_loc_len > 0) { 113758600ac3SJames Wright CeedSize ctx_x_loc_len, ctx_y_loc_len; 113858600ac3SJames Wright PetscInt X_loc_len, dm_x_loc_len, Y_loc_len, dm_y_loc_len; 113958600ac3SJames Wright Vec dm_X_loc, dm_Y_loc; 114058600ac3SJames Wright 114158600ac3SJames Wright // -- Input 114258600ac3SJames Wright PetscCall(DMGetLocalVector(dm_x, &dm_X_loc)); 114358600ac3SJames Wright PetscCall(VecGetLocalSize(dm_X_loc, &dm_x_loc_len)); 114458600ac3SJames Wright PetscCall(DMRestoreLocalVector(dm_x, &dm_X_loc)); 114550f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedVectorGetLength((*ctx)->x_loc, &ctx_x_loc_len)); 11464c17272bSJames Wright if (X_loc) { 11474c17272bSJames Wright PetscCall(VecGetLocalSize(X_loc, &X_loc_len)); 11484c17272bSJames Wright PetscCheck(X_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, 11494c17272bSJames Wright "X_loc (%" PetscInt_FMT ") must match dm_x (%" PetscInt_FMT ") dimensions", X_loc_len, dm_x_loc_len); 11504c17272bSJames Wright } 11514c17272bSJames Wright PetscCheck(x_loc_len == dm_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op (%" CeedSize_FMT ") must match dm_x (%" PetscInt_FMT ") dimensions", 11524c17272bSJames Wright x_loc_len, dm_x_loc_len); 11534c17272bSJames Wright PetscCheck(x_loc_len == ctx_x_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "x_loc (%" CeedSize_FMT ") must match op dimensions (%" CeedSize_FMT ")", 11544c17272bSJames Wright x_loc_len, ctx_x_loc_len); 115558600ac3SJames Wright 115658600ac3SJames Wright // -- Output 115758600ac3SJames Wright PetscCall(DMGetLocalVector(dm_y, &dm_Y_loc)); 115858600ac3SJames Wright PetscCall(VecGetLocalSize(dm_Y_loc, &dm_y_loc_len)); 115958600ac3SJames Wright PetscCall(DMRestoreLocalVector(dm_y, &dm_Y_loc)); 116050f50432SJames Wright PetscCallCeed((*ctx)->ceed, CeedVectorGetLength((*ctx)->y_loc, &ctx_y_loc_len)); 11614c17272bSJames Wright PetscCheck(ctx_y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, "op (%" CeedSize_FMT ") must match dm_y (%" PetscInt_FMT ") dimensions", 11624c17272bSJames Wright ctx_y_loc_len, dm_y_loc_len); 116358600ac3SJames Wright 116458600ac3SJames Wright // -- Transpose 116558600ac3SJames Wright if (Y_loc_transpose) { 116658600ac3SJames Wright PetscCall(VecGetLocalSize(Y_loc_transpose, &Y_loc_len)); 11674c17272bSJames Wright PetscCheck(Y_loc_len == dm_y_loc_len, PETSC_COMM_SELF, PETSC_ERR_LIB, 11684c17272bSJames Wright "Y_loc_transpose (%" PetscInt_FMT ") must match dm_y (%" PetscInt_FMT ") dimensions", Y_loc_len, dm_y_loc_len); 116958600ac3SJames Wright } 117058600ac3SJames Wright } 117158600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 117258600ac3SJames Wright } 117358600ac3SJames Wright 117458600ac3SJames Wright /** 117558600ac3SJames Wright @brief Increment reference counter for `MATCEED` context. 117658600ac3SJames Wright 117758600ac3SJames Wright Not collective across MPI processes. 117858600ac3SJames Wright 117958600ac3SJames Wright @param[in,out] ctx Context data 118058600ac3SJames Wright 118158600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 118258600ac3SJames Wright **/ 118358600ac3SJames Wright PetscErrorCode MatCeedContextReference(MatCeedContext ctx) { 118458600ac3SJames Wright PetscFunctionBeginUser; 118558600ac3SJames Wright ctx->ref_count++; 118658600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 118758600ac3SJames Wright } 118858600ac3SJames Wright 118958600ac3SJames Wright /** 119058600ac3SJames Wright @brief Copy reference for `MATCEED`. 119158600ac3SJames Wright Note: If `ctx_copy` is non-null, it is assumed to be a valid pointer to a `MatCeedContext`. 119258600ac3SJames Wright 119358600ac3SJames Wright Not collective across MPI processes. 119458600ac3SJames Wright 119558600ac3SJames Wright @param[in] ctx Context data 119658600ac3SJames Wright @param[out] ctx_copy Copy of pointer to context data 119758600ac3SJames Wright 119858600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 119958600ac3SJames Wright **/ 120058600ac3SJames Wright PetscErrorCode MatCeedContextReferenceCopy(MatCeedContext ctx, MatCeedContext *ctx_copy) { 120158600ac3SJames Wright PetscFunctionBeginUser; 120258600ac3SJames Wright PetscCall(MatCeedContextReference(ctx)); 120358600ac3SJames Wright PetscCall(MatCeedContextDestroy(*ctx_copy)); 120458600ac3SJames Wright *ctx_copy = ctx; 120558600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 120658600ac3SJames Wright } 120758600ac3SJames Wright 120858600ac3SJames Wright /** 120958600ac3SJames Wright @brief Destroy context data for operator application. 121058600ac3SJames Wright 121158600ac3SJames Wright Collective across MPI processes. 121258600ac3SJames Wright 121358600ac3SJames Wright @param[in,out] ctx Context data for operator evaluation 121458600ac3SJames Wright 121558600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 121658600ac3SJames Wright **/ 121758600ac3SJames Wright PetscErrorCode MatCeedContextDestroy(MatCeedContext ctx) { 121858600ac3SJames Wright PetscFunctionBeginUser; 121958600ac3SJames Wright if (!ctx || --ctx->ref_count > 0) PetscFunctionReturn(PETSC_SUCCESS); 122058600ac3SJames Wright 122158600ac3SJames Wright // PETSc objects 122258600ac3SJames Wright PetscCall(DMDestroy(&ctx->dm_x)); 122358600ac3SJames Wright PetscCall(DMDestroy(&ctx->dm_y)); 122458600ac3SJames Wright PetscCall(VecDestroy(&ctx->X_loc)); 122558600ac3SJames Wright PetscCall(VecDestroy(&ctx->Y_loc_transpose)); 122658600ac3SJames Wright PetscCall(MatDestroy(&ctx->mat_assembled_full_internal)); 122758600ac3SJames Wright PetscCall(MatDestroy(&ctx->mat_assembled_pbd_internal)); 1228*40d80af1SJames Wright PetscCall(PetscFree(ctx->coo_mat_type)); 122958600ac3SJames Wright PetscCall(PetscFree(ctx->mats_assembled_full)); 123058600ac3SJames Wright PetscCall(PetscFree(ctx->mats_assembled_pbd)); 123158600ac3SJames Wright 123258600ac3SJames Wright // libCEED objects 123350f50432SJames Wright PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->x_loc)); 123450f50432SJames Wright PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->y_loc)); 123550f50432SJames Wright PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_full)); 123650f50432SJames Wright PetscCallCeed(ctx->ceed, CeedVectorDestroy(&ctx->coo_values_pbd)); 123750f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult)); 123850f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorDestroy(&ctx->op_mult_transpose)); 123958600ac3SJames Wright PetscCheck(CeedDestroy(&ctx->ceed) == CEED_ERROR_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "destroying libCEED context object failed"); 124058600ac3SJames Wright 124158600ac3SJames Wright // Deallocate 124258600ac3SJames Wright ctx->is_destroyed = PETSC_TRUE; // Flag as destroyed in case someone has stale ref 124358600ac3SJames Wright PetscCall(PetscFree(ctx)); 124458600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 124558600ac3SJames Wright } 124658600ac3SJames Wright 124758600ac3SJames Wright /** 124858600ac3SJames Wright @brief Compute the diagonal of an operator via libCEED. 124958600ac3SJames Wright 125058600ac3SJames Wright Collective across MPI processes. 125158600ac3SJames Wright 125258600ac3SJames Wright @param[in] A `MATCEED` 125358600ac3SJames Wright @param[out] D Vector holding operator diagonal 125458600ac3SJames Wright 125558600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 125658600ac3SJames Wright **/ 125758600ac3SJames Wright PetscErrorCode MatGetDiagonal_Ceed(Mat A, Vec D) { 125858600ac3SJames Wright PetscMemType mem_type; 125958600ac3SJames Wright Vec D_loc; 126058600ac3SJames Wright MatCeedContext ctx; 126158600ac3SJames Wright 126258600ac3SJames Wright PetscFunctionBeginUser; 126358600ac3SJames Wright PetscCall(MatShellGetContext(A, &ctx)); 126458600ac3SJames Wright 126558600ac3SJames Wright // Place PETSc vector in libCEED vector 126658600ac3SJames Wright PetscCall(DMGetLocalVector(ctx->dm_x, &D_loc)); 1267a7dac1d5SJames Wright PetscCall(VecPetscToCeed(D_loc, &mem_type, ctx->x_loc)); 126858600ac3SJames Wright 126958600ac3SJames Wright // Compute Diagonal 127050f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorLinearAssembleDiagonal(ctx->op_mult, ctx->x_loc, CEED_REQUEST_IMMEDIATE)); 127158600ac3SJames Wright 127258600ac3SJames Wright // Restore PETSc vector 1273a7dac1d5SJames Wright PetscCall(VecCeedToPetsc(ctx->x_loc, mem_type, D_loc)); 127458600ac3SJames Wright 127558600ac3SJames Wright // Local-to-Global 127658600ac3SJames Wright PetscCall(VecZeroEntries(D)); 127758600ac3SJames Wright PetscCall(DMLocalToGlobal(ctx->dm_x, D_loc, ADD_VALUES, D)); 127858600ac3SJames Wright PetscCall(DMRestoreLocalVector(ctx->dm_x, &D_loc)); 127958600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 128058600ac3SJames Wright } 128158600ac3SJames Wright 128258600ac3SJames Wright /** 128358600ac3SJames Wright @brief Compute `A X = Y` for a `MATCEED`. 128458600ac3SJames Wright 128558600ac3SJames Wright Collective across MPI processes. 128658600ac3SJames Wright 128758600ac3SJames Wright @param[in] A `MATCEED` 128858600ac3SJames Wright @param[in] X Input PETSc vector 128958600ac3SJames Wright @param[out] Y Output PETSc vector 129058600ac3SJames Wright 129158600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 129258600ac3SJames Wright **/ 129358600ac3SJames Wright PetscErrorCode MatMult_Ceed(Mat A, Vec X, Vec Y) { 129458600ac3SJames Wright MatCeedContext ctx; 129558600ac3SJames Wright 129658600ac3SJames Wright PetscFunctionBeginUser; 129758600ac3SJames Wright PetscCall(MatShellGetContext(A, &ctx)); 129858600ac3SJames Wright PetscCall(PetscLogEventBegin(ctx->log_event_mult, A, X, Y, 0)); 129958600ac3SJames Wright 130058600ac3SJames Wright { 130158600ac3SJames Wright PetscMemType x_mem_type, y_mem_type; 130258600ac3SJames Wright Vec X_loc = ctx->X_loc, Y_loc; 130358600ac3SJames Wright 130458600ac3SJames Wright // Get local vectors 130558600ac3SJames Wright if (!ctx->X_loc) PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc)); 130658600ac3SJames Wright PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc)); 130758600ac3SJames Wright 130858600ac3SJames Wright // Global-to-local 130958600ac3SJames Wright PetscCall(DMGlobalToLocal(ctx->dm_x, X, INSERT_VALUES, X_loc)); 131058600ac3SJames Wright 131158600ac3SJames Wright // Setup libCEED vectors 1312a7dac1d5SJames Wright PetscCall(VecReadPetscToCeed(X_loc, &x_mem_type, ctx->x_loc)); 131358600ac3SJames Wright PetscCall(VecZeroEntries(Y_loc)); 1314a7dac1d5SJames Wright PetscCall(VecPetscToCeed(Y_loc, &y_mem_type, ctx->y_loc)); 131558600ac3SJames Wright 131658600ac3SJames Wright // Apply libCEED operator 131758600ac3SJames Wright PetscCall(PetscLogGpuTimeBegin()); 131850f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult, ctx->x_loc, ctx->y_loc, CEED_REQUEST_IMMEDIATE)); 131958600ac3SJames Wright PetscCall(PetscLogGpuTimeEnd()); 132058600ac3SJames Wright 132158600ac3SJames Wright // Restore PETSc vectors 1322a7dac1d5SJames Wright PetscCall(VecReadCeedToPetsc(ctx->x_loc, x_mem_type, X_loc)); 1323a7dac1d5SJames Wright PetscCall(VecCeedToPetsc(ctx->y_loc, y_mem_type, Y_loc)); 132458600ac3SJames Wright 132558600ac3SJames Wright // Local-to-global 132658600ac3SJames Wright PetscCall(VecZeroEntries(Y)); 132758600ac3SJames Wright PetscCall(DMLocalToGlobal(ctx->dm_y, Y_loc, ADD_VALUES, Y)); 132858600ac3SJames Wright 132958600ac3SJames Wright // Restore local vectors, as needed 133058600ac3SJames Wright if (!ctx->X_loc) PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc)); 133158600ac3SJames Wright PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc)); 133258600ac3SJames Wright } 133358600ac3SJames Wright 133458600ac3SJames Wright // Log flops 133558600ac3SJames Wright if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult)); 133658600ac3SJames Wright else PetscCall(PetscLogFlops(ctx->flops_mult)); 133758600ac3SJames Wright 133858600ac3SJames Wright PetscCall(PetscLogEventEnd(ctx->log_event_mult, A, X, Y, 0)); 133958600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 134058600ac3SJames Wright } 134158600ac3SJames Wright 134258600ac3SJames Wright /** 134358600ac3SJames Wright @brief Compute `A^T Y = X` for a `MATCEED`. 134458600ac3SJames Wright 134558600ac3SJames Wright Collective across MPI processes. 134658600ac3SJames Wright 134758600ac3SJames Wright @param[in] A `MATCEED` 134858600ac3SJames Wright @param[in] Y Input PETSc vector 134958600ac3SJames Wright @param[out] X Output PETSc vector 135058600ac3SJames Wright 135158600ac3SJames Wright @return An error code: 0 - success, otherwise - failure 135258600ac3SJames Wright **/ 135358600ac3SJames Wright PetscErrorCode MatMultTranspose_Ceed(Mat A, Vec Y, Vec X) { 135458600ac3SJames Wright MatCeedContext ctx; 135558600ac3SJames Wright 135658600ac3SJames Wright PetscFunctionBeginUser; 135758600ac3SJames Wright PetscCall(MatShellGetContext(A, &ctx)); 135858600ac3SJames Wright PetscCall(PetscLogEventBegin(ctx->log_event_mult_transpose, A, Y, X, 0)); 135958600ac3SJames Wright 136058600ac3SJames Wright { 136158600ac3SJames Wright PetscMemType x_mem_type, y_mem_type; 136258600ac3SJames Wright Vec X_loc, Y_loc = ctx->Y_loc_transpose; 136358600ac3SJames Wright 136458600ac3SJames Wright // Get local vectors 136558600ac3SJames Wright if (!ctx->Y_loc_transpose) PetscCall(DMGetLocalVector(ctx->dm_y, &Y_loc)); 136658600ac3SJames Wright PetscCall(DMGetLocalVector(ctx->dm_x, &X_loc)); 136758600ac3SJames Wright 136858600ac3SJames Wright // Global-to-local 136958600ac3SJames Wright PetscCall(DMGlobalToLocal(ctx->dm_y, Y, INSERT_VALUES, Y_loc)); 137058600ac3SJames Wright 137158600ac3SJames Wright // Setup libCEED vectors 1372a7dac1d5SJames Wright PetscCall(VecReadPetscToCeed(Y_loc, &y_mem_type, ctx->y_loc)); 137358600ac3SJames Wright PetscCall(VecZeroEntries(X_loc)); 1374a7dac1d5SJames Wright PetscCall(VecPetscToCeed(X_loc, &x_mem_type, ctx->x_loc)); 137558600ac3SJames Wright 137658600ac3SJames Wright // Apply libCEED operator 137758600ac3SJames Wright PetscCall(PetscLogGpuTimeBegin()); 137850f50432SJames Wright PetscCallCeed(ctx->ceed, CeedOperatorApplyAdd(ctx->op_mult_transpose, ctx->y_loc, ctx->x_loc, CEED_REQUEST_IMMEDIATE)); 137958600ac3SJames Wright PetscCall(PetscLogGpuTimeEnd()); 138058600ac3SJames Wright 138158600ac3SJames Wright // Restore PETSc vectors 1382a7dac1d5SJames Wright PetscCall(VecReadCeedToPetsc(ctx->y_loc, y_mem_type, Y_loc)); 1383a7dac1d5SJames Wright PetscCall(VecCeedToPetsc(ctx->x_loc, x_mem_type, X_loc)); 138458600ac3SJames Wright 138558600ac3SJames Wright // Local-to-global 138658600ac3SJames Wright PetscCall(VecZeroEntries(X)); 138758600ac3SJames Wright PetscCall(DMLocalToGlobal(ctx->dm_x, X_loc, ADD_VALUES, X)); 138858600ac3SJames Wright 138958600ac3SJames Wright // Restore local vectors, as needed 139058600ac3SJames Wright if (!ctx->Y_loc_transpose) PetscCall(DMRestoreLocalVector(ctx->dm_y, &Y_loc)); 139158600ac3SJames Wright PetscCall(DMRestoreLocalVector(ctx->dm_x, &X_loc)); 139258600ac3SJames Wright } 139358600ac3SJames Wright 139458600ac3SJames Wright // Log flops 139558600ac3SJames Wright if (PetscMemTypeDevice(ctx->mem_type)) PetscCall(PetscLogGpuFlops(ctx->flops_mult_transpose)); 139658600ac3SJames Wright else PetscCall(PetscLogFlops(ctx->flops_mult_transpose)); 139758600ac3SJames Wright 139858600ac3SJames Wright PetscCall(PetscLogEventEnd(ctx->log_event_mult_transpose, A, Y, X, 0)); 139958600ac3SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 140058600ac3SJames Wright } 1401