# Source: petsc/src/binding/petsc4py/test/test_tao_py.py (revision ebead697dbf761eb322f829370bbe90b3bd93fa3)
1a82e8c82SStefano Zampiniimport unittest
2a82e8c82SStefano Zampinifrom petsc4py import PETSc
3a82e8c82SStefano Zampinifrom sys import getrefcount
4a82e8c82SStefano Zampiniimport gc
5a82e8c82SStefano Zampini
6a82e8c82SStefano Zampini# --------------------------------------------------------------------
class Objective:
    """Objective callback: f(x) = (x[0] - 1)^2 + (x[1] - 2)^2."""

    def __call__(self, tao, x):
        """Return the objective value at ``x``.

        ``tao`` is accepted to match the callback signature but is not used.
        Only the first two entries of ``x`` are read.
        """
        return (x[0] - 1.0)**2 + (x[1] - 2.0)**2
10a82e8c82SStefano Zampini
class Gradient:
    """Gradient callback for f(x) = (x[0] - 1)^2 + (x[1] - 2)^2."""

    def __call__(self, tao, x, g):
        """Write the gradient at ``x`` into ``g`` and assemble ``g``.

        ``tao`` is accepted to match the callback signature but is not used.
        ``g`` is modified in place; nothing is returned.
        """
        g[0] = 2.0*(x[0] - 1.0)
        g[1] = 2.0*(x[1] - 2.0)
        # Finalize the entries set above before the vector is used.
        g.assemble()
16a82e8c82SStefano Zampini
class MyTao:
    """Python TAO context that counts how often each callback is invoked.

    Every method records its own name in ``self.log`` (a dict mapping the
    method name to a call count) so a test can check which hooks were called
    and how many times.
    """

    def __init__(self):
        # Maps callback name -> number of invocations so far.
        self.log = {}

    def _log(self, method):
        """Increment the call counter stored under ``method``."""
        self.log[method] = self.log.get(method, 0) + 1

    def create(self, tao):
        """Record the call and initialize ``testvec`` with a fresh ``PETSc.Vec``."""
        self._log('create')
        self.testvec = PETSc.Vec()

    def destroy(self, tao):
        """Record the call and destroy ``testvec``."""
        self._log('destroy')
        self.testvec.destroy()

    def setFromOptions(self, tao):
        """Record the call; no options are processed."""
        self._log('setFromOptions')

    def setUp(self, tao):
        """Record the call and rebind ``testvec`` to a duplicate of the solution."""
        self._log('setUp')
        self.testvec = tao.getSolution().duplicate()

    def solve(self, tao):
        """Record the call; performs no iterations."""
        self._log('solve')

    def step(self, tao, x, g, s):
        """Record the call and set ``s`` to the negative gradient at ``x``.

        The gradient is computed into ``g``, copied into ``s``, and ``s`` is
        then scaled by -1.
        """
        self._log('step')
        tao.computeGradient(x,g)
        g.copy(s)
        s.scale(-1.0)

    def preStep(self, tao):
        """Record the call."""
        self._log('preStep')

    def postStep(self, tao):
        """Record the call."""
        self._log('postStep')

    def monitor(self, tao):
        """Record the call."""
        self._log('monitor')
57a82e8c82SStefano Zampini
class TestTaoPython(unittest.TestCase):
    """Exercise a TAO solver whose implementation is a ``MyTao`` instance."""

    def setUp(self):
        """Create a Python-type TAO backed by a new ``MyTao`` context."""
        self.tao = PETSc.TAO()
        self.tao.createPython(MyTao(), comm=PETSc.COMM_SELF)
        ctx = self.tao.getPythonContext()
        # NOTE(review): presumably the three references are the TAO object,
        # the local name `ctx`, and getrefcount's own argument -- confirm.
        self.assertEqual(getrefcount(ctx),  3)
        self.assertEqual(ctx.log['create'], 1)
        self.nsolve = 0

    def tearDown(self):
        """Destroy the TAO and check the context was released exactly once."""
        ctx = self.tao.getPythonContext()
        self.assertEqual(getrefcount(ctx), 3)
        # destroy() must not have been called on the context yet.
        self.assertNotIn('destroy', ctx.log)
        self.tao.destroy()
        self.tao = None
        self.assertEqual(ctx.log['destroy'], 1)
        # One reference fewer than before the TAO was destroyed.
        self.assertEqual(getrefcount(ctx),   2)

    def testGetType(self):
        """The reported Python type is ``<module>.<class>`` of the context."""
        ctx = self.tao.getPythonContext()
        pytype = "{0}.{1}".format(ctx.__module__, type(ctx).__name__)
        self.assertEqual(self.tao.getPythonType(), pytype)

    def testSolve(self):
        """Solve three times, progressively disabling the context's hooks."""
        tao = self.tao
        ctx = tao.getPythonContext()
        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        y1 = x.duplicate()
        y2 = x.duplicate()
        tao.setObjective(Objective())
        tao.setGradient(Gradient(),None)
        tao.setMonitor(ctx.monitor)
        tao.setFromOptions()
        tao.setMaximumIterations(3)
        tao.setSolution(x)

        # Call the solve method of MyTao
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 0)

        # Call the default solve method and use step of MyTao
        ctx.solve = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 3)
        x.copy(y1)

        # Call the default solve method with the default step method
        ctx.step = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 3)
        x.copy(y2)

        # Both default-solve runs must end at the same point.
        self.assertTrue(y1.equal(y2))
        # The factor 2 covers the two default-solve runs; MyTao.step was
        # only available during the first of them.
        self.assertEqual(ctx.log['monitor'], 2*(n+1))
        self.assertEqual(ctx.log['preStep'], 2*n)
        self.assertEqual(ctx.log['postStep'], 2*n)
        self.assertEqual(ctx.log['solve'], 1)
        self.assertEqual(ctx.log['setUp'], 1)
        self.assertEqual(ctx.log['setFromOptions'], 1)
        self.assertEqual(ctx.log['step'], n)
        # NOTE(review): presumably needed so the monitor no longer holds a
        # reference to ctx when tearDown checks the refcount -- confirm.
        tao.cancelMonitor()
128a82e8c82SStefano Zampini
129a82e8c82SStefano Zampini# --------------------------------------------------------------------
130a82e8c82SStefano Zampini
import numpy

# Remove the test case when PETSc.ScalarType is a complex type.
if numpy.iscomplexobj(PETSc.ScalarType()):
    del TestTaoPython

if __name__ == '__main__':
    unittest.main()
137a82e8c82SStefano Zampini
138a82e8c82SStefano Zampini# --------------------------------------------------------------------
139