1 #include <../src/tao/bound/impls/bqnk/bqnk.h> /*I "petsctao.h" I*/ /*I "petscmat.h" I*/ 2 #include <petscksp.h> 3 4 static PetscErrorCode TaoBQNKComputeHessian(Tao tao) { 5 TAO_BNK *bnk = (TAO_BNK *)tao->data; 6 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 7 PetscReal gnorm2, delta; 8 9 PetscFunctionBegin; 10 /* Alias the LMVM matrix into the TAO hessian */ 11 if (tao->hessian) { PetscCall(MatDestroy(&tao->hessian)); } 12 if (tao->hessian_pre) { PetscCall(MatDestroy(&tao->hessian_pre)); } 13 PetscCall(PetscObjectReference((PetscObject)bqnk->B)); 14 tao->hessian = bqnk->B; 15 PetscCall(PetscObjectReference((PetscObject)bqnk->B)); 16 tao->hessian_pre = bqnk->B; 17 /* Update the Hessian with the latest solution */ 18 if (bqnk->is_spd) { 19 gnorm2 = bnk->gnorm * bnk->gnorm; 20 if (gnorm2 == 0.0) gnorm2 = PETSC_MACHINE_EPSILON; 21 if (bnk->f == 0.0) { 22 delta = 2.0 / gnorm2; 23 } else { 24 delta = 2.0 * PetscAbsScalar(bnk->f) / gnorm2; 25 } 26 PetscCall(MatLMVMSymBroydenSetDelta(bqnk->B, delta)); 27 } 28 PetscCall(MatLMVMUpdate(tao->hessian, tao->solution, bnk->unprojected_gradient)); 29 PetscCall(MatLMVMResetShift(tao->hessian)); 30 /* Prepare the reduced sub-matrices for the inactive set */ 31 PetscCall(MatDestroy(&bnk->H_inactive)); 32 if (bnk->active_idx) { 33 PetscCall(MatCreateSubMatrixVirtual(tao->hessian, bnk->inactive_idx, bnk->inactive_idx, &bnk->H_inactive)); 34 PetscCall(PCLMVMSetIS(bqnk->pc, bnk->inactive_idx)); 35 } else { 36 PetscCall(PetscObjectReference((PetscObject)tao->hessian)); 37 bnk->H_inactive = tao->hessian; 38 PetscCall(PCLMVMClearIS(bqnk->pc)); 39 } 40 PetscCall(MatDestroy(&bnk->Hpre_inactive)); 41 PetscCall(PetscObjectReference((PetscObject)bnk->H_inactive)); 42 bnk->Hpre_inactive = bnk->H_inactive; 43 PetscFunctionReturn(0); 44 } 45 46 static PetscErrorCode TaoBQNKComputeStep(Tao tao, PetscBool shift, KSPConvergedReason *ksp_reason, PetscInt *step_type) { 47 TAO_BNK *bnk = (TAO_BNK *)tao->data; 48 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 49 50 PetscFunctionBegin; 51 PetscCall(TaoBNKComputeStep(tao, shift, ksp_reason, step_type)); 52 if (*ksp_reason < 0) { 53 /* Krylov solver failed to converge so reset the LMVM matrix */ 54 PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE)); 55 PetscCall(MatLMVMUpdate(bqnk->B, tao->solution, bnk->unprojected_gradient)); 56 } 57 PetscFunctionReturn(0); 58 } 59 60 PetscErrorCode TaoSolve_BQNK(Tao tao) { 61 TAO_BNK *bnk = (TAO_BNK *)tao->data; 62 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 63 Mat_LMVM *lmvm = (Mat_LMVM *)bqnk->B->data; 64 Mat_LMVM *J0; 65 Mat_SymBrdn *diag_ctx; 66 PetscBool flg = PETSC_FALSE; 67 68 PetscFunctionBegin; 69 if (!tao->recycle) { 70 PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE)); 71 lmvm->nresets = 0; 72 if (lmvm->J0) { 73 PetscCall(PetscObjectBaseTypeCompare((PetscObject)lmvm->J0, MATLMVM, &flg)); 74 if (flg) { 75 J0 = (Mat_LMVM *)lmvm->J0->data; 76 J0->nresets = 0; 77 } 78 } 79 flg = PETSC_FALSE; 80 PetscCall(PetscObjectTypeCompareAny((PetscObject)bqnk->B, &flg, MATLMVMSYMBROYDEN, MATLMVMSYMBADBROYDEN, MATLMVMBFGS, MATLMVMDFP, "")); 81 if (flg) { 82 diag_ctx = (Mat_SymBrdn *)lmvm->ctx; 83 J0 = (Mat_LMVM *)diag_ctx->D->data; 84 J0->nresets = 0; 85 } 86 } 87 PetscCall((*bqnk->solve)(tao)); 88 PetscFunctionReturn(0); 89 } 90 91 PetscErrorCode TaoSetUp_BQNK(Tao tao) { 92 TAO_BNK *bnk = (TAO_BNK *)tao->data; 93 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 94 PetscInt n, N; 95 PetscBool is_lmvm, is_set, is_sym; 96 97 PetscFunctionBegin; 98 PetscCall(TaoSetUp_BNK(tao)); 99 PetscCall(VecGetLocalSize(tao->solution, &n)); 100 PetscCall(VecGetSize(tao->solution, &N)); 101 PetscCall(MatSetSizes(bqnk->B, n, n, N, N)); 102 PetscCall(MatLMVMAllocate(bqnk->B, tao->solution, bnk->unprojected_gradient)); 103 PetscCall(PetscObjectBaseTypeCompare((PetscObject)bqnk->B, MATLMVM, &is_lmvm)); 104 PetscCheck(is_lmvm, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Matrix must be an LMVM-type"); 105 PetscCall(MatIsSymmetricKnown(bqnk->B, &is_set, &is_sym)); 106 PetscCheck(is_set && is_sym, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM matrix must be symmetric"); 107 PetscCall(KSPGetPC(tao->ksp, &bqnk->pc)); 108 PetscCall(PCSetType(bqnk->pc, PCLMVM)); 109 PetscCall(PCLMVMSetMatLMVM(bqnk->pc, bqnk->B)); 110 PetscFunctionReturn(0); 111 } 112 113 static PetscErrorCode TaoSetFromOptions_BQNK(Tao tao, PetscOptionItems *PetscOptionsObject) { 114 TAO_BNK *bnk = (TAO_BNK *)tao->data; 115 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 116 PetscBool is_set; 117 118 PetscFunctionBegin; 119 PetscCall(TaoSetFromOptions_BNK(tao, PetscOptionsObject)); 120 if (bnk->init_type == BNK_INIT_INTERPOLATION) bnk->init_type = BNK_INIT_DIRECTION; 121 PetscCall(MatSetOptionsPrefix(bqnk->B, ((PetscObject)tao)->prefix)); 122 PetscCall(MatAppendOptionsPrefix(bqnk->B, "tao_bqnk_")); 123 PetscCall(MatSetFromOptions(bqnk->B)); 124 PetscCall(MatIsSPDKnown(bqnk->B, &is_set, &bqnk->is_spd)); 125 if (!is_set) bqnk->is_spd = PETSC_FALSE; 126 PetscFunctionReturn(0); 127 } 128 129 static PetscErrorCode TaoView_BQNK(Tao tao, PetscViewer viewer) { 130 TAO_BNK *bnk = (TAO_BNK *)tao->data; 131 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 132 PetscBool isascii; 133 134 PetscFunctionBegin; 135 PetscCall(TaoView_BNK(tao, viewer)); 136 PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii)); 137 if (isascii) { 138 PetscCall(PetscViewerPushFormat(viewer, PETSC_VIEWER_ASCII_INFO)); 139 PetscCall(MatView(bqnk->B, viewer)); 140 PetscCall(PetscViewerPopFormat(viewer)); 141 } 142 PetscFunctionReturn(0); 143 } 144 145 static PetscErrorCode TaoDestroy_BQNK(Tao tao) { 146 TAO_BNK *bnk = (TAO_BNK *)tao->data; 147 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 148 149 PetscFunctionBegin; 150 PetscCall(MatDestroy(&bnk->Hpre_inactive)); 151 PetscCall(MatDestroy(&bnk->H_inactive)); 152 PetscCall(MatDestroy(&bqnk->B)); 153 PetscCall(PetscFree(bnk->ctx)); 154 PetscCall(TaoDestroy_BNK(tao)); 155 PetscFunctionReturn(0); 156 } 157 158 PETSC_INTERN PetscErrorCode TaoCreate_BQNK(Tao tao) { 159 TAO_BNK *bnk; 160 TAO_BQNK *bqnk; 161 162 PetscFunctionBegin; 163 PetscCall(TaoCreate_BNK(tao)); 164 tao->ops->solve = TaoSolve_BQNK; 165 tao->ops->setfromoptions = TaoSetFromOptions_BQNK; 166 tao->ops->destroy = TaoDestroy_BQNK; 167 tao->ops->view = TaoView_BQNK; 168 tao->ops->setup = TaoSetUp_BQNK; 169 170 bnk = (TAO_BNK *)tao->data; 171 bnk->computehessian = TaoBQNKComputeHessian; 172 bnk->computestep = TaoBQNKComputeStep; 173 bnk->init_type = BNK_INIT_DIRECTION; 174 175 PetscCall(PetscNewLog(tao, &bqnk)); 176 bnk->ctx = (void *)bqnk; 177 bqnk->is_spd = PETSC_TRUE; 178 179 PetscCall(MatCreate(PetscObjectComm((PetscObject)tao), &bqnk->B)); 180 PetscCall(PetscObjectIncrementTabLevel((PetscObject)bqnk->B, (PetscObject)tao, 1)); 181 PetscCall(MatSetType(bqnk->B, MATLMVMSR1)); 182 PetscFunctionReturn(0); 183 } 184 185 /*@ 186 TaoGetLMVMMatrix - Returns a pointer to the internal LMVM matrix. Valid 187 only for quasi-Newton family of methods. 188 189 Input Parameters: 190 . tao - Tao solver context 191 192 Output Parameters: 193 . B - LMVM matrix 194 195 Level: advanced 196 197 .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoSetLMVMMatrix()` 198 @*/ 199 PetscErrorCode TaoGetLMVMMatrix(Tao tao, Mat *B) { 200 TAO_BNK *bnk = (TAO_BNK *)tao->data; 201 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 202 PetscBool flg = PETSC_FALSE; 203 204 PetscFunctionBegin; 205 PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, "")); 206 PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms"); 207 *B = bqnk->B; 208 PetscFunctionReturn(0); 209 } 210 211 /*@ 212 TaoSetLMVMMatrix - Sets an external LMVM matrix into the Tao solver. Valid 213 only for quasi-Newton family of methods. 214 215 QN family of methods create their own LMVM matrices and users who wish to 216 manipulate this matrix should use TaoGetLMVMMatrix() instead. 217 218 Input Parameters: 219 + tao - Tao solver context 220 - B - LMVM matrix 221 222 Level: advanced 223 224 .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoGetLMVMMatrix()` 225 @*/ 226 PetscErrorCode TaoSetLMVMMatrix(Tao tao, Mat B) { 227 TAO_BNK *bnk = (TAO_BNK *)tao->data; 228 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 229 PetscBool flg = PETSC_FALSE; 230 231 PetscFunctionBegin; 232 PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, "")); 233 PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms"); 234 PetscCall(PetscObjectBaseTypeCompare((PetscObject)B, MATLMVM, &flg)); 235 PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Given matrix is not an LMVM matrix"); 236 if (bqnk->B) { PetscCall(MatDestroy(&bqnk->B)); } 237 PetscCall(PetscObjectReference((PetscObject)B)); 238 bqnk->B = B; 239 PetscFunctionReturn(0); 240 } 241