"""Unit tests for PETSc matrices of type 'python' implemented in Python.

Each test class builds a 'python' Mat whose operations are delegated to a
context object (one of the context classes below). The tests check the
results of the operations a context implements and that unimplemented
operations raise. They also verify, via ``sys.getrefcount``, that the Mat
holds exactly one reference to its context and drops it on destroy.
"""

from petsc4py import PETSc
import unittest
from sys import getrefcount

# --------------------------------------------------------------------

class Matrix(object):
    """Base Python matrix context: lifecycle hooks only, no operations."""

    def __init__(self):
        pass

    def create(self, mat):
        pass

    def destroy(self, mat):
        pass


class Identity(Matrix):
    """Context implementing the identity operator."""

    def mult(self, mat, x, y):
        # y = I x
        x.copy(y)

    def getDiagonal(self, mat, vd):
        vd.set(1)


class Diagonal(Matrix):
    """Context implementing a diagonal operator whose entries live in ``D``."""

    def create(self, mat):
        super(Diagonal, self).create(mat)
        # Set up the Mat first so a compatible (left) vector can be created.
        mat.setUp()
        self.D = mat.createVecLeft()

    def destroy(self, mat):
        self.D.destroy()
        super(Diagonal, self).destroy(mat)

    def scale(self, mat, a):
        self.D.scale(a)

    def shift(self, mat, a):
        self.D.shift(a)

    def zeroEntries(self, mat):
        self.D.zeroEntries()

    def mult(self, mat, x, y):
        # y_i = x_i * D_i
        y.pointwiseMult(x, self.D)

    def getDiagonal(self, mat, vd):
        self.D.copy(vd)

    def setDiagonal(self, mat, vd, im):
        """Insert or add ``vd`` into the diagonal.

        ``im`` is either a bool (True means "add", False means "insert") or
        a ``PETSc.InsertMode``. Any insert mode other than INSERT_VALUES or
        ADD_VALUES raises ``ValueError``.
        """
        # The bool check must come first: a bool would otherwise be compared
        # against the integer-valued insert modes.
        if isinstance(im, bool):
            add = im
        elif im == PETSc.InsertMode.INSERT_VALUES:
            add = False
        elif im == PETSc.InsertMode.ADD_VALUES:
            add = True
        else:
            raise ValueError('wrong InsertMode %d' % im)
        if add:
            self.D.axpy(1, vd)
        else:
            vd.copy(self.D)

    def diagonalScale(self, mat, vl, vr):
        # D <- diag(vl) * D * diag(vr). A falsy scaling vector (presumably
        # one the caller did not supply) is skipped.
        if vl:
            self.D.pointwiseMult(self.D, vl)
        if vr:
            self.D.pointwiseMult(self.D, vr)

# --------------------------------------------------------------------

class TestMatrix(unittest.TestCase):
    """Tests against the bare ``Matrix`` context, which implements nothing.

    Subclasses override ``PYCLS`` to select a different context class and
    override the tests for the operations that context supports.
    """

    COMM = PETSc.COMM_WORLD
    PYMOD = __name__
    PYCLS = 'Matrix'

    def _getCtx(self):
        return self.A.getPythonContext()

    def _multTransposeUsingMult(self, x, y):
        """Run ``A.multTranspose(x, y)`` with the context's ``mult`` aliased
        as its ``multTranspose``.

        The alias is a bound method stored on the context itself, which forms
        a reference cycle. It is always removed, even if the call raises, so
        that ``tearDown``'s reference-count checks stay meaningful.
        """
        ctx = self._getCtx()
        ctx.multTranspose = ctx.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            del ctx.multTranspose

    def setUp(self):
        N = self.N = 10
        self.A = PETSc.Mat()
        if 0:  # command line way (disabled; kept as a reference alternative)
            self.A.create(self.COMM)
            self.A.setSizes([N, N])
            self.A.setType('python')
            OptDB = PETSc.Options(self.A)
            OptDB['mat_python_type'] = '%s.%s' % (self.PYMOD, self.PYCLS)
            self.A.setFromOptions()
            self.A.setUp()
            del OptDB['mat_python_type']
            self.assertTrue(self._getCtx() is not None)
        else:  # python way
            context = globals()[self.PYCLS]()
            self.A.createPython([N, N], context, comm=self.COMM)
            self.A.setUp()
            self.assertTrue(self._getCtx() is context)
            # getrefcount counts its own argument, so 3 = the local
            # `context` + the Mat's reference + the argument.
            self.assertEqual(getrefcount(context), 3)
            del context
            # Only the Mat's reference (plus the argument) remains.
            self.assertEqual(getrefcount(self._getCtx()), 2)

    def tearDown(self):
        ctx = self.A.getPythonContext()
        # local `ctx` + the Mat's reference + getrefcount's argument
        self.assertEqual(getrefcount(ctx), 3)
        self.A.destroy()
        self.A = None
        # Destroying the Mat must release its reference to the context.
        # To debug a failure here, inspect gc.get_referrers(ctx).
        self.assertEqual(getrefcount(ctx), 2)

    def testBasic(self):
        ctx = self.A.getPythonContext()
        self.assertTrue(self._getCtx() is ctx)
        self.assertEqual(getrefcount(ctx), 3)

    def testZeroEntries(self):
        f = lambda: self.A.zeroEntries()
        self.assertRaises(Exception, f)

    def testMult(self):
        x, y = self.A.createVecs()
        f = lambda: self.A.mult(x, y)
        self.assertRaises(Exception, f)

    def testMultTranspose(self):
        x, y = self.A.createVecs()
        f = lambda: self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        f = lambda: self.A.getDiagonal(d)
        self.assertRaises(Exception, f)

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        f = lambda: self.A.setDiagonal(d)
        self.assertRaises(Exception, f)

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        f = lambda: self.A.diagonalScale(x, y)
        self.assertRaises(Exception, f)


class TestIdentity(TestMatrix):
    """Tests against the ``Identity`` context."""

    PYCLS = 'Identity'

    def testMult(self):
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.mult(x, y)
        self.assertTrue(y.equal(x))

    def testMultTransposeSymmKnown(self):
        x, y = self.A.createVecs()
        x.setRandom()
        # With SYMMETRIC set, multTranspose works without a context method.
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(x))
        # Without it, the missing context method makes the call fail.
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        f = lambda: self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testMultTransposeNewMeth(self):
        x, y = self.A.createVecs()
        x.setRandom()
        self._multTransposeUsingMult(x, y)
        self.assertTrue(y.equal(x))

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        o = d.duplicate()
        o.set(1)
        self.A.getDiagonal(d)
        self.assertTrue(o.equal(d))


class TestDiagonal(TestMatrix):
    """Tests against the ``Diagonal`` context, initialised to diag(1..N)."""

    PYCLS = 'Diagonal'

    def setUp(self):
        super(TestDiagonal, self).setUp()
        D = self.A.createVecLeft()
        s, e = D.getOwnershipRange()
        for i in range(s, e):
            D[i] = i + 1
        D.assemble()
        self.A.setDiagonal(D)

    def testZeroEntries(self):
        self.A.zeroEntries()
        D = self._getCtx().D
        self.assertEqual(D.norm(), 0)

    def testMult(self):
        x, y = self.A.createVecs()
        x.set(1)
        self.A.mult(x, y)
        self.assertTrue(y.equal(self._getCtx().D))

    def testMultTransposeSymmKnown(self):
        x, y = self.A.createVecs()
        x.set(1)
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(self._getCtx().D))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        f = lambda: self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testMultTransposeNewMeth(self):
        x, y = self.A.createVecs()
        x.set(1)
        self._multTransposeUsingMult(x, y)
        self.assertTrue(y.equal(self._getCtx().D))

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        self.A.getDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        d.setRandom()
        self.A.setDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        x.set(2)
        y.set(3)
        old = self._getCtx().D.copy()
        self.A.diagonalScale(x, y)
        D = self._getCtx().D
        # Scaling by 2 on the left and 3 on the right multiplies D by 6.
        self.assertTrue(D.equal(old * 6))

    def testCreateTranspose(self):
        A = self.A
        A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        AT = PETSc.Mat().createTranspose(A)
        try:
            x, y = A.createVecs()
            xt, yt = AT.createVecs()
            # AT * y must equal A^T * y
            y.setRandom()
            A.multTranspose(y, x)
            y.copy(xt)
            AT.mult(xt, yt)
            self.assertTrue(yt.equal(x))
            # AT^T * x must equal A * x
            x.setRandom()
            A.mult(x, y)
            x.copy(yt)
            AT.multTranspose(yt, xt)
            self.assertTrue(xt.equal(y))
        finally:
            # Release the transpose wrapper deterministically. It presumably
            # keeps A alive, which would disturb tearDown's checks.
            AT.destroy()

# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()