1 #include <../src/tao/bound/impls/bqnk/bqnk.h> /*I "petsctao.h" I*/ /*I "petscmat.h" I*/ 2 #include <petscksp.h> 3 4 static PetscErrorCode TaoBQNKComputeHessian(Tao tao) 5 { 6 TAO_BNK *bnk = (TAO_BNK *)tao->data; 7 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 8 PetscReal gnorm2, delta; 9 10 PetscFunctionBegin; 11 /* Alias the LMVM matrix into the TAO hessian */ 12 if (tao->hessian) { 13 PetscCall(MatDestroy(&tao->hessian)); 14 } 15 if (tao->hessian_pre) { 16 PetscCall(MatDestroy(&tao->hessian_pre)); 17 } 18 PetscCall(PetscObjectReference((PetscObject)bqnk->B)); 19 tao->hessian = bqnk->B; 20 PetscCall(PetscObjectReference((PetscObject)bqnk->B)); 21 tao->hessian_pre = bqnk->B; 22 /* Update the Hessian with the latest solution */ 23 if (bqnk->is_spd) { 24 gnorm2 = bnk->gnorm*bnk->gnorm; 25 if (gnorm2 == 0.0) gnorm2 = PETSC_MACHINE_EPSILON; 26 if (bnk->f == 0.0) { 27 delta = 2.0 / gnorm2; 28 } else { 29 delta = 2.0 * PetscAbsScalar(bnk->f) / gnorm2; 30 } 31 PetscCall(MatLMVMSymBroydenSetDelta(bqnk->B, delta)); 32 } 33 PetscCall(MatLMVMUpdate(tao->hessian, tao->solution, bnk->unprojected_gradient)); 34 PetscCall(MatLMVMResetShift(tao->hessian)); 35 /* Prepare the reduced sub-matrices for the inactive set */ 36 PetscCall(MatDestroy(&bnk->H_inactive)); 37 if (bnk->active_idx) { 38 PetscCall(MatCreateSubMatrixVirtual(tao->hessian, bnk->inactive_idx, bnk->inactive_idx, &bnk->H_inactive)); 39 PetscCall(PCLMVMSetIS(bqnk->pc, bnk->inactive_idx)); 40 } else { 41 PetscCall(PetscObjectReference((PetscObject)tao->hessian)); 42 bnk->H_inactive = tao->hessian; 43 PetscCall(PCLMVMClearIS(bqnk->pc)); 44 } 45 PetscCall(MatDestroy(&bnk->Hpre_inactive)); 46 PetscCall(PetscObjectReference((PetscObject)bnk->H_inactive)); 47 bnk->Hpre_inactive = bnk->H_inactive; 48 PetscFunctionReturn(0); 49 } 50 51 static PetscErrorCode TaoBQNKComputeStep(Tao tao, PetscBool shift, KSPConvergedReason *ksp_reason, PetscInt *step_type) 52 { 53 TAO_BNK *bnk = (TAO_BNK *)tao->data; 54 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 55 56 PetscFunctionBegin; 57 PetscCall(TaoBNKComputeStep(tao, shift, ksp_reason, step_type)); 58 if (*ksp_reason < 0) { 59 /* Krylov solver failed to converge so reset the LMVM matrix */ 60 PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE)); 61 PetscCall(MatLMVMUpdate(bqnk->B, tao->solution, bnk->unprojected_gradient)); 62 } 63 PetscFunctionReturn(0); 64 } 65 66 PetscErrorCode TaoSolve_BQNK(Tao tao) 67 { 68 TAO_BNK *bnk = (TAO_BNK *)tao->data; 69 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 70 Mat_LMVM *lmvm = (Mat_LMVM*)bqnk->B->data; 71 Mat_LMVM *J0; 72 Mat_SymBrdn *diag_ctx; 73 PetscBool flg = PETSC_FALSE; 74 75 PetscFunctionBegin; 76 if (!tao->recycle) { 77 PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE)); 78 lmvm->nresets = 0; 79 if (lmvm->J0) { 80 PetscCall(PetscObjectBaseTypeCompare((PetscObject)lmvm->J0, MATLMVM, &flg)); 81 if (flg) { 82 J0 = (Mat_LMVM*)lmvm->J0->data; 83 J0->nresets = 0; 84 } 85 } 86 flg = PETSC_FALSE; 87 PetscCall(PetscObjectTypeCompareAny((PetscObject)bqnk->B, &flg, MATLMVMSYMBROYDEN, MATLMVMSYMBADBROYDEN, MATLMVMBFGS, MATLMVMDFP, "")); 88 if (flg) { 89 diag_ctx = (Mat_SymBrdn*)lmvm->ctx; 90 J0 = (Mat_LMVM*)diag_ctx->D->data; 91 J0->nresets = 0; 92 } 93 } 94 PetscCall((*bqnk->solve)(tao)); 95 PetscFunctionReturn(0); 96 } 97 98 PetscErrorCode TaoSetUp_BQNK(Tao tao) 99 { 100 TAO_BNK *bnk = (TAO_BNK *)tao->data; 101 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 102 PetscInt n, N; 103 PetscBool is_lmvm, is_sym, is_spd; 104 105 PetscFunctionBegin; 106 PetscCall(TaoSetUp_BNK(tao)); 107 PetscCall(VecGetLocalSize(tao->solution,&n)); 108 PetscCall(VecGetSize(tao->solution,&N)); 109 PetscCall(MatSetSizes(bqnk->B, n, n, N, N)); 110 PetscCall(MatLMVMAllocate(bqnk->B,tao->solution,bnk->unprojected_gradient)); 111 PetscCall(PetscObjectBaseTypeCompare((PetscObject)bqnk->B, MATLMVM, &is_lmvm)); 112 PetscCheck(is_lmvm,PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Matrix must be an LMVM-type"); 113 PetscCall(MatGetOption(bqnk->B, MAT_SYMMETRIC, &is_sym)); 114 PetscCheck(is_sym,PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM matrix must be symmetric"); 115 PetscCall(MatGetOption(bqnk->B, MAT_SPD, &is_spd)); 116 PetscCall(KSPGetPC(tao->ksp, &bqnk->pc)); 117 PetscCall(PCSetType(bqnk->pc, PCLMVM)); 118 PetscCall(PCLMVMSetMatLMVM(bqnk->pc, bqnk->B)); 119 PetscFunctionReturn(0); 120 } 121 122 static PetscErrorCode TaoSetFromOptions_BQNK(PetscOptionItems *PetscOptionsObject,Tao tao) 123 { 124 TAO_BNK *bnk = (TAO_BNK *)tao->data; 125 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 126 127 PetscFunctionBegin; 128 PetscCall(TaoSetFromOptions_BNK(PetscOptionsObject,tao)); 129 if (bnk->init_type == BNK_INIT_INTERPOLATION) bnk->init_type = BNK_INIT_DIRECTION; 130 PetscCall(MatSetOptionsPrefix(bqnk->B, ((PetscObject)tao)->prefix)); 131 PetscCall(MatAppendOptionsPrefix(bqnk->B, "tao_bqnk_")); 132 PetscCall(MatSetFromOptions(bqnk->B)); 133 PetscCall(MatGetOption(bqnk->B, MAT_SPD, &bqnk->is_spd)); 134 PetscFunctionReturn(0); 135 } 136 137 static PetscErrorCode TaoView_BQNK(Tao tao, PetscViewer viewer) 138 { 139 TAO_BNK *bnk = (TAO_BNK*)tao->data; 140 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 141 PetscBool isascii; 142 143 PetscFunctionBegin; 144 PetscCall(TaoView_BNK(tao, viewer)); 145 PetscCall(PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&isascii)); 146 if (isascii) { 147 PetscCall(PetscViewerPushFormat(viewer, PETSC_VIEWER_ASCII_INFO)); 148 PetscCall(MatView(bqnk->B, viewer)); 149 PetscCall(PetscViewerPopFormat(viewer)); 150 } 151 PetscFunctionReturn(0); 152 } 153 154 static PetscErrorCode TaoDestroy_BQNK(Tao tao) 155 { 156 TAO_BNK *bnk = (TAO_BNK*)tao->data; 157 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 158 159 PetscFunctionBegin; 160 PetscCall(MatDestroy(&bnk->Hpre_inactive)); 161 PetscCall(MatDestroy(&bnk->H_inactive)); 162 PetscCall(MatDestroy(&bqnk->B)); 163 PetscCall(PetscFree(bnk->ctx)); 164 PetscCall(TaoDestroy_BNK(tao)); 165 PetscFunctionReturn(0); 166 } 167 168 PETSC_INTERN PetscErrorCode TaoCreate_BQNK(Tao tao) 169 { 170 TAO_BNK *bnk; 171 TAO_BQNK *bqnk; 172 173 PetscFunctionBegin; 174 PetscCall(TaoCreate_BNK(tao)); 175 tao->ops->solve = TaoSolve_BQNK; 176 tao->ops->setfromoptions = TaoSetFromOptions_BQNK; 177 tao->ops->destroy = TaoDestroy_BQNK; 178 tao->ops->view = TaoView_BQNK; 179 tao->ops->setup = TaoSetUp_BQNK; 180 181 bnk = (TAO_BNK *)tao->data; 182 bnk->computehessian = TaoBQNKComputeHessian; 183 bnk->computestep = TaoBQNKComputeStep; 184 bnk->init_type = BNK_INIT_DIRECTION; 185 186 PetscCall(PetscNewLog(tao,&bqnk)); 187 bnk->ctx = (void*)bqnk; 188 bqnk->is_spd = PETSC_TRUE; 189 190 PetscCall(MatCreate(PetscObjectComm((PetscObject)tao), &bqnk->B)); 191 PetscCall(PetscObjectIncrementTabLevel((PetscObject)bqnk->B, (PetscObject)tao, 1)); 192 PetscCall(MatSetType(bqnk->B, MATLMVMSR1)); 193 PetscFunctionReturn(0); 194 } 195 196 /*@ 197 TaoGetLMVMMatrix - Returns a pointer to the internal LMVM matrix. Valid 198 only for quasi-Newton family of methods. 199 200 Input Parameters: 201 . tao - Tao solver context 202 203 Output Parameters: 204 . B - LMVM matrix 205 206 Level: advanced 207 208 .seealso: TAOBQNLS, TAOBQNKLS, TAOBQNKTL, TAOBQNKTR, MATLMVM, TaoSetLMVMMatrix() 209 @*/ 210 PetscErrorCode TaoGetLMVMMatrix(Tao tao, Mat *B) 211 { 212 TAO_BNK *bnk = (TAO_BNK*)tao->data; 213 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 214 PetscBool flg = PETSC_FALSE; 215 216 PetscFunctionBegin; 217 PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, "")); 218 PetscCheck(flg,PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms"); 219 *B = bqnk->B; 220 PetscFunctionReturn(0); 221 } 222 223 /*@ 224 TaoSetLMVMMatrix - Sets an external LMVM matrix into the Tao solver. Valid 225 only for quasi-Newton family of methods. 226 227 QN family of methods create their own LMVM matrices and users who wish to 228 manipulate this matrix should use TaoGetLMVMMatrix() instead. 229 230 Input Parameters: 231 + tao - Tao solver context 232 - B - LMVM matrix 233 234 Level: advanced 235 236 .seealso: TAOBQNLS, TAOBQNKLS, TAOBQNKTL, TAOBQNKTR, MATLMVM, TaoGetLMVMMatrix() 237 @*/ 238 PetscErrorCode TaoSetLMVMMatrix(Tao tao, Mat B) 239 { 240 TAO_BNK *bnk = (TAO_BNK*)tao->data; 241 TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx; 242 PetscBool flg = PETSC_FALSE; 243 244 PetscFunctionBegin; 245 PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, "")); 246 PetscCheck(flg,PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms"); 247 PetscCall(PetscObjectBaseTypeCompare((PetscObject)B, MATLMVM, &flg)); 248 PetscCheck(flg,PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Given matrix is not an LMVM matrix"); 249 if (bqnk->B) { 250 PetscCall(MatDestroy(&bqnk->B)); 251 } 252 PetscCall(PetscObjectReference((PetscObject)B)); 253 bqnk->B = B; 254 PetscFunctionReturn(0); 255 } 256