xref: /petsc/src/ksp/ksp/utils/lmvm/brdn/badbrdn.c (revision 3307d110e72ee4e6d2468971620073eb5ff93529)
1 #include <../src/ksp/ksp/utils/lmvm/brdn/brdn.h> /*I "petscksp.h" I*/
2 
3 /*------------------------------------------------------------*/
4 
5 /*
6   The solution method is the matrix-free implementation of the inverse Hessian in
7   Equation 6 on page 312 of Griewank "Broyden Updating, The Good and The Bad!"
8   (http://www.emis.ams.org/journals/DMJDMV/vol-ismp/45_griewank-andreas-broyden.pdf).
9 
10   Q[i] = (B_i)^{-1}*Y[i] terms are computed ahead of time whenever
11   the matrix is updated with a new (S[i], Y[i]) pair. This allows
12   repeated calls of MatSolve without incurring redundant computation.
13 
14   dX <- J0^{-1} * F
15 
16   for i=0,1,2,...,k
17     # Q[i] = (B_i)^{-1} * Y[i]
18     tau = (Y[i]^T F) / (Y[i]^T Y[i])
19     dX <- dX + (tau * (S[i] - Q[i]))
20   end
21  */
22 
23 static PetscErrorCode MatSolve_LMVMBadBrdn(Mat B, Vec F, Vec dX)
24 {
25   Mat_LMVM          *lmvm = (Mat_LMVM*)B->data;
26   Mat_Brdn          *lbb = (Mat_Brdn*)lmvm->ctx;
27   PetscInt          i, j;
28   PetscScalar       yjtyi, ytf;
29 
30   PetscFunctionBegin;
31   VecCheckSameSize(F, 2, dX, 3);
32   VecCheckMatCompatible(B, dX, 3, F, 2);
33 
34   if (lbb->needQ) {
35     /* Pre-compute (Q[i] = (B_i)^{-1} * Y[i]) */
36     for (i = 0; i <= lmvm->k; ++i) {
37       PetscCall(MatLMVMApplyJ0Inv(B, lmvm->Y[i], lbb->Q[i]));
38       for (j = 0; j <= i-1; ++j) {
39         PetscCall(VecDot(lmvm->Y[j], lmvm->Y[i], &yjtyi));
40         PetscCall(VecAXPBYPCZ(lbb->Q[i], PetscRealPart(yjtyi)/lbb->yty[j], -PetscRealPart(yjtyi)/lbb->yty[j], 1.0, lmvm->S[j], lbb->Q[j]));
41       }
42     }
43     lbb->needQ = PETSC_FALSE;
44   }
45 
46   PetscCall(MatLMVMApplyJ0Inv(B, F, dX));
47   for (i = 0; i <= lmvm->k; ++i) {
48     PetscCall(VecDot(lmvm->Y[i], F, &ytf));
49     PetscCall(VecAXPBYPCZ(dX, PetscRealPart(ytf)/lbb->yty[i], -PetscRealPart(ytf)/lbb->yty[i], 1.0, lmvm->S[i], lbb->Q[i]));
50   }
51   PetscFunctionReturn(0);
52 }
53 
54 /*------------------------------------------------------------*/
55 
56 /*
57   The forward product is the matrix-free implementation of the direct update in
58   Equation 6 on page 302 of Griewank "Broyden Updating, The Good and The Bad!"
59   (http://www.emis.ams.org/journals/DMJDMV/vol-ismp/45_griewank-andreas-broyden.pdf).
60 
61   P[i] = (B_i)*S[i] terms are computed ahead of time whenever
62   the matrix is updated with a new (S[i], Y[i]) pair. This allows
63   repeated calls of MatMult inside KSP solvers without unnecessarily
64   recomputing P[i] terms in expensive nested-loops.
65 
66   Z <- J0 * X
67 
68   for i=0,1,2,...,k
69     # P[i] = B_i * S[i]
70     tau = (Y[i]^T X) / (Y[i]^T S[i])
71     Z <- Z + (tau * (Y[i] - P[i]))
72   end
73  */
74 
75 static PetscErrorCode MatMult_LMVMBadBrdn(Mat B, Vec X, Vec Z)
76 {
77   Mat_LMVM          *lmvm = (Mat_LMVM*)B->data;
78   Mat_Brdn          *lbb = (Mat_Brdn*)lmvm->ctx;
79   PetscInt          i, j;
80   PetscScalar       yjtsi, ytx;
81 
82   PetscFunctionBegin;
83   VecCheckSameSize(X, 2, Z, 3);
84   VecCheckMatCompatible(B, X, 2, Z, 3);
85 
86   if (lbb->needP) {
87     /* Pre-compute (P[i] = (B_i) * S[i]) */
88     for (i = 0; i <= lmvm->k; ++i) {
89       PetscCall(MatLMVMApplyJ0Fwd(B, lmvm->S[i], lbb->P[i]));
90       for (j = 0; j <= i-1; ++j) {
91         PetscCall(VecDot(lmvm->Y[j], lmvm->S[i], &yjtsi));
92         PetscCall(VecAXPBYPCZ(lbb->P[i], PetscRealPart(yjtsi)/lbb->yts[j], -PetscRealPart(yjtsi)/lbb->yts[j], 1.0, lmvm->Y[j], lbb->P[j]));
93       }
94     }
95     lbb->needP = PETSC_FALSE;
96   }
97 
98   PetscCall(MatLMVMApplyJ0Fwd(B, X, Z));
99   for (i = 0; i <= lmvm->k; ++i) {
100     PetscCall(VecDot(lmvm->Y[i], X, &ytx));
101     PetscCall(VecAXPBYPCZ(Z, PetscRealPart(ytx)/lbb->yts[i], -PetscRealPart(ytx)/lbb->yts[i], 1.0, lmvm->Y[i], lbb->P[i]));
102   }
103   PetscFunctionReturn(0);
104 }
105 
106 /*------------------------------------------------------------*/
107 
108 static PetscErrorCode MatUpdate_LMVMBadBrdn(Mat B, Vec X, Vec F)
109 {
110   Mat_LMVM          *lmvm = (Mat_LMVM*)B->data;
111   Mat_Brdn          *lbb = (Mat_Brdn*)lmvm->ctx;
112   PetscInt          old_k, i;
113   PetscScalar       yty, yts;
114 
115   PetscFunctionBegin;
116   if (!lmvm->m) PetscFunctionReturn(0);
117   if (lmvm->prev_set) {
118     /* Compute the new (S = X - Xprev) and (Y = F - Fprev) vectors */
119     PetscCall(VecAYPX(lmvm->Xprev, -1.0, X));
120     PetscCall(VecAYPX(lmvm->Fprev, -1.0, F));
121     /* Accept the update */
122     lbb->needP = lbb->needQ = PETSC_TRUE;
123     old_k = lmvm->k;
124     PetscCall(MatUpdateKernel_LMVM(B, lmvm->Xprev, lmvm->Fprev));
125     /* If we hit the memory limit, shift the yty and yts arrays */
126     if (old_k == lmvm->k) {
127       for (i = 0; i <= lmvm->k-1; ++i) {
128         lbb->yty[i] = lbb->yty[i+1];
129         lbb->yts[i] = lbb->yts[i+1];
130       }
131     }
132     /* Accumulate the latest yTy and yTs dot products */
133     PetscCall(VecDotBegin(lmvm->Y[lmvm->k], lmvm->Y[lmvm->k], &yty));
134     PetscCall(VecDotBegin(lmvm->Y[lmvm->k], lmvm->S[lmvm->k], &yts));
135     PetscCall(VecDotEnd(lmvm->Y[lmvm->k], lmvm->Y[lmvm->k], &yty));
136     PetscCall(VecDotEnd(lmvm->Y[lmvm->k], lmvm->S[lmvm->k], &yts));
137     lbb->yty[lmvm->k] = PetscRealPart(yty);
138     lbb->yts[lmvm->k] = PetscRealPart(yts);
139   }
140   /* Save the solution and function to be used in the next update */
141   PetscCall(VecCopy(X, lmvm->Xprev));
142   PetscCall(VecCopy(F, lmvm->Fprev));
143   lmvm->prev_set = PETSC_TRUE;
144   PetscFunctionReturn(0);
145 }
146 
147 /*------------------------------------------------------------*/
148 
149 static PetscErrorCode MatCopy_LMVMBadBrdn(Mat B, Mat M, MatStructure str)
150 {
151   Mat_LMVM          *bdata = (Mat_LMVM*)B->data;
152   Mat_Brdn          *bctx = (Mat_Brdn*)bdata->ctx;
153   Mat_LMVM          *mdata = (Mat_LMVM*)M->data;
154   Mat_Brdn          *mctx = (Mat_Brdn*)mdata->ctx;
155   PetscInt          i;
156 
157   PetscFunctionBegin;
158   mctx->needP = bctx->needP;
159   mctx->needQ = bctx->needQ;
160   for (i=0; i<=bdata->k; ++i) {
161     mctx->yty[i] = bctx->yty[i];
162     mctx->yts[i] = bctx->yts[i];
163     PetscCall(VecCopy(bctx->P[i], mctx->P[i]));
164     PetscCall(VecCopy(bctx->Q[i], mctx->Q[i]));
165   }
166   PetscFunctionReturn(0);
167 }
168 
169 /*------------------------------------------------------------*/
170 
171 static PetscErrorCode MatReset_LMVMBadBrdn(Mat B, PetscBool destructive)
172 {
173   Mat_LMVM          *lmvm = (Mat_LMVM*)B->data;
174   Mat_Brdn          *lbb = (Mat_Brdn*)lmvm->ctx;
175 
176   PetscFunctionBegin;
177   lbb->needP = lbb->needQ = PETSC_TRUE;
178   if (destructive && lbb->allocated) {
179     PetscCall(PetscFree2(lbb->yty, lbb->yts));
180     PetscCall(VecDestroyVecs(lmvm->m, &lbb->P));
181     PetscCall(VecDestroyVecs(lmvm->m, &lbb->Q));
182     lbb->allocated = PETSC_FALSE;
183   }
184   PetscCall(MatReset_LMVM(B, destructive));
185   PetscFunctionReturn(0);
186 }
187 
188 /*------------------------------------------------------------*/
189 
190 static PetscErrorCode MatAllocate_LMVMBadBrdn(Mat B, Vec X, Vec F)
191 {
192   Mat_LMVM          *lmvm = (Mat_LMVM*)B->data;
193   Mat_Brdn          *lbb = (Mat_Brdn*)lmvm->ctx;
194 
195   PetscFunctionBegin;
196   PetscCall(MatAllocate_LMVM(B, X, F));
197   if (!lbb->allocated) {
198     PetscCall(PetscMalloc2(lmvm->m, &lbb->yty, lmvm->m, &lbb->yts));
199     if (lmvm->m > 0) {
200       PetscCall(VecDuplicateVecs(X, lmvm->m, &lbb->P));
201       PetscCall(VecDuplicateVecs(X, lmvm->m, &lbb->Q));
202     }
203     lbb->allocated = PETSC_TRUE;
204   }
205   PetscFunctionReturn(0);
206 }
207 
208 /*------------------------------------------------------------*/
209 
210 static PetscErrorCode MatDestroy_LMVMBadBrdn(Mat B)
211 {
212   Mat_LMVM          *lmvm = (Mat_LMVM*)B->data;
213   Mat_Brdn          *lbb = (Mat_Brdn*)lmvm->ctx;
214 
215   PetscFunctionBegin;
216   if (lbb->allocated) {
217     PetscCall(PetscFree2(lbb->yty, lbb->yts));
218     PetscCall(VecDestroyVecs(lmvm->m, &lbb->P));
219     PetscCall(VecDestroyVecs(lmvm->m, &lbb->Q));
220     lbb->allocated = PETSC_FALSE;
221   }
222   PetscCall(PetscFree(lmvm->ctx));
223   PetscCall(MatDestroy_LMVM(B));
224   PetscFunctionReturn(0);
225 }
226 
227 /*------------------------------------------------------------*/
228 
229 static PetscErrorCode MatSetUp_LMVMBadBrdn(Mat B)
230 {
231   Mat_LMVM          *lmvm = (Mat_LMVM*)B->data;
232   Mat_Brdn          *lbb = (Mat_Brdn*)lmvm->ctx;
233 
234   PetscFunctionBegin;
235   PetscCall(MatSetUp_LMVM(B));
236   if (!lbb->allocated) {
237     PetscCall(PetscMalloc2(lmvm->m, &lbb->yty, lmvm->m, &lbb->yts));
238     if (lmvm->m > 0) {
239       PetscCall(VecDuplicateVecs(lmvm->Xprev, lmvm->m, &lbb->P));
240       PetscCall(VecDuplicateVecs(lmvm->Xprev, lmvm->m, &lbb->Q));
241     }
242     lbb->allocated = PETSC_TRUE;
243   }
244   PetscFunctionReturn(0);
245 }
246 
247 /*------------------------------------------------------------*/
248 
249 PetscErrorCode MatCreate_LMVMBadBrdn(Mat B)
250 {
251   Mat_LMVM          *lmvm;
252   Mat_Brdn          *lbb;
253 
254   PetscFunctionBegin;
255   PetscCall(MatCreate_LMVM(B));
256   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATLMVMBADBROYDEN));
257   B->ops->setup = MatSetUp_LMVMBadBrdn;
258   B->ops->destroy = MatDestroy_LMVMBadBrdn;
259   B->ops->solve = MatSolve_LMVMBadBrdn;
260 
261   lmvm = (Mat_LMVM*)B->data;
262   lmvm->square = PETSC_TRUE;
263   lmvm->ops->allocate = MatAllocate_LMVMBadBrdn;
264   lmvm->ops->reset = MatReset_LMVMBadBrdn;
265   lmvm->ops->mult = MatMult_LMVMBadBrdn;
266   lmvm->ops->update = MatUpdate_LMVMBadBrdn;
267   lmvm->ops->copy = MatCopy_LMVMBadBrdn;
268 
269   PetscCall(PetscNewLog(B, &lbb));
270   lmvm->ctx = (void*)lbb;
271   lbb->allocated = PETSC_FALSE;
272   lbb->needP = lbb->needQ = PETSC_TRUE;
273   PetscFunctionReturn(0);
274 }
275 
276 /*------------------------------------------------------------*/
277 
278 /*@
279    MatCreateLMVMBadBroyden - Creates a limited-memory modified (aka "bad") Broyden-type
280    approximation matrix used for a Jacobian. L-BadBrdn is not guaranteed to be
281    symmetric or positive-definite.
282 
283    The provided local and global sizes must match the solution and function vectors
284    used with MatLMVMUpdate() and MatSolve(). The resulting L-BadBrdn matrix will have
285    storage vectors allocated with VecCreateSeq() in serial and VecCreateMPI() in
286    parallel. To use the L-BadBrdn matrix with other vector types, the matrix must be
287    created using MatCreate() and MatSetType(), followed by MatLMVMAllocate().
288    This ensures that the internal storage and work vectors are duplicated from the
289    correct type of vector.
290 
291    Collective
292 
293    Input Parameters:
294 +  comm - MPI communicator
295 .  n - number of local rows for storage vectors
296 -  N - global size of the storage vectors
297 
298    Output Parameter:
299 .  B - the matrix
300 
301    It is recommended that one use the MatCreate(), MatSetType() and/or MatSetFromOptions()
302    paradigm instead of this routine directly.
303 
304    Options Database Keys:
305 +   -mat_lmvm_num_vecs - maximum number of correction vectors (i.e.: updates) stored
306 .   -mat_lmvm_scale_type - (developer) type of scaling applied to J0 (none, scalar, diagonal)
307 .   -mat_lmvm_theta - (developer) convex ratio between BFGS and DFP components of the diagonal J0 scaling
308 .   -mat_lmvm_rho - (developer) update limiter for the J0 scaling
309 .   -mat_lmvm_alpha - (developer) coefficient factor for the quadratic subproblem in J0 scaling
310 .   -mat_lmvm_beta - (developer) exponential factor for the diagonal J0 scaling
311 -   -mat_lmvm_sigma_hist - (developer) number of past updates to use in J0 scaling
312 
313    Level: intermediate
314 
315 .seealso: MatCreate(), MATLMVM, MATLMVMBADBROYDEN, MatCreateLMVMDFP(), MatCreateLMVMSR1(),
316           MatCreateLMVMBFGS(), MatCreateLMVMBroyden(), MatCreateLMVMSymBroyden()
317 @*/
318 PetscErrorCode MatCreateLMVMBadBroyden(MPI_Comm comm, PetscInt n, PetscInt N, Mat *B)
319 {
320   PetscFunctionBegin;
321   PetscCall(MatCreate(comm, B));
322   PetscCall(MatSetSizes(*B, n, n, N, N));
323   PetscCall(MatSetType(*B, MATLMVMBADBROYDEN));
324   PetscCall(MatSetUp(*B));
325   PetscFunctionReturn(0);
326 }
327