xref: /petsc/src/tao/bound/impls/bqnk/bqnk.c (revision 2cdf5ea42bccd4e651ec69c5d7cf37657be83b41)
1 #include <../src/tao/bound/impls/bqnk/bqnk.h> /*I "petsctao.h" I*/
2 #include <petscksp.h>
3 
4 static PetscErrorCode TaoBQNKComputeHessian(Tao tao)
5 {
6   TAO_BNK  *bnk  = (TAO_BNK *)tao->data;
7   TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
8   PetscReal gnorm2, delta;
9 
10   PetscFunctionBegin;
11   /* Alias the LMVM matrix into the TAO hessian */
12   if (tao->hessian) PetscCall(MatDestroy(&tao->hessian));
13   if (tao->hessian_pre) PetscCall(MatDestroy(&tao->hessian_pre));
14   PetscCall(PetscObjectReference((PetscObject)bqnk->B));
15   tao->hessian = bqnk->B;
16   PetscCall(PetscObjectReference((PetscObject)bqnk->B));
17   tao->hessian_pre = bqnk->B;
18   /* Update the Hessian with the latest solution */
19   if (bqnk->is_spd) {
20     gnorm2 = bnk->gnorm * bnk->gnorm;
21     if (gnorm2 == 0.0) gnorm2 = PETSC_MACHINE_EPSILON;
22     if (bnk->f == 0.0) {
23       delta = 2.0 / gnorm2;
24     } else {
25       delta = 2.0 * PetscAbsScalar(bnk->f) / gnorm2;
26     }
27     PetscCall(MatLMVMSymBroydenSetDelta(bqnk->B, delta));
28   }
29   PetscCall(MatLMVMUpdate(tao->hessian, tao->solution, bnk->unprojected_gradient));
30   PetscCall(MatLMVMResetShift(tao->hessian));
31   /* Prepare the reduced sub-matrices for the inactive set */
32   PetscCall(MatDestroy(&bnk->H_inactive));
33   if (bnk->active_idx) {
34     PetscCall(MatCreateSubMatrixVirtual(tao->hessian, bnk->inactive_idx, bnk->inactive_idx, &bnk->H_inactive));
35     PetscCall(PCLMVMSetIS(bqnk->pc, bnk->inactive_idx));
36   } else {
37     PetscCall(PetscObjectReference((PetscObject)tao->hessian));
38     bnk->H_inactive = tao->hessian;
39     PetscCall(PCLMVMClearIS(bqnk->pc));
40   }
41   PetscCall(MatDestroy(&bnk->Hpre_inactive));
42   PetscCall(PetscObjectReference((PetscObject)bnk->H_inactive));
43   bnk->Hpre_inactive = bnk->H_inactive;
44   PetscFunctionReturn(PETSC_SUCCESS);
45 }
46 
47 static PetscErrorCode TaoBQNKComputeStep(Tao tao, PetscBool shift, KSPConvergedReason *ksp_reason, PetscInt *step_type)
48 {
49   TAO_BNK  *bnk  = (TAO_BNK *)tao->data;
50   TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
51 
52   PetscFunctionBegin;
53   PetscCall(TaoBNKComputeStep(tao, shift, ksp_reason, step_type));
54   if (*ksp_reason < 0) {
55     /* Krylov solver failed to converge so reset the LMVM matrix */
56     PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE));
57     PetscCall(MatLMVMUpdate(bqnk->B, tao->solution, bnk->unprojected_gradient));
58   }
59   PetscFunctionReturn(PETSC_SUCCESS);
60 }
61 
62 PetscErrorCode TaoSolve_BQNK(Tao tao)
63 {
64   TAO_BNK     *bnk  = (TAO_BNK *)tao->data;
65   TAO_BQNK    *bqnk = (TAO_BQNK *)bnk->ctx;
66   Mat_LMVM    *lmvm = (Mat_LMVM *)bqnk->B->data;
67   Mat_LMVM    *J0;
68   Mat_SymBrdn *diag_ctx;
69   PetscBool    flg = PETSC_FALSE;
70 
71   PetscFunctionBegin;
72   if (!tao->recycle) {
73     PetscCall(MatLMVMReset(bqnk->B, PETSC_FALSE));
74     lmvm->nresets = 0;
75     if (lmvm->J0) {
76       PetscCall(PetscObjectBaseTypeCompare((PetscObject)lmvm->J0, MATLMVM, &flg));
77       if (flg) {
78         J0          = (Mat_LMVM *)lmvm->J0->data;
79         J0->nresets = 0;
80       }
81     }
82     flg = PETSC_FALSE;
83     PetscCall(PetscObjectTypeCompareAny((PetscObject)bqnk->B, &flg, MATLMVMSYMBROYDEN, MATLMVMSYMBADBROYDEN, MATLMVMBFGS, MATLMVMDFP, ""));
84     if (flg) {
85       diag_ctx    = (Mat_SymBrdn *)lmvm->ctx;
86       J0          = (Mat_LMVM *)diag_ctx->D->data;
87       J0->nresets = 0;
88     }
89   }
90   PetscCall((*bqnk->solve)(tao));
91   PetscFunctionReturn(PETSC_SUCCESS);
92 }
93 
94 PetscErrorCode TaoSetUp_BQNK(Tao tao)
95 {
96   TAO_BNK  *bnk  = (TAO_BNK *)tao->data;
97   TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
98   PetscInt  n, N;
99   PetscBool is_lmvm, is_set, is_sym;
100 
101   PetscFunctionBegin;
102   PetscCall(TaoSetUp_BNK(tao));
103   PetscCall(VecGetLocalSize(tao->solution, &n));
104   PetscCall(VecGetSize(tao->solution, &N));
105   PetscCall(MatSetSizes(bqnk->B, n, n, N, N));
106   PetscCall(MatLMVMAllocate(bqnk->B, tao->solution, bnk->unprojected_gradient));
107   PetscCall(PetscObjectBaseTypeCompare((PetscObject)bqnk->B, MATLMVM, &is_lmvm));
108   PetscCheck(is_lmvm, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Matrix must be an LMVM-type");
109   PetscCall(MatIsSymmetricKnown(bqnk->B, &is_set, &is_sym));
110   PetscCheck(is_set && is_sym, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM matrix must be symmetric");
111   PetscCall(KSPGetPC(tao->ksp, &bqnk->pc));
112   PetscCall(PCSetType(bqnk->pc, PCLMVM));
113   PetscCall(PCLMVMSetMatLMVM(bqnk->pc, bqnk->B));
114   PetscFunctionReturn(PETSC_SUCCESS);
115 }
116 
117 static PetscErrorCode TaoSetFromOptions_BQNK(Tao tao, PetscOptionItems PetscOptionsObject)
118 {
119   TAO_BNK  *bnk  = (TAO_BNK *)tao->data;
120   TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
121   PetscBool is_set;
122 
123   PetscFunctionBegin;
124   PetscCall(TaoSetFromOptions_BNK(tao, PetscOptionsObject));
125   if (bnk->init_type == BNK_INIT_INTERPOLATION) bnk->init_type = BNK_INIT_DIRECTION;
126   PetscCall(MatSetOptionsPrefix(bqnk->B, ((PetscObject)tao)->prefix));
127   PetscCall(MatAppendOptionsPrefix(bqnk->B, "tao_bqnk_"));
128   PetscCall(MatSetFromOptions(bqnk->B));
129   PetscCall(MatIsSPDKnown(bqnk->B, &is_set, &bqnk->is_spd));
130   if (!is_set) bqnk->is_spd = PETSC_FALSE;
131   PetscFunctionReturn(PETSC_SUCCESS);
132 }
133 
134 static PetscErrorCode TaoView_BQNK(Tao tao, PetscViewer viewer)
135 {
136   TAO_BNK  *bnk  = (TAO_BNK *)tao->data;
137   TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
138   PetscBool isascii;
139 
140   PetscFunctionBegin;
141   PetscCall(TaoView_BNK(tao, viewer));
142   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
143   if (isascii) {
144     PetscCall(PetscViewerPushFormat(viewer, PETSC_VIEWER_ASCII_INFO));
145     PetscCall(MatView(bqnk->B, viewer));
146     PetscCall(PetscViewerPopFormat(viewer));
147   }
148   PetscFunctionReturn(PETSC_SUCCESS);
149 }
150 
151 static PetscErrorCode TaoDestroy_BQNK(Tao tao)
152 {
153   TAO_BNK  *bnk  = (TAO_BNK *)tao->data;
154   TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
155 
156   PetscFunctionBegin;
157   PetscCall(MatDestroy(&bnk->Hpre_inactive));
158   PetscCall(MatDestroy(&bnk->H_inactive));
159   PetscCall(MatDestroy(&bqnk->B));
160   PetscCall(PetscFree(bnk->ctx));
161   PetscCall(TaoDestroy_BNK(tao));
162   PetscFunctionReturn(PETSC_SUCCESS);
163 }
164 
165 PETSC_INTERN PetscErrorCode TaoCreate_BQNK(Tao tao)
166 {
167   TAO_BNK  *bnk;
168   TAO_BQNK *bqnk;
169 
170   PetscFunctionBegin;
171   PetscCall(TaoCreate_BNK(tao));
172   tao->ops->solve          = TaoSolve_BQNK;
173   tao->ops->setfromoptions = TaoSetFromOptions_BQNK;
174   tao->ops->destroy        = TaoDestroy_BQNK;
175   tao->ops->view           = TaoView_BQNK;
176   tao->ops->setup          = TaoSetUp_BQNK;
177 
178   bnk                 = (TAO_BNK *)tao->data;
179   bnk->computehessian = TaoBQNKComputeHessian;
180   bnk->computestep    = TaoBQNKComputeStep;
181   bnk->init_type      = BNK_INIT_DIRECTION;
182 
183   PetscCall(PetscNew(&bqnk));
184   bnk->ctx     = (void *)bqnk;
185   bqnk->is_spd = PETSC_TRUE;
186 
187   PetscCall(MatCreate(PetscObjectComm((PetscObject)tao), &bqnk->B));
188   PetscCall(PetscObjectIncrementTabLevel((PetscObject)bqnk->B, (PetscObject)tao, 1));
189   PetscCall(MatSetType(bqnk->B, MATLMVMSR1));
190   PetscFunctionReturn(PETSC_SUCCESS);
191 }
192 
193 /*@
194   TaoGetLMVMMatrix - Returns a pointer to the internal LMVM matrix. Valid
195   only for quasi-Newton family of methods.
196 
197   Input Parameter:
198 . tao - `Tao` solver context
199 
200   Output Parameter:
201 . B - LMVM matrix
202 
203   Level: advanced
204 
205 .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoSetLMVMMatrix()`
206 @*/
207 PetscErrorCode TaoGetLMVMMatrix(Tao tao, Mat *B)
208 {
209   TAO_BNK  *bnk  = (TAO_BNK *)tao->data;
210   TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
211   PetscBool flg  = PETSC_FALSE;
212 
213   PetscFunctionBegin;
214   PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, ""));
215   PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms");
216   *B = bqnk->B;
217   PetscFunctionReturn(PETSC_SUCCESS);
218 }
219 
220 /*@
221   TaoSetLMVMMatrix - Sets an external LMVM matrix into the Tao solver. Valid
222   only for quasi-Newton family of methods.
223 
224   QN family of methods create their own LMVM matrices and users who wish to
225   manipulate this matrix should use TaoGetLMVMMatrix() instead.
226 
227   Input Parameters:
228 + tao - Tao solver context
229 - B   - LMVM matrix
230 
231   Level: advanced
232 
233 .seealso: `TAOBQNLS`, `TAOBQNKLS`, `TAOBQNKTL`, `TAOBQNKTR`, `MATLMVM`, `TaoGetLMVMMatrix()`
234 @*/
235 PetscErrorCode TaoSetLMVMMatrix(Tao tao, Mat B)
236 {
237   TAO_BNK  *bnk  = (TAO_BNK *)tao->data;
238   TAO_BQNK *bqnk = (TAO_BQNK *)bnk->ctx;
239   PetscBool flg  = PETSC_FALSE;
240 
241   PetscFunctionBegin;
242   PetscCall(PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, ""));
243   PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "LMVM Matrix only exists for quasi-Newton algorithms");
244   PetscCall(PetscObjectBaseTypeCompare((PetscObject)B, MATLMVM, &flg));
245   PetscCheck(flg, PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_INCOMP, "Given matrix is not an LMVM matrix");
246   if (bqnk->B) PetscCall(MatDestroy(&bqnk->B));
247   PetscCall(PetscObjectReference((PetscObject)B));
248   bqnk->B = B;
249   PetscFunctionReturn(PETSC_SUCCESS);
250 }
251