xref: /petsc/src/tao/tutorials/ex4.c (revision 48a46eb9bd028bec07ec0f396b1a3abb43f14558)
/* Help banner for this example (presumably handed to PetscInitialize() in main). */
static char help[] = "Simple example to test separable "
                     "objective optimizers.\n";
2c4762a1bSJed Brown 
3c4762a1bSJed Brown #include <petsc.h>
4c4762a1bSJed Brown #include <petsctao.h>
5c4762a1bSJed Brown #include <petscvec.h>
6c4762a1bSJed Brown #include <petscmath.h>
7c4762a1bSJed Brown 
/* Number of scratch vectors kept in UserCtx::workLeft and UserCtx::workRight. */
enum {
  NWORKLEFT  = 4,
  NWORKRIGHT = 12
};
10c4762a1bSJed Brown 
119371c9d4SSatish Balay typedef struct _UserCtx {
12c4762a1bSJed Brown   PetscInt    m;       /* The row dimension of F */
13c4762a1bSJed Brown   PetscInt    n;       /* The column dimension of F */
14c4762a1bSJed Brown   PetscInt    matops;  /* Matrix format. 0 for stencil, 1 for random */
15c4762a1bSJed Brown   PetscInt    iter;    /* Numer of iterations for ADMM */
16c4762a1bSJed Brown   PetscReal   hStart;  /* Starting point for Taylor test */
17c4762a1bSJed Brown   PetscReal   hFactor; /* Taylor test step factor */
18c4762a1bSJed Brown   PetscReal   hMin;    /* Taylor test end goal */
19c4762a1bSJed Brown   PetscReal   alpha;   /* regularization constant applied to || x ||_p */
20c4762a1bSJed Brown   PetscReal   eps;     /* small constant for approximating gradient of || x ||_1 */
21c4762a1bSJed Brown   PetscReal   mu;      /* the augmented Lagrangian term in ADMM */
22c4762a1bSJed Brown   PetscReal   abstol;
23c4762a1bSJed Brown   PetscReal   reltol;
24c4762a1bSJed Brown   Mat         F;                     /* matrix in least squares component $(1/2) * || F x - d ||_2^2$ */
25c4762a1bSJed Brown   Mat         W;                     /* Workspace matrix. ATA */
26c4762a1bSJed Brown   Mat         Hm;                    /* Hessian Misfit*/
27c4762a1bSJed Brown   Mat         Hr;                    /* Hessian Reg*/
28c4762a1bSJed Brown   Vec         d;                     /* RHS in least squares component $(1/2) * || F x - d ||_2^2$ */
29c4762a1bSJed Brown   Vec         workLeft[NWORKLEFT];   /* Workspace for temporary vec */
30c4762a1bSJed Brown   Vec         workRight[NWORKRIGHT]; /* Workspace for temporary vec */
31c4762a1bSJed Brown   NormType    p;
32c4762a1bSJed Brown   PetscRandom rctx;
33c4762a1bSJed Brown   PetscBool   taylor;   /* Flag to determine whether to run Taylor test or not */
34c4762a1bSJed Brown   PetscBool   use_admm; /* Flag to determine whether to run Taylor test or not */
35c4762a1bSJed Brown } * UserCtx;
36c4762a1bSJed Brown 
379371c9d4SSatish Balay static PetscErrorCode CreateRHS(UserCtx ctx) {
38c4762a1bSJed Brown   PetscFunctionBegin;
39c4762a1bSJed Brown   /* build the rhs d in ctx */
409566063dSJacob Faibussowitsch   PetscCall(VecCreate(PETSC_COMM_WORLD, &(ctx->d)));
419566063dSJacob Faibussowitsch   PetscCall(VecSetSizes(ctx->d, PETSC_DECIDE, ctx->m));
429566063dSJacob Faibussowitsch   PetscCall(VecSetFromOptions(ctx->d));
439566063dSJacob Faibussowitsch   PetscCall(VecSetRandom(ctx->d, ctx->rctx));
44c4762a1bSJed Brown   PetscFunctionReturn(0);
45c4762a1bSJed Brown }
46c4762a1bSJed Brown 
479371c9d4SSatish Balay static PetscErrorCode CreateMatrix(UserCtx ctx) {
48c4762a1bSJed Brown   PetscInt Istart, Iend, i, j, Ii, gridN, I_n, I_s, I_e, I_w;
49c4762a1bSJed Brown #if defined(PETSC_USE_LOG)
50c4762a1bSJed Brown   PetscLogStage stage;
51c4762a1bSJed Brown #endif
52c4762a1bSJed Brown 
53c4762a1bSJed Brown   PetscFunctionBegin;
54c4762a1bSJed Brown   /* build the matrix F in ctx */
559566063dSJacob Faibussowitsch   PetscCall(MatCreate(PETSC_COMM_WORLD, &(ctx->F)));
569566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(ctx->F, PETSC_DECIDE, PETSC_DECIDE, ctx->m, ctx->n));
579566063dSJacob Faibussowitsch   PetscCall(MatSetType(ctx->F, MATAIJ));                          /* TODO: Decide specific SetType other than dummy*/
589566063dSJacob Faibussowitsch   PetscCall(MatMPIAIJSetPreallocation(ctx->F, 5, NULL, 5, NULL)); /*TODO: some number other than 5?*/
599566063dSJacob Faibussowitsch   PetscCall(MatSeqAIJSetPreallocation(ctx->F, 5, NULL));
609566063dSJacob Faibussowitsch   PetscCall(MatSetUp(ctx->F));
619566063dSJacob Faibussowitsch   PetscCall(MatGetOwnershipRange(ctx->F, &Istart, &Iend));
629566063dSJacob Faibussowitsch   PetscCall(PetscLogStageRegister("Assembly", &stage));
639566063dSJacob Faibussowitsch   PetscCall(PetscLogStagePush(stage));
64c4762a1bSJed Brown 
653c859ba3SBarry Smith   /* Set matrix elements in  2-D five point stencil format. */
66c4762a1bSJed Brown   if (!(ctx->matops)) {
673c859ba3SBarry Smith     PetscCheck(ctx->m == ctx->n, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Stencil matrix must be square");
68c4762a1bSJed Brown     gridN = (PetscInt)PetscSqrtReal((PetscReal)ctx->m);
693c859ba3SBarry Smith     PetscCheck(gridN * gridN == ctx->m, PETSC_COMM_WORLD, PETSC_ERR_ARG_SIZ, "Number of rows must be square");
70c4762a1bSJed Brown     for (Ii = Istart; Ii < Iend; Ii++) {
719371c9d4SSatish Balay       i   = Ii / gridN;
729371c9d4SSatish Balay       j   = Ii % gridN;
73c4762a1bSJed Brown       I_n = i * gridN + j + 1;
74c4762a1bSJed Brown       if (j + 1 >= gridN) I_n = -1;
75c4762a1bSJed Brown       I_s = i * gridN + j - 1;
76c4762a1bSJed Brown       if (j - 1 < 0) I_s = -1;
77c4762a1bSJed Brown       I_e = (i + 1) * gridN + j;
78c4762a1bSJed Brown       if (i + 1 >= gridN) I_e = -1;
79c4762a1bSJed Brown       I_w = (i - 1) * gridN + j;
80c4762a1bSJed Brown       if (i - 1 < 0) I_w = -1;
819566063dSJacob Faibussowitsch       PetscCall(MatSetValue(ctx->F, Ii, Ii, 4., INSERT_VALUES));
829566063dSJacob Faibussowitsch       PetscCall(MatSetValue(ctx->F, Ii, I_n, -1., INSERT_VALUES));
839566063dSJacob Faibussowitsch       PetscCall(MatSetValue(ctx->F, Ii, I_s, -1., INSERT_VALUES));
849566063dSJacob Faibussowitsch       PetscCall(MatSetValue(ctx->F, Ii, I_e, -1., INSERT_VALUES));
859566063dSJacob Faibussowitsch       PetscCall(MatSetValue(ctx->F, Ii, I_w, -1., INSERT_VALUES));
86c4762a1bSJed Brown     }
879566063dSJacob Faibussowitsch   } else PetscCall(MatSetRandom(ctx->F, ctx->rctx));
889566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(ctx->F, MAT_FINAL_ASSEMBLY));
899566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(ctx->F, MAT_FINAL_ASSEMBLY));
909566063dSJacob Faibussowitsch   PetscCall(PetscLogStagePop());
91c4762a1bSJed Brown   /* Stencil matrix is symmetric. Setting symmetric flag for ICC/Cholesky preconditioner */
92*48a46eb9SPierre Jolivet   if (!(ctx->matops)) PetscCall(MatSetOption(ctx->F, MAT_SYMMETRIC, PETSC_TRUE));
939566063dSJacob Faibussowitsch   PetscCall(MatTransposeMatMult(ctx->F, ctx->F, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(ctx->W)));
94c4762a1bSJed Brown   /* Setup Hessian Workspace in same shape as W */
959566063dSJacob Faibussowitsch   PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &(ctx->Hm)));
969566063dSJacob Faibussowitsch   PetscCall(MatDuplicate(ctx->W, MAT_DO_NOT_COPY_VALUES, &(ctx->Hr)));
97c4762a1bSJed Brown   PetscFunctionReturn(0);
98c4762a1bSJed Brown }
99c4762a1bSJed Brown 
1009371c9d4SSatish Balay static PetscErrorCode SetupWorkspace(UserCtx ctx) {
101c4762a1bSJed Brown   PetscInt i;
102c4762a1bSJed Brown 
103c4762a1bSJed Brown   PetscFunctionBegin;
1049566063dSJacob Faibussowitsch   PetscCall(MatCreateVecs(ctx->F, &ctx->workLeft[0], &ctx->workRight[0]));
105*48a46eb9SPierre Jolivet   for (i = 1; i < NWORKLEFT; i++) PetscCall(VecDuplicate(ctx->workLeft[0], &(ctx->workLeft[i])));
106*48a46eb9SPierre Jolivet   for (i = 1; i < NWORKRIGHT; i++) PetscCall(VecDuplicate(ctx->workRight[0], &(ctx->workRight[i])));
107c4762a1bSJed Brown   PetscFunctionReturn(0);
108c4762a1bSJed Brown }
109c4762a1bSJed Brown 
1109371c9d4SSatish Balay static PetscErrorCode ConfigureContext(UserCtx ctx) {
111c4762a1bSJed Brown   PetscFunctionBegin;
112c4762a1bSJed Brown   ctx->m        = 16;
113c4762a1bSJed Brown   ctx->n        = 16;
114c4762a1bSJed Brown   ctx->eps      = 1.e-3;
115c4762a1bSJed Brown   ctx->abstol   = 1.e-4;
116c4762a1bSJed Brown   ctx->reltol   = 1.e-2;
117c4762a1bSJed Brown   ctx->hStart   = 1.;
118c4762a1bSJed Brown   ctx->hMin     = 1.e-3;
119c4762a1bSJed Brown   ctx->hFactor  = 0.5;
120c4762a1bSJed Brown   ctx->alpha    = 1.;
121c4762a1bSJed Brown   ctx->mu       = 1.0;
122c4762a1bSJed Brown   ctx->matops   = 0;
123c4762a1bSJed Brown   ctx->iter     = 10;
124c4762a1bSJed Brown   ctx->p        = NORM_2;
125c4762a1bSJed Brown   ctx->taylor   = PETSC_TRUE;
126c4762a1bSJed Brown   ctx->use_admm = PETSC_FALSE;
127d0609cedSBarry Smith   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "ex4.c");
1289566063dSJacob Faibussowitsch   PetscCall(PetscOptionsInt("-m", "The row dimension of matrix F", "ex4.c", ctx->m, &(ctx->m), NULL));
1299566063dSJacob Faibussowitsch   PetscCall(PetscOptionsInt("-n", "The column dimension of matrix F", "ex4.c", ctx->n, &(ctx->n), NULL));
1309566063dSJacob Faibussowitsch   PetscCall(PetscOptionsInt("-matrix_format", "Decide format of F matrix. 0 for stencil, 1 for random", "ex4.c", ctx->matops, &(ctx->matops), NULL));
1319566063dSJacob Faibussowitsch   PetscCall(PetscOptionsInt("-iter", "Iteration number ADMM", "ex4.c", ctx->iter, &(ctx->iter), NULL));
1329566063dSJacob Faibussowitsch   PetscCall(PetscOptionsReal("-alpha", "The regularization multiplier. 1 default", "ex4.c", ctx->alpha, &(ctx->alpha), NULL));
1339566063dSJacob Faibussowitsch   PetscCall(PetscOptionsReal("-epsilon", "The small constant added to |x_i| in the denominator to approximate the gradient of ||x||_1", "ex4.c", ctx->eps, &(ctx->eps), NULL));
1349566063dSJacob Faibussowitsch   PetscCall(PetscOptionsReal("-mu", "The augmented lagrangian multiplier in ADMM", "ex4.c", ctx->mu, &(ctx->mu), NULL));
1359566063dSJacob Faibussowitsch   PetscCall(PetscOptionsReal("-hStart", "Taylor test starting point. 1 default.", "ex4.c", ctx->hStart, &(ctx->hStart), NULL));
1369566063dSJacob Faibussowitsch   PetscCall(PetscOptionsReal("-hFactor", "Taylor test multiplier factor. 0.5 default", "ex4.c", ctx->hFactor, &(ctx->hFactor), NULL));
1379566063dSJacob Faibussowitsch   PetscCall(PetscOptionsReal("-hMin", "Taylor test ending condition. 1.e-3 default", "ex4.c", ctx->hMin, &(ctx->hMin), NULL));
1389566063dSJacob Faibussowitsch   PetscCall(PetscOptionsReal("-abstol", "Absolute stopping criterion for ADMM", "ex4.c", ctx->abstol, &(ctx->abstol), NULL));
1399566063dSJacob Faibussowitsch   PetscCall(PetscOptionsReal("-reltol", "Relative stopping criterion for ADMM", "ex4.c", ctx->reltol, &(ctx->reltol), NULL));
1409566063dSJacob Faibussowitsch   PetscCall(PetscOptionsBool("-taylor", "Flag for Taylor test. Default is true.", "ex4.c", ctx->taylor, &(ctx->taylor), NULL));
1419566063dSJacob Faibussowitsch   PetscCall(PetscOptionsBool("-use_admm", "Use the ADMM solver in this example.", "ex4.c", ctx->use_admm, &(ctx->use_admm), NULL));
1429566063dSJacob Faibussowitsch   PetscCall(PetscOptionsEnum("-p", "Norm type.", "ex4.c", NormTypes, (PetscEnum)ctx->p, (PetscEnum *)&(ctx->p), NULL));
143d0609cedSBarry Smith   PetscOptionsEnd();
144c4762a1bSJed Brown   /* Creating random ctx */
1459566063dSJacob Faibussowitsch   PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &(ctx->rctx)));
1469566063dSJacob Faibussowitsch   PetscCall(PetscRandomSetFromOptions(ctx->rctx));
1479566063dSJacob Faibussowitsch   PetscCall(CreateMatrix(ctx));
1489566063dSJacob Faibussowitsch   PetscCall(CreateRHS(ctx));
1499566063dSJacob Faibussowitsch   PetscCall(SetupWorkspace(ctx));
150c4762a1bSJed Brown   PetscFunctionReturn(0);
151c4762a1bSJed Brown }
152c4762a1bSJed Brown 
1539371c9d4SSatish Balay static PetscErrorCode DestroyContext(UserCtx *ctx) {
154c4762a1bSJed Brown   PetscInt i;
155c4762a1bSJed Brown 
156c4762a1bSJed Brown   PetscFunctionBegin;
1579566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&((*ctx)->F)));
1589566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&((*ctx)->W)));
1599566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&((*ctx)->Hm)));
1609566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&((*ctx)->Hr)));
1619566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&((*ctx)->d)));
162*48a46eb9SPierre Jolivet   for (i = 0; i < NWORKLEFT; i++) PetscCall(VecDestroy(&((*ctx)->workLeft[i])));
163*48a46eb9SPierre Jolivet   for (i = 0; i < NWORKRIGHT; i++) PetscCall(VecDestroy(&((*ctx)->workRight[i])));
1649566063dSJacob Faibussowitsch   PetscCall(PetscRandomDestroy(&((*ctx)->rctx)));
1659566063dSJacob Faibussowitsch   PetscCall(PetscFree(*ctx));
166c4762a1bSJed Brown   PetscFunctionReturn(0);
167c4762a1bSJed Brown }
168c4762a1bSJed Brown 
169c4762a1bSJed Brown /* compute (1/2) * ||F x - d||^2 */
1709371c9d4SSatish Balay static PetscErrorCode ObjectiveMisfit(Tao tao, Vec x, PetscReal *J, void *_ctx) {
171c4762a1bSJed Brown   UserCtx ctx = (UserCtx)_ctx;
172c4762a1bSJed Brown   Vec     y;
173c4762a1bSJed Brown 
174c4762a1bSJed Brown   PetscFunctionBegin;
175c4762a1bSJed Brown   y = ctx->workLeft[0];
1769566063dSJacob Faibussowitsch   PetscCall(MatMult(ctx->F, x, y));
1779566063dSJacob Faibussowitsch   PetscCall(VecAXPY(y, -1., ctx->d));
1789566063dSJacob Faibussowitsch   PetscCall(VecDot(y, y, J));
179c4762a1bSJed Brown   *J *= 0.5;
180c4762a1bSJed Brown   PetscFunctionReturn(0);
181c4762a1bSJed Brown }
182c4762a1bSJed Brown 
183c4762a1bSJed Brown /* compute V = FTFx - FTd */
1849371c9d4SSatish Balay static PetscErrorCode GradientMisfit(Tao tao, Vec x, Vec V, void *_ctx) {
185c4762a1bSJed Brown   UserCtx ctx = (UserCtx)_ctx;
186c4762a1bSJed Brown   Vec     FTFx, FTd;
187c4762a1bSJed Brown 
188c4762a1bSJed Brown   PetscFunctionBegin;
189c4762a1bSJed Brown   /* work1 is A^T Ax, work2 is Ab, W is A^T A*/
190c4762a1bSJed Brown   FTFx = ctx->workRight[0];
191c4762a1bSJed Brown   FTd  = ctx->workRight[1];
1929566063dSJacob Faibussowitsch   PetscCall(MatMult(ctx->W, x, FTFx));
1939566063dSJacob Faibussowitsch   PetscCall(MatMultTranspose(ctx->F, ctx->d, FTd));
1949566063dSJacob Faibussowitsch   PetscCall(VecWAXPY(V, -1., FTd, FTFx));
195c4762a1bSJed Brown   PetscFunctionReturn(0);
196c4762a1bSJed Brown }
197c4762a1bSJed Brown 
198c4762a1bSJed Brown /* returns FTF */
1999371c9d4SSatish Balay static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) {
200c4762a1bSJed Brown   UserCtx ctx = (UserCtx)_ctx;
201c4762a1bSJed Brown 
202c4762a1bSJed Brown   PetscFunctionBegin;
2039566063dSJacob Faibussowitsch   if (H != ctx->W) PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
2049566063dSJacob Faibussowitsch   if (Hpre != ctx->W) PetscCall(MatCopy(ctx->W, Hpre, DIFFERENT_NONZERO_PATTERN));
205c4762a1bSJed Brown   PetscFunctionReturn(0);
206c4762a1bSJed Brown }
207c4762a1bSJed Brown 
208c4762a1bSJed Brown /* computes augment Lagrangian objective (with scaled dual):
209c4762a1bSJed Brown  * 0.5 * ||F x - d||^2  + 0.5 * mu ||x - z + u||^2 */
2109371c9d4SSatish Balay static PetscErrorCode ObjectiveMisfitADMM(Tao tao, Vec x, PetscReal *J, void *_ctx) {
211c4762a1bSJed Brown   UserCtx   ctx = (UserCtx)_ctx;
212c4762a1bSJed Brown   PetscReal mu, workNorm, misfit;
213c4762a1bSJed Brown   Vec       z, u, temp;
214c4762a1bSJed Brown 
215c4762a1bSJed Brown   PetscFunctionBegin;
216c4762a1bSJed Brown   mu   = ctx->mu;
217c4762a1bSJed Brown   z    = ctx->workRight[5];
218c4762a1bSJed Brown   u    = ctx->workRight[6];
219c4762a1bSJed Brown   temp = ctx->workRight[10];
220c4762a1bSJed Brown   /* misfit = f(x) */
2219566063dSJacob Faibussowitsch   PetscCall(ObjectiveMisfit(tao, x, &misfit, _ctx));
2229566063dSJacob Faibussowitsch   PetscCall(VecCopy(x, temp));
223c4762a1bSJed Brown   /* temp = x - z + u */
2249566063dSJacob Faibussowitsch   PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u));
225c4762a1bSJed Brown   /* workNorm = ||x - z + u||^2 */
2269566063dSJacob Faibussowitsch   PetscCall(VecDot(temp, temp, &workNorm));
227c4762a1bSJed Brown   /* augment Lagrangian objective (with scaled dual): f(x) + 0.5 * mu ||x -z + u||^2 */
228c4762a1bSJed Brown   *J = misfit + 0.5 * mu * workNorm;
229c4762a1bSJed Brown   PetscFunctionReturn(0);
230c4762a1bSJed Brown }
231c4762a1bSJed Brown 
232c4762a1bSJed Brown /* computes FTFx - FTd  mu*(x - z + u) */
2339371c9d4SSatish Balay static PetscErrorCode GradientMisfitADMM(Tao tao, Vec x, Vec V, void *_ctx) {
234c4762a1bSJed Brown   UserCtx   ctx = (UserCtx)_ctx;
235c4762a1bSJed Brown   PetscReal mu;
236c4762a1bSJed Brown   Vec       z, u, temp;
237c4762a1bSJed Brown 
238c4762a1bSJed Brown   PetscFunctionBegin;
239c4762a1bSJed Brown   mu   = ctx->mu;
240c4762a1bSJed Brown   z    = ctx->workRight[5];
241c4762a1bSJed Brown   u    = ctx->workRight[6];
242c4762a1bSJed Brown   temp = ctx->workRight[10];
2439566063dSJacob Faibussowitsch   PetscCall(GradientMisfit(tao, x, V, _ctx));
2449566063dSJacob Faibussowitsch   PetscCall(VecCopy(x, temp));
245c4762a1bSJed Brown   /* temp = x - z + u */
2469566063dSJacob Faibussowitsch   PetscCall(VecAXPBYPCZ(temp, -1., 1., 1., z, u));
247c4762a1bSJed Brown   /* V =  FTFx - FTd  mu*(x - z + u) */
2489566063dSJacob Faibussowitsch   PetscCall(VecAXPY(V, mu, temp));
249c4762a1bSJed Brown   PetscFunctionReturn(0);
250c4762a1bSJed Brown }
251c4762a1bSJed Brown 
252c4762a1bSJed Brown /* returns FTF + diag(mu) */
2539371c9d4SSatish Balay static PetscErrorCode HessianMisfitADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) {
254c4762a1bSJed Brown   UserCtx ctx = (UserCtx)_ctx;
255c4762a1bSJed Brown 
256c4762a1bSJed Brown   PetscFunctionBegin;
2579566063dSJacob Faibussowitsch   PetscCall(MatCopy(ctx->W, H, DIFFERENT_NONZERO_PATTERN));
2589566063dSJacob Faibussowitsch   PetscCall(MatShift(H, ctx->mu));
259*48a46eb9SPierre Jolivet   if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN));
260c4762a1bSJed Brown   PetscFunctionReturn(0);
261c4762a1bSJed Brown }
262c4762a1bSJed Brown 
263c4762a1bSJed Brown /* computes || x ||_p (mult by 0.5 in case of NORM_2) */
2649371c9d4SSatish Balay static PetscErrorCode ObjectiveRegularization(Tao tao, Vec x, PetscReal *J, void *_ctx) {
265c4762a1bSJed Brown   UserCtx   ctx = (UserCtx)_ctx;
266c4762a1bSJed Brown   PetscReal norm;
267c4762a1bSJed Brown 
268c4762a1bSJed Brown   PetscFunctionBegin;
269c4762a1bSJed Brown   *J = 0;
2709566063dSJacob Faibussowitsch   PetscCall(VecNorm(x, ctx->p, &norm));
271c4762a1bSJed Brown   if (ctx->p == NORM_2) norm = 0.5 * norm * norm;
272c4762a1bSJed Brown   *J = ctx->alpha * norm;
273c4762a1bSJed Brown   PetscFunctionReturn(0);
274c4762a1bSJed Brown }
275c4762a1bSJed Brown 
276c4762a1bSJed Brown /* NORM_2 Case: return x
277c4762a1bSJed Brown  * NORM_1 Case: x/(|x| + eps)
278c4762a1bSJed Brown  * Else: TODO */
2799371c9d4SSatish Balay static PetscErrorCode GradientRegularization(Tao tao, Vec x, Vec V, void *_ctx) {
280c4762a1bSJed Brown   UserCtx   ctx = (UserCtx)_ctx;
281c4762a1bSJed Brown   PetscReal eps = ctx->eps;
282c4762a1bSJed Brown 
283c4762a1bSJed Brown   PetscFunctionBegin;
284c4762a1bSJed Brown   if (ctx->p == NORM_2) {
2859566063dSJacob Faibussowitsch     PetscCall(VecCopy(x, V));
286c4762a1bSJed Brown   } else if (ctx->p == NORM_1) {
2879566063dSJacob Faibussowitsch     PetscCall(VecCopy(x, ctx->workRight[1]));
2889566063dSJacob Faibussowitsch     PetscCall(VecAbs(ctx->workRight[1]));
2899566063dSJacob Faibussowitsch     PetscCall(VecShift(ctx->workRight[1], eps));
2909566063dSJacob Faibussowitsch     PetscCall(VecPointwiseDivide(V, x, ctx->workRight[1]));
291c4762a1bSJed Brown   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
292c4762a1bSJed Brown   PetscFunctionReturn(0);
293c4762a1bSJed Brown }
294c4762a1bSJed Brown 
295c4762a1bSJed Brown /* NORM_2 Case: returns diag(mu)
296c4762a1bSJed Brown  * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)))  */
2979371c9d4SSatish Balay static PetscErrorCode HessianRegularization(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) {
298c4762a1bSJed Brown   UserCtx   ctx = (UserCtx)_ctx;
299c4762a1bSJed Brown   PetscReal eps = ctx->eps;
300c4762a1bSJed Brown   Vec       copy1, copy2, copy3;
301c4762a1bSJed Brown 
302c4762a1bSJed Brown   PetscFunctionBegin;
303c4762a1bSJed Brown   if (ctx->p == NORM_2) {
304c4762a1bSJed Brown     /* Identity matrix scaled by mu */
3059566063dSJacob Faibussowitsch     PetscCall(MatZeroEntries(H));
3069566063dSJacob Faibussowitsch     PetscCall(MatShift(H, ctx->mu));
307c4762a1bSJed Brown     if (Hpre != H) {
3089566063dSJacob Faibussowitsch       PetscCall(MatZeroEntries(Hpre));
3099566063dSJacob Faibussowitsch       PetscCall(MatShift(Hpre, ctx->mu));
310c4762a1bSJed Brown     }
311c4762a1bSJed Brown   } else if (ctx->p == NORM_1) {
312c4762a1bSJed Brown     /* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps)) */
313c4762a1bSJed Brown     copy1 = ctx->workRight[1];
314c4762a1bSJed Brown     copy2 = ctx->workRight[2];
315c4762a1bSJed Brown     copy3 = ctx->workRight[3];
316c4762a1bSJed Brown     /* copy1 : 1/sqrt(x_i^2 + eps) */
3179566063dSJacob Faibussowitsch     PetscCall(VecCopy(x, copy1));
3189566063dSJacob Faibussowitsch     PetscCall(VecPow(copy1, 2));
3199566063dSJacob Faibussowitsch     PetscCall(VecShift(copy1, eps));
3209566063dSJacob Faibussowitsch     PetscCall(VecSqrtAbs(copy1));
3219566063dSJacob Faibussowitsch     PetscCall(VecReciprocal(copy1));
322c4762a1bSJed Brown     /* copy2:  x_i^2.*/
3239566063dSJacob Faibussowitsch     PetscCall(VecCopy(x, copy2));
3249566063dSJacob Faibussowitsch     PetscCall(VecPow(copy2, 2));
325c4762a1bSJed Brown     /* copy3: abs(x_i^2 + eps) */
3269566063dSJacob Faibussowitsch     PetscCall(VecCopy(x, copy3));
3279566063dSJacob Faibussowitsch     PetscCall(VecPow(copy3, 2));
3289566063dSJacob Faibussowitsch     PetscCall(VecShift(copy3, eps));
3299566063dSJacob Faibussowitsch     PetscCall(VecAbs(copy3));
330c4762a1bSJed Brown     /* copy2: 1 - x_i^2/abs(x_i^2 + eps) */
3319566063dSJacob Faibussowitsch     PetscCall(VecPointwiseDivide(copy2, copy2, copy3));
3329566063dSJacob Faibussowitsch     PetscCall(VecScale(copy2, -1.));
3339566063dSJacob Faibussowitsch     PetscCall(VecShift(copy2, 1.));
3349566063dSJacob Faibussowitsch     PetscCall(VecAXPY(copy1, 1., copy2));
3359566063dSJacob Faibussowitsch     PetscCall(VecScale(copy1, ctx->mu));
3369566063dSJacob Faibussowitsch     PetscCall(MatZeroEntries(H));
3379566063dSJacob Faibussowitsch     PetscCall(MatDiagonalSet(H, copy1, INSERT_VALUES));
338c4762a1bSJed Brown     if (Hpre != H) {
3399566063dSJacob Faibussowitsch       PetscCall(MatZeroEntries(Hpre));
3409566063dSJacob Faibussowitsch       PetscCall(MatDiagonalSet(Hpre, copy1, INSERT_VALUES));
341c4762a1bSJed Brown     }
342c4762a1bSJed Brown   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
343c4762a1bSJed Brown   PetscFunctionReturn(0);
344c4762a1bSJed Brown }
345c4762a1bSJed Brown 
346c4762a1bSJed Brown /* NORM_2 Case: 0.5 || x ||_2 + 0.5 * mu * ||x + u - z||^2
347c4762a1bSJed Brown  * Else : || x ||_2 + 0.5 * mu * ||x + u - z||^2 */
3489371c9d4SSatish Balay static PetscErrorCode ObjectiveRegularizationADMM(Tao tao, Vec z, PetscReal *J, void *_ctx) {
349c4762a1bSJed Brown   UserCtx   ctx = (UserCtx)_ctx;
350c4762a1bSJed Brown   PetscReal mu, workNorm, reg;
351c4762a1bSJed Brown   Vec       x, u, temp;
352c4762a1bSJed Brown 
353c4762a1bSJed Brown   PetscFunctionBegin;
354c4762a1bSJed Brown   mu   = ctx->mu;
355c4762a1bSJed Brown   x    = ctx->workRight[4];
356c4762a1bSJed Brown   u    = ctx->workRight[6];
357c4762a1bSJed Brown   temp = ctx->workRight[10];
3589566063dSJacob Faibussowitsch   PetscCall(ObjectiveRegularization(tao, z, &reg, _ctx));
3599566063dSJacob Faibussowitsch   PetscCall(VecCopy(z, temp));
360c4762a1bSJed Brown   /* temp = x + u -z */
3619566063dSJacob Faibussowitsch   PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u));
362c4762a1bSJed Brown   /* workNorm = ||x + u - z ||^2 */
3639566063dSJacob Faibussowitsch   PetscCall(VecDot(temp, temp, &workNorm));
364c4762a1bSJed Brown   *J = reg + 0.5 * mu * workNorm;
365c4762a1bSJed Brown   PetscFunctionReturn(0);
366c4762a1bSJed Brown }
367c4762a1bSJed Brown 
368c4762a1bSJed Brown /* NORM_2 Case: x - mu*(x + u - z)
369c4762a1bSJed Brown  * NORM_1 Case: x/(|x| + eps) - mu*(x + u - z)
370c4762a1bSJed Brown  * Else: TODO */
3719371c9d4SSatish Balay static PetscErrorCode GradientRegularizationADMM(Tao tao, Vec z, Vec V, void *_ctx) {
372c4762a1bSJed Brown   UserCtx   ctx = (UserCtx)_ctx;
373c4762a1bSJed Brown   PetscReal mu;
374c4762a1bSJed Brown   Vec       x, u, temp;
375c4762a1bSJed Brown 
376c4762a1bSJed Brown   PetscFunctionBegin;
377c4762a1bSJed Brown   mu   = ctx->mu;
378c4762a1bSJed Brown   x    = ctx->workRight[4];
379c4762a1bSJed Brown   u    = ctx->workRight[6];
380c4762a1bSJed Brown   temp = ctx->workRight[10];
3819566063dSJacob Faibussowitsch   PetscCall(GradientRegularization(tao, z, V, _ctx));
3829566063dSJacob Faibussowitsch   PetscCall(VecCopy(z, temp));
383c4762a1bSJed Brown   /* temp = x + u -z */
3849566063dSJacob Faibussowitsch   PetscCall(VecAXPBYPCZ(temp, 1., 1., -1., x, u));
3859566063dSJacob Faibussowitsch   PetscCall(VecAXPY(V, -mu, temp));
386c4762a1bSJed Brown   PetscFunctionReturn(0);
387c4762a1bSJed Brown }
388c4762a1bSJed Brown 
389c4762a1bSJed Brown /* NORM_2 Case: returns diag(mu)
390c4762a1bSJed Brown  * NORM_1 Case: FTF + diag(mu) */
3919371c9d4SSatish Balay static PetscErrorCode HessianRegularizationADMM(Tao tao, Vec x, Mat H, Mat Hpre, void *_ctx) {
392c4762a1bSJed Brown   UserCtx ctx = (UserCtx)_ctx;
393c4762a1bSJed Brown 
394c4762a1bSJed Brown   PetscFunctionBegin;
395c4762a1bSJed Brown   if (ctx->p == NORM_2) {
396c4762a1bSJed Brown     /* Identity matrix scaled by mu */
3979566063dSJacob Faibussowitsch     PetscCall(MatZeroEntries(H));
3989566063dSJacob Faibussowitsch     PetscCall(MatShift(H, ctx->mu));
399c4762a1bSJed Brown     if (Hpre != H) {
4009566063dSJacob Faibussowitsch       PetscCall(MatZeroEntries(Hpre));
4019566063dSJacob Faibussowitsch       PetscCall(MatShift(Hpre, ctx->mu));
402c4762a1bSJed Brown     }
403c4762a1bSJed Brown   } else if (ctx->p == NORM_1) {
4049566063dSJacob Faibussowitsch     PetscCall(HessianMisfit(tao, x, H, Hpre, (void *)ctx));
4059566063dSJacob Faibussowitsch     PetscCall(MatShift(H, ctx->mu));
4069566063dSJacob Faibussowitsch     if (Hpre != H) PetscCall(MatShift(Hpre, ctx->mu));
407c4762a1bSJed Brown   } else SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ARG_OUTOFRANGE, "Example only works for NORM_1 and NORM_2");
408c4762a1bSJed Brown   PetscFunctionReturn(0);
409c4762a1bSJed Brown }
410c4762a1bSJed Brown 
411c4762a1bSJed Brown /* NORM_2 Case : (1/2) * ||F x - d||^2 + 0.5 * || x ||_p
412c4762a1bSJed Brown *  NORM_1 Case : (1/2) * ||F x - d||^2 + || x ||_p */
4139371c9d4SSatish Balay static PetscErrorCode ObjectiveComplete(Tao tao, Vec x, PetscReal *J, void *ctx) {
414c4762a1bSJed Brown   PetscReal Jm, Jr;
415c4762a1bSJed Brown 
416c4762a1bSJed Brown   PetscFunctionBegin;
4179566063dSJacob Faibussowitsch   PetscCall(ObjectiveMisfit(tao, x, &Jm, ctx));
4189566063dSJacob Faibussowitsch   PetscCall(ObjectiveRegularization(tao, x, &Jr, ctx));
419c4762a1bSJed Brown   *J = Jm + Jr;
420c4762a1bSJed Brown   PetscFunctionReturn(0);
421c4762a1bSJed Brown }
422c4762a1bSJed Brown 
423c4762a1bSJed Brown /* NORM_2 Case: FTFx - FTd + x
424c4762a1bSJed Brown  * NORM_1 Case: FTFx - FTd + x/(|x| + eps) */
4259371c9d4SSatish Balay static PetscErrorCode GradientComplete(Tao tao, Vec x, Vec V, void *ctx) {
426c4762a1bSJed Brown   UserCtx cntx = (UserCtx)ctx;
427c4762a1bSJed Brown 
428c4762a1bSJed Brown   PetscFunctionBegin;
4299566063dSJacob Faibussowitsch   PetscCall(GradientMisfit(tao, x, cntx->workRight[2], ctx));
4309566063dSJacob Faibussowitsch   PetscCall(GradientRegularization(tao, x, cntx->workRight[3], ctx));
4319566063dSJacob Faibussowitsch   PetscCall(VecWAXPY(V, 1, cntx->workRight[2], cntx->workRight[3]));
432c4762a1bSJed Brown   PetscFunctionReturn(0);
433c4762a1bSJed Brown }
434c4762a1bSJed Brown 
435c4762a1bSJed Brown /* NORM_2 Case: diag(mu) + FTF
436c4762a1bSJed Brown  * NORM_1 Case: diag(mu* 1/sqrt(x_i^2 + eps) * (1 - x_i^2/ABS(x_i^2+eps))) + FTF  */
4379371c9d4SSatish Balay static PetscErrorCode HessianComplete(Tao tao, Vec x, Mat H, Mat Hpre, void *ctx) {
438c4762a1bSJed Brown   Mat tempH;
439c4762a1bSJed Brown 
440c4762a1bSJed Brown   PetscFunctionBegin;
4419566063dSJacob Faibussowitsch   PetscCall(MatDuplicate(H, MAT_SHARE_NONZERO_PATTERN, &tempH));
4429566063dSJacob Faibussowitsch   PetscCall(HessianMisfit(tao, x, H, H, ctx));
4439566063dSJacob Faibussowitsch   PetscCall(HessianRegularization(tao, x, tempH, tempH, ctx));
4449566063dSJacob Faibussowitsch   PetscCall(MatAXPY(H, 1., tempH, DIFFERENT_NONZERO_PATTERN));
445*48a46eb9SPierre Jolivet   if (Hpre != H) PetscCall(MatCopy(H, Hpre, DIFFERENT_NONZERO_PATTERN));
4469566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&tempH));
447c4762a1bSJed Brown   PetscFunctionReturn(0);
448c4762a1bSJed Brown }
449c4762a1bSJed Brown 
4509371c9d4SSatish Balay static PetscErrorCode TaoSolveADMM(UserCtx ctx, Vec x) {
451c4762a1bSJed Brown   PetscInt  i;
452c4762a1bSJed Brown   PetscReal u_norm, r_norm, s_norm, primal, dual, x_norm, z_norm;
453c4762a1bSJed Brown   Tao       tao1, tao2;
454c4762a1bSJed Brown   Vec       xk, z, u, diff, zold, zdiff, temp;
455c4762a1bSJed Brown   PetscReal mu;
456c4762a1bSJed Brown 
457c4762a1bSJed Brown   PetscFunctionBegin;
458c4762a1bSJed Brown   xk    = ctx->workRight[4];
459c4762a1bSJed Brown   z     = ctx->workRight[5];
460c4762a1bSJed Brown   u     = ctx->workRight[6];
461c4762a1bSJed Brown   diff  = ctx->workRight[7];
462c4762a1bSJed Brown   zold  = ctx->workRight[8];
463c4762a1bSJed Brown   zdiff = ctx->workRight[9];
464c4762a1bSJed Brown   temp  = ctx->workRight[11];
465c4762a1bSJed Brown   mu    = ctx->mu;
4669566063dSJacob Faibussowitsch   PetscCall(VecSet(u, 0.));
4679566063dSJacob Faibussowitsch   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao1));
4689566063dSJacob Faibussowitsch   PetscCall(TaoSetType(tao1, TAONLS));
4699566063dSJacob Faibussowitsch   PetscCall(TaoSetObjective(tao1, ObjectiveMisfitADMM, (void *)ctx));
4709566063dSJacob Faibussowitsch   PetscCall(TaoSetGradient(tao1, NULL, GradientMisfitADMM, (void *)ctx));
4719566063dSJacob Faibussowitsch   PetscCall(TaoSetHessian(tao1, ctx->Hm, ctx->Hm, HessianMisfitADMM, (void *)ctx));
4729566063dSJacob Faibussowitsch   PetscCall(VecSet(xk, 0.));
4739566063dSJacob Faibussowitsch   PetscCall(TaoSetSolution(tao1, xk));
4749566063dSJacob Faibussowitsch   PetscCall(TaoSetOptionsPrefix(tao1, "misfit_"));
4759566063dSJacob Faibussowitsch   PetscCall(TaoSetFromOptions(tao1));
4769566063dSJacob Faibussowitsch   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao2));
477c4762a1bSJed Brown   if (ctx->p == NORM_2) {
4789566063dSJacob Faibussowitsch     PetscCall(TaoSetType(tao2, TAONLS));
4799566063dSJacob Faibussowitsch     PetscCall(TaoSetObjective(tao2, ObjectiveRegularizationADMM, (void *)ctx));
4809566063dSJacob Faibussowitsch     PetscCall(TaoSetGradient(tao2, NULL, GradientRegularizationADMM, (void *)ctx));
4819566063dSJacob Faibussowitsch     PetscCall(TaoSetHessian(tao2, ctx->Hr, ctx->Hr, HessianRegularizationADMM, (void *)ctx));
482c4762a1bSJed Brown   }
4839566063dSJacob Faibussowitsch   PetscCall(VecSet(z, 0.));
4849566063dSJacob Faibussowitsch   PetscCall(TaoSetSolution(tao2, z));
4859566063dSJacob Faibussowitsch   PetscCall(TaoSetOptionsPrefix(tao2, "reg_"));
4869566063dSJacob Faibussowitsch   PetscCall(TaoSetFromOptions(tao2));
487c4762a1bSJed Brown 
488c4762a1bSJed Brown   for (i = 0; i < ctx->iter; i++) {
4899566063dSJacob Faibussowitsch     PetscCall(VecCopy(z, zold));
4909566063dSJacob Faibussowitsch     PetscCall(TaoSolve(tao1)); /* Updates xk */
491c4762a1bSJed Brown     if (ctx->p == NORM_1) {
4929566063dSJacob Faibussowitsch       PetscCall(VecWAXPY(temp, 1., xk, u));
4939566063dSJacob Faibussowitsch       PetscCall(TaoSoftThreshold(temp, -ctx->alpha / mu, ctx->alpha / mu, z));
494c4762a1bSJed Brown     } else {
4959566063dSJacob Faibussowitsch       PetscCall(TaoSolve(tao2)); /* Update zk */
496c4762a1bSJed Brown     }
497c4762a1bSJed Brown     /* u = u + xk -z */
4989566063dSJacob Faibussowitsch     PetscCall(VecAXPBYPCZ(u, 1., -1., 1., xk, z));
499c4762a1bSJed Brown     /* r_norm : norm(x-z) */
5009566063dSJacob Faibussowitsch     PetscCall(VecWAXPY(diff, -1., z, xk));
5019566063dSJacob Faibussowitsch     PetscCall(VecNorm(diff, NORM_2, &r_norm));
502c4762a1bSJed Brown     /* s_norm : norm(-mu(z-zold)) */
5039566063dSJacob Faibussowitsch     PetscCall(VecWAXPY(zdiff, -1., zold, z));
5049566063dSJacob Faibussowitsch     PetscCall(VecNorm(zdiff, NORM_2, &s_norm));
505c4762a1bSJed Brown     s_norm = s_norm * mu;
506c4762a1bSJed Brown     /* primal : sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z))*/
5079566063dSJacob Faibussowitsch     PetscCall(VecNorm(xk, NORM_2, &x_norm));
5089566063dSJacob Faibussowitsch     PetscCall(VecNorm(z, NORM_2, &z_norm));
509c4762a1bSJed Brown     primal = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * PetscMax(x_norm, z_norm);
510c4762a1bSJed Brown     /* Duality : sqrt(n)*ABSTOL + RELTOL*norm(mu*u)*/
5119566063dSJacob Faibussowitsch     PetscCall(VecNorm(u, NORM_2, &u_norm));
512c4762a1bSJed Brown     dual = PetscSqrtReal(ctx->n) * ctx->abstol + ctx->reltol * u_norm * mu;
51363a3b9bcSJacob Faibussowitsch     PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao1), "Iter %" PetscInt_FMT " : ||x-z||: %g, mu*||z-zold||: %g\n", i, (double)r_norm, (double)s_norm));
514c4762a1bSJed Brown     if (r_norm < primal && s_norm < dual) break;
515c4762a1bSJed Brown   }
5169566063dSJacob Faibussowitsch   PetscCall(VecCopy(xk, x));
5179566063dSJacob Faibussowitsch   PetscCall(TaoDestroy(&tao1));
5189566063dSJacob Faibussowitsch   PetscCall(TaoDestroy(&tao2));
519c4762a1bSJed Brown   PetscFunctionReturn(0);
520c4762a1bSJed Brown }
521c4762a1bSJed Brown 
522c4762a1bSJed Brown /* Second order Taylor remainder convergence test */
5239371c9d4SSatish Balay static PetscErrorCode TaylorTest(UserCtx ctx, Tao tao, Vec x, PetscReal *C) {
524c4762a1bSJed Brown   PetscReal  h, J, temp;
525c4762a1bSJed Brown   PetscInt   i, j;
526c4762a1bSJed Brown   PetscInt   numValues;
527c4762a1bSJed Brown   PetscReal  Jx, Jxhat_comp, Jxhat_pred;
528c4762a1bSJed Brown   PetscReal *Js, *hs;
529c4762a1bSJed Brown   PetscReal  gdotdx;
530c4762a1bSJed Brown   PetscReal  minrate = PETSC_MAX_REAL;
531c4762a1bSJed Brown   MPI_Comm   comm    = PetscObjectComm((PetscObject)x);
532c4762a1bSJed Brown   Vec        g, dx, xhat;
533c4762a1bSJed Brown 
534c4762a1bSJed Brown   PetscFunctionBegin;
5359566063dSJacob Faibussowitsch   PetscCall(VecDuplicate(x, &g));
5369566063dSJacob Faibussowitsch   PetscCall(VecDuplicate(x, &xhat));
537c4762a1bSJed Brown   /* choose a perturbation direction */
5389566063dSJacob Faibussowitsch   PetscCall(VecDuplicate(x, &dx));
5399566063dSJacob Faibussowitsch   PetscCall(VecSetRandom(dx, ctx->rctx));
540c4762a1bSJed Brown   /* evaluate objective at x: J(x) */
5419566063dSJacob Faibussowitsch   PetscCall(TaoComputeObjective(tao, x, &Jx));
542c4762a1bSJed Brown   /* evaluate gradient at x, save in vector g */
5439566063dSJacob Faibussowitsch   PetscCall(TaoComputeGradient(tao, x, g));
5449566063dSJacob Faibussowitsch   PetscCall(VecDot(g, dx, &gdotdx));
545c4762a1bSJed Brown 
546c4762a1bSJed Brown   for (numValues = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor) numValues++;
5479566063dSJacob Faibussowitsch   PetscCall(PetscCalloc2(numValues, &Js, numValues, &hs));
548c4762a1bSJed Brown   for (i = 0, h = ctx->hStart; h >= ctx->hMin; h *= ctx->hFactor, i++) {
5499566063dSJacob Faibussowitsch     PetscCall(VecWAXPY(xhat, h, dx, x));
5509566063dSJacob Faibussowitsch     PetscCall(TaoComputeObjective(tao, xhat, &Jxhat_comp));
551c4762a1bSJed Brown     /* J(\hat(x)) \approx J(x) + g^T (xhat - x) = J(x) + h * g^T dx */
552c4762a1bSJed Brown     Jxhat_pred = Jx + h * gdotdx;
553c4762a1bSJed Brown     /* Vector to dJdm scalar? Dot?*/
554c4762a1bSJed Brown     J          = PetscAbsReal(Jxhat_comp - Jxhat_pred);
5559566063dSJacob Faibussowitsch     PetscCall(PetscPrintf(comm, "J(xhat): %g, predicted: %g, diff %g\n", (double)Jxhat_comp, (double)Jxhat_pred, (double)J));
556c4762a1bSJed Brown     Js[i] = J;
557c4762a1bSJed Brown     hs[i] = h;
558c4762a1bSJed Brown   }
559c4762a1bSJed Brown   for (j = 1; j < numValues; j++) {
560c4762a1bSJed Brown     temp = PetscLogReal(Js[j] / Js[j - 1]) / PetscLogReal(hs[j] / hs[j - 1]);
56163a3b9bcSJacob Faibussowitsch     PetscCall(PetscPrintf(comm, "Convergence rate step %" PetscInt_FMT ": %g\n", j - 1, (double)temp));
562c4762a1bSJed Brown     minrate = PetscMin(minrate, temp);
563c4762a1bSJed Brown   }
564c4762a1bSJed Brown   /* If O is not ~2, then the test is wrong */
5659566063dSJacob Faibussowitsch   PetscCall(PetscFree2(Js, hs));
566c4762a1bSJed Brown   *C = minrate;
5679566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&dx));
5689566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&xhat));
5699566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&g));
570c4762a1bSJed Brown   PetscFunctionReturn(0);
571c4762a1bSJed Brown }
572c4762a1bSJed Brown 
5739371c9d4SSatish Balay int main(int argc, char **argv) {
574c4762a1bSJed Brown   UserCtx ctx;
575c4762a1bSJed Brown   Tao     tao;
576c4762a1bSJed Brown   Vec     x;
577c4762a1bSJed Brown   Mat     H;
578c4762a1bSJed Brown 
579327415f7SBarry Smith   PetscFunctionBeginUser;
5809566063dSJacob Faibussowitsch   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
5819566063dSJacob Faibussowitsch   PetscCall(PetscNew(&ctx));
5829566063dSJacob Faibussowitsch   PetscCall(ConfigureContext(ctx));
583a82e8c82SStefano Zampini   /* Define two functions that could pass as objectives to TaoSetObjective(): one
584c4762a1bSJed Brown    * for the misfit component, and one for the regularization component */
585c4762a1bSJed Brown   /* ObjectiveMisfit() and ObjectiveRegularization() */
586c4762a1bSJed Brown 
587c4762a1bSJed Brown   /* Define a single function that calls both components adds them together: the complete objective,
588c4762a1bSJed Brown    * in the absence of a Tao implementation that handles separability */
589c4762a1bSJed Brown   /* ObjectiveComplete() */
5909566063dSJacob Faibussowitsch   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao));
5919566063dSJacob Faibussowitsch   PetscCall(TaoSetType(tao, TAONM));
5929566063dSJacob Faibussowitsch   PetscCall(TaoSetObjective(tao, ObjectiveComplete, (void *)ctx));
5939566063dSJacob Faibussowitsch   PetscCall(TaoSetGradient(tao, NULL, GradientComplete, (void *)ctx));
5949566063dSJacob Faibussowitsch   PetscCall(MatDuplicate(ctx->W, MAT_SHARE_NONZERO_PATTERN, &H));
5959566063dSJacob Faibussowitsch   PetscCall(TaoSetHessian(tao, H, H, HessianComplete, (void *)ctx));
5969566063dSJacob Faibussowitsch   PetscCall(MatCreateVecs(ctx->F, NULL, &x));
5979566063dSJacob Faibussowitsch   PetscCall(VecSet(x, 0.));
5989566063dSJacob Faibussowitsch   PetscCall(TaoSetSolution(tao, x));
5999566063dSJacob Faibussowitsch   PetscCall(TaoSetFromOptions(tao));
6001baa6e33SBarry Smith   if (ctx->use_admm) PetscCall(TaoSolveADMM(ctx, x));
6011baa6e33SBarry Smith   else PetscCall(TaoSolve(tao));
602c4762a1bSJed Brown   /* examine solution */
6039566063dSJacob Faibussowitsch   PetscCall(VecViewFromOptions(x, NULL, "-view_sol"));
604c4762a1bSJed Brown   if (ctx->taylor) {
605c4762a1bSJed Brown     PetscReal rate;
6069566063dSJacob Faibussowitsch     PetscCall(TaylorTest(ctx, tao, x, &rate));
607c4762a1bSJed Brown   }
6089566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&H));
6099566063dSJacob Faibussowitsch   PetscCall(TaoDestroy(&tao));
6109566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&x));
6119566063dSJacob Faibussowitsch   PetscCall(DestroyContext(&ctx));
6129566063dSJacob Faibussowitsch   PetscCall(PetscFinalize());
613b122ec5aSJacob Faibussowitsch   return 0;
614c4762a1bSJed Brown }
615c4762a1bSJed Brown 
616c4762a1bSJed Brown /*TEST
617c4762a1bSJed Brown 
618c4762a1bSJed Brown   build:
619c4762a1bSJed Brown     requires: !complex
620c4762a1bSJed Brown 
621c4762a1bSJed Brown   test:
622c4762a1bSJed Brown     suffix: 0
623c4762a1bSJed Brown     args:
624c4762a1bSJed Brown 
625c4762a1bSJed Brown   test:
626c4762a1bSJed Brown     suffix: l1_1
627c4762a1bSJed Brown     args: -p 1 -tao_type lmvm -alpha 1. -epsilon 1.e-7 -m 64 -n 64 -view_sol -matrix_format 1
628c4762a1bSJed Brown 
629c4762a1bSJed Brown   test:
630c4762a1bSJed Brown     suffix: hessian_1
631c5f5e425SStefano Zampini     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nls
632c4762a1bSJed Brown 
633c4762a1bSJed Brown   test:
634c4762a1bSJed Brown     suffix: hessian_2
635c5f5e425SStefano Zampini     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nls
636c4762a1bSJed Brown 
637c4762a1bSJed Brown   test:
638c4762a1bSJed Brown     suffix: nm_1
639c4762a1bSJed Brown     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type nm -tao_max_it 50
640c4762a1bSJed Brown 
641c4762a1bSJed Brown   test:
642c4762a1bSJed Brown     suffix: nm_2
643c4762a1bSJed Brown     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type nm -tao_max_it 50
644c4762a1bSJed Brown 
645c4762a1bSJed Brown   test:
646c4762a1bSJed Brown     suffix: lmvm_1
647c4762a1bSJed Brown     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -tao_type lmvm -tao_max_it 40
648c4762a1bSJed Brown 
649c4762a1bSJed Brown   test:
650c4762a1bSJed Brown     suffix: lmvm_2
651c4762a1bSJed Brown     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -tao_type lmvm -tao_max_it 15
652c4762a1bSJed Brown 
653c4762a1bSJed Brown   test:
654c4762a1bSJed Brown     suffix: soft_threshold_admm_1
655c4762a1bSJed Brown     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm
656c4762a1bSJed Brown 
657c4762a1bSJed Brown   test:
658c4762a1bSJed Brown     suffix: hessian_admm_1
659c4762a1bSJed Brown     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nls -misfit_tao_type nls
660c4762a1bSJed Brown 
661c4762a1bSJed Brown   test:
662c4762a1bSJed Brown     suffix: hessian_admm_2
663c4762a1bSJed Brown     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nls -misfit_tao_type nls
664c4762a1bSJed Brown 
665c4762a1bSJed Brown   test:
666c4762a1bSJed Brown     suffix: nm_admm_1
667c4762a1bSJed Brown     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type nm -misfit_tao_type nm
668c4762a1bSJed Brown 
669c4762a1bSJed Brown   test:
670c4762a1bSJed Brown     suffix: nm_admm_2
671c4762a1bSJed Brown     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type nm -misfit_tao_type nm -iter 7
672c4762a1bSJed Brown 
673c4762a1bSJed Brown   test:
674c4762a1bSJed Brown     suffix: lmvm_admm_1
675c4762a1bSJed Brown     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 1 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm
676c4762a1bSJed Brown 
677c4762a1bSJed Brown   test:
678c4762a1bSJed Brown     suffix: lmvm_admm_2
679c4762a1bSJed Brown     args: -matrix_format 1 -m 100 -n 100 -tao_monitor -p 2 -use_admm -reg_tao_type lmvm -misfit_tao_type lmvm
680c4762a1bSJed Brown 
681c4762a1bSJed Brown TEST*/
682