xref: /petsc/src/binding/petsc4py/test/test_tao_py.py (revision a82e8c82ed9474375a7f877f23dfa96948657643)
1*a82e8c82SStefano Zampiniimport unittest
2*a82e8c82SStefano Zampinifrom petsc4py import PETSc
3*a82e8c82SStefano Zampinifrom sys import getrefcount
4*a82e8c82SStefano Zampiniimport gc
5*a82e8c82SStefano Zampini
6*a82e8c82SStefano Zampini# --------------------------------------------------------------------
class Objective:
    """Callable objective f(x) = (x[0] - 1)**2 + (x[1] - 2)**2.

    The minimum value 0 is reached at x = (1, 2).
    """

    def __call__(self, tao, x):
        """Return f(x); the *tao* argument is accepted but not used."""
        shift_a, shift_b = x[0] - 1.0, x[1] - 2.0
        return shift_a**2 + shift_b**2
10*a82e8c82SStefano Zampini
class Gradient:
    """Callable that stores the gradient of the quadratic objective in *g*.

    grad f(x) = (2*(x[0] - 1), 2*(x[1] - 2)).
    """

    def __call__(self, tao, x, g):
        """Fill *g* with grad f(x) and assemble it; *tao* is unused."""
        for idx, center in enumerate((1.0, 2.0)):
            g[idx] = 2.0*(x[idx] - center)
        g.assemble()
16*a82e8c82SStefano Zampini
class MyTao:
    """Python-side TAO implementation that counts how often each hook runs.

    Every hook bumps a per-name counter in ``self.log``; the test case
    inspects those counters after solving.
    """

    def __init__(self):
        # hook name -> number of invocations
        self.log = {}

    def _log(self, name):
        """Increment the invocation counter for hook *name*."""
        self.log[name] = self.log.get(name, 0) + 1

    def create(self, tao):
        """Creation hook: start with an empty placeholder work vector."""
        self._log('create')
        self.testvec = PETSc.Vec()

    def destroy(self, tao):
        """Destruction hook: release the work vector."""
        self._log('destroy')
        self.testvec.destroy()

    def setFromOptions(self, tao):
        """Options hook: only records the call."""
        self._log('setFromOptions')

    def setUp(self, tao):
        """Setup hook: allocate a work vector shaped like the solution."""
        self._log('setUp')
        solution = tao.getSolution()
        self.testvec = solution.duplicate()

    def solve(self, tao):
        """Custom solve hook: records the call and performs no iterations."""
        self._log('solve')

    def step(self, tao, x, g, s):
        """Compute the steepest-descent direction s = -grad f(x)."""
        self._log('step')
        tao.computeGradient(x, g)
        g.copy(s)
        s.scale(-1.0)

    def preStep(self, tao):
        """Hook run before each step: only records the call."""
        self._log('preStep')

    def postStep(self, tao):
        """Hook run after each step: only records the call."""
        self._log('postStep')

    def monitor(self, tao):
        """Monitor callback: only records the call."""
        self._log('monitor')
57*a82e8c82SStefano Zampini
class TestTaoPython(unittest.TestCase):
    """Exercise a TAO solver whose implementation is the Python class MyTao."""

    def setUp(self):
        self.tao = PETSc.TAO()
        self.tao.createPython(MyTao(), comm=PETSc.COMM_SELF)
        ctx = self.tao.getPythonContext()
        # Expected holders: the TAO object, the local ``ctx`` and the
        # temporary argument reference taken by getrefcount() itself.
        self.assertEqual(getrefcount(ctx), 3)
        self.assertEqual(ctx.log['create'], 1)

    def tearDown(self):
        ctx = self.tao.getPythonContext()
        # One more reference than in setUp -- presumably the bound method
        # ``ctx.monitor`` registered via tao.setMonitor() in testSolve.
        self.assertEqual(getrefcount(ctx), 4)
        self.assertNotIn('destroy', ctx.log)
        self.tao.destroy()
        self.tao = None
        self.assertEqual(ctx.log['destroy'], 1)
        # After destruction only ``ctx`` and getrefcount's argument remain.
        self.assertEqual(getrefcount(ctx), 2)

    def testSolve(self):
        """Solve three ways and check iteration counts and hook call counts."""
        tao = self.tao
        ctx = tao.getPythonContext()
        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        y1 = x.duplicate()
        y2 = x.duplicate()
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        tao.setMonitor(ctx.monitor)
        tao.setFromOptions()
        tao.setMaximumIterations(3)
        tao.setSolution(x)

        # Call the solve method of MyTao (it performs no iterations).
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 0)

        # Call the default solve method and use step of MyTao.
        ctx.solve = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 3)
        x.copy(y1)

        # Call the default solve method with the default step method.
        ctx.step = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 3)
        x.copy(y2)

        # The Python steepest-descent step must match the built-in one.
        self.assertTrue(y1.equal(y2))
        # Hook counts cover the two default solves (n iterations each).
        self.assertEqual(ctx.log['monitor'], 2*(n+1))
        self.assertEqual(ctx.log['preStep'], 2*n)
        self.assertEqual(ctx.log['postStep'], 2*n)
        self.assertEqual(ctx.log['solve'], 1)
        self.assertEqual(ctx.log['setUp'], 1)
        self.assertEqual(ctx.log['setFromOptions'], 1)
        self.assertEqual(ctx.log['step'], n)
122*a82e8c82SStefano Zampini
123*a82e8c82SStefano Zampini# --------------------------------------------------------------------
124*a82e8c82SStefano Zampini
import numpy

# On PETSc builds whose scalar type is complex, remove the test case so it
# is never collected (presumably because TAO needs real scalars -- confirm).
if numpy.issubdtype(PETSc.ScalarType, numpy.complexfloating):
    del TestTaoPython
128*a82e8c82SStefano Zampini
if __name__ == "__main__":
    # Executed as a script: discover and run this module's test cases.
    unittest.main()
131*a82e8c82SStefano Zampini
132*a82e8c82SStefano Zampini# --------------------------------------------------------------------
133