import unittest
from petsc4py import PETSc
from sys import getrefcount
import gc

# --------------------------------------------------------------------


class Objective:
    """Objective f(x) = (x[0] - 1)**2 + (x[1] - 2)**2, minimized at (1, 2)."""

    def __call__(self, tao, x):
        return (x[0] - 1.0)**2 + (x[1] - 2.0)**2


class Gradient:
    """Gradient of ``Objective``, written in place into the vector ``g``."""

    def __call__(self, tao, x, g):
        g[0] = 2.0*(x[0] - 1.0)
        g[1] = 2.0*(x[1] - 2.0)
        g.assemble()


class MyTao:
    """Python context for a TAO created with ``createPython``.

    Every hook records how many times it was invoked in ``self.log``
    (method name -> call count), so the tests can check which hooks the
    solver actually drives.
    """

    def __init__(self):
        # method name -> number of invocations
        self.log = {}

    def _log(self, method):
        """Increment the call counter for ``method``."""
        self.log.setdefault(method, 0)
        self.log[method] += 1

    def create(self, tao):
        self._log('create')
        # Placeholder so destroy() has a Vec to release even if setUp()
        # never runs.
        self.testvec = PETSc.Vec()

    def destroy(self, tao):
        self._log('destroy')
        self.testvec.destroy()

    def setFromOptions(self, tao):
        self._log('setFromOptions')

    def setUp(self, tao):
        self._log('setUp')
        # Work vector shaped like the solution vector.
        self.testvec = tao.getSolution().duplicate()

    def solve(self, tao):
        # Intentionally a no-op: the test expects zero iterations from it.
        self._log('solve')

    def step(self, tao, x, g, s):
        """Compute the steepest-descent direction ``s = -grad f(x)``."""
        self._log('step')
        tao.computeGradient(x, g)
        g.copy(s)
        s.scale(-1.0)

    def preStep(self, tao):
        self._log('preStep')

    def postStep(self, tao):
        self._log('postStep')

    def monitor(self, tao):
        self._log('monitor')


class TestTaoPython(unittest.TestCase):

    def setUp(self):
        self.tao = PETSc.TAO()
        self.tao.createPython(MyTao(), comm=PETSc.COMM_SELF)
        ctx = self.tao.getPythonContext()
        # Expected references: presumably one held by the TAO object, the
        # local ``ctx``, and getrefcount()'s own temporary argument.
        self.assertEqual(getrefcount(ctx), 3)
        self.assertEqual(ctx.log['create'], 1)
        self.nsolve = 0

    def tearDown(self):
        ctx = self.tao.getPythonContext()
        # One more reference than in setUp(); presumably the bound method
        # ``ctx.monitor`` registered via setMonitor() in testSolve -- TODO
        # confirm.
        self.assertEqual(getrefcount(ctx), 4)
        self.assertNotIn('destroy', ctx.log)
        self.tao.destroy()
        self.tao = None
        self.assertEqual(ctx.log['destroy'], 1)
        # After destruction only the local ``ctx`` and getrefcount()'s
        # argument should remain.
        self.assertEqual(getrefcount(ctx), 2)

    def testSolve(self):
        tao = self.tao
        ctx = tao.getPythonContext()
        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        y1 = x.duplicate()
        y2 = x.duplicate()
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        tao.setMonitor(ctx.monitor)
        tao.setFromOptions()
        tao.setMaximumIterations(3)
        tao.setSolution(x)

        # Call the solve method of MyTao (a no-op): no iterations are taken.
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 0)

        # Call the default solve method and use the step method of MyTao.
        ctx.solve = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 3)
        x.copy(y1)

        # Call the default solve method with the default step method.
        ctx.step = None
        x.set(0.5)
        tao.solve()
        n = tao.getIterationNumber()
        self.assertEqual(n, 3)
        x.copy(y2)

        # MyTao.step must reproduce the iterates of the default step.
        self.assertTrue(y1.equal(y2))
        # Hook counts: preStep/postStep/monitor accumulate over the two runs
        # that used the default solve loop; MyTao.solve ran once (first run)
        # and MyTao.step only during the second run.
        self.assertEqual(ctx.log['monitor'], 2*(n+1))
        self.assertEqual(ctx.log['preStep'], 2*n)
        self.assertEqual(ctx.log['postStep'], 2*n)
        self.assertEqual(ctx.log['solve'], 1)
        self.assertEqual(ctx.log['setUp'], 1)
        self.assertEqual(ctx.log['setFromOptions'], 1)
        self.assertEqual(ctx.log['step'], n)

# --------------------------------------------------------------------


import numpy
if numpy.iscomplexobj(PETSc.ScalarType()):
    # Drop these tests when PETSc is built with complex scalars.
    del TestTaoPython

if __name__ == '__main__':
    unittest.main()

# --------------------------------------------------------------------