"""Tests for PETSc matrices of type ``python``.

Each test case builds a ``PETSc.Mat`` whose operations are delegated to a
plain Python context object (``Matrix``, ``Identity`` or ``Diagonal``
below), then checks the numerical results as well as the reference counts
held on that context object.
"""

from petsc4py import PETSc
import unittest
from sys import getrefcount

# --------------------------------------------------------------------


class Matrix(object):
    """Bare Python matrix context: lifecycle hooks only, no operations.

    ``TestMatrix`` expects every matrix operation on a Mat backed by this
    context to raise.
    """

    def __init__(self):
        pass

    def create(self, mat):
        # Lifecycle hook; nothing to allocate for the bare context.
        pass

    def destroy(self, mat):
        # Lifecycle hook; nothing to release for the bare context.
        pass


class Identity(Matrix):
    """Python context implementing the identity operator."""

    def mult(self, mat, x, y):
        # y = I x
        x.copy(y)

    def getDiagonal(self, mat, vd):
        vd.set(1)


class Diagonal(Matrix):
    """Python context for a diagonal matrix whose diagonal is the Vec ``D``."""

    def create(self, mat):
        super(Diagonal, self).create(mat)
        # Set up the Mat first so a Vec with a compatible layout can be made.
        mat.setUp()
        self.D = mat.createVecLeft()

    def destroy(self, mat):
        self.D.destroy()
        super(Diagonal, self).destroy(mat)

    def scale(self, mat, a):
        self.D.scale(a)

    def shift(self, mat, a):
        self.D.shift(a)

    def zeroEntries(self, mat):
        self.D.zeroEntries()

    def mult(self, mat, x, y):
        # y = D .* x (elementwise)
        y.pointwiseMult(x, self.D)

    def getDiagonal(self, mat, vd):
        self.D.copy(vd)

    def setDiagonal(self, mat, vd, im):
        """Insert ``vd`` into, or add it to, the stored diagonal.

        ``im`` is either a bool (True: add, False: insert) or one of
        ``PETSc.InsertMode.INSERT_VALUES`` / ``PETSc.InsertMode.ADD_VALUES``.

        Raises:
            ValueError: for any other value of ``im``.
        """
        # bool is a subclass of int, so it must be checked before the
        # InsertMode equality comparisons below.
        if isinstance(im, bool):
            addv = im
        elif im == PETSc.InsertMode.INSERT_VALUES:
            addv = False
        elif im == PETSc.InsertMode.ADD_VALUES:
            addv = True
        else:
            # %r (not %d) so a non-integer mode still yields ValueError
            # rather than a TypeError raised by the string formatting.
            raise ValueError('wrong InsertMode %r' % (im,))
        if addv:
            self.D.axpy(1, vd)
        else:
            vd.copy(self.D)

    def diagonalScale(self, mat, vl, vr):
        """Scale the diagonal in place: D <- vl .* D .* vr.

        Either side is skipped when its Vec argument is falsy.
        """
        # Truthiness rather than ``is not None``: presumably the caller may
        # pass an empty (null-handle) Vec for an omitted side -- TODO confirm.
        if vl:
            self.D.pointwiseMult(self.D, vl)
        if vr:
            self.D.pointwiseMult(self.D, vr)

# --------------------------------------------------------------------


class TestMatrix(unittest.TestCase):
    """Base case: with the bare ``Matrix`` context every operation fails.

    Subclasses set ``PYCLS`` to the name of another context class in this
    module and override the tests whose expected outcome changes.
    """

    COMM = PETSc.COMM_WORLD
    PYMOD = __name__
    PYCLS = 'Matrix'

    def _getCtx(self):
        return self.A.getPythonContext()

    def setUp(self):
        N = self.N = 10
        self.A = PETSc.Mat()
        if 0:  # command line way (options database); currently disabled
            self.A.create(self.COMM)
            self.A.setSizes([N, N])
            self.A.setType('python')
            OptDB = PETSc.Options(self.A)
            OptDB['mat_python_type'] = '%s.%s' % (self.PYMOD, self.PYCLS)
            self.A.setFromOptions()
            self.A.setUp()
            del OptDB['mat_python_type']
            self.assertIsNotNone(self._getCtx())
        else:  # python way
            context = globals()[self.PYCLS]()
            self.A.createPython([N, N], context, comm=self.COMM)
            self.A.setUp()
            self.assertIs(self._getCtx(), context)
            # Expected holders: the local name, getrefcount's own argument,
            # and (presumably) the Mat itself.
            self.assertEqual(getrefcount(context), 3)
            del context
            # Only the Mat's reference (+ getrefcount's argument) remains.
            self.assertEqual(getrefcount(self._getCtx()), 2)

    def tearDown(self):
        ctx = self._getCtx()
        self.assertEqual(getrefcount(ctx), 3)
        self.A.destroy()
        self.A = None
        # Destroying the Mat must drop its reference to the context.
        # To debug a mismatch: import gc, pprint;
        # pprint.pprint(gc.get_referrers(ctx))
        self.assertEqual(getrefcount(ctx), 2)

    def testBasic(self):
        ctx = self.A.getPythonContext()
        self.assertIs(self._getCtx(), ctx)
        self.assertEqual(getrefcount(ctx), 3)

    def testZeroEntries(self):
        self.assertRaises(Exception, self.A.zeroEntries)

    def testMult(self):
        x, y = self.A.createVecs()
        self.assertRaises(Exception, self.A.mult, x, y)

    def testMultTranspose(self):
        x, y = self.A.createVecs()
        self.assertRaises(Exception, self.A.multTranspose, x, y)

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        self.assertRaises(Exception, self.A.getDiagonal, d)

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        self.assertRaises(Exception, self.A.setDiagonal, d)

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        self.assertRaises(Exception, self.A.diagonalScale, x, y)


class TestIdentity(TestMatrix):
    """Mat backed by the ``Identity`` context."""

    PYCLS = 'Identity'

    def testMult(self):
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.mult(x, y)
        self.assertTrue(y.equal(x))

    def testMultTransposeSymmKnown(self):
        x, y = self.A.createVecs()
        x.setRandom()
        # The context has no multTranspose; the test expects it to work
        # only while the SYMMETRIC option is set.
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(x))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        self.assertRaises(Exception, self.A.multTranspose, x, y)

    def testMultTransposeNewMeth(self):
        x, y = self.A.createVecs()
        x.setRandom()
        AA = self.A.getPythonContext()
        # Attach multTranspose dynamically. The bound method stored on AA
        # forms a reference cycle, so it must be removed even on failure,
        # otherwise tearDown's refcount check fails misleadingly.
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            del AA.multTranspose
        self.assertTrue(y.equal(x))

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        o = d.duplicate()
        o.set(1)
        self.A.getDiagonal(d)
        self.assertTrue(o.equal(d))


class TestDiagonal(TestMatrix):
    """Mat backed by the ``Diagonal`` context, diagonal set to 1..N."""

    PYCLS = 'Diagonal'

    def setUp(self):
        super(TestDiagonal, self).setUp()
        D = self.A.createVecLeft()
        s, e = D.getOwnershipRange()
        # Fill only the locally owned entries with global index + 1.
        for i in range(s, e):
            D[i] = i + 1
        D.assemble()
        self.A.setDiagonal(D)

    def testZeroEntries(self):
        self.A.zeroEntries()
        D = self._getCtx().D
        self.assertEqual(D.norm(), 0)

    def testMult(self):
        x, y = self.A.createVecs()
        x.set(1)
        # D .* ones == D
        self.A.mult(x, y)
        self.assertTrue(y.equal(self._getCtx().D))

    def testMultTransposeSymmKnown(self):
        x, y = self.A.createVecs()
        x.set(1)
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(self._getCtx().D))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        self.assertRaises(Exception, self.A.multTranspose, x, y)

    def testMultTransposeNewMeth(self):
        x, y = self.A.createVecs()
        x.set(1)
        AA = self.A.getPythonContext()
        # See TestIdentity.testMultTransposeNewMeth: always undo the
        # self-referencing attribute so tearDown's refcounts hold.
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            del AA.multTranspose
        self.assertTrue(y.equal(self._getCtx().D))

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        self.A.getDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        d.setRandom()
        self.A.setDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        x.set(2)
        y.set(3)
        old = self._getCtx().D.copy()
        self.A.diagonalScale(x, y)
        D = self._getCtx().D
        # Left scale by 2 and right scale by 3 multiply each entry by 6.
        self.assertTrue(D.equal(old * 6))

    def testCreateTranspose(self):
        A = self.A
        A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        AT = PETSc.Mat().createTranspose(A)
        x, y = A.createVecs()
        xt, yt = AT.createVecs()
        # AT * y must match A^T * y.
        y.setRandom()
        A.multTranspose(y, x)
        y.copy(xt)
        AT.mult(xt, yt)
        self.assertTrue(yt.equal(x))
        # AT^T * x must match A * x.
        x.setRandom()
        A.mult(x, y)
        x.copy(yt)
        AT.multTranspose(yt, xt)
        self.assertTrue(xt.equal(y))
        del A

# --------------------------------------------------------------------


if __name__ == '__main__':
    unittest.main()