1 #include <../src/tao/bound/impls/bqnk/bqnk.h> /*I "petsctao.h" I*/ /*I "petscmat.h" I*/ 2 #include <petscksp.h> 3 4 static PetscErrorCode TaoBQNKComputeHessian(Tao tao) 5 { 6 TAO_BNK *bnk = (TAO_BNK *)tao->data; 7 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 8 PetscReal gnorm2, delta; 9 10 PetscFunctionBegin; 11 /* Alias the LMVM matrix into the TAO hessian */ 12 if (tao->hessian) { 13 PetscCall(MatDestroy(&tao->hessian)); 14 } 15 if (tao->hessian_pre) { 16 PetscCall(MatDestroy(&tao->hessian_pre)); 17 } 18 PetscCall(PetscObjectReference((PetscObject)bqnk->B)); 19 tao->hessian = bqnk->B; 20 PetscCall(PetscObjectReference((PetscObject)bqnk->B)); 21 tao->hessian_pre = bqnk->B; 22 /* Update the Hessian with the latest solution */ 23 if (bqnk->is_spd) { 24 gnorm2 = bnk->gnorm*bnk->gnorm; 25 if (gnorm2 == 0.0) gnorm2 = PETSC_MACHINE_EPSILON; 26 if (bnk->f == 0.0) { 27 delta = 2.0 / gnorm2; 28 } else { 29 delta = 2.0 * PetscAbsScalar(bnk->f) / gnorm2; 30 } 31 PetscCall(MatLMVMSymBroydenSetDelta(bqnk->B, delta)); 32 } 33 PetscCall(MatLMVMUpdate(tao->hessian, tao->solution, bnk->unprojected_gradient)); 34 PetscCall(MatLMVMResetShift(tao->hessian)); 35 /* Prepare the reduced sub-matrices for the inactive set */ 36 PetscCall(MatDestroy(&bnk->H_inactive)); 37 if (bnk->active_idx) { 38 PetscCall(MatCreateSubMatrixVirtual(tao->hessian, bnk->inactive_idx, bnk->inactive_idx, &bnk->H_inactive)); 39 PetscCall(PCLMVMSetIS(bqnk->pc, bnk->inactive_idx)); 40 } else { 41 PetscCall(PetscObjectReference((PetscObject)tao->hessian)); 42 bnk->H_inactive = tao->hessian; 43 PetscCall(PCLMVMClearIS(bqnk->pc)); 44 } 45 PetscCall(MatDestroy(&bnk->Hpre_inactive)); 46 PetscCall(PetscObjectReference((PetscObject)bnk->H_inactive)); 47 bnk->Hpre_inactive = bnk->H_inactive; 48 PetscFunctionReturn(0); 49 } 50 51 static PetscErrorCode TaoBQNKComputeStep(Tao tao, PetscBool shift, KSPConvergedReason *ksp_reason, PetscInt *step_type) 52 { 53 TAO_BNK *bnk = (TAO_BNK *)tao->data; 54 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 55 56 PetscFunctionBegin; 57 PetscCall(TaoBNKComputeStep(tao, shift, ksp_reason, step_type)); 58 if (*ksp_reason < 0) { 59 /* Krylov solver failed to converge so reset the LMVM matrix */ 60 PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE)); 61 PetscCall(MatLMVMUpdate(bqnk->B, tao->solution, bnk->unprojected_gradient)); 62 } 63 PetscFunctionReturn(0); 64 } 65 66 PetscErrorCode TaoSolve_BQNK(Tao tao) 67 { 68 TAO_BNK *bnk = (TAO_BNK *)tao->data; 69 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 70 Mat_LMVM *lmvm = (Mat_LMVM*)bqnk->B->data; 71 Mat_LMVM *J0; 72 Mat_SymBrdn *diag_ctx; 73 PetscBool flg = PETSC_FALSE; 74 75 PetscFunctionBegin; 76 if (!tao->recycle) { 77 PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE)); 78 lmvm->nresets = 0; 79 if (lmvm->J0) { 80 PetscCall(PetscObjectBaseTypeCompare((PetscObject)lmvm->J0, MATLMVM, &flg)); 81 if (flg) { 82 J0 = (Mat_LMVM*)lmvm->J0->data; 83 J0->nresets = 0; 84 } 85 } 86 flg = PETSC_FALSE; 87 PetscCall(PetscObjectTypeCompareAny((PetscObject)bqnk->B, &flg, MATLMVMSYMBROYDEN, MATLMVMSYMBADBROYDEN, MATLMVMBFGS, MATLMVMDFP, "")); 88 if (flg) { 89 diag_ctx = (Mat_SymBrdn*)lmvm->ctx; 90 J0 = (Mat_LMVM*)diag_ctx->D->data; 91 J0->nresets = 0; 92 } 93 } 94 PetscCall((*bqnk->solve)(tao)); 95 PetscFunctionReturn(0); 96 } 97 98 PetscErrorCode TaoSetUp_BQNK(Tao tao) 99 { 100 TAO_BNK *bnk = (TAO_BNK *)tao->data; 101 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 102 PetscInt n, N; 103 PetscBool is_lmvm, is_set,is_sym; 104 105 PetscFunctionBegin; 106 PetscCall(TaoSetUp_BNK(tao)); 107 PetscCall(VecGetLocalSize(tao->solution,&n)); 108 PetscCall(VecGetSize(tao->solution,&N)); 109 PetscCall(MatSetSizes(bqnk->B, n, n, N, N)); 110 PetscCall(MatLMVMAllocate(bqnk->B,tao->solution,bnk->unprojected_gradient)); 111 PetscCall(PetscObjectBaseTypeCompare((PetscObject)bqnk->B, MATLMVM, &is_lmvm)); 112 PetscCheck(is_lmvm,PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Matrix must be an LMVM-type"); 113 PetscCall(MatIsSymmetricKnown(bqnk->B, &is_set,&is_sym)); 114 PetscCheck(is_set && is_sym,PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM matrix must be symmetric"); 115 PetscCall(KSPGetPC(tao->ksp, &bqnk->pc)); 116 PetscCall(PCSetType(bqnk->pc, PCLMVM)); 117 PetscCall(PCLMVMSetMatLMVM(bqnk->pc, bqnk->B)); 118 PetscFunctionReturn(0); 119 } 120 121 static PetscErrorCode TaoSetFromOptions_BQNK(PetscOptionItems *PetscOptionsObject,Tao tao) 122 { 123 TAO_BNK *bnk = (TAO_BNK *)tao->data; 124 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 125 PetscBool is_set; 126 127 PetscFunctionBegin; 128 PetscCall(TaoSetFromOptions_BNK(PetscOptionsObject,tao)); 129 if (bnk->init_type == BNK_INIT_INTERPOLATION) bnk->init_type = BNK_INIT_DIRECTION; 130 PetscCall(MatSetOptionsPrefix(bqnk->B, ((PetscObject)tao)->prefix)); 131 PetscCall(MatAppendOptionsPrefix(bqnk->B, "tao_bqnk_")); 132 PetscCall(MatSetFromOptions(bqnk->B)); 133 PetscCall(MatIsSPDKnown(bqnk->B, &is_set, &bqnk->is_spd)); 134 if (!is_set) bqnk->is_spd = PETSC_FALSE; 135 PetscFunctionReturn(0); 136 } 137 138 static PetscErrorCode TaoView_BQNK(Tao tao, PetscViewer viewer) 139 { 140 TAO_BNK *bnk = (TAO_BNK*)tao->data; 141 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 142 PetscBool isascii; 143 144 PetscFunctionBegin; 145 PetscCall(TaoView_BNK(tao, viewer)); 146 PetscCall(PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&isascii)); 147 if (isascii) { 148 PetscCall(PetscViewerPushFormat(viewer, PETSC_VIEWER_ASCII_INFO)); 149 PetscCall(MatView(bqnk->B, viewer)); 150 PetscCall(PetscViewerPopFormat(viewer)); 151 } 152 PetscFunctionReturn(0); 153 } 154 155 static PetscErrorCode TaoDestroy_BQNK(Tao tao) 156 { 157 TAO_BNK *bnk = (TAO_BNK*)tao->data; 158 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 159 160 PetscFunctionBegin; 161 PetscCall(MatDestroy(&bnk->Hpre_inactive)); 162 PetscCall(MatDestroy(&bnk->H_inactive)); 163 PetscCall(MatDestroy(&bqnk->B)); 164 PetscCall(PetscFree(bnk->ctx)); 165 PetscCall(TaoDestroy_BNK(tao)); 166 PetscFunctionReturn(0); 167 } 168 169 PETSC_INTERN PetscErrorCode TaoCreate_BQNK(Tao tao) 170 { 171 TAO_BNK *bnk; 172 TAO_BQNK *bqnk; 173 174 PetscFunctionBegin; 175 PetscCall(TaoCreate_BNK(tao)); 176 tao->ops->solve = TaoSolve_BQNK; 177 tao->ops->setfromoptions = TaoSetFromOptions_BQNK; 178 tao->ops->destroy = TaoDestroy_BQNK; 179 tao->ops->view = TaoView_BQNK; 180 tao->ops->setup = TaoSetUp_BQNK; 181 182 bnk = (TAO_BNK *)tao->data; 183 bnk->computehessian = TaoBQNKComputeHessian; 184 bnk->computestep = TaoBQNKComputeStep; 185 bnk->init_type = BNK_INIT_DIRECTION; 186 187 PetscCall(PetscNewLog(tao,&bqnk)); 188 bnk->ctx = (void*)bqnk; 189 bqnk->is_spd = PETSC_TRUE; 190 191 PetscCall(MatCreate(PetscObjectComm((PetscObject)tao), &bqnk->B)); 192 PetscCall(PetscObjectIncrementTabLevel((PetscObject)bqnk->B, (PetscObject)tao, 1)); 193 PetscCall(MatSetType(bqnk->B, MATLMVMSR1)); 194 PetscFunctionReturn(0); 195 } 196 197 /*@ 198 TaoGetLMVMMatrix - Returns a pointer to the internal LMVM matrix. Valid 199 only for quasi-Newton family of methods. 200 201 Input Parameters: 202 . tao - Tao solver context 203 204 Output Parameters: 205 . B - LMVM matrix 206 207 Level: advanced 208 209 .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoSetLMVMMatrix()` 210 @*/ 211 PetscErrorCode TaoGetLMVMMatrix(Tao tao, Mat *B) 212 { 213 TAO_BNK *bnk = (TAO_BNK*)tao->data; 214 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 215 PetscBool flg = PETSC_FALSE; 216 217 PetscFunctionBegin; 218 PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, "")); 219 PetscCheck(flg,PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms"); 220 *B = bqnk->B; 221 PetscFunctionReturn(0); 222 } 223 224 /*@ 225 TaoSetLMVMMatrix - Sets an external LMVM matrix into the Tao solver. Valid 226 only for quasi-Newton family of methods. 227 228 QN family of methods create their own LMVM matrices and users who wish to 229 manipulate this matrix should use TaoGetLMVMMatrix() instead. 230 231 Input Parameters: 232 + tao - Tao solver context 233 - B - LMVM matrix 234 235 Level: advanced 236 237 .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoGetLMVMMatrix()` 238 @*/ 239 PetscErrorCode TaoSetLMVMMatrix(Tao tao, Mat B) 240 { 241 TAO_BNK *bnk = (TAO_BNK*)tao->data; 242 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 243 PetscBool flg = PETSC_FALSE; 244 245 PetscFunctionBegin; 246 PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, "")); 247 PetscCheck(flg,PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms"); 248 PetscCall(PetscObjectBaseTypeCompare((PetscObject)B, MATLMVM, &flg)); 249 PetscCheck(flg,PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Given matrix is not an LMVM matrix"); 250 if (bqnk->B) { 251 PetscCall(MatDestroy(&bqnk->B)); 252 } 253 PetscCall(PetscObjectReference((PetscObject)B)); 254 bqnk->B = B; 255 PetscFunctionReturn(0); 256 } 257