1 #include <../src/tao/bound/impls/bqnk/bqnk.h> /*I "petsctao.h" I*/ 2 #include <petscksp.h> 3 4 static PetscErrorCode TaoBQNKComputeHessian(Tao tao) 5 { 6 TAO_BNK *bnk = (TAO_BNK *)tao->data; 7 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 8 PetscReal gnorm2, delta; 9 10 PetscFunctionBegin; 11 /* Alias the LMVM matrix into the TAO hessian */ 12 if (tao->hessian) PetscCall(MatDestroy(&tao->hessian)); 13 if (tao->hessian_pre) PetscCall(MatDestroy(&tao->hessian_pre)); 14 PetscCall(PetscObjectReference((PetscObject)bqnk->B)); 15 tao->hessian = bqnk->B; 16 PetscCall(PetscObjectReference((PetscObject)bqnk->B)); 17 tao->hessian_pre = bqnk->B; 18 /* Update the Hessian with the latest solution */ 19 if (bqnk->is_spd) { 20 gnorm2 = bnk->gnorm * bnk->gnorm; 21 if (gnorm2 == 0.0) gnorm2 = PETSC_MACHINE_EPSILON; 22 if (bnk->f == 0.0) { 23 delta = 2.0 / gnorm2; 24 } else { 25 delta = 2.0 * PetscAbsScalar(bnk->f) / gnorm2; 26 } 27 PetscCall(MatLMVMSymBroydenSetDelta(bqnk->B, delta)); 28 } 29 PetscCall(MatLMVMUpdate(tao->hessian, tao->solution, bnk->unprojected_gradient)); 30 PetscCall(MatLMVMResetShift(tao->hessian)); 31 /* Prepare the reduced sub-matrices for the inactive set */ 32 PetscCall(MatDestroy(&bnk->H_inactive)); 33 if (bnk->active_idx) { 34 PetscCall(MatCreateSubMatrixVirtual(tao->hessian, bnk->inactive_idx, bnk->inactive_idx, &bnk->H_inactive)); 35 PetscCall(PCLMVMSetIS(bqnk->pc, bnk->inactive_idx)); 36 } else { 37 PetscCall(PetscObjectReference((PetscObject)tao->hessian)); 38 bnk->H_inactive = tao->hessian; 39 PetscCall(PCLMVMClearIS(bqnk->pc)); 40 } 41 PetscCall(MatDestroy(&bnk->Hpre_inactive)); 42 PetscCall(PetscObjectReference((PetscObject)bnk->H_inactive)); 43 bnk->Hpre_inactive = bnk->H_inactive; 44 PetscFunctionReturn(PETSC_SUCCESS); 45 } 46 47 static PetscErrorCode TaoBQNKComputeStep(Tao tao, PetscBool shift, KSPConvergedReason *ksp_reason, PetscInt *step_type) 48 { 49 TAO_BNK *bnk = (TAO_BNK *)tao->data; 50 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 51 52 PetscFunctionBegin; 53 PetscCall(TaoBNKComputeStep(tao, shift, ksp_reason, step_type)); 54 if (*ksp_reason < 0) { 55 /* Krylov solver failed to converge so reset the LMVM matrix */ 56 PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE)); 57 PetscCall(MatLMVMUpdate(bqnk->B, tao->solution, bnk->unprojected_gradient)); 58 } 59 PetscFunctionReturn(PETSC_SUCCESS); 60 } 61 62 PetscErrorCode TaoSolve_BQNK(Tao tao) 63 { 64 TAO_BNK *bnk = (TAO_BNK *)tao->data; 65 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 66 Mat_LMVM *lmvm = (Mat_LMVM *)bqnk->B->data; 67 Mat_LMVM *J0; 68 Mat_SymBrdn *diag_ctx; 69 PetscBool flg = PETSC_FALSE; 70 71 PetscFunctionBegin; 72 if (!tao->recycle) { 73 PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE)); 74 lmvm->nresets = 0; 75 if (lmvm->J0) { 76 PetscCall(PetscObjectBaseTypeCompare((PetscObject)lmvm->J0, MATLMVM, &flg)); 77 if (flg) { 78 J0 = (Mat_LMVM *)lmvm->J0->data; 79 J0->nresets = 0; 80 } 81 } 82 flg = PETSC_FALSE; 83 PetscCall(PetscObjectTypeCompareAny((PetscObject)bqnk->B, &flg, MATLMVMSYMBROYDEN, MATLMVMSYMBADBROYDEN, MATLMVMBFGS, MATLMVMDFP, "")); 84 if (flg) { 85 diag_ctx = (Mat_SymBrdn *)lmvm->ctx; 86 J0 = (Mat_LMVM *)diag_ctx->D->data; 87 J0->nresets = 0; 88 } 89 } 90 PetscCall((*bqnk->solve)(tao)); 91 PetscFunctionReturn(PETSC_SUCCESS); 92 } 93 94 PetscErrorCode TaoSetUp_BQNK(Tao tao) 95 { 96 TAO_BNK *bnk = (TAO_BNK *)tao->data; 97 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 98 PetscInt n, N; 99 PetscBool is_lmvm, is_set, is_sym; 100 101 PetscFunctionBegin; 102 PetscCall(TaoSetUp_BNK(tao)); 103 PetscCall(VecGetLocalSize(tao->solution, &n)); 104 PetscCall(VecGetSize(tao->solution, &N)); 105 PetscCall(MatSetSizes(bqnk->B, n, n, N, N)); 106 PetscCall(MatLMVMAllocate(bqnk->B, tao->solution, bnk->unprojected_gradient)); 107 PetscCall(PetscObjectBaseTypeCompare((PetscObject)bqnk->B, MATLMVM, &is_lmvm)); 108 PetscCheck(is_lmvm, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Matrix must be an LMVM-type"); 109 PetscCall(MatIsSymmetricKnown(bqnk->B, &is_set, &is_sym)); 110 PetscCheck(is_set && is_sym, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM matrix must be symmetric"); 111 PetscCall(KSPGetPC(tao->ksp, &bqnk->pc)); 112 PetscCall(PCSetType(bqnk->pc, PCLMVM)); 113 PetscCall(PCLMVMSetMatLMVM(bqnk->pc, bqnk->B)); 114 PetscFunctionReturn(PETSC_SUCCESS); 115 } 116 117 static PetscErrorCode TaoSetFromOptions_BQNK(Tao tao, PetscOptionItems *PetscOptionsObject) 118 { 119 TAO_BNK *bnk = (TAO_BNK *)tao->data; 120 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 121 PetscBool is_set; 122 123 PetscFunctionBegin; 124 PetscCall(TaoSetFromOptions_BNK(tao, PetscOptionsObject)); 125 if (bnk->init_type == BNK_INIT_INTERPOLATION) bnk->init_type = BNK_INIT_DIRECTION; 126 PetscCall(MatSetOptionsPrefix(bqnk->B, ((PetscObject)tao)->prefix)); 127 PetscCall(MatAppendOptionsPrefix(bqnk->B, "tao_bqnk_")); 128 PetscCall(MatSetFromOptions(bqnk->B)); 129 PetscCall(MatIsSPDKnown(bqnk->B, &is_set, &bqnk->is_spd)); 130 if (!is_set) bqnk->is_spd = PETSC_FALSE; 131 PetscFunctionReturn(PETSC_SUCCESS); 132 } 133 134 static PetscErrorCode TaoView_BQNK(Tao tao, PetscViewer viewer) 135 { 136 TAO_BNK *bnk = (TAO_BNK *)tao->data; 137 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 138 PetscBool isascii; 139 140 PetscFunctionBegin; 141 PetscCall(TaoView_BNK(tao, viewer)); 142 PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii)); 143 if (isascii) { 144 PetscCall(PetscViewerPushFormat(viewer, PETSC_VIEWER_ASCII_INFO)); 145 PetscCall(MatView(bqnk->B, viewer)); 146 PetscCall(PetscViewerPopFormat(viewer)); 147 } 148 PetscFunctionReturn(PETSC_SUCCESS); 149 } 150 151 static PetscErrorCode TaoDestroy_BQNK(Tao tao) 152 { 153 TAO_BNK *bnk = (TAO_BNK *)tao->data; 154 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 155 156 PetscFunctionBegin; 157 PetscCall(MatDestroy(&bnk->Hpre_inactive)); 158 PetscCall(MatDestroy(&bnk->H_inactive)); 159 PetscCall(MatDestroy(&bqnk->B)); 160 PetscCall(PetscFree(bnk->ctx)); 161 PetscCall(TaoDestroy_BNK(tao)); 162 PetscFunctionReturn(PETSC_SUCCESS); 163 } 164 165 PETSC_INTERN PetscErrorCode TaoCreate_BQNK(Tao tao) 166 { 167 TAO_BNK *bnk; 168 TAO_BQNK *bqnk; 169 170 PetscFunctionBegin; 171 PetscCall(TaoCreate_BNK(tao)); 172 tao->ops->solve = TaoSolve_BQNK; 173 tao->ops->setfromoptions = TaoSetFromOptions_BQNK; 174 tao->ops->destroy = TaoDestroy_BQNK; 175 tao->ops->view = TaoView_BQNK; 176 tao->ops->setup = TaoSetUp_BQNK; 177 178 bnk = (TAO_BNK *)tao->data; 179 bnk->computehessian = TaoBQNKComputeHessian; 180 bnk->computestep = TaoBQNKComputeStep; 181 bnk->init_type = BNK_INIT_DIRECTION; 182 183 PetscCall(PetscNew(&bqnk)); 184 bnk->ctx = (void *)bqnk; 185 bqnk->is_spd = PETSC_TRUE; 186 187 PetscCall(MatCreate(PetscObjectComm((PetscObject)tao), &bqnk->B)); 188 PetscCall(PetscObjectIncrementTabLevel((PetscObject)bqnk->B, (PetscObject)tao, 1)); 189 PetscCall(MatSetType(bqnk->B, MATLMVMSR1)); 190 PetscFunctionReturn(PETSC_SUCCESS); 191 } 192 193 /*@ 194 TaoGetLMVMMatrix - Returns a pointer to the internal LMVM matrix. Valid 195 only for quasi-Newton family of methods. 196 197 Input Parameter: 198 . tao - `Tao` solver context 199 200 Output Parameter: 201 . B - LMVM matrix 202 203 Level: advanced 204 205 .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoSetLMVMMatrix()` 206 @*/ 207 PetscErrorCode TaoGetLMVMMatrix(Tao tao, Mat *B) 208 { 209 TAO_BNK *bnk = (TAO_BNK *)tao->data; 210 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 211 PetscBool flg = PETSC_FALSE; 212 213 PetscFunctionBegin; 214 PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, "")); 215 PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms"); 216 *B = bqnk->B; 217 PetscFunctionReturn(PETSC_SUCCESS); 218 } 219 220 /*@ 221 TaoSetLMVMMatrix - Sets an external LMVM matrix into the Tao solver. Valid 222 only for quasi-Newton family of methods. 223 224 QN family of methods create their own LMVM matrices and users who wish to 225 manipulate this matrix should use TaoGetLMVMMatrix() instead. 226 227 Input Parameters: 228 + tao - Tao solver context 229 - B - LMVM matrix 230 231 Level: advanced 232 233 .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoGetLMVMMatrix()` 234 @*/ 235 PetscErrorCode TaoSetLMVMMatrix(Tao tao, Mat B) 236 { 237 TAO_BNK *bnk = (TAO_BNK *)tao->data; 238 TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx; 239 PetscBool flg = PETSC_FALSE; 240 241 PetscFunctionBegin; 242 PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, "")); 243 PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms"); 244 PetscCall(PetscObjectBaseTypeCompare((PetscObject)B, MATLMVM, &flg)); 245 PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Given matrix is not an LMVM matrix"); 246 if (bqnk->B) PetscCall(MatDestroy(&bqnk->B)); 247 PetscCall(PetscObjectReference((PetscObject)B)); 248 bqnk->B = B; 249 PetscFunctionReturn(PETSC_SUCCESS); 250 } 251