# xref: /petsc/src/binding/petsc4py/demo/legacy/ode/heat.py (revision 697771379c1747b47000b38c62729b9fba7ac359)
# Solves the heat equation on a periodic domain, using a raw VecScatter
import sys
import petsc4py

# Hand the command-line arguments to PETSc. This call is deliberately placed
# before ``from petsc4py import PETSc`` so that run-time options given on the
# command line are seen when the PETSc module is loaded.
petsc4py.init(sys.argv)

from petsc4py import PETSc
from mpi4py import MPI
import numpy


class Heat:
    """Finite-difference heat equation on the periodic unit interval.

    The ``N`` grid points are split into contiguous chunks, one per rank of
    ``comm``.  Each rank keeps a local work vector holding its ``n`` owned
    values plus one ghost value on each side; the ghost values are filled
    from the global vector with an explicit ``PETSc.Scatter``.

    The ``eval*`` and ``monitor`` methods have the signatures expected by
    ``TS.setIFunction``, ``TS.setIJacobian`` and ``TS.setMonitor``.
    """

    def __init__(self, comm, N, diagnostics=False):
        """Create the matrix, vectors and scatters for ``N`` grid points.

        Parameters:
            comm: mpi4py communicator (``rank``, ``size``, ``exscan`` and
                ``barrier`` are used).
            N: global number of grid points, must be at least 1.
            diagnostics: when true, print the partitioning and the contents
                of the ghosted local vector (overwrites ``self.gvec``).

        Raises:
            ValueError: if ``N`` is smaller than 1.
        """
        if N < 1:
            # Reject this up front instead of failing in 1 / N below.
            raise ValueError('N must be a positive integer, got %r' % (N,))
        self.comm = comm
        self.N = N  # global problem size
        self.h = 1 / N  # grid spacing on unit interval
        # Owned part of global problem: the first N % size ranks get one
        # extra point.
        self.n = N // comm.size + int(comm.rank < (N % comm.size))
        # Offset of the first owned point = sum of n over lower ranks.
        self.start = comm.exscan(self.n)
        if comm.rank == 0:
            self.start = 0  # exscan does not define a result on rank 0
        # Global indices of the owned points plus one ghost on each side,
        # wrapped modulo N to make the domain periodic.
        gindices = (
            numpy.arange(self.start - 1, self.start + self.n + 1, dtype=PETSc.IntType)
            % N
        )
        self.mat = PETSc.Mat().create(comm=comm)
        size = (self.n, self.N)  # local and global sizes
        self.mat.setSizes((size, size))
        self.mat.setFromOptions()
        # Up to 3 entries per row in the diagonal block.  A rank that owns a
        # single row in a parallel run has both neighbours off-process, so
        # reserve 2 entries per row in the off-diagonal block.
        self.mat.setPreallocationNNZ((3, 2))

        # Allow matrix insertion using local indices [0:n+2]
        lgmap = PETSc.LGMap().create(gindices, comm=comm)
        self.mat.setLGMap(lgmap, lgmap)

        # Global and local vectors
        self.gvec = self.mat.createVecRight()
        self.lvec = PETSc.Vec().create(comm=PETSc.COMM_SELF)
        self.lvec.setSizes(self.n + 2)
        self.lvec.setUp()
        # Configure scatter from global to local
        isg = PETSc.IS().createGeneral(gindices, comm=comm)
        self.g2l = PETSc.Scatter().create(self.gvec, isg, self.lvec, None)

        # Scatter used by monitor() to gather the whole solution in zvec.
        self.tozero, self.zvec = PETSc.Scatter.toZero(self.gvec)
        self.history = []  # (step, time, solution as list) snapshots
        self.tstype = None  # TS type name, recorded by monitor()

        if diagnostics:
            self._printDiagnostics()

    def _printDiagnostics(self):
        """Print the partitioning and the ghosted local vector of each rank."""
        comm = self.comm
        print(
            '[%d] local size %d, global size %d, starting offset %d'
            % (comm.rank, self.n, self.N, self.start)
        )
        # Fill the global vector with the global index of each entry so the
        # scattered ghost values are easy to check by eye.
        self.gvec.setArray(numpy.arange(self.start, self.start + self.n))
        self.gvec.view()
        self.g2l.scatter(self.gvec, self.lvec, PETSc.InsertMode.INSERT)
        # Take turns, one rank per barrier.
        for rank in range(comm.size):
            if rank == comm.rank:
                print('Contents of local Vec on rank %d' % rank)
                self.lvec.view()
            comm.barrier()

    def evalSolution(self, t, x):
        """Store the initial condition in ``x``.

        The owned entries of ``x`` are set to 1.0 where the grid coordinate
        ``i / N`` satisfies ``|coord - 0.5| < 0.1`` and to 0.0 elsewhere.

        Raises:
            ValueError: if ``t`` is not 0.
        """
        if t != 0.0:
            raise ValueError('Only for t=0')
        coord = numpy.arange(self.start, self.start + self.n) / self.N
        x.setArray((numpy.abs(coord - 0.5) < 0.1) * 1.0)

    def evalFunction(self, ts, t, x, xdot, f):
        """Evaluate the implicit residual and store it in ``f``.

        For each owned point the residual is
        ``h * udot_i + (2 * u_i - u_{i-1} - u_{i+1}) / h``.
        """
        # lvec is a work vector: owned values plus one ghost on each side.
        self.g2l.scatter(x, self.lvec, PETSc.InsertMode.INSERT)
        h = self.h
        with self.lvec as u, xdot as udot:
            # Scale equation by volume element
            f.setArray(udot * h + 2 * u[1:-1] / h - u[:-2] / h - u[2:] / h)

    def evalJacobian(self, ts, t, x, xdot, a, A, B):
        """Fill ``B`` with the Jacobian of the residual of ``evalFunction``.

        Each owned row receives the stencil ``[-1/h, a*h + 2/h, -1/h]``,
        where ``a`` multiplies the contribution of the ``udot`` term.
        """
        h = self.h
        stencil = [-1 / h, a * h + 2 / h, -1 / h]  # identical for every row
        for i in range(self.n):
            lidx = i + 1  # local index; slot 0 is the left ghost
            B.setValuesLocal([lidx], [lidx - 1, lidx, lidx + 1], stencil)
        B.assemble()
        if A != B:
            A.assemble()  # If operator is different from preconditioning matrix

    def monitor(self, ts, i, t, x):
        """Record a snapshot ``(i, t, solution)`` in ``self.history``.

        A snapshot is skipped when fewer than 4 steps or less than 1e-4 time
        units separate it from the previously recorded one.
        """
        # Remember which integrator is running so plotHistory() can label the
        # figure without relying on a module-level ``ts`` variable.
        self.tstype = ts.getType()
        if self.history:
            lasti, lastt, lastx = self.history[-1]
            if i < lasti + 4 or t < lastt + 1e-4:
                return
        self.tozero.scatter(x, self.zvec, PETSc.InsertMode.INSERT)
        xx = self.zvec[:].tolist()
        self.history.append((i, t, xx))

    def plotHistory(self):
        """Plot every recorded snapshot and save it as ``heat-history.png``.

        Does nothing when matplotlib cannot be imported.
        """
        try:
            from matplotlib import pylab, rcParams
        except ImportError:
            return
        rcParams.update({'text.usetex': True, 'figure.figsize': (10, 6)})
        tstype = self.tstype if self.tstype is not None else 'unknown'
        pylab.title('Heat: TS \\texttt{%s}' % tstype)
        x = numpy.arange(self.N) / self.N
        for i, t, u in self.history:
            pylab.plot(x, u, label='step=%d t=%8.2g' % (i, t))
        pylab.xlabel('$x$')
        pylab.ylabel('$u$')
        pylab.legend(loc='upper right')
        pylab.savefig('heat-history.png')
        # pylab.show()


# Driver: build the problem, configure the time integrator and solve.
OptDB = PETSc.Options()
ode = Heat(MPI.COMM_WORLD, OptDB.getInt('n', 100))

x = ode.gvec.duplicate()  # solution vector

ts = PETSc.TS().create(comm=ode.comm)
ts.setType(ts.Type.ROSW)  # Rosenbrock-W. ARKIMEX is a nonlinearly implicit alternative.

ts.setIFunction(ode.evalFunction, ode.gvec)
ts.setIJacobian(ode.evalJacobian, ode.mat)

ts.setMonitor(ode.monitor)

ts.setTime(0.0)
ts.setTimeStep(ode.h**2)
ts.setMaxTime(1)
ts.setMaxSteps(100)
ts.setExactFinalTime(PETSc.TS.ExactFinalTime.INTERPOLATE)
# allow an unlimited number of failures (step will be rejected and retried)
ts.setMaxSNESFailures(-1)

snes = ts.getSNES()  # Nonlinear solver
# Stop nonlinear solve after 10 iterations (TS will retry with shorter step)
snes.setTolerances(max_it=10)
ksp = snes.getKSP()  # Linear solver
ksp.setType(ksp.Type.CG)  # Conjugate gradients
pc = ksp.getPC()  # Preconditioner
# Optionally configure algebraic multigrid from code; off unless -use_gamg is
# given.  The same setup can be requested purely through run-time options.
if OptDB.getBool('use_gamg', False):
    # PETSc's native AMG implementation, mostly based on smoothed aggregation
    pc.setType(pc.Type.GAMG)
    OptDB['mg_coarse_pc_type'] = 'svd'  # more specific multigrid options
    OptDB['mg_levels_pc_type'] = 'sor'

ts.setFromOptions()  # Apply run-time options, e.g. -ts_adapt_monitor -ts_type arkimex -snes_converged_reason
ode.evalSolution(0.0, x)
ts.solve(x)
if ode.comm.rank == 0:
    print(
        'steps %d (%d rejected, %d SNES fails), nonlinear its %d, linear its %d'
        % (
            ts.getStepNumber(),
            ts.getStepRejections(),
            ts.getSNESFailures(),
            ts.getSNESIterations(),
            ts.getKSPIterations(),
        )
    )

# Only rank 0 holds the gathered snapshots, so only rank 0 plots.
if OptDB.getBool('plot_history', True) and ode.comm.rank == 0:
    ode.plotHistory()