xref: /petsc/src/tao/unconstrained/tutorials/rosenbrock4.h (revision ec9796c4f88c0039d05d273918e252d002ea08fe)
1*ec9796c4SHansol Suh #pragma once
2*ec9796c4SHansol Suh 
3*ec9796c4SHansol Suh #include <petsctao.h>
4*ec9796c4SHansol Suh #include <petscsf.h>
5*ec9796c4SHansol Suh #include <petscdevice.h>
6*ec9796c4SHansol Suh #include <petscdevice_cupm.h>
7*ec9796c4SHansol Suh 
8*ec9796c4SHansol Suh /*
9*ec9796c4SHansol Suh    User-defined application context - contains data needed by the
10*ec9796c4SHansol Suh    application-provided call-back routines that evaluate the function,
11*ec9796c4SHansol Suh    gradient, and hessian.
12*ec9796c4SHansol Suh */
13*ec9796c4SHansol Suh 
14*ec9796c4SHansol Suh typedef struct _Rosenbrock {
15*ec9796c4SHansol Suh   PetscInt  bs; // each block of bs variables is one chained multidimensional rosenbrock problem
16*ec9796c4SHansol Suh   PetscInt  i_start, i_end;
17*ec9796c4SHansol Suh   PetscInt  c_start, c_end;
18*ec9796c4SHansol Suh   PetscReal alpha; // condition parameter
19*ec9796c4SHansol Suh } Rosenbrock;
20*ec9796c4SHansol Suh 
21*ec9796c4SHansol Suh typedef struct _AppCtx *AppCtx;
22*ec9796c4SHansol Suh struct _AppCtx {
23*ec9796c4SHansol Suh   MPI_Comm      comm;
24*ec9796c4SHansol Suh   PetscInt      n; /* dimension */
25*ec9796c4SHansol Suh   PetscInt      n_local;
26*ec9796c4SHansol Suh   PetscInt      n_local_comp;
27*ec9796c4SHansol Suh   Rosenbrock    problem;
28*ec9796c4SHansol Suh   Vec           Hvalues; /* vector for writing COO values of this MPI process */
29*ec9796c4SHansol Suh   Vec           gvalues; /* vector for writing gradient values of this mpi process */
30*ec9796c4SHansol Suh   Vec           fvector;
31*ec9796c4SHansol Suh   PetscSF       off_process_scatter;
32*ec9796c4SHansol Suh   PetscSF       gscatter;
33*ec9796c4SHansol Suh   Vec           off_process_values; /* buffer for off-process values if chained */
34*ec9796c4SHansol Suh   PetscBool     test_lmvm;
35*ec9796c4SHansol Suh   PetscLogEvent event_f, event_g, event_fg;
36*ec9796c4SHansol Suh };
37*ec9796c4SHansol Suh 
38*ec9796c4SHansol Suh /* -------------- User-defined routines ---------- */
39*ec9796c4SHansol Suh 
40*ec9796c4SHansol Suh static PETSC_HOSTDEVICE_INLINE_DECL PetscReal RosenbrockObjective(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2)
41*ec9796c4SHansol Suh {
42*ec9796c4SHansol Suh   PetscScalar d = x_2 - x_1 * x_1;
43*ec9796c4SHansol Suh   PetscScalar e = 1.0 - x_1;
44*ec9796c4SHansol Suh   return alpha * d * d + e * e;
45*ec9796c4SHansol Suh }
46*ec9796c4SHansol Suh 
47*ec9796c4SHansol Suh static const PetscLogDouble RosenbrockObjectiveFlops = 7.0;
48*ec9796c4SHansol Suh 
49*ec9796c4SHansol Suh static PETSC_HOSTDEVICE_INLINE_DECL void RosenbrockGradient(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar g[2])
50*ec9796c4SHansol Suh {
51*ec9796c4SHansol Suh   PetscScalar d  = x_2 - x_1 * x_1;
52*ec9796c4SHansol Suh   PetscScalar e  = 1.0 - x_1;
53*ec9796c4SHansol Suh   PetscScalar g2 = alpha * d * 2.0;
54*ec9796c4SHansol Suh 
55*ec9796c4SHansol Suh   g[0] = -2.0 * x_1 * g2 - 2.0 * e;
56*ec9796c4SHansol Suh   g[1] = g2;
57*ec9796c4SHansol Suh }
58*ec9796c4SHansol Suh 
59*ec9796c4SHansol Suh static const PetscInt RosenbrockGradientFlops = 9.0;
60*ec9796c4SHansol Suh 
61*ec9796c4SHansol Suh static PETSC_HOSTDEVICE_INLINE_DECL PetscReal RosenbrockObjectiveGradient(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar g[2])
62*ec9796c4SHansol Suh {
63*ec9796c4SHansol Suh   PetscScalar d  = x_2 - x_1 * x_1;
64*ec9796c4SHansol Suh   PetscScalar e  = 1.0 - x_1;
65*ec9796c4SHansol Suh   PetscScalar ad = alpha * d;
66*ec9796c4SHansol Suh   PetscScalar g2 = ad * 2.0;
67*ec9796c4SHansol Suh 
68*ec9796c4SHansol Suh   g[0] = -2.0 * x_1 * g2 - 2.0 * e;
69*ec9796c4SHansol Suh   g[1] = g2;
70*ec9796c4SHansol Suh   return ad * d + e * e;
71*ec9796c4SHansol Suh }
72*ec9796c4SHansol Suh 
73*ec9796c4SHansol Suh static const PetscLogDouble RosenbrockObjectiveGradientFlops = 12.0;
74*ec9796c4SHansol Suh 
75*ec9796c4SHansol Suh static PETSC_HOSTDEVICE_INLINE_DECL void RosenbrockHessian(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar h[4])
76*ec9796c4SHansol Suh {
77*ec9796c4SHansol Suh   PetscScalar d  = x_2 - x_1 * x_1;
78*ec9796c4SHansol Suh   PetscScalar g2 = alpha * d * 2.0;
79*ec9796c4SHansol Suh   PetscScalar h2 = -4.0 * alpha * x_1;
80*ec9796c4SHansol Suh 
81*ec9796c4SHansol Suh   h[0] = -2.0 * (g2 + x_1 * h2) + 2.0;
82*ec9796c4SHansol Suh   h[1] = h[2] = h2;
83*ec9796c4SHansol Suh   h[3]        = 2.0 * alpha;
84*ec9796c4SHansol Suh }
85*ec9796c4SHansol Suh 
86*ec9796c4SHansol Suh static const PetscLogDouble RosenbrockHessianFlops = 11.0;
87*ec9796c4SHansol Suh 
88*ec9796c4SHansol Suh static PetscErrorCode AppCtxCreate(MPI_Comm comm, AppCtx *ctx)
89*ec9796c4SHansol Suh {
90*ec9796c4SHansol Suh   AppCtx             user;
91*ec9796c4SHansol Suh   PetscDeviceContext dctx;
92*ec9796c4SHansol Suh 
93*ec9796c4SHansol Suh   PetscFunctionBegin;
94*ec9796c4SHansol Suh   PetscCall(PetscNew(ctx));
95*ec9796c4SHansol Suh   user       = *ctx;
96*ec9796c4SHansol Suh   user->comm = PETSC_COMM_WORLD;
97*ec9796c4SHansol Suh 
98*ec9796c4SHansol Suh   /* Initialize problem parameters */
99*ec9796c4SHansol Suh   user->n             = 2;
100*ec9796c4SHansol Suh   user->problem.alpha = 99.0;
101*ec9796c4SHansol Suh   user->problem.bs    = 2; // bs = 2 is block Rosenbrock, bs = n is chained Rosenbrock
102*ec9796c4SHansol Suh   user->test_lmvm     = PETSC_FALSE;
103*ec9796c4SHansol Suh   /* Check for command line arguments to override defaults */
104*ec9796c4SHansol Suh   PetscOptionsBegin(user->comm, NULL, "Rosenbrock example", NULL);
105*ec9796c4SHansol Suh   PetscCall(PetscOptionsInt("-n", "Rosenbrock problem size", NULL, user->n, &user->n, NULL));
106*ec9796c4SHansol Suh   PetscCall(PetscOptionsInt("-bs", "Rosenbrock block size (2 <= bs <= n)", NULL, user->problem.bs, &user->problem.bs, NULL));
107*ec9796c4SHansol Suh   PetscCall(PetscOptionsReal("-alpha", "Rosenbrock off-diagonal coefficient", NULL, user->problem.alpha, &user->problem.alpha, NULL));
108*ec9796c4SHansol Suh   PetscCall(PetscOptionsBool("-test_lmvm", "Test LMVM solve againt LMVM mult", NULL, user->test_lmvm, &user->test_lmvm, NULL));
109*ec9796c4SHansol Suh   PetscOptionsEnd();
110*ec9796c4SHansol Suh   PetscCheck(user->problem.bs >= 1, comm, PETSC_ERR_ARG_INCOMP, "Block size %" PetscInt_FMT " is not bigger than 1", user->problem.bs);
111*ec9796c4SHansol Suh   PetscCheck((user->n % user->problem.bs) == 0, comm, PETSC_ERR_ARG_INCOMP, "Block size %" PetscInt_FMT " doest not divide problem size % " PetscInt_FMT, user->problem.bs, user->n);
112*ec9796c4SHansol Suh   PetscCall(PetscLogEventRegister("Rbock_Obj", TAO_CLASSID, &user->event_f));
113*ec9796c4SHansol Suh   PetscCall(PetscLogEventRegister("Rbock_Grad", TAO_CLASSID, &user->event_g));
114*ec9796c4SHansol Suh   PetscCall(PetscLogEventRegister("Rbock_ObjGrad", TAO_CLASSID, &user->event_fg));
115*ec9796c4SHansol Suh   PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
116*ec9796c4SHansol Suh   PetscCall(PetscDeviceContextSetUp(dctx));
117*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
118*ec9796c4SHansol Suh }
119*ec9796c4SHansol Suh 
120*ec9796c4SHansol Suh static PetscErrorCode AppCtxDestroy(AppCtx *ctx)
121*ec9796c4SHansol Suh {
122*ec9796c4SHansol Suh   AppCtx user;
123*ec9796c4SHansol Suh 
124*ec9796c4SHansol Suh   PetscFunctionBegin;
125*ec9796c4SHansol Suh   user = *ctx;
126*ec9796c4SHansol Suh   *ctx = NULL;
127*ec9796c4SHansol Suh   PetscCall(VecDestroy(&user->Hvalues));
128*ec9796c4SHansol Suh   PetscCall(VecDestroy(&user->gvalues));
129*ec9796c4SHansol Suh   PetscCall(VecDestroy(&user->fvector));
130*ec9796c4SHansol Suh   PetscCall(VecDestroy(&user->off_process_values));
131*ec9796c4SHansol Suh   PetscCall(PetscSFDestroy(&user->off_process_scatter));
132*ec9796c4SHansol Suh   PetscCall(PetscSFDestroy(&user->gscatter));
133*ec9796c4SHansol Suh   PetscCall(PetscFree(user));
134*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
135*ec9796c4SHansol Suh }
136*ec9796c4SHansol Suh 
137*ec9796c4SHansol Suh static PetscErrorCode CreateHessian(AppCtx user, Mat *Hessian)
138*ec9796c4SHansol Suh {
139*ec9796c4SHansol Suh   Mat         H;
140*ec9796c4SHansol Suh   PetscLayout layout;
141*ec9796c4SHansol Suh   PetscInt    i_start, i_end, n_local_comp, nnz_local;
142*ec9796c4SHansol Suh   PetscInt    c_start, c_end;
143*ec9796c4SHansol Suh   PetscInt   *coo_i;
144*ec9796c4SHansol Suh   PetscInt   *coo_j;
145*ec9796c4SHansol Suh   PetscInt    bs = user->problem.bs;
146*ec9796c4SHansol Suh   VecType     vec_type;
147*ec9796c4SHansol Suh 
148*ec9796c4SHansol Suh   PetscFunctionBegin;
149*ec9796c4SHansol Suh   /* Partition the optimization variables and the computations.
150*ec9796c4SHansol Suh      There are (bs - 1) contributions to the objective function for every (bs)
151*ec9796c4SHansol Suh      degrees of freedom. */
152*ec9796c4SHansol Suh   PetscCall(PetscLayoutCreateFromSizes(user->comm, PETSC_DECIDE, user->n, 1, &layout));
153*ec9796c4SHansol Suh   PetscCall(PetscLayoutSetUp(layout));
154*ec9796c4SHansol Suh   PetscCall(PetscLayoutGetRange(layout, &i_start, &i_end));
155*ec9796c4SHansol Suh   user->problem.i_start = i_start;
156*ec9796c4SHansol Suh   user->problem.i_end   = i_end;
157*ec9796c4SHansol Suh   user->n_local         = i_end - i_start;
158*ec9796c4SHansol Suh   user->problem.c_start = c_start = (i_start / bs) * (bs - 1) + (i_start % bs);
159*ec9796c4SHansol Suh   user->problem.c_end = c_end = (i_end / bs) * (bs - 1) + (i_end % bs);
160*ec9796c4SHansol Suh   user->n_local_comp = n_local_comp = c_end - c_start;
161*ec9796c4SHansol Suh 
162*ec9796c4SHansol Suh   PetscCall(MatCreate(user->comm, Hessian));
163*ec9796c4SHansol Suh   H = *Hessian;
164*ec9796c4SHansol Suh   PetscCall(MatSetLayouts(H, layout, layout));
165*ec9796c4SHansol Suh   PetscCall(PetscLayoutDestroy(&layout));
166*ec9796c4SHansol Suh   PetscCall(MatSetType(H, MATAIJ));
167*ec9796c4SHansol Suh   PetscCall(MatSetOption(H, MAT_HERMITIAN, PETSC_TRUE));
168*ec9796c4SHansol Suh   PetscCall(MatSetOption(H, MAT_SYMMETRIC, PETSC_TRUE));
169*ec9796c4SHansol Suh   PetscCall(MatSetOption(H, MAT_SYMMETRY_ETERNAL, PETSC_TRUE));
170*ec9796c4SHansol Suh   PetscCall(MatSetOption(H, MAT_STRUCTURALLY_SYMMETRIC, PETSC_TRUE));
171*ec9796c4SHansol Suh   PetscCall(MatSetOption(H, MAT_STRUCTURAL_SYMMETRY_ETERNAL, PETSC_TRUE));
172*ec9796c4SHansol Suh   PetscCall(MatSetFromOptions(H)); /* set from options so that we can change the underlying matrix type */
173*ec9796c4SHansol Suh 
174*ec9796c4SHansol Suh   nnz_local = n_local_comp * 4;
175*ec9796c4SHansol Suh   PetscCall(PetscMalloc2(nnz_local, &coo_i, nnz_local, &coo_j));
176*ec9796c4SHansol Suh   /* Instead of having one computation thread per row of the matrix,
177*ec9796c4SHansol Suh      this example uses one thread per contribution to the objective
178*ec9796c4SHansol Suh      function.  Each contribution to the objective function relates
179*ec9796c4SHansol Suh      two adjacent degrees of freedom, so each contribution to
180*ec9796c4SHansol Suh      the objective function adds a 2x2 block into the matrix.
181*ec9796c4SHansol Suh      We describe these 2x2 blocks in COO format. */
182*ec9796c4SHansol Suh   for (PetscInt c = c_start, k = 0; c < c_end; c++, k += 4) {
183*ec9796c4SHansol Suh     PetscInt i = (c / (bs - 1)) * bs + c % (bs - 1);
184*ec9796c4SHansol Suh 
185*ec9796c4SHansol Suh     coo_i[k + 0] = i;
186*ec9796c4SHansol Suh     coo_i[k + 1] = i;
187*ec9796c4SHansol Suh     coo_i[k + 2] = i + 1;
188*ec9796c4SHansol Suh     coo_i[k + 3] = i + 1;
189*ec9796c4SHansol Suh 
190*ec9796c4SHansol Suh     coo_j[k + 0] = i;
191*ec9796c4SHansol Suh     coo_j[k + 1] = i + 1;
192*ec9796c4SHansol Suh     coo_j[k + 2] = i;
193*ec9796c4SHansol Suh     coo_j[k + 3] = i + 1;
194*ec9796c4SHansol Suh   }
195*ec9796c4SHansol Suh   PetscCall(MatSetPreallocationCOO(H, nnz_local, coo_i, coo_j));
196*ec9796c4SHansol Suh   PetscCall(PetscFree2(coo_i, coo_j));
197*ec9796c4SHansol Suh 
198*ec9796c4SHansol Suh   PetscCall(MatGetVecType(H, &vec_type));
199*ec9796c4SHansol Suh   PetscCall(VecCreate(user->comm, &user->Hvalues));
200*ec9796c4SHansol Suh   PetscCall(VecSetSizes(user->Hvalues, nnz_local, PETSC_DETERMINE));
201*ec9796c4SHansol Suh   PetscCall(VecSetType(user->Hvalues, vec_type));
202*ec9796c4SHansol Suh 
203*ec9796c4SHansol Suh   // vector to collect contributions to the objective
204*ec9796c4SHansol Suh   PetscCall(VecCreate(user->comm, &user->fvector));
205*ec9796c4SHansol Suh   PetscCall(VecSetSizes(user->fvector, user->n_local_comp, PETSC_DETERMINE));
206*ec9796c4SHansol Suh   PetscCall(VecSetType(user->fvector, vec_type));
207*ec9796c4SHansol Suh 
208*ec9796c4SHansol Suh   { /* If we are using a device (such as a GPU), run some computations that will
209*ec9796c4SHansol Suh        warm up its linear algebra runtime before the problem we actually want
210*ec9796c4SHansol Suh        to profile */
211*ec9796c4SHansol Suh 
212*ec9796c4SHansol Suh     PetscMemType       memtype;
213*ec9796c4SHansol Suh     const PetscScalar *a;
214*ec9796c4SHansol Suh 
215*ec9796c4SHansol Suh     PetscCall(VecGetArrayReadAndMemType(user->fvector, &a, &memtype));
216*ec9796c4SHansol Suh     PetscCall(VecRestoreArrayReadAndMemType(user->fvector, &a));
217*ec9796c4SHansol Suh 
218*ec9796c4SHansol Suh     if (memtype == PETSC_MEMTYPE_DEVICE) {
219*ec9796c4SHansol Suh       PetscLogStage      warmup;
220*ec9796c4SHansol Suh       Mat                A, AtA;
221*ec9796c4SHansol Suh       Vec                x, b;
222*ec9796c4SHansol Suh       PetscInt           warmup_size = 1000;
223*ec9796c4SHansol Suh       PetscDeviceContext dctx;
224*ec9796c4SHansol Suh 
225*ec9796c4SHansol Suh       PetscCall(PetscLogStageRegister("Device Warmup", &warmup));
226*ec9796c4SHansol Suh       PetscCall(PetscLogStageSetActive(warmup, PETSC_FALSE));
227*ec9796c4SHansol Suh 
228*ec9796c4SHansol Suh       PetscCall(PetscLogStagePush(warmup));
229*ec9796c4SHansol Suh       PetscCall(MatCreateDenseFromVecType(PETSC_COMM_SELF, vec_type, warmup_size, warmup_size, warmup_size, warmup_size, PETSC_DEFAULT, NULL, &A));
230*ec9796c4SHansol Suh       PetscCall(MatSetRandom(A, NULL));
231*ec9796c4SHansol Suh       PetscCall(MatCreateVecs(A, &x, &b));
232*ec9796c4SHansol Suh       PetscCall(VecSetRandom(x, NULL));
233*ec9796c4SHansol Suh 
234*ec9796c4SHansol Suh       PetscCall(MatMult(A, x, b));
235*ec9796c4SHansol Suh       PetscCall(MatTransposeMatMult(A, A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &AtA));
236*ec9796c4SHansol Suh       PetscCall(MatShift(AtA, (PetscScalar)warmup_size));
237*ec9796c4SHansol Suh       PetscCall(MatSetOption(AtA, MAT_SPD, PETSC_TRUE));
238*ec9796c4SHansol Suh       PetscCall(MatCholeskyFactor(AtA, NULL, NULL));
239*ec9796c4SHansol Suh       PetscCall(MatDestroy(&AtA));
240*ec9796c4SHansol Suh       PetscCall(VecDestroy(&b));
241*ec9796c4SHansol Suh       PetscCall(VecDestroy(&x));
242*ec9796c4SHansol Suh       PetscCall(MatDestroy(&A));
243*ec9796c4SHansol Suh       PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
244*ec9796c4SHansol Suh       PetscCall(PetscDeviceContextSynchronize(dctx));
245*ec9796c4SHansol Suh       PetscCall(PetscLogStagePop());
246*ec9796c4SHansol Suh     }
247*ec9796c4SHansol Suh   }
248*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
249*ec9796c4SHansol Suh }
250*ec9796c4SHansol Suh 
251*ec9796c4SHansol Suh static PetscErrorCode CreateVectors(AppCtx user, Mat H, Vec *solution, Vec *gradient)
252*ec9796c4SHansol Suh {
253*ec9796c4SHansol Suh   VecType     vec_type;
254*ec9796c4SHansol Suh   PetscInt    n_coo, *coo_i, i_start, i_end;
255*ec9796c4SHansol Suh   Vec         x;
256*ec9796c4SHansol Suh   PetscInt    n_recv;
257*ec9796c4SHansol Suh   PetscSFNode recv;
258*ec9796c4SHansol Suh   PetscLayout layout;
259*ec9796c4SHansol Suh   PetscInt    c_start = user->problem.c_start, c_end = user->problem.c_end, bs = user->problem.bs;
260*ec9796c4SHansol Suh 
261*ec9796c4SHansol Suh   PetscFunctionBegin;
262*ec9796c4SHansol Suh   PetscCall(MatCreateVecs(H, solution, gradient));
263*ec9796c4SHansol Suh   x = *solution;
264*ec9796c4SHansol Suh   PetscCall(VecGetOwnershipRange(x, &i_start, &i_end));
265*ec9796c4SHansol Suh   PetscCall(VecGetType(x, &vec_type));
266*ec9796c4SHansol Suh   // create scatter for communicating values
267*ec9796c4SHansol Suh   PetscCall(VecGetLayout(x, &layout));
268*ec9796c4SHansol Suh   n_recv = 0;
269*ec9796c4SHansol Suh   if (user->n_local_comp && i_end < user->n) {
270*ec9796c4SHansol Suh     PetscMPIInt rank;
271*ec9796c4SHansol Suh     PetscInt    index;
272*ec9796c4SHansol Suh 
273*ec9796c4SHansol Suh     n_recv = 1;
274*ec9796c4SHansol Suh     PetscCall(PetscLayoutFindOwnerIndex(layout, i_end, &rank, &index));
275*ec9796c4SHansol Suh     recv.rank  = rank;
276*ec9796c4SHansol Suh     recv.index = index;
277*ec9796c4SHansol Suh   }
278*ec9796c4SHansol Suh   PetscCall(PetscSFCreate(user->comm, &user->off_process_scatter));
279*ec9796c4SHansol Suh   PetscCall(PetscSFSetGraph(user->off_process_scatter, user->n_local, n_recv, NULL, PETSC_USE_POINTER, &recv, PETSC_COPY_VALUES));
280*ec9796c4SHansol Suh   PetscCall(VecCreate(user->comm, &user->off_process_values));
281*ec9796c4SHansol Suh   PetscCall(VecSetSizes(user->off_process_values, 1, PETSC_DETERMINE));
282*ec9796c4SHansol Suh   PetscCall(VecSetType(user->off_process_values, vec_type));
283*ec9796c4SHansol Suh   PetscCall(VecZeroEntries(user->off_process_values));
284*ec9796c4SHansol Suh 
285*ec9796c4SHansol Suh   // create COO data for writing the gradient
286*ec9796c4SHansol Suh   n_coo = user->n_local_comp * 2;
287*ec9796c4SHansol Suh   PetscCall(PetscMalloc1(n_coo, &coo_i));
288*ec9796c4SHansol Suh   for (PetscInt c = c_start, k = 0; c < c_end; c++, k += 2) {
289*ec9796c4SHansol Suh     PetscInt i = (c / (bs - 1)) * bs + (c % (bs - 1));
290*ec9796c4SHansol Suh 
291*ec9796c4SHansol Suh     coo_i[k + 0] = i;
292*ec9796c4SHansol Suh     coo_i[k + 1] = i + 1;
293*ec9796c4SHansol Suh   }
294*ec9796c4SHansol Suh   PetscCall(PetscSFCreate(user->comm, &user->gscatter));
295*ec9796c4SHansol Suh   PetscCall(PetscSFSetGraphLayout(user->gscatter, layout, n_coo, NULL, PETSC_USE_POINTER, coo_i));
296*ec9796c4SHansol Suh   PetscCall(PetscSFSetUp(user->gscatter));
297*ec9796c4SHansol Suh   PetscCall(PetscFree(coo_i));
298*ec9796c4SHansol Suh   PetscCall(VecCreate(user->comm, &user->gvalues));
299*ec9796c4SHansol Suh   PetscCall(VecSetSizes(user->gvalues, n_coo, PETSC_DETERMINE));
300*ec9796c4SHansol Suh   PetscCall(VecSetType(user->gvalues, vec_type));
301*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
302*ec9796c4SHansol Suh }
303*ec9796c4SHansol Suh 
304*ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC)
305*ec9796c4SHansol Suh 
306*ec9796c4SHansol Suh   #if PetscDefined(USING_NVCC)
307*ec9796c4SHansol Suh typedef cudaStream_t cupmStream_t;
308*ec9796c4SHansol Suh     #define PetscCUPMLaunch(...) \
309*ec9796c4SHansol Suh       do { \
310*ec9796c4SHansol Suh         __VA_ARGS__; \
311*ec9796c4SHansol Suh         PetscCallCUDA(cudaGetLastError()); \
312*ec9796c4SHansol Suh       } while (0)
313*ec9796c4SHansol Suh   #elif PetscDefined(USING_HCC)
314*ec9796c4SHansol Suh     #define PetscCUPMLaunch(...) \
315*ec9796c4SHansol Suh       do { \
316*ec9796c4SHansol Suh         __VA_ARGS__; \
317*ec9796c4SHansol Suh         PetscCallHIP(hipGetLastError()); \
318*ec9796c4SHansol Suh       } while (0)
319*ec9796c4SHansol Suh typedef hipStream_t cupmStream_t;
320*ec9796c4SHansol Suh   #endif
321*ec9796c4SHansol Suh 
322*ec9796c4SHansol Suh // x: on-process optimization variables
323*ec9796c4SHansol Suh // o: buffer that contains the next optimization variable after the variables on this process
324*ec9796c4SHansol Suh template <typename T>
325*ec9796c4SHansol Suh PETSC_DEVICE_INLINE_DECL static void rosenbrock_for_loop(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], T &&func) noexcept
326*ec9796c4SHansol Suh {
327*ec9796c4SHansol Suh   PetscInt idx         = blockIdx.x * blockDim.x + threadIdx.x; // 1D grid
328*ec9796c4SHansol Suh   PetscInt num_threads = gridDim.x * blockDim.x;
329*ec9796c4SHansol Suh 
330*ec9796c4SHansol Suh   for (PetscInt c = r.c_start + idx, k = idx; c < r.c_end; c += num_threads, k += num_threads) {
331*ec9796c4SHansol Suh     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
332*ec9796c4SHansol Suh     PetscScalar x_a = x[i - r.i_start];
333*ec9796c4SHansol Suh     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
334*ec9796c4SHansol Suh 
335*ec9796c4SHansol Suh     func(k, x_a, x_b);
336*ec9796c4SHansol Suh   }
337*ec9796c4SHansol Suh   return;
338*ec9796c4SHansol Suh }
339*ec9796c4SHansol Suh 
340*ec9796c4SHansol Suh PETSC_KERNEL_DECL void RosenbrockObjective_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[])
341*ec9796c4SHansol Suh {
342*ec9796c4SHansol Suh   rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { f_vec[k] = RosenbrockObjective(r.alpha, x_a, x_b); });
343*ec9796c4SHansol Suh }
344*ec9796c4SHansol Suh 
345*ec9796c4SHansol Suh PETSC_KERNEL_DECL void RosenbrockGradient_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
346*ec9796c4SHansol Suh {
347*ec9796c4SHansol Suh   rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { RosenbrockGradient(r.alpha, x_a, x_b, &g[2 * k]); });
348*ec9796c4SHansol Suh }
349*ec9796c4SHansol Suh 
350*ec9796c4SHansol Suh PETSC_KERNEL_DECL void RosenbrockObjectiveGradient_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[], PetscScalar g[])
351*ec9796c4SHansol Suh {
352*ec9796c4SHansol Suh   rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { f_vec[k] = RosenbrockObjectiveGradient(r.alpha, x_a, x_b, &g[2 * k]); });
353*ec9796c4SHansol Suh }
354*ec9796c4SHansol Suh 
355*ec9796c4SHansol Suh PETSC_KERNEL_DECL void RosenbrockHessian_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
356*ec9796c4SHansol Suh {
357*ec9796c4SHansol Suh   rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { RosenbrockHessian(r.alpha, x_a, x_b, &h[4 * k]); });
358*ec9796c4SHansol Suh }
359*ec9796c4SHansol Suh 
360*ec9796c4SHansol Suh static PetscErrorCode RosenbrockObjective_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[])
361*ec9796c4SHansol Suh {
362*ec9796c4SHansol Suh   PetscInt n_comp = r.c_end - r.c_start;
363*ec9796c4SHansol Suh 
364*ec9796c4SHansol Suh   PetscFunctionBegin;
365*ec9796c4SHansol Suh   if (n_comp) PetscCUPMLaunch(RosenbrockObjective_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, f_vec));
366*ec9796c4SHansol Suh   PetscCall(PetscLogGpuFlops(RosenbrockObjectiveFlops * n_comp));
367*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
368*ec9796c4SHansol Suh }
369*ec9796c4SHansol Suh 
370*ec9796c4SHansol Suh static PetscErrorCode RosenbrockGradient_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
371*ec9796c4SHansol Suh {
372*ec9796c4SHansol Suh   PetscInt n_comp = r.c_end - r.c_start;
373*ec9796c4SHansol Suh 
374*ec9796c4SHansol Suh   PetscFunctionBegin;
375*ec9796c4SHansol Suh   if (n_comp) PetscCUPMLaunch(RosenbrockGradient_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, g));
376*ec9796c4SHansol Suh   PetscCall(PetscLogGpuFlops(RosenbrockGradientFlops * n_comp));
377*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
378*ec9796c4SHansol Suh }
379*ec9796c4SHansol Suh 
380*ec9796c4SHansol Suh static PetscErrorCode RosenbrockObjectiveGradient_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[], PetscScalar g[])
381*ec9796c4SHansol Suh {
382*ec9796c4SHansol Suh   PetscInt n_comp = r.c_end - r.c_start;
383*ec9796c4SHansol Suh 
384*ec9796c4SHansol Suh   PetscFunctionBegin;
385*ec9796c4SHansol Suh   if (n_comp) PetscCUPMLaunch(RosenbrockObjectiveGradient_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, f_vec, g));
386*ec9796c4SHansol Suh   PetscCall(PetscLogGpuFlops(RosenbrockObjectiveGradientFlops * n_comp));
387*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
388*ec9796c4SHansol Suh }
389*ec9796c4SHansol Suh 
390*ec9796c4SHansol Suh static PetscErrorCode RosenbrockHessian_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
391*ec9796c4SHansol Suh {
392*ec9796c4SHansol Suh   PetscInt n_comp = r.c_end - r.c_start;
393*ec9796c4SHansol Suh 
394*ec9796c4SHansol Suh   PetscFunctionBegin;
395*ec9796c4SHansol Suh   if (n_comp) PetscCUPMLaunch(RosenbrockHessian_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, h));
396*ec9796c4SHansol Suh   PetscCall(PetscLogGpuFlops(RosenbrockHessianFlops * n_comp));
397*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
398*ec9796c4SHansol Suh }
399*ec9796c4SHansol Suh #endif
400*ec9796c4SHansol Suh 
401*ec9796c4SHansol Suh static PetscErrorCode RosenbrockObjective_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscReal *f)
402*ec9796c4SHansol Suh {
403*ec9796c4SHansol Suh   PetscReal _f = 0.0;
404*ec9796c4SHansol Suh 
405*ec9796c4SHansol Suh   PetscFunctionBegin;
406*ec9796c4SHansol Suh   for (PetscInt c = r.c_start; c < r.c_end; c++) {
407*ec9796c4SHansol Suh     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
408*ec9796c4SHansol Suh     PetscScalar x_a = x[i - r.i_start];
409*ec9796c4SHansol Suh     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
410*ec9796c4SHansol Suh 
411*ec9796c4SHansol Suh     _f += RosenbrockObjective(r.alpha, x_a, x_b);
412*ec9796c4SHansol Suh   }
413*ec9796c4SHansol Suh   *f = _f;
414*ec9796c4SHansol Suh   PetscCall(PetscLogFlops((RosenbrockObjectiveFlops + 1.0) * (r.c_end - r.c_start)));
415*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
416*ec9796c4SHansol Suh }
417*ec9796c4SHansol Suh 
418*ec9796c4SHansol Suh static PetscErrorCode RosenbrockGradient_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[])
419*ec9796c4SHansol Suh {
420*ec9796c4SHansol Suh   PetscFunctionBegin;
421*ec9796c4SHansol Suh   for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
422*ec9796c4SHansol Suh     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
423*ec9796c4SHansol Suh     PetscScalar x_a = x[i - r.i_start];
424*ec9796c4SHansol Suh     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
425*ec9796c4SHansol Suh 
426*ec9796c4SHansol Suh     RosenbrockGradient(r.alpha, x_a, x_b, &g[2 * k]);
427*ec9796c4SHansol Suh   }
428*ec9796c4SHansol Suh   PetscCall(PetscLogFlops(RosenbrockGradientFlops * (r.c_end - r.c_start)));
429*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
430*ec9796c4SHansol Suh }
431*ec9796c4SHansol Suh 
432*ec9796c4SHansol Suh static PetscErrorCode RosenbrockObjectiveGradient_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscReal *f, PetscScalar g[])
433*ec9796c4SHansol Suh {
434*ec9796c4SHansol Suh   PetscReal _f = 0.0;
435*ec9796c4SHansol Suh 
436*ec9796c4SHansol Suh   PetscFunctionBegin;
437*ec9796c4SHansol Suh   for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
438*ec9796c4SHansol Suh     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
439*ec9796c4SHansol Suh     PetscScalar x_a = x[i - r.i_start];
440*ec9796c4SHansol Suh     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
441*ec9796c4SHansol Suh 
442*ec9796c4SHansol Suh     _f += RosenbrockObjectiveGradient(r.alpha, x_a, x_b, &g[2 * k]);
443*ec9796c4SHansol Suh   }
444*ec9796c4SHansol Suh   *f = _f;
445*ec9796c4SHansol Suh   PetscCall(PetscLogFlops(RosenbrockObjectiveGradientFlops * (r.c_end - r.c_start)));
446*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
447*ec9796c4SHansol Suh }
448*ec9796c4SHansol Suh 
449*ec9796c4SHansol Suh static PetscErrorCode RosenbrockHessian_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[])
450*ec9796c4SHansol Suh {
451*ec9796c4SHansol Suh   PetscFunctionBegin;
452*ec9796c4SHansol Suh   for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) {
453*ec9796c4SHansol Suh     PetscInt    i   = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1));
454*ec9796c4SHansol Suh     PetscScalar x_a = x[i - r.i_start];
455*ec9796c4SHansol Suh     PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0];
456*ec9796c4SHansol Suh 
457*ec9796c4SHansol Suh     RosenbrockHessian(r.alpha, x_a, x_b, &h[4 * k]);
458*ec9796c4SHansol Suh   }
459*ec9796c4SHansol Suh   PetscCall(PetscLogFlops(RosenbrockHessianFlops * (r.c_end - r.c_start)));
460*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
461*ec9796c4SHansol Suh }
462*ec9796c4SHansol Suh 
463*ec9796c4SHansol Suh /* -------------------------------------------------------------------- */
464*ec9796c4SHansol Suh 
465*ec9796c4SHansol Suh static PetscErrorCode FormObjective(Tao tao, Vec X, PetscReal *f, void *ptr)
466*ec9796c4SHansol Suh {
467*ec9796c4SHansol Suh   AppCtx             user    = (AppCtx)ptr;
468*ec9796c4SHansol Suh   PetscReal          f_local = 0.0;
469*ec9796c4SHansol Suh   const PetscScalar *x;
470*ec9796c4SHansol Suh   const PetscScalar *o = NULL;
471*ec9796c4SHansol Suh   PetscMemType       memtype_x;
472*ec9796c4SHansol Suh 
473*ec9796c4SHansol Suh   PetscFunctionBeginUser;
474*ec9796c4SHansol Suh   PetscCall(PetscLogEventBegin(user->event_f, tao, NULL, NULL, NULL));
475*ec9796c4SHansol Suh   PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
476*ec9796c4SHansol Suh   PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
477*ec9796c4SHansol Suh   PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
478*ec9796c4SHansol Suh   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
479*ec9796c4SHansol Suh   if (memtype_x == PETSC_MEMTYPE_HOST) {
480*ec9796c4SHansol Suh     PetscCall(RosenbrockObjective_Host(user->problem, x, o, &f_local));
481*ec9796c4SHansol Suh     PetscCallMPI(MPI_Allreduce(&f_local, f, 1, MPI_DOUBLE, MPI_SUM, user->comm));
482*ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC)
483*ec9796c4SHansol Suh   } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
484*ec9796c4SHansol Suh     PetscScalar       *_fvec;
485*ec9796c4SHansol Suh     PetscScalar        f_scalar;
486*ec9796c4SHansol Suh     cupmStream_t      *stream;
487*ec9796c4SHansol Suh     PetscDeviceContext dctx;
488*ec9796c4SHansol Suh 
489*ec9796c4SHansol Suh     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
490*ec9796c4SHansol Suh     PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
491*ec9796c4SHansol Suh     PetscCall(VecGetArrayWriteAndMemType(user->fvector, &_fvec, NULL));
492*ec9796c4SHansol Suh     PetscCall(RosenbrockObjective_Device(*stream, user->problem, x, o, _fvec));
493*ec9796c4SHansol Suh     PetscCall(VecRestoreArrayWriteAndMemType(user->fvector, &_fvec));
494*ec9796c4SHansol Suh     PetscCall(VecSum(user->fvector, &f_scalar));
495*ec9796c4SHansol Suh     *f = PetscRealPart(f_scalar);
496*ec9796c4SHansol Suh #endif
497*ec9796c4SHansol Suh   } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsuported memtype %d", (int)memtype_x);
498*ec9796c4SHansol Suh   PetscCall(VecRestoreArrayReadAndMemType(X, &x));
499*ec9796c4SHansol Suh   PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
500*ec9796c4SHansol Suh   PetscCall(PetscLogEventEnd(user->event_f, tao, NULL, NULL, NULL));
501*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
502*ec9796c4SHansol Suh }
503*ec9796c4SHansol Suh 
504*ec9796c4SHansol Suh static PetscErrorCode FormGradient(Tao tao, Vec X, Vec G, void *ptr)
505*ec9796c4SHansol Suh {
506*ec9796c4SHansol Suh   AppCtx             user = (AppCtx)ptr;
507*ec9796c4SHansol Suh   PetscScalar       *g;
508*ec9796c4SHansol Suh   const PetscScalar *x;
509*ec9796c4SHansol Suh   const PetscScalar *o = NULL;
510*ec9796c4SHansol Suh   PetscMemType       memtype_x, memtype_g;
511*ec9796c4SHansol Suh 
512*ec9796c4SHansol Suh   PetscFunctionBeginUser;
513*ec9796c4SHansol Suh   PetscCall(PetscLogEventBegin(user->event_g, tao, NULL, NULL, NULL));
514*ec9796c4SHansol Suh   PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
515*ec9796c4SHansol Suh   PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
516*ec9796c4SHansol Suh   PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
517*ec9796c4SHansol Suh   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
518*ec9796c4SHansol Suh   PetscCall(VecGetArrayWriteAndMemType(user->gvalues, &g, &memtype_g));
519*ec9796c4SHansol Suh   PetscAssert(memtype_x == memtype_g, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and gradient must have save memtype");
520*ec9796c4SHansol Suh   if (memtype_x == PETSC_MEMTYPE_HOST) {
521*ec9796c4SHansol Suh     PetscCall(RosenbrockGradient_Host(user->problem, x, o, g));
522*ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC)
523*ec9796c4SHansol Suh   } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
524*ec9796c4SHansol Suh     cupmStream_t      *stream;
525*ec9796c4SHansol Suh     PetscDeviceContext dctx;
526*ec9796c4SHansol Suh 
527*ec9796c4SHansol Suh     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
528*ec9796c4SHansol Suh     PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
529*ec9796c4SHansol Suh     PetscCall(RosenbrockGradient_Device(*stream, user->problem, x, o, g));
530*ec9796c4SHansol Suh #endif
531*ec9796c4SHansol Suh   } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsuported memtype %d", (int)memtype_x);
532*ec9796c4SHansol Suh   PetscCall(VecRestoreArrayWriteAndMemType(user->gvalues, &g));
533*ec9796c4SHansol Suh   PetscCall(VecRestoreArrayReadAndMemType(X, &x));
534*ec9796c4SHansol Suh   PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
535*ec9796c4SHansol Suh   PetscCall(VecZeroEntries(G));
536*ec9796c4SHansol Suh   PetscCall(VecScatterBegin(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
537*ec9796c4SHansol Suh   PetscCall(VecScatterEnd(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
538*ec9796c4SHansol Suh   PetscCall(PetscLogEventEnd(user->event_g, tao, NULL, NULL, NULL));
539*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
540*ec9796c4SHansol Suh }
541*ec9796c4SHansol Suh 
542*ec9796c4SHansol Suh /*
543*ec9796c4SHansol Suh     FormObjectiveGradient - Evaluates the function, f(X), and gradient, G(X).
544*ec9796c4SHansol Suh 
545*ec9796c4SHansol Suh     Input Parameters:
546*ec9796c4SHansol Suh .   tao  - the Tao context
547*ec9796c4SHansol Suh .   X    - input vector
548*ec9796c4SHansol Suh .   ptr  - optional user-defined context, as set by TaoSetObjectiveGradient()
549*ec9796c4SHansol Suh 
550*ec9796c4SHansol Suh     Output Parameters:
551*ec9796c4SHansol Suh .   G - vector containing the newly evaluated gradient
552*ec9796c4SHansol Suh .   f - function value
553*ec9796c4SHansol Suh 
554*ec9796c4SHansol Suh     Note:
555*ec9796c4SHansol Suh     Some optimization methods ask for the function and the gradient evaluation
556*ec9796c4SHansol Suh     at the same time.  Evaluating both at once may be more efficient that
557*ec9796c4SHansol Suh     evaluating each separately.
558*ec9796c4SHansol Suh */
559*ec9796c4SHansol Suh static PetscErrorCode FormObjectiveGradient(Tao tao, Vec X, PetscReal *f, Vec G, void *ptr)
560*ec9796c4SHansol Suh {
561*ec9796c4SHansol Suh   AppCtx             user    = (AppCtx)ptr;
562*ec9796c4SHansol Suh   PetscReal          f_local = 0.0;
563*ec9796c4SHansol Suh   PetscScalar       *g;
564*ec9796c4SHansol Suh   const PetscScalar *x;
565*ec9796c4SHansol Suh   const PetscScalar *o = NULL;
566*ec9796c4SHansol Suh   PetscMemType       memtype_x, memtype_g;
567*ec9796c4SHansol Suh 
568*ec9796c4SHansol Suh   PetscFunctionBeginUser;
569*ec9796c4SHansol Suh   PetscCall(PetscLogEventBegin(user->event_fg, tao, NULL, NULL, NULL));
570*ec9796c4SHansol Suh   PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
571*ec9796c4SHansol Suh   PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
572*ec9796c4SHansol Suh   PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
573*ec9796c4SHansol Suh   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
574*ec9796c4SHansol Suh   PetscCall(VecGetArrayWriteAndMemType(user->gvalues, &g, &memtype_g));
575*ec9796c4SHansol Suh   PetscAssert(memtype_x == memtype_g, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and gradient must have save memtype");
576*ec9796c4SHansol Suh   if (memtype_x == PETSC_MEMTYPE_HOST) {
577*ec9796c4SHansol Suh     PetscCall(RosenbrockObjectiveGradient_Host(user->problem, x, o, &f_local, g));
578*ec9796c4SHansol Suh     PetscCallMPI(MPI_Allreduce((void *)&f_local, (void *)f, 1, MPI_DOUBLE, MPI_SUM, PETSC_COMM_WORLD));
579*ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC)
580*ec9796c4SHansol Suh   } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
581*ec9796c4SHansol Suh     PetscScalar       *_fvec;
582*ec9796c4SHansol Suh     PetscScalar        f_scalar;
583*ec9796c4SHansol Suh     cupmStream_t      *stream;
584*ec9796c4SHansol Suh     PetscDeviceContext dctx;
585*ec9796c4SHansol Suh 
586*ec9796c4SHansol Suh     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
587*ec9796c4SHansol Suh     PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
588*ec9796c4SHansol Suh     PetscCall(VecGetArrayWriteAndMemType(user->fvector, &_fvec, NULL));
589*ec9796c4SHansol Suh     PetscCall(RosenbrockObjectiveGradient_Device(*stream, user->problem, x, o, _fvec, g));
590*ec9796c4SHansol Suh     PetscCall(VecRestoreArrayWriteAndMemType(user->fvector, &_fvec));
591*ec9796c4SHansol Suh     PetscCall(VecSum(user->fvector, &f_scalar));
592*ec9796c4SHansol Suh     *f = PetscRealPart(f_scalar);
593*ec9796c4SHansol Suh #endif
594*ec9796c4SHansol Suh   } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsuported memtype %d", (int)memtype_x);
595*ec9796c4SHansol Suh 
596*ec9796c4SHansol Suh   PetscCall(VecRestoreArrayWriteAndMemType(user->gvalues, &g));
597*ec9796c4SHansol Suh   PetscCall(VecRestoreArrayReadAndMemType(X, &x));
598*ec9796c4SHansol Suh   PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
599*ec9796c4SHansol Suh   PetscCall(VecZeroEntries(G));
600*ec9796c4SHansol Suh   PetscCall(VecScatterBegin(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
601*ec9796c4SHansol Suh   PetscCall(VecScatterEnd(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE));
602*ec9796c4SHansol Suh   PetscCall(PetscLogEventEnd(user->event_fg, tao, NULL, NULL, NULL));
603*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
604*ec9796c4SHansol Suh }
605*ec9796c4SHansol Suh 
606*ec9796c4SHansol Suh /* ------------------------------------------------------------------- */
607*ec9796c4SHansol Suh /*
608*ec9796c4SHansol Suh    FormHessian - Evaluates Hessian matrix.
609*ec9796c4SHansol Suh 
610*ec9796c4SHansol Suh    Input Parameters:
611*ec9796c4SHansol Suh .  tao   - the Tao context
612*ec9796c4SHansol Suh .  x     - input vector
613*ec9796c4SHansol Suh .  ptr   - optional user-defined context, as set by TaoSetHessian()
614*ec9796c4SHansol Suh 
615*ec9796c4SHansol Suh    Output Parameters:
616*ec9796c4SHansol Suh .  H     - Hessian matrix
617*ec9796c4SHansol Suh 
618*ec9796c4SHansol Suh    Note:  Providing the Hessian may not be necessary.  Only some solvers
619*ec9796c4SHansol Suh    require this matrix.
620*ec9796c4SHansol Suh */
621*ec9796c4SHansol Suh static PetscErrorCode FormHessian(Tao tao, Vec X, Mat H, Mat Hpre, void *ptr)
622*ec9796c4SHansol Suh {
623*ec9796c4SHansol Suh   AppCtx             user = (AppCtx)ptr;
624*ec9796c4SHansol Suh   PetscScalar       *h;
625*ec9796c4SHansol Suh   const PetscScalar *x;
626*ec9796c4SHansol Suh   const PetscScalar *o = NULL;
627*ec9796c4SHansol Suh   PetscMemType       memtype_x, memtype_h;
628*ec9796c4SHansol Suh 
629*ec9796c4SHansol Suh   PetscFunctionBeginUser;
630*ec9796c4SHansol Suh   PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
631*ec9796c4SHansol Suh   PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD));
632*ec9796c4SHansol Suh   PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL));
633*ec9796c4SHansol Suh   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x));
634*ec9796c4SHansol Suh   PetscCall(VecGetArrayWriteAndMemType(user->Hvalues, &h, &memtype_h));
635*ec9796c4SHansol Suh   PetscAssert(memtype_x == memtype_h, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and hessian must have save memtype");
636*ec9796c4SHansol Suh   if (memtype_x == PETSC_MEMTYPE_HOST) {
637*ec9796c4SHansol Suh     PetscCall(RosenbrockHessian_Host(user->problem, x, o, h));
638*ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC)
639*ec9796c4SHansol Suh   } else if (memtype_x == PETSC_MEMTYPE_DEVICE) {
640*ec9796c4SHansol Suh     cupmStream_t      *stream;
641*ec9796c4SHansol Suh     PetscDeviceContext dctx;
642*ec9796c4SHansol Suh 
643*ec9796c4SHansol Suh     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
644*ec9796c4SHansol Suh     PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
645*ec9796c4SHansol Suh     PetscCall(RosenbrockHessian_Device(*stream, user->problem, x, o, h));
646*ec9796c4SHansol Suh #endif
647*ec9796c4SHansol Suh   } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsuported memtype %d", (int)memtype_x);
648*ec9796c4SHansol Suh 
649*ec9796c4SHansol Suh   PetscCall(MatSetValuesCOO(H, h, INSERT_VALUES));
650*ec9796c4SHansol Suh   PetscCall(VecRestoreArrayWriteAndMemType(user->Hvalues, &h));
651*ec9796c4SHansol Suh 
652*ec9796c4SHansol Suh   PetscCall(VecRestoreArrayReadAndMemType(X, &x));
653*ec9796c4SHansol Suh   PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o));
654*ec9796c4SHansol Suh 
655*ec9796c4SHansol Suh   if (Hpre != H) PetscCall(MatCopy(H, Hpre, SAME_NONZERO_PATTERN));
656*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
657*ec9796c4SHansol Suh }
658*ec9796c4SHansol Suh 
659*ec9796c4SHansol Suh static PetscErrorCode TestLMVM(Tao tao)
660*ec9796c4SHansol Suh {
661*ec9796c4SHansol Suh   KSP       ksp;
662*ec9796c4SHansol Suh   PC        pc;
663*ec9796c4SHansol Suh   PetscBool is_lmvm;
664*ec9796c4SHansol Suh 
665*ec9796c4SHansol Suh   PetscFunctionBegin;
666*ec9796c4SHansol Suh   PetscCall(TaoGetKSP(tao, &ksp));
667*ec9796c4SHansol Suh   if (!ksp) PetscFunctionReturn(PETSC_SUCCESS);
668*ec9796c4SHansol Suh   PetscCall(KSPGetPC(ksp, &pc));
669*ec9796c4SHansol Suh   PetscCall(PetscObjectTypeCompare((PetscObject)pc, PCLMVM, &is_lmvm));
670*ec9796c4SHansol Suh   if (is_lmvm) {
671*ec9796c4SHansol Suh     Mat       M;
672*ec9796c4SHansol Suh     Vec       in, out, out2;
673*ec9796c4SHansol Suh     PetscReal mult_solve_dist;
674*ec9796c4SHansol Suh     Vec       x;
675*ec9796c4SHansol Suh 
676*ec9796c4SHansol Suh     PetscCall(PCLMVMGetMatLMVM(pc, &M));
677*ec9796c4SHansol Suh     PetscCall(TaoGetSolution(tao, &x));
678*ec9796c4SHansol Suh     PetscCall(VecDuplicate(x, &in));
679*ec9796c4SHansol Suh     PetscCall(VecDuplicate(x, &out));
680*ec9796c4SHansol Suh     PetscCall(VecDuplicate(x, &out2));
681*ec9796c4SHansol Suh     PetscCall(VecSetRandom(in, NULL));
682*ec9796c4SHansol Suh     PetscCall(MatMult(M, in, out));
683*ec9796c4SHansol Suh     PetscCall(MatSolve(M, out, out2));
684*ec9796c4SHansol Suh 
685*ec9796c4SHansol Suh     PetscCall(VecAXPY(out2, -1.0, in));
686*ec9796c4SHansol Suh     PetscCall(VecNorm(out2, NORM_2, &mult_solve_dist));
687*ec9796c4SHansol Suh     if (mult_solve_dist < 1.e-11) {
688*ec9796c4SHansol Suh       PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve: < 1.e-11\n"));
689*ec9796c4SHansol Suh     } else if (mult_solve_dist < 1.e-6) {
690*ec9796c4SHansol Suh       PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve: < 1.e-6\n"));
691*ec9796c4SHansol Suh     } else {
692*ec9796c4SHansol Suh       PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve is not small: %e\n", (double)mult_solve_dist));
693*ec9796c4SHansol Suh     }
694*ec9796c4SHansol Suh     PetscCall(VecDestroy(&in));
695*ec9796c4SHansol Suh     PetscCall(VecDestroy(&out));
696*ec9796c4SHansol Suh     PetscCall(VecDestroy(&out2));
697*ec9796c4SHansol Suh   }
698*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
699*ec9796c4SHansol Suh }
700*ec9796c4SHansol Suh 
701*ec9796c4SHansol Suh static PetscErrorCode RosenbrockMain(void)
702*ec9796c4SHansol Suh {
703*ec9796c4SHansol Suh   Vec           x;    /* solution vector */
704*ec9796c4SHansol Suh   Vec           g;    /* gradient vector */
705*ec9796c4SHansol Suh   Mat           H;    /* Hessian matrix */
706*ec9796c4SHansol Suh   Tao           tao;  /* Tao solver context */
707*ec9796c4SHansol Suh   AppCtx        user; /* user-defined application context */
708*ec9796c4SHansol Suh   PetscLogStage solve;
709*ec9796c4SHansol Suh 
710*ec9796c4SHansol Suh   /* Initialize TAO and PETSc */
711*ec9796c4SHansol Suh   PetscFunctionBegin;
712*ec9796c4SHansol Suh   PetscCall(PetscLogStageRegister("Rosenbrock solve", &solve));
713*ec9796c4SHansol Suh 
714*ec9796c4SHansol Suh   PetscCall(AppCtxCreate(PETSC_COMM_WORLD, &user));
715*ec9796c4SHansol Suh   PetscCall(CreateHessian(user, &H));
716*ec9796c4SHansol Suh   PetscCall(CreateVectors(user, H, &x, &g));
717*ec9796c4SHansol Suh 
718*ec9796c4SHansol Suh   /* The TAO code begins here */
719*ec9796c4SHansol Suh 
720*ec9796c4SHansol Suh   PetscCall(TaoCreate(user->comm, &tao));
721*ec9796c4SHansol Suh   PetscCall(VecZeroEntries(x));
722*ec9796c4SHansol Suh   PetscCall(TaoSetSolution(tao, x));
723*ec9796c4SHansol Suh 
724*ec9796c4SHansol Suh   /* Set routines for function, gradient, hessian evaluation */
725*ec9796c4SHansol Suh   PetscCall(TaoSetObjective(tao, FormObjective, user));
726*ec9796c4SHansol Suh   PetscCall(TaoSetObjectiveAndGradient(tao, g, FormObjectiveGradient, user));
727*ec9796c4SHansol Suh   PetscCall(TaoSetGradient(tao, g, FormGradient, user));
728*ec9796c4SHansol Suh   PetscCall(TaoSetHessian(tao, H, H, FormHessian, user));
729*ec9796c4SHansol Suh 
730*ec9796c4SHansol Suh   PetscCall(TaoSetFromOptions(tao));
731*ec9796c4SHansol Suh 
732*ec9796c4SHansol Suh   /* SOLVE THE APPLICATION */
733*ec9796c4SHansol Suh   PetscCall(PetscLogStagePush(solve));
734*ec9796c4SHansol Suh   PetscCall(TaoSolve(tao));
735*ec9796c4SHansol Suh   PetscCall(PetscLogStagePop());
736*ec9796c4SHansol Suh 
737*ec9796c4SHansol Suh   if (user->test_lmvm) PetscCall(TestLMVM(tao));
738*ec9796c4SHansol Suh 
739*ec9796c4SHansol Suh   PetscCall(TaoDestroy(&tao));
740*ec9796c4SHansol Suh   PetscCall(VecDestroy(&g));
741*ec9796c4SHansol Suh   PetscCall(VecDestroy(&x));
742*ec9796c4SHansol Suh   PetscCall(MatDestroy(&H));
743*ec9796c4SHansol Suh   PetscCall(AppCtxDestroy(&user));
744*ec9796c4SHansol Suh   PetscFunctionReturn(PETSC_SUCCESS);
745*ec9796c4SHansol Suh }
746