xref: /petsc/src/tao/bound/impls/bqnk/bqnk.c (revision 9371c9d470a9602b6d10a8bf50c9b2280a79e45a)
1414d97d3SAlp Dener #include <../src/tao/bound/impls/bqnk/bqnk.h> /*I "petsctao.h" I*/ /*I "petscmat.h" I*/
2e0ed867bSAlp Dener #include <petscksp.h>
3e0ed867bSAlp Dener 
4*9371c9d4SSatish Balay static PetscErrorCode TaoBQNKComputeHessian(Tao tao) {
5e0ed867bSAlp Dener   TAO_BNK  *bnk  = (TAO_BNK *)tao->data;
6e0ed867bSAlp Dener   TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
765f5217aSAlp Dener   PetscReal gnorm2, delta;
8e0ed867bSAlp Dener 
9e0ed867bSAlp Dener   PetscFunctionBegin;
10e0ed867bSAlp Dener   /* Alias the LMVM matrix into the TAO hessian */
11*9371c9d4SSatish Balay   if (tao->hessian) { PetscCall(MatDestroy(&tao->hessian)); }
12*9371c9d4SSatish Balay   if (tao->hessian_pre) { PetscCall(MatDestroy(&tao->hessian_pre)); }
139566063dSJacob Faibussowitsch   PetscCall(PetscObjectReference((PetscObject)bqnk->B));
14e0ed867bSAlp Dener   tao->hessian = bqnk->B;
159566063dSJacob Faibussowitsch   PetscCall(PetscObjectReference((PetscObject)bqnk->B));
16e0ed867bSAlp Dener   tao->hessian_pre = bqnk->B;
17e0ed867bSAlp Dener   /* Update the Hessian with the latest solution */
18f5766c09SAlp Dener   if (bqnk->is_spd) {
1965f5217aSAlp Dener     gnorm2 = bnk->gnorm * bnk->gnorm;
208cabe928SAlp Dener     if (gnorm2 == 0.0) gnorm2 = PETSC_MACHINE_EPSILON;
218cabe928SAlp Dener     if (bnk->f == 0.0) {
228cabe928SAlp Dener       delta = 2.0 / gnorm2;
238cabe928SAlp Dener     } else {
248cabe928SAlp Dener       delta = 2.0 * PetscAbsScalar(bnk->f) / gnorm2;
258cabe928SAlp Dener     }
269566063dSJacob Faibussowitsch     PetscCall(MatLMVMSymBroydenSetDelta(bqnk->B, delta));
27f5766c09SAlp Dener   }
289566063dSJacob Faibussowitsch   PetscCall(MatLMVMUpdate(tao->hessian, tao->solution, bnk->unprojected_gradient));
299566063dSJacob Faibussowitsch   PetscCall(MatLMVMResetShift(tao->hessian));
30e0ed867bSAlp Dener   /* Prepare the reduced sub-matrices for the inactive set */
319566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&bnk->H_inactive));
32e0ed867bSAlp Dener   if (bnk->active_idx) {
339566063dSJacob Faibussowitsch     PetscCall(MatCreateSubMatrixVirtual(tao->hessian, bnk->inactive_idx, bnk->inactive_idx, &bnk->H_inactive));
349566063dSJacob Faibussowitsch     PetscCall(PCLMVMSetIS(bqnk->pc, bnk->inactive_idx));
35e0ed867bSAlp Dener   } else {
369566063dSJacob Faibussowitsch     PetscCall(PetscObjectReference((PetscObject)tao->hessian));
37e0ed867bSAlp Dener     bnk->H_inactive = tao->hessian;
389566063dSJacob Faibussowitsch     PetscCall(PCLMVMClearIS(bqnk->pc));
39e0ed867bSAlp Dener   }
409566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&bnk->Hpre_inactive));
419566063dSJacob Faibussowitsch   PetscCall(PetscObjectReference((PetscObject)bnk->H_inactive));
42e0ed867bSAlp Dener   bnk->Hpre_inactive = bnk->H_inactive;
43e0ed867bSAlp Dener   PetscFunctionReturn(0);
44e0ed867bSAlp Dener }
45e0ed867bSAlp Dener 
46*9371c9d4SSatish Balay static PetscErrorCode TaoBQNKComputeStep(Tao tao, PetscBool shift, KSPConvergedReason *ksp_reason, PetscInt *step_type) {
47e0ed867bSAlp Dener   TAO_BNK  *bnk  = (TAO_BNK *)tao->data;
48e0ed867bSAlp Dener   TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
49e0ed867bSAlp Dener 
50e0ed867bSAlp Dener   PetscFunctionBegin;
519566063dSJacob Faibussowitsch   PetscCall(TaoBNKComputeStep(tao, shift, ksp_reason, step_type));
52e0ed867bSAlp Dener   if (*ksp_reason < 0) {
53e0ed867bSAlp Dener     /* Krylov solver failed to converge so reset the LMVM matrix */
549566063dSJacob Faibussowitsch     PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE));
559566063dSJacob Faibussowitsch     PetscCall(MatLMVMUpdate(bqnk->B, tao->solution, bnk->unprojected_gradient));
56e0ed867bSAlp Dener   }
57e0ed867bSAlp Dener   PetscFunctionReturn(0);
58e0ed867bSAlp Dener }
59e0ed867bSAlp Dener 
60*9371c9d4SSatish Balay PetscErrorCode TaoSolve_BQNK(Tao tao) {
61414d97d3SAlp Dener   TAO_BNK     *bnk  = (TAO_BNK *)tao->data;
62414d97d3SAlp Dener   TAO_BQNK    *bqnk = (TAO_BQNK *)bnk->ctx;
63414d97d3SAlp Dener   Mat_LMVM    *lmvm = (Mat_LMVM *)bqnk->B->data;
64414d97d3SAlp Dener   Mat_LMVM    *J0;
65414d97d3SAlp Dener   Mat_SymBrdn *diag_ctx;
66414d97d3SAlp Dener   PetscBool    flg = PETSC_FALSE;
67414d97d3SAlp Dener 
68414d97d3SAlp Dener   PetscFunctionBegin;
69414d97d3SAlp Dener   if (!tao->recycle) {
709566063dSJacob Faibussowitsch     PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE));
71414d97d3SAlp Dener     lmvm->nresets = 0;
72414d97d3SAlp Dener     if (lmvm->J0) {
739566063dSJacob Faibussowitsch       PetscCall(PetscObjectBaseTypeCompare((PetscObject)lmvm->J0, MATLMVM, &flg));
74414d97d3SAlp Dener       if (flg) {
75414d97d3SAlp Dener         J0          = (Mat_LMVM *)lmvm->J0->data;
76414d97d3SAlp Dener         J0->nresets = 0;
77414d97d3SAlp Dener       }
78414d97d3SAlp Dener     }
79414d97d3SAlp Dener     flg = PETSC_FALSE;
809566063dSJacob Faibussowitsch     PetscCall(PetscObjectTypeCompareAny((PetscObject)bqnk->B, &flg, MATLMVMSYMBROYDEN, MATLMVMSYMBADBROYDEN, MATLMVMBFGS, MATLMVMDFP, ""));
81414d97d3SAlp Dener     if (flg) {
82414d97d3SAlp Dener       diag_ctx    = (Mat_SymBrdn *)lmvm->ctx;
83414d97d3SAlp Dener       J0          = (Mat_LMVM *)diag_ctx->D->data;
84414d97d3SAlp Dener       J0->nresets = 0;
85414d97d3SAlp Dener     }
86414d97d3SAlp Dener   }
879566063dSJacob Faibussowitsch   PetscCall((*bqnk->solve)(tao));
88414d97d3SAlp Dener   PetscFunctionReturn(0);
89414d97d3SAlp Dener }
90414d97d3SAlp Dener 
91*9371c9d4SSatish Balay PetscErrorCode TaoSetUp_BQNK(Tao tao) {
924f4fdda4SAlp Dener   TAO_BNK  *bnk  = (TAO_BNK *)tao->data;
934f4fdda4SAlp Dener   TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
944f4fdda4SAlp Dener   PetscInt  n, N;
95b94d7dedSBarry Smith   PetscBool is_lmvm, is_set, is_sym;
964f4fdda4SAlp Dener 
974f4fdda4SAlp Dener   PetscFunctionBegin;
989566063dSJacob Faibussowitsch   PetscCall(TaoSetUp_BNK(tao));
999566063dSJacob Faibussowitsch   PetscCall(VecGetLocalSize(tao->solution, &n));
1009566063dSJacob Faibussowitsch   PetscCall(VecGetSize(tao->solution, &N));
1019566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(bqnk->B, n, n, N, N));
1029566063dSJacob Faibussowitsch   PetscCall(MatLMVMAllocate(bqnk->B, tao->solution, bnk->unprojected_gradient));
1039566063dSJacob Faibussowitsch   PetscCall(PetscObjectBaseTypeCompare((PetscObject)bqnk->B, MATLMVM, &is_lmvm));
1043c859ba3SBarry Smith   PetscCheck(is_lmvm, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Matrix must be an LMVM-type");
105b94d7dedSBarry Smith   PetscCall(MatIsSymmetricKnown(bqnk->B, &is_set, &is_sym));
106b94d7dedSBarry Smith   PetscCheck(is_set && is_sym, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM matrix must be symmetric");
1079566063dSJacob Faibussowitsch   PetscCall(KSPGetPC(tao->ksp, &bqnk->pc));
1089566063dSJacob Faibussowitsch   PetscCall(PCSetType(bqnk->pc, PCLMVM));
1099566063dSJacob Faibussowitsch   PetscCall(PCLMVMSetMatLMVM(bqnk->pc, bqnk->B));
1104f4fdda4SAlp Dener   PetscFunctionReturn(0);
1114f4fdda4SAlp Dener }
1124f4fdda4SAlp Dener 
113*9371c9d4SSatish Balay static PetscErrorCode TaoSetFromOptions_BQNK(Tao tao, PetscOptionItems *PetscOptionsObject) {
114e0ed867bSAlp Dener   TAO_BNK  *bnk  = (TAO_BNK *)tao->data;
115e0ed867bSAlp Dener   TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
116b94d7dedSBarry Smith   PetscBool is_set;
117e0ed867bSAlp Dener 
118e0ed867bSAlp Dener   PetscFunctionBegin;
119dbbe0bcdSBarry Smith   PetscCall(TaoSetFromOptions_BNK(tao, PetscOptionsObject));
120e0ed867bSAlp Dener   if (bnk->init_type == BNK_INIT_INTERPOLATION) bnk->init_type = BNK_INIT_DIRECTION;
1219566063dSJacob Faibussowitsch   PetscCall(MatSetOptionsPrefix(bqnk->B, ((PetscObject)tao)->prefix));
1229566063dSJacob Faibussowitsch   PetscCall(MatAppendOptionsPrefix(bqnk->B, "tao_bqnk_"));
1239566063dSJacob Faibussowitsch   PetscCall(MatSetFromOptions(bqnk->B));
124b94d7dedSBarry Smith   PetscCall(MatIsSPDKnown(bqnk->B, &is_set, &bqnk->is_spd));
125b94d7dedSBarry Smith   if (!is_set) bqnk->is_spd = PETSC_FALSE;
126e0ed867bSAlp Dener   PetscFunctionReturn(0);
127e0ed867bSAlp Dener }
128e0ed867bSAlp Dener 
129*9371c9d4SSatish Balay static PetscErrorCode TaoView_BQNK(Tao tao, PetscViewer viewer) {
130e0ed867bSAlp Dener   TAO_BNK  *bnk  = (TAO_BNK *)tao->data;
131e0ed867bSAlp Dener   TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
132e0ed867bSAlp Dener   PetscBool isascii;
133e0ed867bSAlp Dener 
134e0ed867bSAlp Dener   PetscFunctionBegin;
1359566063dSJacob Faibussowitsch   PetscCall(TaoView_BNK(tao, viewer));
1369566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
137e0ed867bSAlp Dener   if (isascii) {
1389566063dSJacob Faibussowitsch     PetscCall(PetscViewerPushFormat(viewer, PETSC_VIEWER_ASCII_INFO));
1399566063dSJacob Faibussowitsch     PetscCall(MatView(bqnk->B, viewer));
1409566063dSJacob Faibussowitsch     PetscCall(PetscViewerPopFormat(viewer));
141e0ed867bSAlp Dener   }
142e0ed867bSAlp Dener   PetscFunctionReturn(0);
143e0ed867bSAlp Dener }
144e0ed867bSAlp Dener 
145*9371c9d4SSatish Balay static PetscErrorCode TaoDestroy_BQNK(Tao tao) {
146e0ed867bSAlp Dener   TAO_BNK  *bnk  = (TAO_BNK *)tao->data;
147e0ed867bSAlp Dener   TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
148e0ed867bSAlp Dener 
149e0ed867bSAlp Dener   PetscFunctionBegin;
1509566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&bnk->Hpre_inactive));
1519566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&bnk->H_inactive));
1529566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&bqnk->B));
1539566063dSJacob Faibussowitsch   PetscCall(PetscFree(bnk->ctx));
1549566063dSJacob Faibussowitsch   PetscCall(TaoDestroy_BNK(tao));
155e0ed867bSAlp Dener   PetscFunctionReturn(0);
156e0ed867bSAlp Dener }
157e0ed867bSAlp Dener 
158*9371c9d4SSatish Balay PETSC_INTERN PetscErrorCode TaoCreate_BQNK(Tao tao) {
159e0ed867bSAlp Dener   TAO_BNK  *bnk;
160e0ed867bSAlp Dener   TAO_BQNK *bqnk;
161e0ed867bSAlp Dener 
162e0ed867bSAlp Dener   PetscFunctionBegin;
1639566063dSJacob Faibussowitsch   PetscCall(TaoCreate_BNK(tao));
164414d97d3SAlp Dener   tao->ops->solve          = TaoSolve_BQNK;
165e0ed867bSAlp Dener   tao->ops->setfromoptions = TaoSetFromOptions_BQNK;
166e0ed867bSAlp Dener   tao->ops->destroy        = TaoDestroy_BQNK;
167e0ed867bSAlp Dener   tao->ops->view           = TaoView_BQNK;
1684f4fdda4SAlp Dener   tao->ops->setup          = TaoSetUp_BQNK;
169e0ed867bSAlp Dener 
170e0ed867bSAlp Dener   bnk                 = (TAO_BNK *)tao->data;
171e0ed867bSAlp Dener   bnk->computehessian = TaoBQNKComputeHessian;
172e0ed867bSAlp Dener   bnk->computestep    = TaoBQNKComputeStep;
173e0ed867bSAlp Dener   bnk->init_type      = BNK_INIT_DIRECTION;
174e0ed867bSAlp Dener 
1759566063dSJacob Faibussowitsch   PetscCall(PetscNewLog(tao, &bqnk));
176e0ed867bSAlp Dener   bnk->ctx     = (void *)bqnk;
177f5766c09SAlp Dener   bqnk->is_spd = PETSC_TRUE;
178e0ed867bSAlp Dener 
1799566063dSJacob Faibussowitsch   PetscCall(MatCreate(PetscObjectComm((PetscObject)tao), &bqnk->B));
1809566063dSJacob Faibussowitsch   PetscCall(PetscObjectIncrementTabLevel((PetscObject)bqnk->B, (PetscObject)tao, 1));
1819566063dSJacob Faibussowitsch   PetscCall(MatSetType(bqnk->B, MATLMVMSR1));
182e0ed867bSAlp Dener   PetscFunctionReturn(0);
183e0ed867bSAlp Dener }
184f5766c09SAlp Dener 
185414d97d3SAlp Dener /*@
186414d97d3SAlp Dener    TaoGetLMVMMatrix - Returns a pointer to the internal LMVM matrix. Valid
187414d97d3SAlp Dener    only for quasi-Newton family of methods.
188414d97d3SAlp Dener 
189414d97d3SAlp Dener    Input Parameters:
190414d97d3SAlp Dener .  tao - Tao solver context
191414d97d3SAlp Dener 
192414d97d3SAlp Dener    Output Parameters:
193414d97d3SAlp Dener .  B - LMVM matrix
194414d97d3SAlp Dener 
195414d97d3SAlp Dener    Level: advanced
196414d97d3SAlp Dener 
197db781477SPatrick Sanan .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoSetLMVMMatrix()`
198414d97d3SAlp Dener @*/
199*9371c9d4SSatish Balay PetscErrorCode TaoGetLMVMMatrix(Tao tao, Mat *B) {
200f5766c09SAlp Dener   TAO_BNK  *bnk  = (TAO_BNK *)tao->data;
201f5766c09SAlp Dener   TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
202414d97d3SAlp Dener   PetscBool flg  = PETSC_FALSE;
203f5766c09SAlp Dener 
204f5766c09SAlp Dener   PetscFunctionBegin;
2059566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, ""));
2063c859ba3SBarry Smith   PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms");
207f5766c09SAlp Dener   *B = bqnk->B;
208f5766c09SAlp Dener   PetscFunctionReturn(0);
209f5766c09SAlp Dener }
210414d97d3SAlp Dener 
211414d97d3SAlp Dener /*@
212414d97d3SAlp Dener    TaoSetLMVMMatrix - Sets an external LMVM matrix into the Tao solver. Valid
213414d97d3SAlp Dener    only for quasi-Newton family of methods.
214414d97d3SAlp Dener 
215414d97d3SAlp Dener    QN family of methods create their own LMVM matrices and users who wish to
216414d97d3SAlp Dener    manipulate this matrix should use TaoGetLMVMMatrix() instead.
217414d97d3SAlp Dener 
218414d97d3SAlp Dener    Input Parameters:
219414d97d3SAlp Dener +  tao - Tao solver context
220414d97d3SAlp Dener -  B - LMVM matrix
221414d97d3SAlp Dener 
222414d97d3SAlp Dener    Level: advanced
223414d97d3SAlp Dener 
224db781477SPatrick Sanan .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoGetLMVMMatrix()`
225414d97d3SAlp Dener @*/
226*9371c9d4SSatish Balay PetscErrorCode TaoSetLMVMMatrix(Tao tao, Mat B) {
227414d97d3SAlp Dener   TAO_BNK  *bnk  = (TAO_BNK *)tao->data;
228414d97d3SAlp Dener   TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
229414d97d3SAlp Dener   PetscBool flg  = PETSC_FALSE;
230414d97d3SAlp Dener 
231414d97d3SAlp Dener   PetscFunctionBegin;
2329566063dSJacob Faibussowitsch   PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, ""));
2333c859ba3SBarry Smith   PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms");
2349566063dSJacob Faibussowitsch   PetscCall(PetscObjectBaseTypeCompare((PetscObject)B, MATLMVM, &flg));
2353c859ba3SBarry Smith   PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Given matrix is not an LMVM matrix");
236*9371c9d4SSatish Balay   if (bqnk->B) { PetscCall(MatDestroy(&bqnk->B)); }
2379566063dSJacob Faibussowitsch   PetscCall(PetscObjectReference((PetscObject)B));
238414d97d3SAlp Dener   bqnk->B = B;
239414d97d3SAlp Dener   PetscFunctionReturn(0);
240414d97d3SAlp Dener }
241