"""Tests for driving a PETSc TAO solver through a Python context object.

A ``MyTao`` instance is attached to a TAO of type ``python`` and records how
many times each of its callbacks is invoked.  The test case checks those
counts, the reference counting of the context object, and that a
Python-level step produces the same result as TAO's default step.
"""

import unittest
from petsc4py import PETSc
from sys import getrefcount
import numpy


# --------------------------------------------------------------------
class Objective:
    """Objective f(x) = (x[0] - 1)^2 + (x[1] - 2)^2, minimized at (1, 2)."""

    def __call__(self, tao, x):
        return (x[0] - 1.0) ** 2 + (x[1] - 2.0) ** 2


class Gradient:
    """Gradient of ``Objective``, written in place into ``g``."""

    def __call__(self, tao, x, g):
        g[0] = 2.0 * (x[0] - 1.0)
        g[1] = 2.0 * (x[1] - 2.0)
        # Assemble g after setting its entries.
        g.assemble()


class MyTao:
    """Python context for a TAO of type ``python``.

    Every callback records its invocation in ``self.log`` (a mapping from
    method name to call count) so the tests can check which hooks TAO
    invoked and how often.
    """

    def __init__(self):
        self.log = {}

    def _log(self, method):
        # Increment the call counter for ``method``, starting from 0.
        self.log.setdefault(method, 0)
        self.log[method] += 1

    def create(self, tao):
        self._log('create')
        # Placeholder; setUp() replaces it with a solution-sized vector.
        self.testvec = PETSc.Vec()

    def destroy(self, tao):
        self._log('destroy')
        self.testvec.destroy()

    def setFromOptions(self, tao):
        self._log('setFromOptions')

    def setUp(self, tao):
        self._log('setUp')
        # Work vector with the same layout as the current solution.
        self.testvec = tao.getSolution().duplicate()

    def solve(self, tao):
        # Deliberately does nothing but log: while this hook is present the
        # test expects the iteration number to stay at 0.  The test sets
        # ``ctx.solve = None`` to fall back to TAO's default solve.
        self._log('solve')

    def step(self, tao, x, g, s):
        # Steepest-descent direction: s = -grad f(x).
        self._log('step')
        tao.computeGradient(x, g)
        g.copy(s)
        s.scale(-1.0)

    def preStep(self, tao):
        self._log('preStep')

    def postStep(self, tao):
        self._log('postStep')

    def monitor(self, tao):
        self._log('monitor')


class TestTaoPython(unittest.TestCase):
    """Exercise the Python TAO type through the callbacks of ``MyTao``."""

    def setUp(self):
        self.tao = PETSc.TAO()
        self.tao.createPython(MyTao(), comm=PETSc.COMM_SELF)
        ctx = self.tao.getPythonContext()
        # References: the local ``ctx``, getrefcount's own argument, and one
        # held by the TAO object (dropped on destroy, checked in tearDown).
        self.assertEqual(getrefcount(ctx), 3)
        self.assertEqual(ctx.log['create'], 1)
        self.nsolve = 0

    def tearDown(self):
        ctx = self.tao.getPythonContext()
        # A test must not leave extra references to the context behind.
        self.assertEqual(getrefcount(ctx), 3)
        self.assertNotIn('destroy', ctx.log)
        self.tao.destroy()
        self.tao = None
        PETSc.garbage_cleanup()
        self.assertEqual(ctx.log['destroy'], 1)
        # Only ``ctx`` and getrefcount's argument remain: TAO released its
        # reference to the context.
        self.assertEqual(getrefcount(ctx), 2)

    def testGetType(self):
        ctx = self.tao.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        self.assertEqual(self.tao.getPythonType(), pytype)

    def testSolve(self):
        tao = self.tao
        ctx = tao.getPythonContext()
        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        y1 = x.duplicate()
        y2 = x.duplicate()
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        # The bound method ``ctx.monitor`` keeps a reference to ``ctx``.  It
        # is cancelled in ``finally`` so tearDown's reference-count check
        # still holds when an assertion below fails (cleanups registered
        # with addCleanup would run only after tearDown).
        tao.setMonitor(ctx.monitor)
        try:
            tao.setFromOptions()
            tao.setMaximumIterations(3)
            tao.setSolution(x)

            # Call the solve method of MyTAO (logs only, so no iterations)
            x.set(0.5)
            tao.solve()
            n = tao.getIterationNumber()
            self.assertEqual(n, 0)

            # Call the default solve method and use step of MyTAO
            ctx.solve = None
            x.set(0.5)
            tao.solve()
            n = tao.getIterationNumber()
            self.assertEqual(n, 3)
            x.copy(y1)

            # Call the default solve method with the default step method
            ctx.step = None
            x.set(0.5)
            tao.solve()
            n = tao.getIterationNumber()
            self.assertEqual(n, 3)
            x.copy(y2)

            # MyTao.step must yield the same iterate as the default step.
            self.assertTrue(y1.equal(y2))
            # Expected counts over the two default solves of n iterations:
            # monitor n + 1 times and pre/postStep n times per solve, and
            # MyTao.step only during the first of them.
            self.assertEqual(ctx.log['monitor'], 2 * (n + 1))
            self.assertEqual(ctx.log['preStep'], 2 * n)
            self.assertEqual(ctx.log['postStep'], 2 * n)
            self.assertEqual(ctx.log['solve'], 1)
            # setUp/setFromOptions ran once although solve() ran three times.
            self.assertEqual(ctx.log['setUp'], 1)
            self.assertEqual(ctx.log['setFromOptions'], 1)
            self.assertEqual(ctx.log['step'], n)
        finally:
            tao.cancelMonitor()


# --------------------------------------------------------------------

# The test case is removed when PETSc is built with complex scalars.
if numpy.iscomplexobj(PETSc.ScalarType()):
    del TestTaoPython

if __name__ == '__main__':
    unittest.main()

# --------------------------------------------------------------------