1 #include <../src/tao/leastsquares/impls/brgn/brgn.h> 2 3 static PetscErrorCode GNHessianProd(Mat H, Vec in, Vec out) 4 { 5 TAO_BRGN *gn; 6 PetscErrorCode ierr; 7 8 PetscFunctionBegin; 9 ierr = MatShellGetContext(H, &gn);CHKERRQ(ierr); 10 ierr = MatMult(gn->subsolver->ls_jac, in, gn->r_work);CHKERRQ(ierr); 11 ierr = MatMultTranspose(gn->subsolver->ls_jac, gn->r_work, out);CHKERRQ(ierr); 12 ierr = VecAXPY(out, gn->lambda, in);CHKERRQ(ierr); 13 PetscFunctionReturn(0); 14 } 15 16 static PetscErrorCode GNObjectiveGradientEval(Tao tao, Vec X, PetscReal *fcn, Vec G, void *ptr) 17 { 18 TAO_BRGN *gn = (TAO_BRGN *)ptr; 19 PetscScalar xnorm2; 20 PetscErrorCode ierr; 21 22 PetscFunctionBegin; 23 ierr = TaoComputeResidual(tao, X, tao->ls_res);CHKERRQ(ierr); 24 ierr = VecDotBegin(tao->ls_res, tao->ls_res, fcn);CHKERRQ(ierr); 25 ierr = VecAXPBYPCZ(gn->x_work, 1.0, -1.0, 0.0, X, gn->x_old);CHKERRQ(ierr); 26 ierr = VecDotBegin(gn->x_work, gn->x_work, &xnorm2);CHKERRQ(ierr); 27 ierr = VecDotEnd(tao->ls_res, tao->ls_res, fcn);CHKERRQ(ierr); 28 ierr = VecDotEnd(gn->x_work, gn->x_work, &xnorm2);CHKERRQ(ierr); 29 *fcn = 0.5*(*fcn) + 0.5*gn->lambda*xnorm2; 30 31 ierr = TaoComputeResidualJacobian(tao, X, tao->ls_jac, tao->ls_jac_pre);CHKERRQ(ierr); 32 ierr = MatMultTranspose(tao->ls_jac, tao->ls_res, G);CHKERRQ(ierr); 33 ierr = VecAXPBYPCZ(G, gn->lambda, -gn->lambda, 1.0, X, gn->x_old);CHKERRQ(ierr); 34 PetscFunctionReturn(0); 35 } 36 37 static PetscErrorCode GNComputeHessian(Tao tao, Vec X, Mat H, Mat Hpre, void *ptr) 38 { 39 PetscErrorCode ierr; 40 41 PetscFunctionBegin; 42 ierr = TaoComputeResidualJacobian(tao, X, tao->ls_jac, tao->ls_jac_pre);CHKERRQ(ierr); 43 PetscFunctionReturn(0); 44 } 45 46 static PetscErrorCode GNHookFunction(Tao tao, PetscInt iter) 47 { 48 TAO_BRGN *gn = (TAO_BRGN *)tao->user_update; 49 PetscErrorCode ierr; 50 51 PetscFunctionBegin; 52 /* Update basic tao information from the subsolver */ 53 gn->parent->nfuncs = tao->nfuncs; 54 gn->parent->ngrads = tao->ngrads; 55 gn->parent->nfuncgrads = tao->nfuncgrads; 56 gn->parent->nhess = tao->nhess; 57 gn->parent->niter = tao->niter; 58 gn->parent->ksp_its = tao->ksp_its; 59 gn->parent->ksp_tot_its = tao->ksp_tot_its; 60 ierr = TaoGetConvergedReason(tao, &gn->parent->reason);CHKERRQ(ierr); 61 /* Update the solution vectors */ 62 if (iter == 0) { 63 ierr = VecSet(gn->x_old, 0.0);CHKERRQ(ierr); 64 } else { 65 ierr = VecCopy(tao->solution, gn->x_old);CHKERRQ(ierr); 66 ierr = VecCopy(tao->solution, gn->parent->solution);CHKERRQ(ierr); 67 } 68 /* Update the gradient */ 69 ierr = VecCopy(tao->gradient, gn->parent->gradient);CHKERRQ(ierr); 70 /* Call general purpose update function */ 71 if (gn->parent->ops->update) { 72 ierr = (*gn->parent->ops->update)(gn->parent, gn->parent->niter);CHKERRQ(ierr); 73 } 74 PetscFunctionReturn(0); 75 } 76 77 static PetscErrorCode TaoSolve_BRGN(Tao tao) 78 { 79 TAO_BRGN *gn = (TAO_BRGN *)tao->data; 80 PetscErrorCode ierr; 81 82 PetscFunctionBegin; 83 ierr = TaoSolve(gn->subsolver);CHKERRQ(ierr); 84 /* Update basic tao information from the subsolver */ 85 tao->nfuncs = gn->subsolver->nfuncs; 86 tao->ngrads = gn->subsolver->ngrads; 87 tao->nfuncgrads = gn->subsolver->nfuncgrads; 88 tao->nhess = gn->subsolver->nhess; 89 tao->niter = gn->subsolver->niter; 90 tao->ksp_its = gn->subsolver->ksp_its; 91 tao->ksp_tot_its = gn->subsolver->ksp_tot_its; 92 ierr = TaoGetConvergedReason(gn->subsolver, &tao->reason);CHKERRQ(ierr); 93 /* Update vectors */ 94 ierr = VecCopy(gn->subsolver->solution, tao->solution);CHKERRQ(ierr); 95 ierr = VecCopy(gn->subsolver->gradient, tao->gradient);CHKERRQ(ierr); 96 PetscFunctionReturn(0); 97 } 98 99 static PetscErrorCode TaoSetFromOptions_BRGN(PetscOptionItems *PetscOptionsObject,Tao tao) 100 { 101 TAO_BRGN *gn = (TAO_BRGN *)tao->data; 102 PetscErrorCode ierr; 103 104 PetscFunctionBegin; 105 ierr = PetscOptionsHead(PetscOptionsObject,"Gauss-Newton method for least-squares problems using Tikhonov regularization");CHKERRQ(ierr); 106 ierr = PetscOptionsReal("-tao_brgn_lambda", "Tikhonov regularization factor", "", gn->lambda, &gn->lambda, NULL);CHKERRQ(ierr); 107 ierr = PetscOptionsTail();CHKERRQ(ierr); 108 ierr = TaoSetFromOptions(gn->subsolver);CHKERRQ(ierr); 109 PetscFunctionReturn(0); 110 } 111 112 static PetscErrorCode TaoView_BRGN(Tao tao, PetscViewer viewer) 113 { 114 TAO_BRGN *gn = (TAO_BRGN *)tao->data; 115 PetscErrorCode ierr; 116 117 PetscFunctionBegin; 118 ierr = PetscViewerASCIIPushTab(viewer);CHKERRQ(ierr); 119 ierr = TaoView(gn->subsolver, viewer);CHKERRQ(ierr); 120 ierr = PetscViewerASCIIPopTab(viewer);CHKERRQ(ierr); 121 PetscFunctionReturn(0); 122 } 123 124 static PetscErrorCode TaoSetUp_BRGN(Tao tao) 125 { 126 TAO_BRGN *gn = (TAO_BRGN *)tao->data; 127 PetscErrorCode ierr; 128 PetscBool is_bnls, is_bntr, is_bntl; 129 PetscInt i, nx, Nx; 130 131 PetscFunctionBegin; 132 if (!tao->ls_res) SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ORDER, "TaoSetResidualRoutine() must be called before setup!"); 133 ierr = PetscObjectTypeCompare((PetscObject)gn->subsolver, TAOBNLS, &is_bnls);CHKERRQ(ierr); 134 ierr = PetscObjectTypeCompare((PetscObject)gn->subsolver, TAOBNTR, &is_bntr);CHKERRQ(ierr); 135 ierr = PetscObjectTypeCompare((PetscObject)gn->subsolver, TAOBNTL, &is_bntl);CHKERRQ(ierr); 136 if ((is_bnls || is_bntr || is_bntl) && !tao->ls_jac) SETERRQ(PetscObjectComm((PetscObject)tao), PETSC_ERR_ORDER, "TaoSetResidualJacobianRoutine() must be called before setup!"); 137 if (!tao->gradient){ 138 ierr = VecDuplicate(tao->solution, &tao->gradient);CHKERRQ(ierr); 139 } 140 if (!gn->x_work){ 141 ierr = VecDuplicate(tao->solution, &gn->x_work);CHKERRQ(ierr); 142 } 143 if (!gn->r_work){ 144 ierr = VecDuplicate(tao->ls_res, &gn->r_work);CHKERRQ(ierr); 145 } 146 if (!gn->x_old) { 147 ierr = VecDuplicate(tao->solution, &gn->x_old);CHKERRQ(ierr); 148 ierr = VecSet(gn->x_old, 0.0);CHKERRQ(ierr); 149 } 150 if (!tao->setupcalled) { 151 /* Hessian setup */ 152 ierr = VecGetLocalSize(tao->solution, &nx);CHKERRQ(ierr); 153 ierr = VecGetSize(tao->solution, &Nx);CHKERRQ(ierr); 154 ierr = MatSetSizes(gn->H, nx, nx, Nx, Nx);CHKERRQ(ierr); 155 ierr = MatSetType(gn->H, MATSHELL);CHKERRQ(ierr); 156 ierr = MatSetUp(gn->H);CHKERRQ(ierr); 157 ierr = MatShellSetOperation(gn->H, MATOP_MULT, (void (*)(void))GNHessianProd);CHKERRQ(ierr); 158 ierr = MatShellSetContext(gn->H, (void*)gn);CHKERRQ(ierr); 159 /* Subsolver setup */ 160 ierr = TaoSetUpdate(gn->subsolver, GNHookFunction, (void*)gn);CHKERRQ(ierr); 161 ierr = TaoSetInitialVector(gn->subsolver, tao->solution);CHKERRQ(ierr); 162 if (tao->bounded) { 163 ierr = TaoSetVariableBounds(gn->subsolver, tao->XL, tao->XU);CHKERRQ(ierr); 164 } 165 ierr = TaoSetResidualRoutine(gn->subsolver, tao->ls_res, tao->ops->computeresidual, tao->user_lsresP);CHKERRQ(ierr); 166 ierr = TaoSetJacobianResidualRoutine(gn->subsolver, tao->ls_jac, tao->ls_jac, tao->ops->computeresidualjacobian, tao->user_lsjacP);CHKERRQ(ierr); 167 ierr = TaoSetObjectiveAndGradientRoutine(gn->subsolver, GNObjectiveGradientEval, (void*)gn);CHKERRQ(ierr); 168 ierr = TaoSetHessianRoutine(gn->subsolver, gn->H, gn->H, GNComputeHessian, (void*)gn);CHKERRQ(ierr); 169 /* Propagate some options down */ 170 ierr = TaoSetTolerances(gn->subsolver, tao->gatol, tao->grtol, tao->gttol);CHKERRQ(ierr); 171 ierr = TaoSetMaximumIterations(gn->subsolver, tao->max_it);CHKERRQ(ierr); 172 ierr = TaoSetMaximumFunctionEvaluations(gn->subsolver, tao->max_funcs);CHKERRQ(ierr); 173 for (i=0; i<tao->numbermonitors; ++i) { 174 ierr = TaoSetMonitor(gn->subsolver, tao->monitor[i], tao->monitorcontext[i], tao->monitordestroy[i]);CHKERRQ(ierr); 175 ierr = PetscObjectReference((PetscObject)(tao->monitorcontext[i]));CHKERRQ(ierr); 176 } 177 ierr = TaoSetUp(gn->subsolver);CHKERRQ(ierr); 178 } 179 PetscFunctionReturn(0); 180 } 181 182 static PetscErrorCode TaoDestroy_BRGN(Tao tao) 183 { 184 TAO_BRGN *gn = (TAO_BRGN *)tao->data; 185 PetscErrorCode ierr; 186 187 PetscFunctionBegin; 188 if (tao->setupcalled) { 189 ierr = VecDestroy(&tao->gradient);CHKERRQ(ierr); 190 ierr = VecDestroy(&gn->x_work);CHKERRQ(ierr); 191 ierr = VecDestroy(&gn->r_work);CHKERRQ(ierr); 192 ierr = VecDestroy(&gn->x_old);CHKERRQ(ierr); 193 } 194 ierr = MatDestroy(&gn->H);CHKERRQ(ierr); 195 ierr = TaoDestroy(&gn->subsolver);CHKERRQ(ierr); 196 gn->parent = NULL; 197 ierr = PetscFree(tao->data);CHKERRQ(ierr); 198 PetscFunctionReturn(0); 199 } 200 201 PETSC_EXTERN PetscErrorCode TaoCreate_BRGN(Tao tao) 202 { 203 TAO_BRGN *gn; 204 PetscErrorCode ierr; 205 206 PetscFunctionBegin; 207 ierr = PetscNewLog(tao,&gn);CHKERRQ(ierr); 208 209 tao->ops->destroy = TaoDestroy_BRGN; 210 tao->ops->setup = TaoSetUp_BRGN; 211 tao->ops->setfromoptions = TaoSetFromOptions_BRGN; 212 tao->ops->view = TaoView_BRGN; 213 tao->ops->solve = TaoSolve_BRGN; 214 215 tao->data = (void*)gn; 216 gn->lambda = 1e-4; 217 gn->parent = tao; 218 219 ierr = MatCreate(PetscObjectComm((PetscObject)tao), &gn->H);CHKERRQ(ierr); 220 ierr = MatSetOptionsPrefix(gn->H, "tao_brgn_hessian_");CHKERRQ(ierr); 221 222 ierr = TaoCreate(PetscObjectComm((PetscObject)tao), &gn->subsolver);CHKERRQ(ierr); 223 ierr = TaoSetType(gn->subsolver, TAOBNLS);CHKERRQ(ierr); 224 ierr = TaoSetOptionsPrefix(gn->subsolver, "tao_brgn_subsolver_");CHKERRQ(ierr); 225 PetscFunctionReturn(0); 226 } 227 228 /*@C 229 TaoBRGNGetSubsolver - Get the pointer to the subsolver inside BRGN 230 231 Collective on Tao 232 233 Level: developer 234 235 Input Parameters: 236 + tao - the Tao solver context 237 - subsolver - the Tao sub-solver context 238 @*/ 239 PetscErrorCode TaoBRGNGetSubsolver(Tao tao, Tao *subsolver) 240 { 241 TAO_BRGN *gn = (TAO_BRGN *)tao->data; 242 243 PetscFunctionBegin; 244 *subsolver = gn->subsolver; 245 PetscFunctionReturn(0); 246 } 247 248 /*@C 249 TaoBRGNSetTikhonovLambda - Set the Tikhonov regularization factor for the Gauss-Newton least-squares algorithm 250 251 Collective on Tao 252 253 Level: developer 254 255 Input Parameters: 256 + tao - the Tao solver context 257 - lambda - Tikhonov regularization factor 258 @*/ 259 PetscErrorCode TaoBRGNSetTikhonovLambda(Tao tao, PetscReal lambda) 260 { 261 TAO_BRGN *gn = (TAO_BRGN *)tao->data; 262 263 PetscFunctionBegin; 264 gn->lambda = lambda; 265 PetscFunctionReturn(0); 266 } 267