xref: /petsc/src/ts/tutorials/ex23fwdadj.c (revision 48a46eb9bd028bec07ec0f396b1a3abb43f14558)
182ad9101SHong Zhang static char help[] = "A toy example for testing forward and adjoint sensitivity analysis of an implicit ODE with a paramerized mass matrice.\n";
282ad9101SHong Zhang 
382ad9101SHong Zhang /*
482ad9101SHong Zhang   This example solves the simple ODE
582ad9101SHong Zhang     c x' = b x, x(0) = a,
682ad9101SHong Zhang   whose analytical solution is x(T)=a*exp(b/c*T), and calculates the derivative of x(T) w.r.t. c (by default) or w.r.t. b (can be enabled with command line option -der 2).
782ad9101SHong Zhang 
882ad9101SHong Zhang */
982ad9101SHong Zhang 
1082ad9101SHong Zhang #include <petscts.h>
1182ad9101SHong Zhang 
1282ad9101SHong Zhang typedef struct _n_User *User;
1382ad9101SHong Zhang struct _n_User {
1482ad9101SHong Zhang   PetscReal a;
1582ad9101SHong Zhang   PetscReal b;
1682ad9101SHong Zhang   PetscReal c;
1782ad9101SHong Zhang   /* Sensitivity analysis support */
1882ad9101SHong Zhang   PetscInt  steps;
1982ad9101SHong Zhang   PetscReal ftime;
2082ad9101SHong Zhang   Mat       Jac;  /* Jacobian matrix */
2182ad9101SHong Zhang   Mat       Jacp; /* JacobianP matrix */
2282ad9101SHong Zhang   Vec       x;
2382ad9101SHong Zhang   Mat       sp;        /* forward sensitivity variables */
2482ad9101SHong Zhang   Vec       lambda[1]; /* adjoint sensitivity variables */
2582ad9101SHong Zhang   Vec       mup[1];    /* adjoint sensitivity variables */
2682ad9101SHong Zhang   PetscInt  der;
2782ad9101SHong Zhang };
2882ad9101SHong Zhang 
299371c9d4SSatish Balay static PetscErrorCode IFunction(TS ts, PetscReal t, Vec X, Vec Xdot, Vec F, void *ctx) {
3082ad9101SHong Zhang   User               user = (User)ctx;
3182ad9101SHong Zhang   const PetscScalar *x, *xdot;
3282ad9101SHong Zhang   PetscScalar       *f;
3382ad9101SHong Zhang 
3482ad9101SHong Zhang   PetscFunctionBeginUser;
359566063dSJacob Faibussowitsch   PetscCall(VecGetArrayRead(X, &x));
369566063dSJacob Faibussowitsch   PetscCall(VecGetArrayRead(Xdot, &xdot));
379566063dSJacob Faibussowitsch   PetscCall(VecGetArrayWrite(F, &f));
3882ad9101SHong Zhang   f[0] = user->c * xdot[0] - user->b * x[0];
399566063dSJacob Faibussowitsch   PetscCall(VecRestoreArrayRead(X, &x));
409566063dSJacob Faibussowitsch   PetscCall(VecRestoreArrayRead(Xdot, &xdot));
419566063dSJacob Faibussowitsch   PetscCall(VecRestoreArrayWrite(F, &f));
4282ad9101SHong Zhang   PetscFunctionReturn(0);
4382ad9101SHong Zhang }
4482ad9101SHong Zhang 
459371c9d4SSatish Balay static PetscErrorCode IJacobian(TS ts, PetscReal t, Vec X, Vec Xdot, PetscReal a, Mat A, Mat B, void *ctx) {
4682ad9101SHong Zhang   User               user     = (User)ctx;
4782ad9101SHong Zhang   PetscInt           rowcol[] = {0};
4882ad9101SHong Zhang   PetscScalar        J[1][1];
4982ad9101SHong Zhang   const PetscScalar *x;
5082ad9101SHong Zhang 
5182ad9101SHong Zhang   PetscFunctionBeginUser;
529566063dSJacob Faibussowitsch   PetscCall(VecGetArrayRead(X, &x));
5382ad9101SHong Zhang   J[0][0] = user->c * a - user->b * 1.0;
549566063dSJacob Faibussowitsch   PetscCall(MatSetValues(B, 1, rowcol, 1, rowcol, &J[0][0], INSERT_VALUES));
559566063dSJacob Faibussowitsch   PetscCall(VecRestoreArrayRead(X, &x));
5682ad9101SHong Zhang 
579566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
589566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
5982ad9101SHong Zhang   if (A != B) {
609566063dSJacob Faibussowitsch     PetscCall(MatAssemblyBegin(B, MAT_FINAL_ASSEMBLY));
619566063dSJacob Faibussowitsch     PetscCall(MatAssemblyEnd(B, MAT_FINAL_ASSEMBLY));
6282ad9101SHong Zhang   }
6382ad9101SHong Zhang   PetscFunctionReturn(0);
6482ad9101SHong Zhang }
6582ad9101SHong Zhang 
669371c9d4SSatish Balay static PetscErrorCode IJacobianP(TS ts, PetscReal t, Vec X, Vec Xdot, PetscReal shift, Mat A, void *ctx) {
6782ad9101SHong Zhang   User               user  = (User)ctx;
6882ad9101SHong Zhang   PetscInt           row[] = {0}, col[] = {0};
6982ad9101SHong Zhang   PetscScalar        J[1][1];
7082ad9101SHong Zhang   const PetscScalar *x, *xdot;
7182ad9101SHong Zhang   PetscReal          dt;
7282ad9101SHong Zhang 
7382ad9101SHong Zhang   PetscFunctionBeginUser;
749566063dSJacob Faibussowitsch   PetscCall(VecGetArrayRead(X, &x));
759566063dSJacob Faibussowitsch   PetscCall(VecGetArrayRead(Xdot, &xdot));
769566063dSJacob Faibussowitsch   PetscCall(TSGetTimeStep(ts, &dt));
7782ad9101SHong Zhang   if (user->der == 1) J[0][0] = xdot[0];
7882ad9101SHong Zhang   if (user->der == 2) J[0][0] = -x[0];
799566063dSJacob Faibussowitsch   PetscCall(MatSetValues(A, 1, row, 1, col, &J[0][0], INSERT_VALUES));
809566063dSJacob Faibussowitsch   PetscCall(VecRestoreArrayRead(X, &x));
8182ad9101SHong Zhang 
829566063dSJacob Faibussowitsch   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
839566063dSJacob Faibussowitsch   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
8482ad9101SHong Zhang   PetscFunctionReturn(0);
8582ad9101SHong Zhang }
8682ad9101SHong Zhang 
879371c9d4SSatish Balay int main(int argc, char **argv) {
8882ad9101SHong Zhang   TS             ts;
8982ad9101SHong Zhang   PetscScalar   *x_ptr;
9082ad9101SHong Zhang   PetscMPIInt    size;
9182ad9101SHong Zhang   struct _n_User user;
9282ad9101SHong Zhang   PetscInt       rows, cols;
9382ad9101SHong Zhang 
94327415f7SBarry Smith   PetscFunctionBeginUser;
959566063dSJacob Faibussowitsch   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
9682ad9101SHong Zhang 
979566063dSJacob Faibussowitsch   PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD, &size));
983c633725SBarry Smith   PetscCheck(size == 1, PETSC_COMM_WORLD, PETSC_ERR_WRONG_MPI_SIZE, "This is a uniprocessor example only!");
9982ad9101SHong Zhang 
10082ad9101SHong Zhang   user.a     = 2.0;
10182ad9101SHong Zhang   user.b     = 4.0;
10282ad9101SHong Zhang   user.c     = 3.0;
10382ad9101SHong Zhang   user.steps = 0;
10482ad9101SHong Zhang   user.ftime = 1.0;
10582ad9101SHong Zhang   user.der   = 1;
1069566063dSJacob Faibussowitsch   PetscCall(PetscOptionsGetInt(NULL, NULL, "-der", &user.der, NULL));
10782ad9101SHong Zhang 
10882ad9101SHong Zhang   rows = 1;
10982ad9101SHong Zhang   cols = 1;
1109566063dSJacob Faibussowitsch   PetscCall(MatCreate(PETSC_COMM_WORLD, &user.Jac));
1119566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(user.Jac, PETSC_DECIDE, PETSC_DECIDE, 1, 1));
1129566063dSJacob Faibussowitsch   PetscCall(MatSetFromOptions(user.Jac));
1139566063dSJacob Faibussowitsch   PetscCall(MatSetUp(user.Jac));
1149566063dSJacob Faibussowitsch   PetscCall(MatCreateVecs(user.Jac, &user.x, NULL));
11582ad9101SHong Zhang 
1169566063dSJacob Faibussowitsch   PetscCall(TSCreate(PETSC_COMM_WORLD, &ts));
1179566063dSJacob Faibussowitsch   PetscCall(TSSetType(ts, TSBEULER));
1189566063dSJacob Faibussowitsch   PetscCall(TSSetIFunction(ts, NULL, IFunction, &user));
1199566063dSJacob Faibussowitsch   PetscCall(TSSetIJacobian(ts, user.Jac, user.Jac, IJacobian, &user));
1209566063dSJacob Faibussowitsch   PetscCall(TSSetExactFinalTime(ts, TS_EXACTFINALTIME_MATCHSTEP));
1219566063dSJacob Faibussowitsch   PetscCall(TSSetMaxTime(ts, user.ftime));
12282ad9101SHong Zhang 
1239566063dSJacob Faibussowitsch   PetscCall(VecGetArrayWrite(user.x, &x_ptr));
12482ad9101SHong Zhang   x_ptr[0] = user.a;
1259566063dSJacob Faibussowitsch   PetscCall(VecRestoreArrayWrite(user.x, &x_ptr));
1269566063dSJacob Faibussowitsch   PetscCall(TSSetTimeStep(ts, 0.001));
12782ad9101SHong Zhang 
12882ad9101SHong Zhang   /* Set up forward sensitivity */
1299566063dSJacob Faibussowitsch   PetscCall(MatCreate(PETSC_COMM_WORLD, &user.Jacp));
1309566063dSJacob Faibussowitsch   PetscCall(MatSetSizes(user.Jacp, PETSC_DECIDE, PETSC_DECIDE, rows, cols));
1319566063dSJacob Faibussowitsch   PetscCall(MatSetFromOptions(user.Jacp));
1329566063dSJacob Faibussowitsch   PetscCall(MatSetUp(user.Jacp));
1339566063dSJacob Faibussowitsch   PetscCall(MatCreateDense(PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, rows, cols, NULL, &user.sp));
1349566063dSJacob Faibussowitsch   PetscCall(MatZeroEntries(user.sp));
1359566063dSJacob Faibussowitsch   PetscCall(TSForwardSetSensitivities(ts, cols, user.sp));
1369566063dSJacob Faibussowitsch   PetscCall(TSSetIJacobianP(ts, user.Jacp, IJacobianP, &user));
13782ad9101SHong Zhang 
1389566063dSJacob Faibussowitsch   PetscCall(TSSetSaveTrajectory(ts));
1399566063dSJacob Faibussowitsch   PetscCall(TSSetFromOptions(ts));
14082ad9101SHong Zhang 
1419566063dSJacob Faibussowitsch   PetscCall(TSSolve(ts, user.x));
1429566063dSJacob Faibussowitsch   PetscCall(TSGetSolveTime(ts, &user.ftime));
1439566063dSJacob Faibussowitsch   PetscCall(TSGetStepNumber(ts, &user.steps));
1449566063dSJacob Faibussowitsch   PetscCall(VecGetArray(user.x, &x_ptr));
1459566063dSJacob Faibussowitsch   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n ode solution %g\n", (double)PetscRealPart(x_ptr[0])));
1469566063dSJacob Faibussowitsch   PetscCall(VecRestoreArray(user.x, &x_ptr));
14763a3b9bcSJacob Faibussowitsch   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n analytical solution %g\n", (double)(user.a * PetscExpReal(user.b / user.c * user.ftime))));
14882ad9101SHong Zhang 
149*48a46eb9SPierre Jolivet   if (user.der == 1) PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n analytical derivative w.r.t. c %g\n", (double)(-user.a * user.ftime * user.b / (user.c * user.c) * PetscExpReal(user.b / user.c * user.ftime))));
150*48a46eb9SPierre Jolivet   if (user.der == 2) PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n analytical derivative w.r.t. b %g\n", (double)(user.a * user.ftime / user.c * PetscExpReal(user.b / user.c * user.ftime))));
1519566063dSJacob Faibussowitsch   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n forward sensitivity:\n"));
1529566063dSJacob Faibussowitsch   PetscCall(MatView(user.sp, PETSC_VIEWER_STDOUT_WORLD));
15382ad9101SHong Zhang 
1549566063dSJacob Faibussowitsch   PetscCall(MatCreateVecs(user.Jac, &user.lambda[0], NULL));
15582ad9101SHong Zhang   /* Set initial conditions for the adjoint integration */
1569566063dSJacob Faibussowitsch   PetscCall(VecGetArrayWrite(user.lambda[0], &x_ptr));
15782ad9101SHong Zhang   x_ptr[0] = 1.0;
1589566063dSJacob Faibussowitsch   PetscCall(VecRestoreArrayWrite(user.lambda[0], &x_ptr));
1599566063dSJacob Faibussowitsch   PetscCall(MatCreateVecs(user.Jacp, &user.mup[0], NULL));
1609566063dSJacob Faibussowitsch   PetscCall(VecGetArrayWrite(user.mup[0], &x_ptr));
16182ad9101SHong Zhang   x_ptr[0] = 0.0;
1629566063dSJacob Faibussowitsch   PetscCall(VecRestoreArrayWrite(user.mup[0], &x_ptr));
16382ad9101SHong Zhang 
1649566063dSJacob Faibussowitsch   PetscCall(TSSetCostGradients(ts, 1, user.lambda, user.mup));
1659566063dSJacob Faibussowitsch   PetscCall(TSAdjointSolve(ts));
16682ad9101SHong Zhang 
1679566063dSJacob Faibussowitsch   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n adjoint sensitivity:\n"));
1689566063dSJacob Faibussowitsch   PetscCall(VecView(user.mup[0], PETSC_VIEWER_STDOUT_WORLD));
16982ad9101SHong Zhang 
1709566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&user.Jac));
1719566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&user.sp));
1729566063dSJacob Faibussowitsch   PetscCall(MatDestroy(&user.Jacp));
1739566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&user.x));
1749566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&user.lambda[0]));
1759566063dSJacob Faibussowitsch   PetscCall(VecDestroy(&user.mup[0]));
1769566063dSJacob Faibussowitsch   PetscCall(TSDestroy(&ts));
17782ad9101SHong Zhang 
1789566063dSJacob Faibussowitsch   PetscCall(PetscFinalize());
179b122ec5aSJacob Faibussowitsch   return 0;
18082ad9101SHong Zhang }
18182ad9101SHong Zhang 
18282ad9101SHong Zhang /*TEST
18382ad9101SHong Zhang 
18482ad9101SHong Zhang     test:
18582ad9101SHong Zhang       args: -ts_type beuler
18682ad9101SHong Zhang 
18782ad9101SHong Zhang     test:
18882ad9101SHong Zhang       suffix: 2
18982ad9101SHong Zhang       args: -ts_type cn
19082ad9101SHong Zhang       output_file: output/ex23fwdadj_1.out
19182ad9101SHong Zhang 
19282ad9101SHong Zhang TEST*/
193