xref: /petsc/src/binding/petsc4py/test/test_tao.py (revision d6e07cdc171d2c54a390877dc4eb7bbacca9b380)
15808f684SSatish Balay# --------------------------------------------------------------------
25808f684SSatish Balay
35808f684SSatish Balayfrom petsc4py import PETSc
45808f684SSatish Balayimport unittest
55808f684SSatish Balay
65808f684SSatish Balay# --------------------------------------------------------------------
class Objective:
    """Scalar objective f(x) = (x0 - 2)^2 + (x1 - 2)^2 - 2*(x0 + x1).

    Instances are passed to ``tao.setObjective``; the ``tao`` argument is
    accepted for the callback signature but not used.
    """

    def __call__(self, tao, x):
        first, second = x[0], x[1]
        return (first - 2.0)**2 + (second - 2.0)**2 - 2.0*(first + second)
105971bccaSStefano Zampini
class Gradient:
    """Gradient of ``Objective``: g_i = 2*(x_i - 2) - 2 for i in {0, 1}.

    The result is written into ``g``, which is then assembled.
    """

    def __call__(self, tao, x, g):
        for idx in (0, 1):
            g[idx] = 2.0*(x[idx] - 2.0) - 2.0
        g.assemble()
165971bccaSStefano Zampini
class EqConstraints:
    """Equality-constraint residual c(x) = x0^2 + x1 - 2, stored in ``c[0]``."""

    def __call__(self, tao, x, c):
        residual = x[0]**2 + x[1] - 2.0
        c[0] = residual
        c.assemble()
215971bccaSStefano Zampini
class EqJacobian:
    """Jacobian of ``EqConstraints``: the single row [2*x0, 1].

    The entries go into ``P``. ``J`` is assembled separately only when it
    compares unequal to ``P``.
    """

    def __call__(self, tao, x, J, P):
        row = (2.0*x[0], 1.0)
        P[0, 0], P[0, 1] = row
        P.assemble()
        if J != P:
            J.assemble()
285808f684SSatish Balay
class BaseTestTAO(object):
    """Shared TAO test cases; concrete subclasses choose the communicator.

    Mix this into ``unittest.TestCase`` and set ``COMM``.
    """

    # Communicator the TAO solver is created on; set by concrete subclasses.
    COMM = None

    def setUp(self):
        self.tao = PETSc.TAO().create(comm=self.COMM)

    def tearDown(self):
        # Drop our reference, then ask PETSc to clean up deferred garbage.
        self.tao = None
        PETSc.garbage_cleanup()

    def _skipIfParallel(self):
        """Skip the running test when the solver communicator has >1 rank.

        The callbacks index ``x[0]`` and ``x[1]`` directly, so the whole
        two-entry vector must be local to this process.
        """
        if self.tao.getComm().Get_size() > 1:
            self.skipTest('requires a single-process communicator')

    def _createVec(self, size):
        """Create a 'standard' vector of global length ``size`` on the
        solver's communicator."""
        vec = PETSc.Vec().create(self.tao.getComm())
        vec.setType('standard')
        vec.setSizes(size)
        return vec

    def testSetRoutinesToNone(self):
        """Passing None for every callback must be accepted."""
        tao = self.tao
        objective, gradient, objgrad = None, None, None
        constraint, varbounds = None, None
        hessian, jacobian = None, None
        tao.setObjective(objective)
        tao.setGradient(gradient, None)
        tao.setVariableBounds(varbounds)
        tao.setObjectiveGradient(objgrad, None)
        tao.setConstraints(constraint)
        tao.setHessian(hessian)
        tao.setJacobian(jacobian)

    def testGetVecsAndMats(self):
        """A fresh solver has no solution, gradient or bound vectors yet."""
        tao = self.tao
        x = tao.getSolution()
        (g, _) = tao.getGradient()
        l, u = tao.getVariableBounds()
        r = None  # tao.getConstraintVec()
        H, HP = None, None  # tao.getHessianMat()
        J, JP = None, None  # tao.getJacobianMat()
        for o in [x, g, r, l, u, H, HP, J, JP]:
            self.assertFalse(o)

    def testGetKSP(self):
        """A fresh solver has no KSP attached."""
        ksp = self.tao.getKSP()
        self.assertFalse(ksp)

    def testEqualityConstraints(self):
        """Minimize ``Objective`` subject to x0**2 + x1 == 2 using ALMM."""
        self._skipIfParallel()
        tao = self.tao

        x = self._createVec(2)
        x.set(0.0)  # explicit initial guess
        c = PETSc.Vec().create(tao.getComm())
        c.setSizes(1)
        c.setType(x.getType())
        J = PETSc.Mat().create(tao.getComm())
        J.setSizes([1, 2])
        J.setType(PETSc.Mat.Type.DENSE)
        J.setUp()

        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        tao.setEqualityConstraints(EqConstraints(), c)
        tao.setJacobianEquality(EqJacobian(), J, J)
        tao.setSolution(x)
        tao.setType(PETSc.TAO.Type.ALMM)
        tao.setTolerances(gatol=1.e-4)
        tao.setFromOptions()
        tao.solve()
        # The computed solution must satisfy the equality constraint.
        self.assertAlmostEqual(abs(x[0]**2 + x[1] - 2.0), 0.0, places=4)

    def testBNCG(self):
        """Bound-constrained CG on the box [0, 2] x [0, 2].

        The objective's unconstrained minimizer is (3, 3), so the bounded
        optimum is the corner (2, 2).
        """
        self._skipIfParallel()
        tao = self.tao

        x = self._createVec(2)
        x.set(0.0)  # feasible initial guess inside the bounds
        xl = self._createVec(2)
        xl.set(0.0)
        xu = self._createVec(2)
        xu.set(2.0)
        tao.setVariableBounds((xl, xu))
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        tao.setSolution(x)
        tao.setType(PETSc.TAO.Type.BNCG)
        tao.setTolerances(gatol=1.e-4)
        ls = tao.getLineSearch()
        ls.setType(PETSc.TAOLineSearch.Type.UNIT)
        tao.setFromOptions()
        tao.solve()
        self.assertAlmostEqual(x[0], 2.0, places=4)
        self.assertAlmostEqual(x[1], 2.0, places=4)
123*d6e07cdcSHong Zhang
1245808f684SSatish Balay# --------------------------------------------------------------------
1255808f684SSatish Balay
class TestTAOSelf(BaseTestTAO, unittest.TestCase):
    """Run the shared TAO tests on ``PETSc.COMM_SELF``."""

    COMM = PETSc.COMM_SELF
1285808f684SSatish Balay
class TestTAOWorld(BaseTestTAO, unittest.TestCase):
    """Run the shared TAO tests on ``PETSc.COMM_WORLD``."""

    COMM = PETSc.COMM_WORLD
1315808f684SSatish Balay
1325808f684SSatish Balay# --------------------------------------------------------------------
1335808f684SSatish Balay
import numpy

# Register the TAO test cases only for real-scalar PETSc builds. When
# PETSc.ScalarType is complex, remove the base class and both concrete
# TestCase classes so unittest discovers none of them.
if numpy.iscomplexobj(PETSc.ScalarType()):
    del BaseTestTAO, TestTAOSelf, TestTAOWorld
1395808f684SSatish Balay
if __name__ == '__main__':
    # Running the file directly executes every test case defined above.
    unittest.main()
142