xref: /petsc/src/binding/petsc4py/test/test_tao.py (revision 6f33641175f69f1db294cc9ba81c3f4ad4f81d49)
15808f684SSatish Balay# --------------------------------------------------------------------
25808f684SSatish Balay
35808f684SSatish Balayfrom petsc4py import PETSc
45808f684SSatish Balayimport unittest
5*6f336411SStefano Zampiniimport numpy
6*6f336411SStefano Zampini
75808f684SSatish Balay
85808f684SSatish Balay# --------------------------------------------------------------------
class Objective:
    """TAO objective callback.

    Evaluates f(x) = (x0 - 2)**2 + (x1 - 2)**2 - 2*(x0 + x1) for a
    two-component vector ``x``; ``tao`` is accepted but not used.
    """

    def __call__(self, tao, x):
        x0, x1 = x[0], x[1]
        return (x0 - 2.0) ** 2 + (x1 - 2.0) ** 2 - 2.0 * (x0 + x1)
125971bccaSStefano Zampini
13*6f336411SStefano Zampini
class Gradient:
    """TAO gradient callback matching ``Objective``.

    Writes g[i] = 2*(x[i] - 2) - 2 for i = 0, 1 into ``g`` and then
    assembles ``g``; ``tao`` is accepted but not used.
    """

    def __call__(self, tao, x, g):
        for i in (0, 1):
            g[i] = 2.0 * (x[i] - 2.0) - 2.0
        g.assemble()
195971bccaSStefano Zampini
20*6f336411SStefano Zampini
class EqConstraints:
    """TAO equality-constraint callback.

    Stores the single residual c[0] = x[0]**2 + x[1] - 2 and assembles
    ``c``; ``tao`` is accepted but not used.
    """

    def __call__(self, tao, x, c):
        residual = x[0] ** 2 + x[1] - 2.0
        c[0] = residual
        c.assemble()
255971bccaSStefano Zampini
26*6f336411SStefano Zampini
class EqJacobian:
    """TAO Jacobian callback for ``EqConstraints``.

    Fills the single row of ``P`` with [2*x[0], 1] and assembles it.
    ``J`` is assembled as well when it compares unequal to ``P``
    (presumably when a separate Jacobian matrix was supplied).
    """

    def __call__(self, tao, x, J, P):
        row = (2.0 * x[0], 1.0)
        for col, val in enumerate(row):
            P[0, col] = val
        P.assemble()
        # Keep '!=' (not 'is not'): wrappers for the same matrix may be
        # distinct Python objects, so identity is not the right test here.
        if J != P:
            J.assemble()
345808f684SSatish Balay
355808f684SSatish Balay
class BaseTestTAO:
    """Shared ``PETSc.TAO`` tests, parametrized by communicator.

    Concrete subclasses also inherit from ``unittest.TestCase`` and set
    ``COMM`` to the communicator on which the TAO under test is created.
    """

    # Communicator passed to PETSc.TAO().create(); set by the subclasses.
    COMM = None

    def setUp(self):
        self.tao = PETSc.TAO().create(comm=self.COMM)

    def tearDown(self):
        # Drop our reference before PETSc.garbage_cleanup() so the TAO
        # object is eligible for cleanup.
        self.tao = None
        PETSc.garbage_cleanup()

    def _skipIfParallel(self):
        """Skip the current test when the TAO communicator has >1 rank.

        The solver tests read solution entries such as ``x[0]`` directly
        and are only meant to run serially. Skipping (rather than
        returning early) makes unittest report them as skipped instead of
        silently counting them as passed.
        """
        if self.tao.getComm().Get_size() > 1:
            self.skipTest('requires a single process')

    def testSetRoutinesToNone(self):
        # Every callback setter must accept None without raising.
        tao = self.tao
        objective, gradient, objgrad = None, None, None
        constraint, varbounds = None, None
        hessian, jacobian = None, None
        tao.setObjective(objective)
        tao.setGradient(gradient, None)
        tao.setVariableBounds(varbounds)
        tao.setObjectiveGradient(objgrad, None)
        tao.setConstraints(constraint)
        tao.setHessian(hessian)
        tao.setJacobian(jacobian)

    def testGetVecsAndMats(self):
        # On a freshly created TAO, the solution, gradient and bounds
        # getters return objects that evaluate as false (nothing set yet).
        tao = self.tao
        x = tao.getSolution()
        (g, _) = tao.getGradient()
        low, up = tao.getVariableBounds()
        r = None  # tao.getConstraintVec()
        H, HP = None, None  # tao.getHessianMat()
        J, JP = None, None  # tao.getJacobianMat()
        for o in [
            x,
            g,
            r,
            low,
            up,
            H,
            HP,
            J,
            JP,
        ]:
            self.assertFalse(o)

    def testGetKSP(self):
        # No KSP is attached before a TAO type has been chosen.
        ksp = self.tao.getKSP()
        self.assertFalse(ksp)

    def testEqualityConstraints(self):
        # Minimize Objective subject to x0**2 + x1 - 2 == 0 using ALMM and
        # check that the constraint holds at the returned solution.
        self._skipIfParallel()
        tao = self.tao

        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        c = PETSc.Vec().create(tao.getComm())
        c.setSizes(1)
        c.setType(x.getType())
        # One constraint row, two variables.
        J = PETSc.Mat().create(tao.getComm())
        J.setSizes([1, 2])
        J.setType(PETSc.Mat.Type.DENSE)
        J.setUp()

        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        tao.setEqualityConstraints(EqConstraints(), c)
        tao.setJacobianEquality(EqJacobian(), J, J)
        tao.setSolution(x)
        tao.setType(PETSc.TAO.Type.ALMM)
        tao.setTolerances(gatol=1.0e-4)
        tao.setFromOptions()
        tao.solve()
        self.assertAlmostEqual(abs(x[0] ** 2 + x[1] - 2.0), 0.0, places=4)

    def testBNCG(self):
        # Bound-constrained CG with 0 <= x <= 2. The gradient of Objective
        # vanishes at (3, 3), outside the box, so the bound-constrained
        # minimizer is expected on the upper bound (2, 2).
        self._skipIfParallel()
        tao = self.tao

        x = PETSc.Vec().create(tao.getComm())
        x.setType('standard')
        x.setSizes(2)
        xl = PETSc.Vec().create(tao.getComm())
        xl.setType('standard')
        xl.setSizes(2)
        xl.set(0.0)
        xu = PETSc.Vec().create(tao.getComm())
        xu.setType('standard')
        xu.setSizes(2)
        xu.set(2.0)
        tao.setVariableBounds((xl, xu))
        tao.setObjective(Objective())
        tao.setGradient(Gradient(), None)
        tao.setSolution(x)
        tao.setType(PETSc.TAO.Type.BNCG)
        tao.setTolerances(gatol=1.0e-4)
        ls = tao.getLineSearch()
        ls.setType(PETSc.TAOLineSearch.Type.UNIT)
        tao.setFromOptions()
        tao.solve()
        self.assertAlmostEqual(x[0], 2.0, places=4)
        self.assertAlmostEqual(x[1], 2.0, places=4)
140*6f336411SStefano Zampini
1415808f684SSatish Balay# --------------------------------------------------------------------
1425808f684SSatish Balay
143*6f336411SStefano Zampini
class TestTAOSelf(BaseTestTAO, unittest.TestCase):
    """Run the shared ``BaseTestTAO`` tests on ``PETSc.COMM_SELF``."""

    COMM = PETSc.COMM_SELF
1465808f684SSatish Balay
147*6f336411SStefano Zampini
class TestTAOWorld(BaseTestTAO, unittest.TestCase):
    """Run the shared ``BaseTestTAO`` tests on ``PETSc.COMM_WORLD``."""

    COMM = PETSc.COMM_WORLD
1505808f684SSatish Balay
151*6f336411SStefano Zampini
1525808f684SSatish Balay# --------------------------------------------------------------------
1535808f684SSatish Balay
154*6f336411SStefano Zampini
# When PETSc uses complex scalars, remove the TAO test classes from the
# module so unittest does not collect them (presumably because these TAO
# tests assume real scalars).
if numpy.iscomplexobj(PETSc.ScalarType()):
    del BaseTestTAO, TestTAOSelf, TestTAOWorld

if __name__ == '__main__':
    unittest.main()
162