1 #include <../src/tao/unconstrained/impls/neldermead/neldermead.h> 2 #include <petscvec.h> 3 4 /*------------------------------------------------------------*/ 5 static PetscErrorCode NelderMeadSort(TAO_NelderMead *nm) 6 { 7 PetscReal *values = nm->f_values; 8 PetscInt *indices = nm->indices; 9 PetscInt dim = nm->N + 1; 10 PetscInt i, j, index; 11 PetscReal val; 12 13 PetscFunctionBegin; 14 for (i = 1; i < dim; i++) { 15 index = indices[i]; 16 val = values[index]; 17 for (j = i - 1; j >= 0 && values[indices[j]] > val; j--) indices[j + 1] = indices[j]; 18 indices[j + 1] = index; 19 } 20 PetscFunctionReturn(PETSC_SUCCESS); 21 } 22 23 /*------------------------------------------------------------*/ 24 static PetscErrorCode NelderMeadReplace(TAO_NelderMead *nm, PetscInt index, Vec Xmu, PetscReal f) 25 { 26 PetscFunctionBegin; 27 /* Add new vector's fraction of average */ 28 PetscCall(VecAXPY(nm->Xbar, nm->oneOverN, Xmu)); 29 PetscCall(VecCopy(Xmu, nm->simplex[index])); 30 nm->f_values[index] = f; 31 32 PetscCall(NelderMeadSort(nm)); 33 34 /* Subtract last vector from average */ 35 PetscCall(VecAXPY(nm->Xbar, -nm->oneOverN, nm->simplex[nm->indices[nm->N]])); 36 PetscFunctionReturn(PETSC_SUCCESS); 37 } 38 39 /* ---------------------------------------------------------- */ 40 static PetscErrorCode TaoSetUp_NM(Tao tao) 41 { 42 TAO_NelderMead *nm = (TAO_NelderMead *)tao->data; 43 PetscInt n; 44 45 PetscFunctionBegin; 46 PetscCall(VecGetSize(tao->solution, &n)); 47 nm->N = n; 48 nm->oneOverN = 1.0 / n; 49 PetscCall(VecDuplicateVecs(tao->solution, nm->N + 1, &nm->simplex)); 50 PetscCall(PetscMalloc1(nm->N + 1, &nm->f_values)); 51 PetscCall(PetscMalloc1(nm->N + 1, &nm->indices)); 52 PetscCall(VecDuplicate(tao->solution, &nm->Xbar)); 53 PetscCall(VecDuplicate(tao->solution, &nm->Xmur)); 54 PetscCall(VecDuplicate(tao->solution, &nm->Xmue)); 55 PetscCall(VecDuplicate(tao->solution, &nm->Xmuc)); 56 57 tao->gradient = NULL; 58 tao->step = 0; 59 PetscFunctionReturn(PETSC_SUCCESS); 60 } 61 62 /* ---------------------------------------------------------- */ 63 static PetscErrorCode TaoDestroy_NM(Tao tao) 64 { 65 TAO_NelderMead *nm = (TAO_NelderMead *)tao->data; 66 67 PetscFunctionBegin; 68 if (tao->setupcalled) { 69 PetscCall(VecDestroyVecs(nm->N + 1, &nm->simplex)); 70 PetscCall(VecDestroy(&nm->Xmuc)); 71 PetscCall(VecDestroy(&nm->Xmue)); 72 PetscCall(VecDestroy(&nm->Xmur)); 73 PetscCall(VecDestroy(&nm->Xbar)); 74 } 75 PetscCall(PetscFree(nm->indices)); 76 PetscCall(PetscFree(nm->f_values)); 77 PetscCall(PetscFree(tao->data)); 78 PetscFunctionReturn(PETSC_SUCCESS); 79 } 80 81 /*------------------------------------------------------------*/ 82 static PetscErrorCode TaoSetFromOptions_NM(Tao tao, PetscOptionItems PetscOptionsObject) 83 { 84 TAO_NelderMead *nm = (TAO_NelderMead *)tao->data; 85 86 PetscFunctionBegin; 87 PetscOptionsHeadBegin(PetscOptionsObject, "Nelder-Mead options"); 88 PetscCall(PetscOptionsDeprecated("-tao_nm_lamda", "-tao_nm_lambda", "3.18.4", NULL)); 89 PetscCall(PetscOptionsReal("-tao_nm_lambda", "initial step length", "", nm->lambda, &nm->lambda, NULL)); 90 PetscCall(PetscOptionsReal("-tao_nm_mu", "mu", "", nm->mu_oc, &nm->mu_oc, NULL)); 91 nm->mu_ic = -nm->mu_oc; 92 nm->mu_r = nm->mu_oc * 2.0; 93 nm->mu_e = nm->mu_oc * 4.0; 94 PetscOptionsHeadEnd(); 95 PetscFunctionReturn(PETSC_SUCCESS); 96 } 97 98 /*------------------------------------------------------------*/ 99 static PetscErrorCode TaoView_NM(Tao tao, PetscViewer viewer) 100 { 101 TAO_NelderMead *nm = (TAO_NelderMead *)tao->data; 102 PetscBool isascii; 103 104 PetscFunctionBegin; 105 PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii)); 106 if (isascii) { 107 PetscCall(PetscViewerASCIIPushTab(viewer)); 108 PetscCall(PetscViewerASCIIPrintf(viewer, "expansions: %" PetscInt_FMT "\n", nm->nexpand)); 109 PetscCall(PetscViewerASCIIPrintf(viewer, "reflections: %" PetscInt_FMT "\n", nm->nreflect)); 110 PetscCall(PetscViewerASCIIPrintf(viewer, "inside contractions: %" PetscInt_FMT "\n", nm->nincontract)); 111 PetscCall(PetscViewerASCIIPrintf(viewer, "outside contractionss: %" PetscInt_FMT "\n", nm->noutcontract)); 112 PetscCall(PetscViewerASCIIPrintf(viewer, "Shrink steps: %" PetscInt_FMT "\n", nm->nshrink)); 113 PetscCall(PetscViewerASCIIPopTab(viewer)); 114 } 115 PetscFunctionReturn(PETSC_SUCCESS); 116 } 117 118 /*------------------------------------------------------------*/ 119 static PetscErrorCode TaoSolve_NM(Tao tao) 120 { 121 TAO_NelderMead *nm = (TAO_NelderMead *)tao->data; 122 PetscReal *x; 123 PetscInt i; 124 Vec Xmur = nm->Xmur, Xmue = nm->Xmue, Xmuc = nm->Xmuc, Xbar = nm->Xbar; 125 PetscReal fr, fe, fc; 126 PetscInt shrink; 127 PetscInt low, high; 128 129 PetscFunctionBegin; 130 nm->nshrink = 0; 131 nm->nreflect = 0; 132 nm->nincontract = 0; 133 nm->noutcontract = 0; 134 nm->nexpand = 0; 135 136 if (tao->XL || tao->XU || tao->ops->computebounds) PetscCall(PetscInfo(tao, "WARNING: Variable bounds have been set but will be ignored by NelderMead algorithm\n")); 137 138 PetscCall(VecCopy(tao->solution, nm->simplex[0])); 139 PetscCall(TaoComputeObjective(tao, nm->simplex[0], &nm->f_values[0])); 140 nm->indices[0] = 0; 141 for (i = 1; i < nm->N + 1; i++) { 142 PetscCall(VecCopy(tao->solution, nm->simplex[i])); 143 PetscCall(VecGetOwnershipRange(nm->simplex[i], &low, &high)); 144 if (i - 1 >= low && i - 1 < high) { 145 PetscCall(VecGetArray(nm->simplex[i], &x)); 146 x[i - 1 - low] += nm->lambda; 147 PetscCall(VecRestoreArray(nm->simplex[i], &x)); 148 } 149 150 PetscCall(TaoComputeObjective(tao, nm->simplex[i], &nm->f_values[i])); 151 nm->indices[i] = i; 152 } 153 154 /* Xbar = (Sum of all simplex vectors - worst vector)/N */ 155 PetscCall(NelderMeadSort(nm)); 156 PetscCall(VecSet(Xbar, 0.0)); 157 for (i = 0; i < nm->N; i++) PetscCall(VecAXPY(Xbar, 1.0, nm->simplex[nm->indices[i]])); 158 PetscCall(VecScale(Xbar, nm->oneOverN)); 159 tao->reason = TAO_CONTINUE_ITERATING; 160 while (1) { 161 /* Call general purpose update function */ 162 PetscTryTypeMethod(tao, update, tao->niter, tao->user_update); 163 ++tao->niter; 164 shrink = 0; 165 PetscCall(VecCopy(nm->simplex[nm->indices[0]], tao->solution)); 166 PetscCall(TaoLogConvergenceHistory(tao, nm->f_values[nm->indices[0]], nm->f_values[nm->indices[nm->N]] - nm->f_values[nm->indices[0]], 0.0, tao->ksp_its)); 167 PetscCall(TaoMonitor(tao, tao->niter, nm->f_values[nm->indices[0]], nm->f_values[nm->indices[nm->N]] - nm->f_values[nm->indices[0]], 0.0, 1.0)); 168 PetscUseTypeMethod(tao, convergencetest, tao->cnvP); 169 if (tao->reason != TAO_CONTINUE_ITERATING) break; 170 171 /* x(mu) = (1 + mu)Xbar - mu*X_N+1 */ 172 PetscCall(VecAXPBYPCZ(Xmur, 1 + nm->mu_r, -nm->mu_r, 0, Xbar, nm->simplex[nm->indices[nm->N]])); 173 PetscCall(TaoComputeObjective(tao, Xmur, &fr)); 174 175 if (nm->f_values[nm->indices[0]] <= fr && fr < nm->f_values[nm->indices[nm->N - 1]]) { 176 /* reflect */ 177 nm->nreflect++; 178 PetscCall(PetscInfo(0, "Reflect\n")); 179 PetscCall(NelderMeadReplace(nm, nm->indices[nm->N], Xmur, fr)); 180 } else if (fr < nm->f_values[nm->indices[0]]) { 181 /* expand */ 182 nm->nexpand++; 183 PetscCall(PetscInfo(0, "Expand\n")); 184 PetscCall(VecAXPBYPCZ(Xmue, 1 + nm->mu_e, -nm->mu_e, 0, Xbar, nm->simplex[nm->indices[nm->N]])); 185 PetscCall(TaoComputeObjective(tao, Xmue, &fe)); 186 if (fe < fr) { 187 PetscCall(NelderMeadReplace(nm, nm->indices[nm->N], Xmue, fe)); 188 } else { 189 PetscCall(NelderMeadReplace(nm, nm->indices[nm->N], Xmur, fr)); 190 } 191 } else if (nm->f_values[nm->indices[nm->N - 1]] <= fr && fr < nm->f_values[nm->indices[nm->N]]) { 192 /* outside contraction */ 193 nm->noutcontract++; 194 PetscCall(PetscInfo(0, "Outside Contraction\n")); 195 PetscCall(VecAXPBYPCZ(Xmuc, 1 + nm->mu_oc, -nm->mu_oc, 0, Xbar, nm->simplex[nm->indices[nm->N]])); 196 197 PetscCall(TaoComputeObjective(tao, Xmuc, &fc)); 198 if (fc <= fr) PetscCall(NelderMeadReplace(nm, nm->indices[nm->N], Xmuc, fc)); 199 else shrink = 1; 200 } else { 201 /* inside contraction */ 202 nm->nincontract++; 203 PetscCall(PetscInfo(0, "Inside Contraction\n")); 204 PetscCall(VecAXPBYPCZ(Xmuc, 1 + nm->mu_ic, -nm->mu_ic, 0, Xbar, nm->simplex[nm->indices[nm->N]])); 205 PetscCall(TaoComputeObjective(tao, Xmuc, &fc)); 206 if (fc < nm->f_values[nm->indices[nm->N]]) PetscCall(NelderMeadReplace(nm, nm->indices[nm->N], Xmuc, fc)); 207 else shrink = 1; 208 } 209 210 if (shrink) { 211 nm->nshrink++; 212 PetscCall(PetscInfo(0, "Shrink\n")); 213 214 for (i = 1; i < nm->N + 1; i++) { 215 PetscCall(VecAXPBY(nm->simplex[nm->indices[i]], 1.5, -0.5, nm->simplex[nm->indices[0]])); 216 PetscCall(TaoComputeObjective(tao, nm->simplex[nm->indices[i]], &nm->f_values[nm->indices[i]])); 217 } 218 PetscCall(VecAXPBY(Xbar, 1.5 * nm->oneOverN, -0.5, nm->simplex[nm->indices[0]])); 219 220 /* Add last vector's fraction of average */ 221 PetscCall(VecAXPY(Xbar, nm->oneOverN, nm->simplex[nm->indices[nm->N]])); 222 PetscCall(NelderMeadSort(nm)); 223 /* Subtract new last vector from average */ 224 PetscCall(VecAXPY(Xbar, -nm->oneOverN, nm->simplex[nm->indices[nm->N]])); 225 } 226 } 227 PetscFunctionReturn(PETSC_SUCCESS); 228 } 229 230 /* ---------------------------------------------------------- */ 231 /*MC 232 TAONM - Nelder-Mead solver for derivative free, unconstrained minimization 233 234 Options Database Keys: 235 + -tao_nm_lambda - initial step length 236 - -tao_nm_mu - expansion/contraction factor 237 238 Level: beginner 239 M*/ 240 241 PETSC_EXTERN PetscErrorCode TaoCreate_NM(Tao tao) 242 { 243 TAO_NelderMead *nm; 244 245 PetscFunctionBegin; 246 PetscCall(PetscNew(&nm)); 247 tao->data = (void *)nm; 248 249 tao->ops->setup = TaoSetUp_NM; 250 tao->ops->solve = TaoSolve_NM; 251 tao->ops->view = TaoView_NM; 252 tao->ops->setfromoptions = TaoSetFromOptions_NM; 253 tao->ops->destroy = TaoDestroy_NM; 254 255 /* Override default settings (unless already changed) */ 256 PetscCall(TaoParametersInitialize(tao)); 257 PetscObjectParameterSetDefault(tao, max_it, 2000); 258 PetscObjectParameterSetDefault(tao, max_funcs, 4000); 259 260 nm->simplex = NULL; 261 nm->lambda = 1; 262 263 nm->mu_ic = -0.5; 264 nm->mu_oc = 0.5; 265 nm->mu_r = 1.0; 266 nm->mu_e = 2.0; 267 PetscFunctionReturn(PETSC_SUCCESS); 268 } 269