1*55a74a43SLisandro Dalcin# Solves Heat equation on a periodic domain, using raw VecScatter 2*55a74a43SLisandro Dalcinfrom __future__ import division 3*55a74a43SLisandro Dalcinimport sys, petsc4py 4*55a74a43SLisandro Dalcinpetsc4py.init(sys.argv) 5*55a74a43SLisandro Dalcin 6*55a74a43SLisandro Dalcinfrom petsc4py import PETSc 7*55a74a43SLisandro Dalcinfrom mpi4py import MPI 8*55a74a43SLisandro Dalcinimport numpy 9*55a74a43SLisandro Dalcin 10*55a74a43SLisandro Dalcinclass Heat(object): 11*55a74a43SLisandro Dalcin def __init__(self,comm,N): 12*55a74a43SLisandro Dalcin self.comm = comm 13*55a74a43SLisandro Dalcin self.N = N # global problem size 14*55a74a43SLisandro Dalcin self.h = 1/N # grid spacing on unit interval 15*55a74a43SLisandro Dalcin self.n = N // comm.size + int(comm.rank < (N % comm.size)) # owned part of global problem 16*55a74a43SLisandro Dalcin self.start = comm.exscan(self.n) 17*55a74a43SLisandro Dalcin if comm.rank == 0: self.start = 0 18*55a74a43SLisandro Dalcin gindices = numpy.arange(self.start-1, self.start+self.n+1, dtype=PETSc.IntType) % N # periodic 19*55a74a43SLisandro Dalcin self.mat = PETSc.Mat().create(comm=comm) 20*55a74a43SLisandro Dalcin size = (self.n, self.N) # local and global sizes 21*55a74a43SLisandro Dalcin self.mat.setSizes((size,size)) 22*55a74a43SLisandro Dalcin self.mat.setFromOptions() 23*55a74a43SLisandro Dalcin self.mat.setPreallocationNNZ((3,1)) # Conservative preallocation for 3 "local" columns and one non-local 24*55a74a43SLisandro Dalcin 25*55a74a43SLisandro Dalcin # Allow matrix insertion using local indices [0:n+2] 26*55a74a43SLisandro Dalcin lgmap = PETSc.LGMap().create(list(gindices), comm=comm) 27*55a74a43SLisandro Dalcin self.mat.setLGMap(lgmap, lgmap) 28*55a74a43SLisandro Dalcin 29*55a74a43SLisandro Dalcin # Global and local vectors 30*55a74a43SLisandro Dalcin self.gvec = self.mat.createVecRight() 31*55a74a43SLisandro Dalcin self.lvec = PETSc.Vec().create(comm=PETSc.COMM_SELF) 32*55a74a43SLisandro Dalcin self.lvec.setSizes(self.n+2) 33*55a74a43SLisandro Dalcin self.lvec.setUp() 34*55a74a43SLisandro Dalcin # Configure scatter from global to local 35*55a74a43SLisandro Dalcin isg = PETSc.IS().createGeneral(list(gindices), comm=comm) 36*55a74a43SLisandro Dalcin self.g2l = PETSc.Scatter().create(self.gvec, isg, self.lvec, None) 37*55a74a43SLisandro Dalcin 38*55a74a43SLisandro Dalcin self.tozero, self.zvec = PETSc.Scatter.toZero(self.gvec) 39*55a74a43SLisandro Dalcin self.history = [] 40*55a74a43SLisandro Dalcin 41*55a74a43SLisandro Dalcin if False: # Print some diagnostics 42*55a74a43SLisandro Dalcin print('[%d] local size %d, global size %d, starting offset %d' % (comm.rank, self.n, self.N, self.start)) 43*55a74a43SLisandro Dalcin self.gvec.setArray(numpy.arange(self.start,self.start+self.n)) 44*55a74a43SLisandro Dalcin self.gvec.view() 45*55a74a43SLisandro Dalcin self.g2l.scatter(self.gvec, self.lvec, PETSc.InsertMode.INSERT) 46*55a74a43SLisandro Dalcin for rank in range(comm.size): 47*55a74a43SLisandro Dalcin if rank == comm.rank: 48*55a74a43SLisandro Dalcin print('Contents of local Vec on rank %d' % rank) 49*55a74a43SLisandro Dalcin self.lvec.view() 50*55a74a43SLisandro Dalcin comm.barrier() 51*55a74a43SLisandro Dalcin def evalSolution(self, t, x): 52*55a74a43SLisandro Dalcin assert t == 0.0, "only for t=0.0" 53*55a74a43SLisandro Dalcin coord = numpy.arange(self.start, self.start+self.n) / self.N 54*55a74a43SLisandro Dalcin x.setArray((numpy.abs(coord-0.5) < 0.1) * 1.0) 55*55a74a43SLisandro Dalcin def evalFunction(self, ts, t, x, xdot, f): 56*55a74a43SLisandro Dalcin self.g2l.scatter(x, self.lvec, PETSc.InsertMode.INSERT) # lvec is a work vector 57*55a74a43SLisandro Dalcin h = self.h 58*55a74a43SLisandro Dalcin with self.lvec as u, xdot as udot: 59*55a74a43SLisandro Dalcin f.setArray(udot*h + 2*u[1:-1]/h - u[:-2]/h - u[2:]/h) # Scale equation by volume element 60*55a74a43SLisandro Dalcin def evalJacobian(self, ts, t, x, xdot, a, A, B): 61*55a74a43SLisandro Dalcin h = self.h 62*55a74a43SLisandro Dalcin for i in range(self.n): 63*55a74a43SLisandro Dalcin lidx = i + 1 64*55a74a43SLisandro Dalcin gidx = self.start + i 65*55a74a43SLisandro Dalcin B.setValuesLocal([lidx], [lidx-1,lidx,lidx+1], [-1/h, a*h+2/h, -1/h]) 66*55a74a43SLisandro Dalcin B.assemble() 67*55a74a43SLisandro Dalcin if A != B: A.assemble() # If operator is different from preconditioning matrix 68*55a74a43SLisandro Dalcin return True # same nonzero pattern 69*55a74a43SLisandro Dalcin def monitor(self, ts, i, t, x): 70*55a74a43SLisandro Dalcin if self.history: 71*55a74a43SLisandro Dalcin lasti, lastt, lastx = self.history[-1] 72*55a74a43SLisandro Dalcin if i < lasti + 4 or t < lastt + 1e-4: return 73*55a74a43SLisandro Dalcin self.tozero.scatter(x, self.zvec, PETSc.InsertMode.INSERT) 74*55a74a43SLisandro Dalcin xx = self.zvec[:].tolist() 75*55a74a43SLisandro Dalcin self.history.append((i, t, xx)) 76*55a74a43SLisandro Dalcin def plotHistory(self): 77*55a74a43SLisandro Dalcin try: 78*55a74a43SLisandro Dalcin from matplotlib import pylab, rcParams 79*55a74a43SLisandro Dalcin except ImportError: 80*55a74a43SLisandro Dalcin print("matplotlib not available") 81*55a74a43SLisandro Dalcin raise SystemExit 82*55a74a43SLisandro Dalcin rcParams.update({'text.usetex':True, 'figure.figsize':(10,6)}) 83*55a74a43SLisandro Dalcin #rc('figure', figsize=(600,400)) 84*55a74a43SLisandro Dalcin pylab.title('Heat: TS \\texttt{%s}' % ts.getType()) 85*55a74a43SLisandro Dalcin x = numpy.arange(self.N) / self.N 86*55a74a43SLisandro Dalcin for i,t,u in self.history: 87*55a74a43SLisandro Dalcin pylab.plot(x, u, label='step=%d t=%8.2g'%(i,t)) 88*55a74a43SLisandro Dalcin pylab.xlabel('$x$') 89*55a74a43SLisandro Dalcin pylab.ylabel('$u$') 90*55a74a43SLisandro Dalcin pylab.legend(loc='upper right') 91*55a74a43SLisandro Dalcin pylab.savefig('heat-history.png') 92*55a74a43SLisandro Dalcin #pylab.show() 93*55a74a43SLisandro Dalcin 94*55a74a43SLisandro DalcinOptDB = PETSc.Options() 95*55a74a43SLisandro Dalcinode = Heat(MPI.COMM_WORLD, OptDB.getInt('n',100)) 96*55a74a43SLisandro Dalcin 97*55a74a43SLisandro Dalcinx = ode.gvec.duplicate() 98*55a74a43SLisandro Dalcinf = ode.gvec.duplicate() 99*55a74a43SLisandro Dalcin 100*55a74a43SLisandro Dalcints = PETSc.TS().create(comm=ode.comm) 101*55a74a43SLisandro Dalcints.setType(ts.Type.ROSW) # Rosenbrock-W. ARKIMEX is a nonlinearly implicit alternative. 102*55a74a43SLisandro Dalcin 103*55a74a43SLisandro Dalcints.setIFunction(ode.evalFunction, ode.gvec) 104*55a74a43SLisandro Dalcints.setIJacobian(ode.evalJacobian, ode.mat) 105*55a74a43SLisandro Dalcin 106*55a74a43SLisandro Dalcints.setMonitor(ode.monitor) 107*55a74a43SLisandro Dalcin 108*55a74a43SLisandro Dalcints.setTime(0.0) 109*55a74a43SLisandro Dalcints.setTimeStep(ode.h**2) 110*55a74a43SLisandro Dalcints.setMaxTime(1) 111*55a74a43SLisandro Dalcints.setMaxSteps(100) 112*55a74a43SLisandro Dalcints.setExactFinalTime(PETSc.TS.ExactFinalTime.INTERPOLATE) 113*55a74a43SLisandro Dalcints.setMaxSNESFailures(-1) # allow an unlimited number of failures (step will be rejected and retried) 114*55a74a43SLisandro Dalcin 115*55a74a43SLisandro Dalcinsnes = ts.getSNES() # Nonlinear solver 116*55a74a43SLisandro Dalcinsnes.setTolerances(max_it=10) # Stop nonlinear solve after 10 iterations (TS will retry with shorter step) 117*55a74a43SLisandro Dalcinksp = snes.getKSP() # Linear solver 118*55a74a43SLisandro Dalcinksp.setType(ksp.Type.CG) # Conjugate gradients 119*55a74a43SLisandro Dalcinpc = ksp.getPC() # Preconditioner 120*55a74a43SLisandro Dalcinif False: # Configure algebraic multigrid, could use run-time options instead 121*55a74a43SLisandro Dalcin pc.setType(pc.Type.GAMG) # PETSc's native AMG implementation, mostly based on smoothed aggregation 122*55a74a43SLisandro Dalcin OptDB['mg_coarse_pc_type'] = 'svd' # more specific multigrid options 123*55a74a43SLisandro Dalcin OptDB['mg_levels_pc_type'] = 'sor' 124*55a74a43SLisandro Dalcin 125*55a74a43SLisandro Dalcints.setFromOptions() # Apply run-time options, e.g. -ts_adapt_monitor -ts_type arkimex -snes_converged_reason 126*55a74a43SLisandro Dalcinode.evalSolution(0.0, x) 127*55a74a43SLisandro Dalcints.solve(x) 128*55a74a43SLisandro Dalcinif ode.comm.rank == 0: 129*55a74a43SLisandro Dalcin print('steps %d (%d rejected, %d SNES fails), nonlinear its %d, linear its %d' 130*55a74a43SLisandro Dalcin % (ts.getStepNumber(), ts.getStepRejections(), ts.getSNESFailures(), 131*55a74a43SLisandro Dalcin ts.getSNESIterations(), ts.getKSPIterations())) 132*55a74a43SLisandro Dalcin 133*55a74a43SLisandro Dalcinif OptDB.getBool('plot_history', True) and ode.comm.rank == 0: 134*55a74a43SLisandro Dalcin ode.plotHistory() 135