xref: /petsc/src/ksp/ksp/utils/lmvm/brdn/badbrdn.c (revision ff6450cf88c38d9373e3b4bf5fb1660da3ed7cfa)
1 #include <../src/ksp/ksp/utils/lmvm/brdn/brdn.h> /*I "petscksp.h" I*/
2 
3 /*------------------------------------------------------------*/
4 
5 /*
6   The solution method is the matrix-free implementation of the inverse Hessian in
7   Equation 6 on page 312 of Griewank "Broyden Updating, The Good and The Bad!"
8   (http://www.emis.ams.org/journals/DMJDMV/vol-ismp/45_griewank-andreas-broyden.pdf).
9 
10   Q[i] = (B_i)^{-1}*S[i] terms are computed ahead of time whenever
11   the matrix is updated with a new (S[i], Y[i]) pair. This allows
12   repeated calls of MatSolve without incurring redundant computation.
13 
14   dX <- J0^{-1} * F
15 
16   for i=0,1,2,...,k
17     # Q[i] = (B_i)^{-1} * Y[i]
18     tau = (Y[i]^T F) / (Y[i]^T Y[i])
19     dX <- dX + (tau * (S[i] - Q[i]))
20   end
21  */
22 
23 static PetscErrorCode MatSolve_LMVMBadBrdn(Mat B, Vec F, Vec dX) {
24   Mat_LMVM   *lmvm = (Mat_LMVM *)B->data;
25   Mat_Brdn   *lbb  = (Mat_Brdn *)lmvm->ctx;
26   PetscInt    i, j;
27   PetscScalar yjtyi, ytf;
28 
29   PetscFunctionBegin;
30   VecCheckSameSize(F, 2, dX, 3);
31   VecCheckMatCompatible(B, dX, 3, F, 2);
32 
33   if (lbb->needQ) {
34     /* Pre-compute (Q[i] = (B_i)^{-1} * Y[i]) */
35     for (i = 0; i <= lmvm->k; ++i) {
36       PetscCall(MatLMVMApplyJ0Inv(B, lmvm->Y[i], lbb->Q[i]));
37       for (j = 0; j <= i - 1; ++j) {
38         PetscCall(VecDot(lmvm->Y[j], lmvm->Y[i], &yjtyi));
39         PetscCall(VecAXPBYPCZ(lbb->Q[i], PetscRealPart(yjtyi) / lbb->yty[j], -PetscRealPart(yjtyi) / lbb->yty[j], 1.0, lmvm->S[j], lbb->Q[j]));
40       }
41     }
42     lbb->needQ = PETSC_FALSE;
43   }
44 
45   PetscCall(MatLMVMApplyJ0Inv(B, F, dX));
46   for (i = 0; i <= lmvm->k; ++i) {
47     PetscCall(VecDot(lmvm->Y[i], F, &ytf));
48     PetscCall(VecAXPBYPCZ(dX, PetscRealPart(ytf) / lbb->yty[i], -PetscRealPart(ytf) / lbb->yty[i], 1.0, lmvm->S[i], lbb->Q[i]));
49   }
50   PetscFunctionReturn(0);
51 }
52 
53 /*------------------------------------------------------------*/
54 
55 /*
56   The forward product is the matrix-free implementation of the direct update in
57   Equation 6 on page 302 of Griewank "Broyden Updating, The Good and The Bad!"
58   (http://www.emis.ams.org/journals/DMJDMV/vol-ismp/45_griewank-andreas-broyden.pdf).
59 
60   P[i] = (B_i)*S[i] terms are computed ahead of time whenever
61   the matrix is updated with a new (S[i], Y[i]) pair. This allows
62   repeated calls of MatMult inside KSP solvers without unnecessarily
63   recomputing P[i] terms in expensive nested-loops.
64 
65   Z <- J0 * X
66 
67   for i=0,1,2,...,k
68     # P[i] = B_i * S[i]
69     tau = (Y[i]^T X) / (Y[i]^T S[i])
70     dX <- dX + (tau * (Y[i] - P[i]))
71   end
72  */
73 
74 static PetscErrorCode MatMult_LMVMBadBrdn(Mat B, Vec X, Vec Z) {
75   Mat_LMVM   *lmvm = (Mat_LMVM *)B->data;
76   Mat_Brdn   *lbb  = (Mat_Brdn *)lmvm->ctx;
77   PetscInt    i, j;
78   PetscScalar yjtsi, ytx;
79 
80   PetscFunctionBegin;
81   VecCheckSameSize(X, 2, Z, 3);
82   VecCheckMatCompatible(B, X, 2, Z, 3);
83 
84   if (lbb->needP) {
85     /* Pre-compute (P[i] = (B_i) * S[i]) */
86     for (i = 0; i <= lmvm->k; ++i) {
87       PetscCall(MatLMVMApplyJ0Fwd(B, lmvm->S[i], lbb->P[i]));
88       for (j = 0; j <= i - 1; ++j) {
89         PetscCall(VecDot(lmvm->Y[j], lmvm->S[i], &yjtsi));
90         PetscCall(VecAXPBYPCZ(lbb->P[i], PetscRealPart(yjtsi) / lbb->yts[j], -PetscRealPart(yjtsi) / lbb->yts[j], 1.0, lmvm->Y[j], lbb->P[j]));
91       }
92     }
93     lbb->needP = PETSC_FALSE;
94   }
95 
96   PetscCall(MatLMVMApplyJ0Fwd(B, X, Z));
97   for (i = 0; i <= lmvm->k; ++i) {
98     PetscCall(VecDot(lmvm->Y[i], X, &ytx));
99     PetscCall(VecAXPBYPCZ(Z, PetscRealPart(ytx) / lbb->yts[i], -PetscRealPart(ytx) / lbb->yts[i], 1.0, lmvm->Y[i], lbb->P[i]));
100   }
101   PetscFunctionReturn(0);
102 }
103 
104 /*------------------------------------------------------------*/
105 
106 static PetscErrorCode MatUpdate_LMVMBadBrdn(Mat B, Vec X, Vec F) {
107   Mat_LMVM   *lmvm = (Mat_LMVM *)B->data;
108   Mat_Brdn   *lbb  = (Mat_Brdn *)lmvm->ctx;
109   PetscInt    old_k, i;
110   PetscScalar yty, yts;
111 
112   PetscFunctionBegin;
113   if (!lmvm->m) PetscFunctionReturn(0);
114   if (lmvm->prev_set) {
115     /* Compute the new (S = X - Xprev) and (Y = F - Fprev) vectors */
116     PetscCall(VecAYPX(lmvm->Xprev, -1.0, X));
117     PetscCall(VecAYPX(lmvm->Fprev, -1.0, F));
118     /* Accept the update */
119     lbb->needP = lbb->needQ = PETSC_TRUE;
120     old_k                   = lmvm->k;
121     PetscCall(MatUpdateKernel_LMVM(B, lmvm->Xprev, lmvm->Fprev));
122     /* If we hit the memory limit, shift the yty and yts arrays */
123     if (old_k == lmvm->k) {
124       for (i = 0; i <= lmvm->k - 1; ++i) {
125         lbb->yty[i] = lbb->yty[i + 1];
126         lbb->yts[i] = lbb->yts[i + 1];
127       }
128     }
129     /* Accumulate the latest yTy and yTs dot products */
130     PetscCall(VecDotBegin(lmvm->Y[lmvm->k], lmvm->Y[lmvm->k], &yty));
131     PetscCall(VecDotBegin(lmvm->Y[lmvm->k], lmvm->S[lmvm->k], &yts));
132     PetscCall(VecDotEnd(lmvm->Y[lmvm->k], lmvm->Y[lmvm->k], &yty));
133     PetscCall(VecDotEnd(lmvm->Y[lmvm->k], lmvm->S[lmvm->k], &yts));
134     lbb->yty[lmvm->k] = PetscRealPart(yty);
135     lbb->yts[lmvm->k] = PetscRealPart(yts);
136   }
137   /* Save the solution and function to be used in the next update */
138   PetscCall(VecCopy(X, lmvm->Xprev));
139   PetscCall(VecCopy(F, lmvm->Fprev));
140   lmvm->prev_set = PETSC_TRUE;
141   PetscFunctionReturn(0);
142 }
143 
144 /*------------------------------------------------------------*/
145 
146 static PetscErrorCode MatCopy_LMVMBadBrdn(Mat B, Mat M, MatStructure str) {
147   Mat_LMVM *bdata = (Mat_LMVM *)B->data;
148   Mat_Brdn *bctx  = (Mat_Brdn *)bdata->ctx;
149   Mat_LMVM *mdata = (Mat_LMVM *)M->data;
150   Mat_Brdn *mctx  = (Mat_Brdn *)mdata->ctx;
151   PetscInt  i;
152 
153   PetscFunctionBegin;
154   mctx->needP = bctx->needP;
155   mctx->needQ = bctx->needQ;
156   for (i = 0; i <= bdata->k; ++i) {
157     mctx->yty[i] = bctx->yty[i];
158     mctx->yts[i] = bctx->yts[i];
159     PetscCall(VecCopy(bctx->P[i], mctx->P[i]));
160     PetscCall(VecCopy(bctx->Q[i], mctx->Q[i]));
161   }
162   PetscFunctionReturn(0);
163 }
164 
165 /*------------------------------------------------------------*/
166 
167 static PetscErrorCode MatReset_LMVMBadBrdn(Mat B, PetscBool destructive) {
168   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
169   Mat_Brdn *lbb  = (Mat_Brdn *)lmvm->ctx;
170 
171   PetscFunctionBegin;
172   lbb->needP = lbb->needQ = PETSC_TRUE;
173   if (destructive && lbb->allocated) {
174     PetscCall(PetscFree2(lbb->yty, lbb->yts));
175     PetscCall(VecDestroyVecs(lmvm->m, &lbb->P));
176     PetscCall(VecDestroyVecs(lmvm->m, &lbb->Q));
177     lbb->allocated = PETSC_FALSE;
178   }
179   PetscCall(MatReset_LMVM(B, destructive));
180   PetscFunctionReturn(0);
181 }
182 
183 /*------------------------------------------------------------*/
184 
185 static PetscErrorCode MatAllocate_LMVMBadBrdn(Mat B, Vec X, Vec F) {
186   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
187   Mat_Brdn *lbb  = (Mat_Brdn *)lmvm->ctx;
188 
189   PetscFunctionBegin;
190   PetscCall(MatAllocate_LMVM(B, X, F));
191   if (!lbb->allocated) {
192     PetscCall(PetscMalloc2(lmvm->m, &lbb->yty, lmvm->m, &lbb->yts));
193     if (lmvm->m > 0) {
194       PetscCall(VecDuplicateVecs(X, lmvm->m, &lbb->P));
195       PetscCall(VecDuplicateVecs(X, lmvm->m, &lbb->Q));
196     }
197     lbb->allocated = PETSC_TRUE;
198   }
199   PetscFunctionReturn(0);
200 }
201 
202 /*------------------------------------------------------------*/
203 
204 static PetscErrorCode MatDestroy_LMVMBadBrdn(Mat B) {
205   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
206   Mat_Brdn *lbb  = (Mat_Brdn *)lmvm->ctx;
207 
208   PetscFunctionBegin;
209   if (lbb->allocated) {
210     PetscCall(PetscFree2(lbb->yty, lbb->yts));
211     PetscCall(VecDestroyVecs(lmvm->m, &lbb->P));
212     PetscCall(VecDestroyVecs(lmvm->m, &lbb->Q));
213     lbb->allocated = PETSC_FALSE;
214   }
215   PetscCall(PetscFree(lmvm->ctx));
216   PetscCall(MatDestroy_LMVM(B));
217   PetscFunctionReturn(0);
218 }
219 
220 /*------------------------------------------------------------*/
221 
222 static PetscErrorCode MatSetUp_LMVMBadBrdn(Mat B) {
223   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
224   Mat_Brdn *lbb  = (Mat_Brdn *)lmvm->ctx;
225 
226   PetscFunctionBegin;
227   PetscCall(MatSetUp_LMVM(B));
228   if (!lbb->allocated) {
229     PetscCall(PetscMalloc2(lmvm->m, &lbb->yty, lmvm->m, &lbb->yts));
230     if (lmvm->m > 0) {
231       PetscCall(VecDuplicateVecs(lmvm->Xprev, lmvm->m, &lbb->P));
232       PetscCall(VecDuplicateVecs(lmvm->Xprev, lmvm->m, &lbb->Q));
233     }
234     lbb->allocated = PETSC_TRUE;
235   }
236   PetscFunctionReturn(0);
237 }
238 
239 /*------------------------------------------------------------*/
240 
241 PetscErrorCode MatCreate_LMVMBadBrdn(Mat B) {
242   Mat_LMVM *lmvm;
243   Mat_Brdn *lbb;
244 
245   PetscFunctionBegin;
246   PetscCall(MatCreate_LMVM(B));
247   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATLMVMBADBROYDEN));
248   B->ops->setup   = MatSetUp_LMVMBadBrdn;
249   B->ops->destroy = MatDestroy_LMVMBadBrdn;
250   B->ops->solve   = MatSolve_LMVMBadBrdn;
251 
252   lmvm                = (Mat_LMVM *)B->data;
253   lmvm->square        = PETSC_TRUE;
254   lmvm->ops->allocate = MatAllocate_LMVMBadBrdn;
255   lmvm->ops->reset    = MatReset_LMVMBadBrdn;
256   lmvm->ops->mult     = MatMult_LMVMBadBrdn;
257   lmvm->ops->update   = MatUpdate_LMVMBadBrdn;
258   lmvm->ops->copy     = MatCopy_LMVMBadBrdn;
259 
260   PetscCall(PetscNew(&lbb));
261   lmvm->ctx      = (void *)lbb;
262   lbb->allocated = PETSC_FALSE;
263   lbb->needP = lbb->needQ = PETSC_TRUE;
264   PetscFunctionReturn(0);
265 }
266 
267 /*------------------------------------------------------------*/
268 
269 /*@
270    MatCreateLMVMBadBroyden - Creates a limited-memory modified (aka "bad") Broyden-type
271    approximation matrix used for a Jacobian. L-BadBrdn is not guaranteed to be
272    symmetric or positive-definite.
273 
274    The provided local and global sizes must match the solution and function vectors
275    used with MatLMVMUpdate() and MatSolve(). The resulting L-BadBrdn matrix will have
276    storage vectors allocated with VecCreateSeq() in serial and VecCreateMPI() in
277    parallel. To use the L-BadBrdn matrix with other vector types, the matrix must be
278    created using MatCreate() and MatSetType(), followed by MatLMVMAllocate().
279    This ensures that the internal storage and work vectors are duplicated from the
280    correct type of vector.
281 
282    Collective
283 
284    Input Parameters:
+  comm - MPI communicator
286 .  n - number of local rows for storage vectors
287 -  N - global size of the storage vectors
288 
289    Output Parameter:
290 .  B - the matrix
291 
292    It is recommended that one use the MatCreate(), MatSetType() and/or MatSetFromOptions()
293    paradigm instead of this routine directly.
294 
295    Options Database Keys:
296 +   -mat_lmvm_scale_type - (developer) type of scaling applied to J0 (none, scalar, diagonal)
297 .   -mat_lmvm_theta - (developer) convex ratio between BFGS and DFP components of the diagonal J0 scaling
298 .   -mat_lmvm_rho - (developer) update limiter for the J0 scaling
299 .   -mat_lmvm_alpha - (developer) coefficient factor for the quadratic subproblem in J0 scaling
300 .   -mat_lmvm_beta - (developer) exponential factor for the diagonal J0 scaling
301 -   -mat_lmvm_sigma_hist - (developer) number of past updates to use in J0 scaling
302 
303    Level: intermediate
304 
.seealso: `MatCreate()`, `MATLMVM`, `MATLMVMBADBROYDEN`, `MatCreateLMVMDFP()`, `MatCreateLMVMSR1()`,
          `MatCreateLMVMBFGS()`, `MatCreateLMVMBroyden()`, `MatCreateLMVMSymBroyden()`
307 @*/
308 PetscErrorCode MatCreateLMVMBadBroyden(MPI_Comm comm, PetscInt n, PetscInt N, Mat *B) {
309   PetscFunctionBegin;
310   PetscCall(MatCreate(comm, B));
311   PetscCall(MatSetSizes(*B, n, n, N, N));
312   PetscCall(MatSetType(*B, MATLMVMBADBROYDEN));
313   PetscCall(MatSetUp(*B));
314   PetscFunctionReturn(0);
315 }
316