from petsc4py import PETSc
import unittest
import numpy
from sys import getrefcount

# --------------------------------------------------------------------

class Matrix(object):
    """Minimal Python context for a ``MATPYTHON`` matrix.

    Only the lifecycle hooks are defined, so every other matrix operation
    on a matrix using this context is unsupported (``TestMatrix`` checks
    that such operations raise).
    """

    def __init__(self):
        pass

    def create(self, mat):
        pass

    def destroy(self, mat):
        pass


class Identity(Matrix):
    """Python matrix context that behaves as the identity operator."""

    def mult(self, mat, x, y):
        # y = I * x
        x.copy(y)

    def getDiagonal(self, mat, vd):
        vd.set(1)

    def productSetFromOptions(self, mat, producttype, A, B, C):
        # Accept every product type here; unsupported ones are rejected
        # in productSymbolic.
        return True

    @staticmethod
    def _setUpProductFrom(product, src):
        """Give ``product`` the type, sizes and entries of ``src``."""
        product.setType(src.getType())
        product.setSizes(src.getSizes())
        product.setUp()
        product.assemble()
        src.copy(product)

    def productSymbolic(self, mat, product, producttype, A, B, C):
        """Set up ``product`` for ``producttype`` with this identity as an operand.

        ``mat`` must be one of the operands ``A``/``B``/``C``; the identity
        operand is dropped and the remaining expression is formed explicitly.
        Its structure is copied into ``product`` and the entries are zeroed;
        ``productNumeric`` fills in the values.

        Raises:
            RuntimeError: if ``mat`` is not a valid operand for the product,
                or the product type is not supported.
        """
        if producttype == 'AB':
            if mat is A: # product = identity * B
                src = B
            elif mat is B: # product = A * identity
                src = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'AtB':
            if mat is A: # product = identity^T * B
                src = B
            elif mat is B: # product = A^T * identity
                src = PETSc.Mat()
                A.transpose(src)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABt':
            if mat is A: # product = identity * B^T
                src = PETSc.Mat()
                B.transpose(src)
            elif mat is B: # product = A * identity^T
                src = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'PtAP':
            if mat is A: # product = P^T * identity * P
                # Cached on self: productNumeric recomputes into self.tmp.
                self.tmp = PETSc.Mat()
                B.transposeMatMult(B, self.tmp)
                src = self.tmp
            elif mat is B: # product = identity^T * A * identity
                src = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A: # product = R * identity * R^t
                # Cached on self: productNumeric recomputes into self.tmp.
                self.tmp = PETSc.Mat()
                B.matTransposeMult(B, self.tmp)
                src = self.tmp
            elif mat is B: # product = identity * A * identity^T
                src = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            # Cached on self: productNumeric recomputes into self.tmp.
            if mat is A: # product = identity * B * C
                self.tmp = PETSc.Mat()
                B.matMult(C, self.tmp)
            elif mat is B: # product = A * identity * C
                self.tmp = PETSc.Mat()
                A.matMult(C, self.tmp)
            elif mat is C: # product = A * B * identity
                self.tmp = PETSc.Mat()
                A.matMult(B, self.tmp)
            else:
                raise RuntimeError('wrong configuration')
            src = self.tmp
        else:
            raise RuntimeError('Product {} not implemented'.format(producttype))
        self._setUpProductFrom(product, src)
        # Only the structure belongs to the symbolic phase.
        product.zeroEntries()

    def productNumeric(self, mat, product, producttype, A, B, C):
        """Compute the entries of ``product`` previously set up by productSymbolic.

        Products with two non-identity operands are recomputed into the
        ``self.tmp`` matrix created by productSymbolic, then copied into
        ``product`` with the same nonzero structure.

        Raises:
            RuntimeError: if ``mat`` is not a valid operand for the product,
                or the product type is not supported.
        """
        if producttype == 'AB':
            if mat is A: # product = identity * B
                B.copy(product, structure=True)
            elif mat is B: # product = A * identity
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'AtB':
            if mat is A: # product = identity^T * B
                B.copy(product, structure=True)
            elif mat is B: # product = A^T * identity
                A.transpose(product)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABt':
            if mat is A: # product = identity * B^T
                B.transpose(product)
            elif mat is B: # product = A * identity^T
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'PtAP':
            if mat is A: # product = P^T * identity * P
                B.transposeMatMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is B: # product = identity^T * A * identity
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A: # product = R * identity * R^t
                B.matTransposeMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is B: # product = identity * A * identity^T
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            if mat is A: # product = identity * B * C
                B.matMult(C, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is B: # product = A * identity * C
                A.matMult(C, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is C: # product = A * B * identity
                A.matMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        else:
            raise RuntimeError('Product {} not implemented'.format(producttype))


class Diagonal(Matrix):
    """Python matrix context for a diagonal matrix stored in the vector ``self.D``."""

    def create(self, mat):
        super(Diagonal,self).create(mat)
        # Set up the matrix first, then allocate D with its left layout.
        mat.setUp()
        self.D = mat.createVecLeft()

    def destroy(self, mat):
        self.D.destroy()
        super(Diagonal,self).destroy(mat)

    def scale(self, mat, a):
        self.D.scale(a)

    def shift(self, mat, a):
        self.D.shift(a)

    def zeroEntries(self, mat):
        self.D.zeroEntries()

    def mult(self, mat, x, y):
        # y = D .* x
        y.pointwiseMult(x, self.D)

    def getDiagonal(self, mat, vd):
        self.D.copy(vd)

    def setDiagonal(self, mat, vd, im):
        """Insert or add ``vd`` into the diagonal.

        ``im`` is either a bool (True means add, False means insert) or a
        ``PETSc.InsertMode``.

        Raises:
            ValueError: for any other insert mode.
        """
        if isinstance (im, bool):
            addv = im
            if addv:
                self.D.axpy(1, vd)
            else:
                vd.copy(self.D)
        elif im == PETSc.InsertMode.INSERT_VALUES:
            vd.copy(self.D)
        elif im == PETSc.InsertMode.ADD_VALUES:
            self.D.axpy(1, vd)
        else:
            raise ValueError('wrong InsertMode %d'% im)

    def diagonalScale(self, mat, vl, vr):
        # Either scaling vector may be absent (falsy); diag(vl) * D * diag(vr).
        if vl: self.D.pointwiseMult(self.D, vl)
        if vr: self.D.pointwiseMult(self.D, vr)

# --------------------------------------------------------------------

class TestMatrix(unittest.TestCase):
    """Tests for the bare ``Matrix`` context: every operation must raise."""

    COMM = PETSc.COMM_WORLD
    PYMOD = __name__
    PYCLS = 'Matrix'

    def _getCtx(self):
        return self.A.getPythonContext()

    def setUp(self):
        N = self.N = 10
        self.A = PETSc.Mat()
        if 0: # command line way (disabled; kept as a usage example)
            self.A.create(self.COMM)
            self.A.setSizes([N,N])
            self.A.setType('python')
            OptDB = PETSc.Options(self.A)
            OptDB['mat_python_type'] = '%s.%s' % (self.PYMOD,self.PYCLS)
            self.A.setFromOptions()
            self.A.setUp()
            del OptDB['mat_python_type']
            self.assertTrue(self._getCtx() is not None)
        else: # python way
            context = globals()[self.PYCLS]()
            self.A.createPython([N,N], context, comm=self.COMM)
            self.A.setUp()
            self.assertTrue(self._getCtx() is context)
            # getrefcount counts its own argument: local name + matrix + arg.
            self.assertEqual(getrefcount(context), 3)
            del context
            # Only the matrix (plus the getrefcount argument) holds it now.
            self.assertEqual(getrefcount(self._getCtx()), 2)

    def tearDown(self):
        ctx = self.A.getPythonContext()
        self.assertEqual(getrefcount(ctx), 3)
        self.A.destroy() # XXX
        self.A = None
        # Destroying the matrix must release its reference to the context.
        self.assertEqual(getrefcount(ctx), 2)
        # Debug aid: import gc,pprint; pprint.pprint(gc.get_referrers(ctx))

    def testBasic(self):
        ctx = self.A.getPythonContext()
        self.assertTrue(self._getCtx() is ctx)
        self.assertEqual(getrefcount(ctx), 3)

    def testZeroEntries(self):
        f = lambda : self.A.zeroEntries()
        self.assertRaises(Exception, f)

    def testMult(self):
        x, y = self.A.createVecs()
        f = lambda : self.A.mult(x, y)
        self.assertRaises(Exception, f)

    def testMultTranspose(self):
        x, y = self.A.createVecs()
        f = lambda : self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        f = lambda : self.A.getDiagonal(d)
        self.assertRaises(Exception, f)

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        f = lambda : self.A.setDiagonal(d)
        self.assertRaises(Exception, f)

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        f = lambda : self.A.diagonalScale(x, y)
        self.assertRaises(Exception, f)


class TestIdentity(TestMatrix):
    """Tests for the ``Identity`` context."""

    PYCLS = 'Identity'

    def testMult(self):
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.mult(x,y)
        self.assertTrue(y.equal(x))

    def testMultTransposeSymmKnown(self):
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x,y)
        self.assertTrue(y.equal(x))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        f = lambda : self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testMultTransposeNewMeth(self):
        x, y = self.A.createVecs()
        x.setRandom()
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x,y)
        finally:
            # The bound method references the context; always drop it so
            # the refcount check in tearDown stays meaningful.
            del AA.multTranspose
        self.assertTrue(y.equal(x))

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        o = d.duplicate()
        o.set(1)
        self.A.getDiagonal(d)
        self.assertTrue(o.equal(d))

    def testMatMat(self):
        R = PETSc.Random().create(self.COMM)
        R.setFromOptions()
        A = PETSc.Mat().create(self.COMM)
        A.setSizes(self.A.getSizes())
        A.setType(PETSc.Mat.Type.AIJ)
        A.setUp()
        A.setRandom(R)
        B = PETSc.Mat().create(self.COMM)
        B.setSizes(self.A.getSizes())
        B.setType(PETSc.Mat.Type.AIJ)
        B.setUp()
        B.setRandom(R)
        # Explicit AIJ identity used as the reference result.
        I = PETSc.Mat().create(self.COMM)
        I.setSizes(self.A.getSizes())
        I.setType(PETSc.Mat.Type.AIJ)
        I.setUp()
        I.assemble()
        I.shift(1.)

        self.assertTrue(self.A.matMult(A).equal(I.matMult(A)))
        self.assertTrue(A.matMult(self.A).equal(A.matMult(I)))
        if self.A.getComm().Get_size() == 1:
            self.assertTrue(self.A.matTransposeMult(A).equal(I.matTransposeMult(A)))
            self.assertTrue(A.matTransposeMult(self.A).equal(A.matTransposeMult(I)))
        self.assertTrue(self.A.transposeMatMult(A).equal(I.transposeMatMult(A)))
        self.assertTrue(A.transposeMatMult(self.A).equal(A.transposeMatMult(I)))
        self.assertAlmostEqual((self.A.ptap(A) - I.ptap(A)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.ptap(self.A) - A.ptap(I)).norm(), 0.0, places=5)
        if self.A.getComm().Get_size() == 1:
            self.assertAlmostEqual((self.A.rart(A) - I.rart(A)).norm(), 0.0, places=5)
            self.assertAlmostEqual((A.rart(self.A) - A.rart(I)).norm(), 0.0, places=5)
        self.assertAlmostEqual((self.A.matMatMult(A,B)-I.matMatMult(A,B)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.matMatMult(self.A,B)-A.matMatMult(I,B)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.matMatMult(B,self.A)-A.matMatMult(B,I)).norm(), 0.0, places=5)

    def testH2Opus(self):
        if not PETSc.Sys.hasExternalPackage("h2opus"):
            self.skipTest("h2opus not available")
        if self.A.getComm().Get_size() > 1:
            self.skipTest("sequential runs only")
        h = PETSc.Mat()

        # need transpose operation for norm estimation
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            # without coordinates
            h.createH2OpusFromMat(self.A,leafsize=2)
            h.assemble()
            h.destroy()

            # with coordinates
            coords = numpy.linspace((1,2,3),(10,20,30),self.A.getSize()[0],dtype=PETSc.RealType)
            h.createH2OpusFromMat(self.A,coords,leafsize=2)
            h.assemble()
            h.destroy()
        finally:
            # The bound method references the context; always drop it so
            # the refcount check in tearDown stays meaningful.
            del AA.multTranspose


class TestDiagonal(TestMatrix):
    """Tests for the ``Diagonal`` context, initialised with D[i] = i + 1."""

    PYCLS = 'Diagonal'

    def setUp(self):
        super(TestDiagonal, self).setUp()
        D = self.A.createVecLeft()
        s, e = D.getOwnershipRange()
        for i in range(s, e):
            D[i] = i+1
        D.assemble()
        self.A.setDiagonal(D)

    def testZeroEntries(self):
        self.A.zeroEntries()
        D = self._getCtx().D
        self.assertEqual(D.norm(), 0)

    def testMult(self):
        x, y = self.A.createVecs()
        x.set(1)
        self.A.mult(x,y)
        self.assertTrue(y.equal(self._getCtx().D))

    def testMultTransposeSymmKnown(self):
        x, y = self.A.createVecs()
        x.set(1)
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x,y)
        self.assertTrue(y.equal(self._getCtx().D))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        f = lambda : self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testMultTransposeNewMeth(self):
        x, y = self.A.createVecs()
        x.set(1)
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x,y)
        finally:
            # The bound method references the context; always drop it so
            # the refcount check in tearDown stays meaningful.
            del AA.multTranspose
        self.assertTrue(y.equal(self._getCtx().D))

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        self.A.getDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        d.setRandom()
        self.A.setDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        x.set(2)
        y.set(3)
        old = self._getCtx().D.copy()
        self.A.diagonalScale(x, y)
        D = self._getCtx().D
        self.assertTrue(D.equal(old*6))

    def testCreateTranspose(self):
        A = self.A
        A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        AT = PETSc.Mat().createTranspose(A)
        x, y = A.createVecs()
        xt, yt = AT.createVecs()
        # AT * y must equal A^T * y
        y.setRandom()
        A.multTranspose(y, x)
        y.copy(xt)
        AT.mult(xt, yt)
        self.assertTrue(yt.equal(x))
        # AT^T * x must equal A * x
        x.setRandom()
        A.mult(x, y)
        x.copy(yt)
        AT.multTranspose(yt, xt)
        self.assertTrue(xt.equal(y))
        del A

# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()