xref: /petsc/src/binding/petsc4py/test/test_mat_py.py (revision 53022affac82b2fcec7b6432d0d3b2c8aa0487f8)
15808f684SSatish Balayfrom petsc4py import PETSc
2*53022affSStefano Zampiniimport unittest, numpy
35808f684SSatish Balayfrom sys import getrefcount
45808f684SSatish Balay# --------------------------------------------------------------------
55808f684SSatish Balay
class Matrix(object):
    """Minimal Python context for a ``PETSc.Mat`` of type ``python``.

    Only the lifecycle hooks are defined, and all of them are no-ops.
    No matrix operation is implemented, so the tests built on this class
    expect every operation to raise.
    """

    def __init__(self):
        # Nothing to initialise: this context carries no state.
        return None

    def create(self, mat):
        """Creation hook receiving the owning ``mat``; intentionally empty."""
        return None

    def destroy(self, mat):
        """Destruction hook receiving the owning ``mat``; intentionally empty."""
        return None
165808f684SSatish Balay
class Identity(Matrix):
    """Python matrix context that behaves as the identity operator.

    Only ``mult`` and ``getDiagonal`` are provided. Everything else is
    inherited from ``Matrix``.
    """

    def mult(self, mat, x, y):
        """Apply the operator: write a copy of ``x`` into ``y``."""
        source, target = x, y
        source.copy(target)

    def getDiagonal(self, mat, vd):
        """Store the identity's diagonal (all ones) into ``vd``."""
        unit = 1
        vd.set(unit)
245808f684SSatish Balay
class Diagonal(Matrix):
    """Python matrix context for a diagonal matrix.

    The diagonal entries live in the vector ``self.D``. It is allocated in
    ``create`` as a left vector of the owning matrix and released in
    ``destroy``.
    """

    def create(self, mat):
        """Run the base hook, set up ``mat`` and allocate ``self.D``."""
        super(Diagonal, self).create(mat)
        mat.setUp()
        self.D = mat.createVecLeft()

    def destroy(self, mat):
        """Release ``self.D`` before delegating to the base hook."""
        self.D.destroy()
        super(Diagonal, self).destroy(mat)

    def scale(self, mat, a):
        """Multiply every diagonal entry by ``a``."""
        self.D.scale(a)

    def shift(self, mat, a):
        """Shift every diagonal entry by ``a`` (via ``Vec.shift``)."""
        self.D.shift(a)

    def zeroEntries(self, mat):
        """Reset all diagonal entries to zero."""
        self.D.zeroEntries()

    def mult(self, mat, x, y):
        """Apply the operator: ``y`` receives the entrywise product ``x .* D``."""
        y.pointwiseMult(x, self.D)

    def getDiagonal(self, mat, vd):
        """Copy the stored diagonal into ``vd``."""
        self.D.copy(vd)

    def setDiagonal(self, mat, vd, im):
        """Insert or add ``vd`` into the stored diagonal.

        ``im`` is either a bool or a ``PETSc.InsertMode``:

        - ``True`` or ``ADD_VALUES`` adds ``vd`` to ``self.D``.
        - ``False`` or ``INSERT_VALUES`` overwrites ``self.D`` with ``vd``.

        The bool check must come first. ``True == 1`` could otherwise
        compare equal to an integer-valued insert mode.

        Raises ``ValueError`` for any other mode.
        """
        if isinstance(im, bool):
            accumulate = im
        elif im == PETSc.InsertMode.INSERT_VALUES:
            accumulate = False
        elif im == PETSc.InsertMode.ADD_VALUES:
            accumulate = True
        else:
            raise ValueError('wrong InsertMode %d' % im)

        if accumulate:
            self.D.axpy(1, vd)
        else:
            vd.copy(self.D)

    def diagonalScale(self, mat, vl, vr):
        """Scale from the left by ``vl`` and then from the right by ``vr``.

        Each side is applied only when its vector is truthy. The test is
        truthiness, not ``is not None``. NOTE(review): presumably an absent
        side arrives as a falsy wrapper object rather than ``None``; confirm
        before changing this test.
        """
        for factor in (vl, vr):
            if factor:
                self.D.pointwiseMult(self.D, factor)
685808f684SSatish Balay
695808f684SSatish Balay# --------------------------------------------------------------------
705808f684SSatish Balay
class TestMatrix(unittest.TestCase):
    """Tests for a python-type ``PETSc.Mat`` whose context is ``PYCLS``.

    With the default ``Matrix`` context no operation is implemented, so each
    operation test expects an exception. Subclasses set ``PYCLS`` to another
    context class defined in this module and override the tests whose
    outcome changes.
    """

    COMM = PETSc.COMM_WORLD
    # Module and class name of the context. PYMOD is only used by the
    # (disabled) options-database branch of setUp.
    PYMOD = __name__
    PYCLS = 'Matrix'

    def _getCtx(self):
        """Return the Python context object attached to ``self.A``."""
        return self.A.getPythonContext()

    def setUp(self):
        """Create an ``N x N`` python matrix ``self.A`` with a new ``PYCLS`` context.

        It also checks that, once the local name is dropped, the matrix holds
        the only reference to the context.
        """
        N = self.N = 10
        self.A = PETSc.Mat()
        if 0: # command line way
            # Disabled alternative: pick the context through the options
            # database key 'mat_python_type' instead of passing it directly.
            self.A.create(self.COMM)
            self.A.setSizes([N,N])
            self.A.setType('python')
            OptDB = PETSc.Options(self.A)
            OptDB['mat_python_type'] = '%s.%s' % (self.PYMOD,self.PYCLS)
            self.A.setFromOptions()
            self.A.setUp()
            del OptDB['mat_python_type']
            self.assertTrue(self._getCtx() is not None)
        else: # python way
            # Look the context class up by name among this module's globals.
            context = globals()[self.PYCLS]()
            self.A.createPython([N,N], context, comm=self.COMM)
            self.A.setUp()
            self.assertTrue(self._getCtx() is context)
            # References: the local 'context', the matrix, and
            # getrefcount's own argument.
            self.assertEqual(getrefcount(context), 3)
            del context
            # Only the matrix (plus getrefcount's argument) remains.
            self.assertEqual(getrefcount(self._getCtx()), 2)

    def tearDown(self):
        """Destroy ``self.A`` and check that it dropped its context reference."""
        ctx = self.A.getPythonContext()
        # References: the local 'ctx', the matrix, getrefcount's argument.
        # A test that leaves an extra reference on the context (for example
        # a bound method stored on it) makes this assertion fail.
        self.assertEqual(getrefcount(ctx), 3)
        self.A.destroy() # XXX
        self.A = None
        # After destruction only the local 'ctx' (plus the argument) is left.
        self.assertEqual(getrefcount(ctx), 2)
        # Debugging aid for reference leaks:
        #import gc,pprint; pprint.pprint(gc.get_referrers(ctx))

    def testBasic(self):
        """The context is stable and the matrix holds exactly one reference to it."""
        ctx = self.A.getPythonContext()
        self.assertTrue(self._getCtx() is ctx)
        self.assertEqual(getrefcount(ctx), 3)

    def testZeroEntries(self):
        """``zeroEntries`` is expected to raise (the context defines none)."""
        f = lambda : self.A.zeroEntries()
        self.assertRaises(Exception, f)

    def testMult(self):
        """``mult`` is expected to raise (the context defines none)."""
        x, y = self.A.createVecs()
        f = lambda : self.A.mult(x, y)
        self.assertRaises(Exception, f)

    def testMultTranspose(self):
        """``multTranspose`` is expected to raise (the context defines none)."""
        x, y = self.A.createVecs()
        f = lambda : self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testGetDiagonal(self):
        """``getDiagonal`` is expected to raise (the context defines none)."""
        d = self.A.createVecLeft()
        f = lambda : self.A.getDiagonal(d)
        self.assertRaises(Exception, f)

    def testSetDiagonal(self):
        """``setDiagonal`` is expected to raise (the context defines none)."""
        d = self.A.createVecLeft()
        f = lambda : self.A.setDiagonal(d)
        self.assertRaises(Exception, f)

    def testDiagonalScale(self):
        """``diagonalScale`` is expected to raise (the context defines none)."""
        x, y = self.A.createVecs()
        f = lambda : self.A.diagonalScale(x, y)
        self.assertRaises(Exception, f)
1435808f684SSatish Balay
1448c2316a8SJeremy Tillay
class TestIdentity(TestMatrix):
    """Run the ``TestMatrix`` suite with the ``Identity`` context."""

    PYCLS = 'Identity'

    def testMult(self):
        """``A*x`` must reproduce ``x`` exactly."""
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.mult(x,y)
        self.assertTrue(y.equal(x))

    def testMultTransposeSymmKnown(self):
        """``A^T*x`` succeeds only while the SYMMETRIC option is set.

        The context defines no ``multTranspose``. With SYMMETRIC set, the
        call is expected to succeed anyway (presumably by reusing ``mult``).
        With the flag cleared it must raise.
        """
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x,y)
        self.assertTrue(y.equal(x))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        f = lambda : self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testMultTransposeNewMeth(self):
        """A ``multTranspose`` attached to the context at runtime is used."""
        x, y = self.A.createVecs()
        x.setRandom()
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x,y)
        finally:
            # The stored bound method references AA itself. Always remove it,
            # or tearDown's reference-count check fails and hides the real
            # error.
            del AA.multTranspose
        self.assertTrue(y.equal(x))

    def testGetDiagonal(self):
        """The diagonal of the identity is all ones."""
        d = self.A.createVecLeft()
        o = d.duplicate()
        o.set(1)
        self.A.getDiagonal(d)
        self.assertTrue(o.equal(d))

    def testH2Opus(self):
        """Build an H2Opus approximation of ``A``, without and with coordinates.

        The test is skipped, rather than silently passing, when PETSc was
        built without h2opus.
        """
        if not PETSc.Sys.hasExternalPackage("h2opus"):
            self.skipTest("PETSc was built without h2opus")
        h = PETSc.Mat()

        # need transpose operation for norm estimation
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            # without coordinates
            h.createH2OpusFromMat(self.A,leafsize=2)
            h.assemble()
            h.destroy()

            # with coordinates: one 3-D point per global row
            coords = numpy.linspace((1,2,3),(10,20,30),self.A.getSize()[0])
            h.createH2OpusFromMat(self.A,coords,leafsize=2)
            h.assemble()
        finally:
            # Always release h and drop the self-referencing bound method so
            # tearDown's reference-count check sees the original state.
            # NOTE(review): this relies on destroying an already destroyed
            # or never-created Mat being harmless; confirm.
            h.destroy()
            del AA.multTranspose
2025808f684SSatish Balay
class TestDiagonal(TestMatrix):
    """Run the ``TestMatrix`` suite with the ``Diagonal`` context."""

    PYCLS = 'Diagonal'

    def setUp(self):
        """Create the matrix, then set its diagonal to ``1, 2, ..., N``."""
        super(TestDiagonal, self).setUp()
        D = self.A.createVecLeft()
        s, e = D.getOwnershipRange()
        # Each process fills only the entries in its ownership range.
        for i in range(s, e):
            D[i] = i+1
        D.assemble()
        self.A.setDiagonal(D)


    def testZeroEntries(self):
        """``zeroEntries`` clears the stored diagonal."""
        self.A.zeroEntries()
        D = self._getCtx().D
        self.assertEqual(D.norm(), 0)

    def testMult(self):
        """``A * ones`` equals the stored diagonal."""
        x, y = self.A.createVecs()
        x.set(1)
        self.A.mult(x,y)
        self.assertTrue(y.equal(self._getCtx().D))

    def testMultTransposeSymmKnown(self):
        """``A^T * ones`` succeeds only while the SYMMETRIC option is set."""
        x, y = self.A.createVecs()
        x.set(1)
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x,y)
        self.assertTrue(y.equal(self._getCtx().D))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        f = lambda : self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testMultTransposeNewMeth(self):
        """A ``multTranspose`` attached to the context at runtime is used."""
        x, y = self.A.createVecs()
        x.set(1)
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x,y)
        finally:
            # The stored bound method references AA itself. Always remove it,
            # or tearDown's reference-count check fails and hides the real
            # error.
            del AA.multTranspose
        self.assertTrue(y.equal(self._getCtx().D))

    def testGetDiagonal(self):
        """``getDiagonal`` returns the stored diagonal."""
        d = self.A.createVecLeft()
        self.A.getDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testSetDiagonal(self):
        """``setDiagonal`` (default insert mode) replaces the stored diagonal."""
        d = self.A.createVecLeft()
        d.setRandom()
        self.A.setDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testDiagonalScale(self):
        """Left scaling by 2 and right scaling by 3 multiplies the diagonal by 6."""
        x, y = self.A.createVecs()
        x.set(2)
        y.set(3)
        old = self._getCtx().D.copy()
        self.A.diagonalScale(x, y)
        D = self._getCtx().D
        self.assertTrue(D.equal(old*6))

    def testCreateTranspose(self):
        """A transpose matrix of ``A`` agrees with ``A`` in both directions."""
        A = self.A
        # Needed so that A itself, which has no multTranspose, can apply A^T.
        A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        AT = PETSc.Mat().createTranspose(A)
        x, y = A.createVecs()
        xt, yt = AT.createVecs()
        # A^T * y must equal AT * y.
        y.setRandom()
        A.multTranspose(y, x)
        y.copy(xt)
        AT.mult(xt, yt)
        self.assertTrue(yt.equal(x))
        # A * x must equal AT^T * x.
        x.setRandom()
        A.mult(x, y)
        x.copy(yt)
        AT.multTranspose(yt, xt)
        self.assertTrue(xt.equal(y))
2865808f684SSatish Balay
2878c2316a8SJeremy Tillay
2885808f684SSatish Balay# --------------------------------------------------------------------
2895808f684SSatish Balay
# Executed directly as a script: let unittest discover and run every
# TestCase defined in this module.
if __name__ == "__main__":
    unittest.main()
292