xref: /petsc/src/binding/petsc4py/test/test_pc_py.py (revision 6f33641175f69f1db294cc9ba81c3f4ad4f81d49)
15808f684SSatish Balay# --------------------------------------------------------------------
25808f684SSatish Balay
35808f684SSatish Balayfrom petsc4py import PETSc
45808f684SSatish Balayimport unittest
55808f684SSatish Balayfrom sys import getrefcount
65808f684SSatish Balay
75808f684SSatish Balay# --------------------------------------------------------------------
85808f684SSatish Balay
9*6f336411SStefano Zampini
class BaseMyPC:
    """Common base for the pure-Python PC implementations used by the tests.

    Subclasses must provide ``apply``.  The transpose, symmetric and
    Richardson hooks fall back to it (the symmetric left/right variants via
    ``applyS``), which is only correct for a symmetric preconditioner.
    ``applyM`` has no fallback and must be provided explicitly.
    ``setup`` and ``reset`` do nothing by default.
    """

    # -- lifecycle hooks ---------------------------------------------------

    def setup(self, pc):
        """Prepare for application; nothing to do by default."""

    def reset(self, pc):
        """Release per-setup state; nothing to do by default."""

    # -- application hooks -------------------------------------------------

    def apply(self, pc, x, y):
        """Apply the preconditioner to ``x`` into ``y``; must be overridden."""
        raise NotImplementedError

    def applyT(self, pc, x, y):
        """Transpose application: delegates to ``apply``."""
        self.apply(pc, x, y)

    def applyS(self, pc, x, y):
        """Symmetric (half) application: delegates to ``apply``."""
        self.apply(pc, x, y)

    def applySL(self, pc, x, y):
        """Left symmetric application: delegates to ``applyS``."""
        self.applyS(pc, x, y)

    def applySR(self, pc, x, y):
        """Right symmetric application: delegates to ``applyS``."""
        self.applyS(pc, x, y)

    def applyRich(self, pc, x, y, w, tols):
        """Richardson application: ignores ``w``/``tols`` and calls ``apply``."""
        self.apply(pc, x, y)

    def applyM(self, pc, x, y):
        """Apply to a matrix of column vectors; no default implementation."""
        raise NotImplementedError
375808f684SSatish Balay
38*6f336411SStefano Zampini
class MyPCNone(BaseMyPC):
    """Identity preconditioner: the output is an exact copy of the input.

    The remaining hooks come from ``BaseMyPC`` and therefore also reduce to
    this copy.
    """

    def apply(self, pc, x, y):
        """Vector case: copy ``x`` into ``y``."""
        x.copy(y)

    def applyM(self, pc, x, y):
        """Matrix (block of vectors) case: copy ``x`` into ``y``."""
        x.copy(y)
455808f684SSatish Balay
46*6f336411SStefano Zampini
class MyPCJacobi(BaseMyPC):
    """Jacobi (diagonal) preconditioner written in Python.

    ``setup`` caches ``d = 1 / diag(P)`` for the preconditioning matrix ``P``;
    the application hooks then scale their input by ``d`` (or by
    ``sqrt(|d|)`` for the symmetric half-application).
    """

    def setup(self, pc):
        """Cache the reciprocal of the diagonal of the preconditioning matrix."""
        _, P = pc.getOperators()
        self.diag = P.getDiagonal()
        self.diag.reciprocal()

    def reset(self, pc):
        """Destroy and forget the cached diagonal.

        A no-op when nothing is cached (``reset`` before ``setup`` or called
        twice) instead of raising ``AttributeError``.
        """
        diag = getattr(self, 'diag', None)
        if diag is None:
            return
        diag.destroy()
        del self.diag

    def apply(self, pc, x, y):
        """``y = d * x`` (pointwise)."""
        y.pointwiseMult(self.diag, x)

    def applyS(self, pc, x, y):
        """``y = sqrt(|d|) * x`` (pointwise), the symmetric half-application."""
        self.diag.copy(y)
        y.sqrtabs()
        y.pointwiseMult(y, x)

    def applyM(self, pc, x, y):
        """``Y = diag(d) X``: copy ``x`` into ``y`` and left-scale its rows."""
        x.copy(y)
        y.diagonalScale(L=self.diag)
685808f684SSatish Balay
695808f684SSatish Balay
class PC_PYTHON_CLASS:
    """Python context for a ``PC`` of type ``python``.

    Every hook counts its own invocations in ``self.log`` (hook name ->
    number of calls) so the tests can check how often PETSc called it.
    The preconditioning work itself is forwarded to ``self.impl``, which
    ``setFromOptions`` instantiates from the module-level class named by the
    ``impl`` option (``MyPCNone`` when the option is unset).
    """

    def __init__(self):
        self.log = {}     # hook name -> call count
        self.impl = None  # delegate object, created in setFromOptions

    def _log(self, method, *args):
        # The callback arguments are accepted for symmetry but not recorded.
        self.log[method] = self.log.get(method, 0) + 1

    def _forward(self, method, target, pc, *args):
        # Count ``method``, then call ``self.impl.<target>(pc, *args)``.
        self._log(method, pc, *args)
        getattr(self.impl, target)(pc, *args)

    # -- lifecycle ---------------------------------------------------------

    def create(self, pc):
        self._log('create', pc)

    def destroy(self, pc):
        self._log('destroy')
        self.impl = None

    def reset(self, pc):
        # Only counted: the delegate's own ``reset`` is NOT invoked here.
        self._log('reset', pc)

    def view(self, pc, vw):
        self._log('view', pc, vw)

    def setFromOptions(self, pc):
        self._log('setFromOptions', pc)
        name = PETSc.Options(pc).getString('impl', 'MyPCNone')
        # NOTE(review): any module-level callable named by the option is
        # accepted; acceptable in a test, not a pattern for untrusted input.
        self.impl = globals()[name]()

    def setUp(self, pc):
        self._forward('setUp', 'setup', pc)

    def preSolve(self, pc, ksp, b, x):
        self._log('preSolve', pc, ksp, b, x)

    def postSolve(self, pc, ksp, b, x):
        self._log('postSolve', pc, ksp, b, x)

    # -- applications, forwarded to the delegate ---------------------------

    def apply(self, pc, x, y):
        self._forward('apply', 'apply', pc, x, y)

    def applySymmetricLeft(self, pc, x, y):
        self._forward('applySymmetricLeft', 'applySL', pc, x, y)

    def applySymmetricRight(self, pc, x, y):
        self._forward('applySymmetricRight', 'applySR', pc, x, y)

    def applyTranspose(self, pc, x, y):
        self._forward('applyTranspose', 'applyT', pc, x, y)

    def matApply(self, pc, x, y):
        self._forward('matApply', 'applyM', pc, x, y)

    def applyRichardson(self, pc, x, y, w, tols):
        self._forward('applyRichardson', 'applyRich', pc, x, y, w, tols)
1325808f684SSatish Balay
1335808f684SSatish Balay
class TestPCPYTHON(unittest.TestCase):
    """Tests for a ``PC`` of type ``python`` backed by ``PC_PYTHON_CLASS``.

    ``setUp`` creates the PC and selects the Python context through the
    options database (``pc_python_type`` under ``PC_PREFIX``); subclasses
    rerun every test with other construction paths or ``impl`` choices.
    """

    PC_TYPE = PETSc.PC.Type.PYTHON
    PC_PREFIX = 'test-'

    def setUp(self):
        pc = self.pc = PETSc.PC()
        pc.create(PETSc.COMM_SELF)
        pc.setType(self.PC_TYPE)
        module = __name__
        factory = 'PC_PYTHON_CLASS'
        pc.prefix = self.PC_PREFIX
        OptDB = PETSc.Options(pc)
        self.assertEqual(OptDB.prefix, pc.prefix)
        OptDB['pc_python_type'] = f'{module}.{factory}'
        try:
            pc.setFromOptions()
        finally:
            # Always remove the option so a failed setup cannot leave it in
            # the global options database for later tests.
            del OptDB['pc_python_type']
        ctx = self._getCtx()
        self.assertEqual(ctx.log['create'], 1)
        self.assertEqual(ctx.log['setFromOptions'], 1)
        # 3 = the local ``ctx`` + getrefcount's own argument + (presumably)
        # the reference held by the PC; tearDown expects 2 once it is gone.
        self.assertEqual(getrefcount(ctx), 3)

    def testGetType(self):
        ctx = self.pc.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        self.assertEqual(self.pc.getPythonType(), pytype)

    def tearDown(self):
        ctx = self._getCtx()
        self.pc.destroy()  # XXX
        self.pc = None
        PETSc.garbage_cleanup()
        self.assertEqual(ctx.log['destroy'], 1)
        # Only the local ``ctx`` and getrefcount's argument remain.
        self.assertEqual(getrefcount(ctx), 2)

    def _prepare(self):
        """Give the PC a 3x3 operator and create operands for the tests.

        Returns ``(A, x, y, X, Y)``: ``A`` is an empty assembled 3x3 AIJ
        matrix shifted by 10 (i.e. 10 on the diagonal), installed as both
        operator and preconditioning matrix; ``x`` (random) and ``y`` are
        compatible vectors; ``X`` and ``Y`` are assembled 3x5 dense matrices.
        """
        A = PETSc.Mat().createAIJ([3, 3], comm=PETSc.COMM_SELF)
        A.setUp()
        A.assemble()
        A.shift(10)
        x, y = A.createVecs()
        x.setRandom()
        self.pc.setOperators(A, A)
        X = PETSc.Mat().createDense([3, 5], comm=PETSc.COMM_SELF).setUp()
        X.assemble()
        Y = PETSc.Mat().createDense([3, 5], comm=PETSc.COMM_SELF).setUp()
        Y.assemble()
        self.assertEqual((A, A), self.pc.getOperators())
        return A, x, y, X, Y

    def _getCtx(self):
        return self.pc.getPythonContext()

    def _applyMeth(self, meth):
        """Call ``self.pc.<meth>`` once and check the context's call counts.

        ``matApply`` operates on the dense ``X``/``Y`` pair, every other
        method on the ``x``/``y`` vectors; the pair not under test is filled
        by a plain copy so the identity check at the end covers both pairs.
        """
        _, x, y, X, Y = self._prepare()
        if meth == 'matApply':
            getattr(self.pc, meth)(X, Y)
            x.copy(y)
        else:
            getattr(self.pc, meth)(x, y)
            X.copy(Y)
        ctx = self._getCtx()
        log = ctx.log
        if 'reset' not in log:
            self.assertEqual(log['setUp'], 1)
            self.assertEqual(log[meth], 1)
        else:
            # Each reset is followed by exactly one setUp and one call of
            # ``meth`` (see testResetAndApply), so the counters stay in step.
            self.assertEqual(log['reset'], log['setUp'])
            self.assertEqual(log['reset'], log[meth])
        if isinstance(ctx.impl, MyPCNone):
            self.assertTrue(y.equal(x))
            self.assertTrue(Y.equal(X))

    def testApply(self):
        self._applyMeth('apply')

    def testApplySymmetricLeft(self):
        self._applyMeth('applySymmetricLeft')

    def testApplySymmetricRight(self):
        self._applyMeth('applySymmetricRight')

    def testApplyTranspose(self):
        self._applyMeth('applyTranspose')

    def testApplyMat(self):
        self._applyMeth('matApply')

    # Disabled tests, kept for reference.
    ## def testApplyRichardson(self):
    ##     _, x, y, _, _ = self._prepare()
    ##     w = x.duplicate()
    ##     tols = 0,0,0,0
    ##     self.pc.applyRichardson(x,y,w,tols)
    ##     assert self._getCtx().log['setUp'] == 1
    ##     assert self._getCtx().log['applyRichardson'] == 1

    ## def testView(self):
    ##     vw = PETSc.ViewerString(100, self.pc.comm)
    ##     self.pc.view(vw)
    ##     s = vw.getString()
    ##     assert 'python' in s
    ##     module = __name__
    ##     factory = 'self._getCtx()'
    ##     assert '.'.join([module, factory]) in s

    def testResetAndApply(self):
        self.pc.reset()
        self.testApply()
        self.pc.reset()
        self.testApply()
        self.pc.reset()

    def testKSPSolve(self):
        _, x, y, _, _ = self._prepare()
        ksp = PETSc.KSP().create(self.pc.comm)
        ksp.setType(PETSc.KSP.Type.PREONLY)
        self.assertEqual(self.pc.getRefCount(), 1)
        ksp.setPC(self.pc)
        self.assertEqual(self.pc.getRefCount(), 2)
        log = self._getCtx().log
        # normal ksp solve, twice
        ksp.solve(x, y)
        self.assertEqual(log['setUp'], 1)
        self.assertEqual(log['apply'], 1)
        self.assertEqual(log['preSolve'], 1)
        self.assertEqual(log['postSolve'], 1)
        ksp.solve(x, y)
        self.assertEqual(log['setUp'], 1)
        self.assertEqual(log['apply'], 2)
        self.assertEqual(log['preSolve'], 2)
        self.assertEqual(log['postSolve'], 2)
        # transpose ksp solve, twice
        ksp.solveTranspose(x, y)
        self.assertEqual(log['setUp'], 1)
        self.assertEqual(log['applyTranspose'], 1)
        ksp.solveTranspose(x, y)
        self.assertEqual(log['setUp'], 1)
        self.assertEqual(log['applyTranspose'], 2)
        del ksp  # ksp.destroy()
        PETSc.garbage_cleanup()
        self.assertEqual(self.pc.getRefCount(), 1)

    def testGetSetContext(self):
        ctx = self.pc.getPythonContext()
        self.pc.setPythonContext(ctx)
        # Re-setting the same context must not add a reference: presumably
        # the PC's, plus the local ``ctx`` and getrefcount's argument.
        self.assertEqual(getrefcount(ctx), 3)
        del ctx
2805808f684SSatish Balay
2815808f684SSatish Balay
class TestPCPYTHON2(TestPCPYTHON):
    """Rerun the TestPCPYTHON suite with the ``MyPCJacobi`` implementation."""

    def setUp(self):
        impl = 'MyPCJacobi'
        OptDB = PETSc.Options(self.PC_PREFIX)
        OptDB['impl'] = impl
        try:
            super().setUp()
            clsname = type(self._getCtx().impl).__name__
            self.assertEqual(clsname, impl)
        finally:
            # Remove the option even if setUp fails, so it cannot leak into
            # test classes that rely on the default ``impl`` (MyPCNone).
            del OptDB['impl']
2905808f684SSatish Balay
291*6f336411SStefano Zampini
class TestPCPYTHON3(TestPCPYTHON):
    """Rerun the suite with the context handed directly to ``createPython``."""

    def setUp(self):
        pc = self.pc = PETSc.PC()
        ctx = PC_PYTHON_CLASS()
        pc.createPython(ctx, comm=PETSc.COMM_SELF)
        pc.prefix = self.PC_PREFIX
        pc.setFromOptions()
        # Check the context the PC actually stores, not just the local one.
        log = self._getCtx().log
        self.assertEqual(log['create'], 1)
        self.assertEqual(log['setFromOptions'], 1)
301*6f336411SStefano Zampini
3025808f684SSatish Balay
class TestPCPYTHON4(TestPCPYTHON3):
    """Rerun the TestPCPYTHON3 suite with the ``MyPCJacobi`` implementation."""

    def setUp(self):
        impl = 'MyPCJacobi'
        OptDB = PETSc.Options(self.PC_PREFIX)
        OptDB['impl'] = impl
        try:
            super().setUp()
            clsname = type(self._getCtx().impl).__name__
            self.assertEqual(clsname, impl)
        finally:
            # Remove the option even if setUp fails, so it cannot leak into
            # test classes that rely on the default ``impl`` (MyPCNone).
            del OptDB['impl']
3115808f684SSatish Balay
312*6f336411SStefano Zampini
3135808f684SSatish Balay# --------------------------------------------------------------------
3145808f684SSatish Balay
if __name__ == '__main__':
    # Executed as a script: run every TestCase defined in this module.
    unittest.main()
317