"""Unit tests for MATPYTHON matrices whose operations are implemented by
Python context objects attached with ``PETSc.Mat.createPython``."""

from sys import getrefcount
import unittest

import numpy
from petsc4py import PETSc

# --------------------------------------------------------------------

class Matrix(object):
    """Minimal MATPYTHON context providing only ``create``/``destroy``.

    Every other matrix operation is left unimplemented; ``TestMatrix``
    checks that calling them raises.
    """

    def __init__(self):
        pass

    def create(self, mat):
        """Called by petsc4py with the Mat this context is attached to; no-op."""
        pass

    def destroy(self, mat):
        """Called by petsc4py when the Mat is destroyed; no-op."""
        pass


class ScaledIdentity(Matrix):
    """The operator ``s * I`` for a scalar ``s`` (class default 2.0).

    Implements mult, scale, shift, duplicate, getDiagonal and the MatProduct
    hooks so that products with ordinary (e.g. AIJ) matrices can be formed.
    """

    s = 2.0  # default factor; scale()/shift() create a per-instance value

    def scale(self, mat, s):
        """Multiply the operator by ``s``."""
        self.s *= s

    def shift(self, mat, s):
        """Add ``s * I`` to the operator."""
        self.s += s

    def mult(self, mat, x, y):
        """Compute ``y = s * x``."""
        x.copy(y)
        y.scale(self.s)

    def duplicate(self, mat, op):
        """Return a new ScaledIdentity matrix with the sizes/comm of ``mat``.

        The scaling factor is copied only when ``op`` is
        ``PETSc.Mat.DuplicateOption.COPY_VALUES``; otherwise the new context
        keeps the class default.
        """
        dmat = PETSc.Mat()
        dctx = ScaledIdentity()
        dmat.createPython(mat.getSizes(), dctx, comm=mat.getComm())
        if op == PETSc.Mat.DuplicateOption.COPY_VALUES:
            dctx.s = self.s
        # the duplicate must be usable whichever copy option was requested
        dmat.setUp()
        return dmat

    def getDiagonal(self, mat, vd):
        """Fill ``vd`` with the constant ``s``."""
        vd.set(self.s)

    def productSetFromOptions(self, mat, producttype, A, B, C):
        """Accept every product type; unsupported ones fail in productSymbolic."""
        return True

    @staticmethod
    def _initProductFrom(product, M):
        """Give ``product`` the type and sizes of ``M`` and copy ``M`` into it."""
        product.setType(M.getType())
        product.setSizes(M.getSizes())
        product.setUp()
        product.assemble()
        M.copy(product)

    def productSymbolic(self, mat, product, producttype, A, B, C):
        """Build the structure of ``product`` for the given product type.

        ``mat`` (this scaled identity) must be one of the operands ``A``,
        ``B`` (or ``C`` for ``'ABC'``).  Since it is a multiple of the
        identity, the structure is that of the product of the remaining
        operand(s).  Intermediate results that productNumeric recomputes in
        place are stored in ``self.tmp``.  ``product`` is left zeroed.

        Raises RuntimeError if ``mat`` is not an operand or the product type
        is not implemented.
        """
        if producttype == 'AB':
            if mat is A:  # product = identity * B
                self._initProductFrom(product, B)
            elif mat is B:  # product = A * identity
                self._initProductFrom(product, A)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'AtB':
            if mat is A:  # product = identity^T * B
                self._initProductFrom(product, B)
            elif mat is B:  # product = A^T * identity
                tmp = PETSc.Mat()
                A.transpose(tmp)
                self._initProductFrom(product, tmp)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABt':
            if mat is A:  # product = identity * B^T
                tmp = PETSc.Mat()
                B.transpose(tmp)
                self._initProductFrom(product, tmp)
            elif mat is B:  # product = A * identity^T
                self._initProductFrom(product, A)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'PtAP':
            if mat is A:  # product = P^T * identity * P
                self.tmp = PETSc.Mat()
                B.transposeMatMult(B, self.tmp)
                self._initProductFrom(product, self.tmp)
            elif mat is B:  # product = identity^T * A * identity
                self._initProductFrom(product, A)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A:  # product = R * identity * R^t
                self.tmp = PETSc.Mat()
                B.matTransposeMult(B, self.tmp)
                self._initProductFrom(product, self.tmp)
            elif mat is B:  # product = identity * A * identity^T
                self._initProductFrom(product, A)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            if mat is A:  # product = identity * B * C
                self.tmp = PETSc.Mat()
                B.matMult(C, self.tmp)
                self._initProductFrom(product, self.tmp)
            elif mat is B:  # product = A * identity * C
                self.tmp = PETSc.Mat()
                A.matMult(C, self.tmp)
                self._initProductFrom(product, self.tmp)
            elif mat is C:  # product = A * B * identity
                self.tmp = PETSc.Mat()
                A.matMult(B, self.tmp)
                self._initProductFrom(product, self.tmp)
            else:
                raise RuntimeError('wrong configuration')
        else:
            raise RuntimeError('Product {} not implemented'.format(producttype))
        product.zeroEntries()

    def productNumeric(self, mat, product, producttype, A, B, C):
        """Compute the values of ``product`` set up by productSymbolic.

        The remaining operand(s) are copied (or multiplied, reusing
        ``self.tmp``) into ``product``, which is then scaled by ``s``, or by
        ``s**2`` when the identity appears on both sides (PtAP/RARt with
        ``mat`` as P/R).

        Raises RuntimeError if ``mat`` is not an operand or the product type
        is not implemented.
        """
        if producttype == 'AB':
            if mat is A:  # product = identity * B
                B.copy(product, structure=True)
            elif mat is B:  # product = A * identity
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
            product.scale(self.s)
        elif producttype == 'AtB':
            if mat is A:  # product = identity^T * B
                B.copy(product, structure=True)
            elif mat is B:  # product = A^T * identity
                A.transpose(product)
            else:
                raise RuntimeError('wrong configuration')
            product.scale(self.s)
        elif producttype == 'ABt':
            if mat is A:  # product = identity * B^T
                B.transpose(product)
            elif mat is B:  # product = A * identity^T
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
            product.scale(self.s)
        elif producttype == 'PtAP':
            if mat is A:  # product = P^T * identity * P
                B.transposeMatMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
                product.scale(self.s)
            elif mat is B:  # product = identity^T * A * identity
                A.copy(product, structure=True)
                product.scale(self.s**2)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A:  # product = R * identity * R^t
                B.matTransposeMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
                product.scale(self.s)
            elif mat is B:  # product = identity * A * identity^T
                A.copy(product, structure=True)
                product.scale(self.s**2)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            if mat is A:  # product = identity * B * C
                B.matMult(C, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is B:  # product = A * identity * C
                A.matMult(C, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is C:  # product = A * B * identity
                A.matMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
            product.scale(self.s)
        else:
            raise RuntimeError('Product {} not implemented'.format(producttype))


class Diagonal(Matrix):
    """A diagonal operator ``diag(D)`` whose entries are stored in ``self.D``."""

    def create(self, mat):
        """Set up ``mat`` and allocate ``D`` with the matrix's left layout."""
        super(Diagonal, self).create(mat)
        mat.setUp()
        self.D = mat.createVecLeft()

    def destroy(self, mat):
        """Release ``D``."""
        self.D.destroy()
        super(Diagonal, self).destroy(mat)

    def scale(self, mat, a):
        """Multiply every diagonal entry by ``a``."""
        self.D.scale(a)

    def shift(self, mat, a):
        """Add ``a`` to every diagonal entry."""
        self.D.shift(a)

    def zeroEntries(self, mat):
        """Set all diagonal entries to zero."""
        self.D.zeroEntries()

    def mult(self, mat, x, y):
        """Compute ``y = D .* x`` (entrywise product)."""
        y.pointwiseMult(x, self.D)

    def duplicate(self, mat, op):
        """Return a new Diagonal matrix with the sizes/comm of ``mat``.

        The diagonal values are copied only when ``op`` is
        ``PETSc.Mat.DuplicateOption.COPY_VALUES``.
        """
        dmat = PETSc.Mat()
        dctx = Diagonal()
        dmat.createPython(mat.getSizes(), dctx, comm=mat.getComm())
        dctx.D = self.D.duplicate()
        if op == PETSc.Mat.DuplicateOption.COPY_VALUES:
            self.D.copy(dctx.D)
        # the duplicate must be usable whichever copy option was requested
        dmat.setUp()
        return dmat

    def getDiagonal(self, mat, vd):
        """Copy the diagonal into ``vd``."""
        self.D.copy(vd)

    def setDiagonal(self, mat, vd, im):
        """Insert or add ``vd`` into the diagonal.

        ``im`` is either a bool (True means add, False means insert) or a
        ``PETSc.InsertMode``; other insert modes raise ValueError.
        """
        if isinstance(im, bool):
            addv = im
            if addv:
                self.D.axpy(1, vd)
            else:
                vd.copy(self.D)
        elif im == PETSc.InsertMode.INSERT_VALUES:
            vd.copy(self.D)
        elif im == PETSc.InsertMode.ADD_VALUES:
            self.D.axpy(1, vd)
        else:
            raise ValueError('wrong InsertMode %d'% im)

    def diagonalScale(self, mat, vl, vr):
        """Replace the operator by ``diag(vl) * A * diag(vr)``.

        A scaling vector that is falsy (e.g. not provided) is skipped.
        """
        if vl: self.D.pointwiseMult(self.D, vl)
        if vr: self.D.pointwiseMult(self.D, vr)

# --------------------------------------------------------------------

class TestMatrix(unittest.TestCase):
    """Tests against the bare ``Matrix`` context, where most ops must raise.

    Subclasses override ``PYCLS`` to run the same fixture with other contexts.
    """

    COMM = PETSc.COMM_WORLD
    PYMOD = __name__
    PYCLS = 'Matrix'

    def _getCtx(self):
        return self.A.getPythonContext()

    def setUp(self):
        N = self.N = 13
        self.A = PETSc.Mat()
        if 0: # command line way (options database; currently disabled)
            self.A.create(self.COMM)
            self.A.setSizes([N,N])
            self.A.setType('python')
            OptDB = PETSc.Options(self.A)
            OptDB['mat_python_type'] = '%s.%s' % (self.PYMOD,self.PYCLS)
            self.A.setFromOptions()
            self.A.setUp()
            del OptDB['mat_python_type']
            self.assertTrue(self._getCtx() is not None)
        else: # python way
            context = globals()[self.PYCLS]()
            self.A.createPython([N,N], context, comm=self.COMM)
            self.A.setUp()
            self.assertTrue(self._getCtx() is context)
            # refs: the Mat, the local name, getrefcount's argument
            self.assertEqual(getrefcount(context), 3)
            del context
            # only the Mat keeps the context alive now
            self.assertEqual(getrefcount(self._getCtx()), 2)

    def tearDown(self):
        ctx = self.A.getPythonContext()
        self.assertEqual(getrefcount(ctx), 3)
        # destroying the Mat must drop its reference to the context
        self.A.destroy()
        self.A = None
        self.assertEqual(getrefcount(ctx), 2)
        # to debug a mismatch, inspect gc.get_referrers(ctx)

    def testBasic(self):
        ctx = self.A.getPythonContext()
        self.assertTrue(self._getCtx() is ctx)
        self.assertEqual(getrefcount(ctx), 3)

    def testZeroEntries(self):
        f = lambda : self.A.zeroEntries()
        self.assertRaises(Exception, f)

    def testMult(self):
        x, y = self.A.createVecs()
        f = lambda : self.A.mult(x, y)
        self.assertRaises(Exception, f)

    def testMultTranspose(self):
        x, y = self.A.createVecs()
        f = lambda : self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        f = lambda : self.A.getDiagonal(d)
        self.assertRaises(Exception, f)

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        f = lambda : self.A.setDiagonal(d)
        self.assertRaises(Exception, f)

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        f = lambda : self.A.diagonalScale(x, y)
        self.assertRaises(Exception, f)

    def testDuplicate(self):
        # Mat.duplicate takes only the copy flag; passing an undefined
        # name here made both checks pass on a NameError instead.
        f1 = lambda : self.A.duplicate(True)
        f2 = lambda : self.A.duplicate(False)
        self.assertRaises(Exception, f1)
        self.assertRaises(Exception, f2)

    def testSetVecType(self):
        self.A.setVecType('mpi')
        self.assertEqual('mpi', self.A.getVecType())

    def testH2Opus(self):
        """Build H2OPUS approximations of ``self.A`` and exercise their API.

        Skipped unless PETSc has h2opus, the run is sequential and the
        Python context implements ``mult``.
        """
        # Check skip conditions before binding the context to a local, so a
        # skip does not keep an extra reference alive for tearDown's check.
        if not PETSc.Sys.hasExternalPackage("h2opus"):
            self.skipTest("PETSc built without h2opus")
        if self.A.getComm().Get_size() > 1:
            self.skipTest("h2opus test runs sequentially only")
        if not hasattr(self._getCtx(), 'mult'):
            self.skipTest("Python context does not implement mult")

        # need matrix vector and its transpose for norm estimation; the
        # contexts here are symmetric, so mult serves as multTranspose.
        # The bound method references the context, so it must be removed.
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        h = PETSc.Mat()
        try:
            # without coordinates
            h.createH2OpusFromMat(self.A,leafsize=2)
            h.assemble()
            h.destroy()

            # with coordinates
            coords = numpy.linspace((1,2,3),(10,20,30),self.A.getSize()[0],dtype=PETSc.RealType)
            h.createH2OpusFromMat(self.A,coords,leafsize=2)
            h.assemble()

            # test API
            h.H2OpusOrthogonalize()
            h.H2OpusCompress(1.e-1)

            # Low-rank update: compare against the dense update h + U*U^T
            U = PETSc.Mat()
            U.createDense([h.getSizes()[0],3],comm=h.getComm())
            U.setUp()
            U.setRandom()

            he = PETSc.Mat()
            h.convert('dense',he)
            he.axpy(1.0, U.matTransposeMult(U))

            h.H2OpusLowRankUpdate(U)
            self.assertTrue(he.equal(h))
        finally:
            h.destroy()
            del AA.multTranspose


class TestScaledIdentity(TestMatrix):

    PYCLS = 'ScaledIdentity'

    def testMult(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.mult(x,y)
        self.assertTrue(y.equal(s*x))

    def testMultTransposeSymmKnown(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x,y)
        self.assertTrue(y.equal(s*x))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        f = lambda : self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testMultTransposeNewMeth(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        self.A.multTranspose(x,y)
        del AA.multTranspose
        self.assertTrue(y.equal(s*x))

    def testGetDiagonal(self):
        s = self._getCtx().s
        d = self.A.createVecLeft()
        o = d.duplicate()
        o.set(s)
        self.A.getDiagonal(d)
        self.assertTrue(o.equal(d))

    def testDuplicate(self):
        B = self.A.duplicate(False)
        self.assertEqual(B.getPythonContext().s, ScaledIdentity.s)
        B = self.A.duplicate(True)
        self.assertTrue(B.getPythonContext().s == self.A.getPythonContext().s)

    def testMatMat(self):
        """Products with the Python matrix match those with an AIJ ``s*I``."""
        s = self._getCtx().s
        R = PETSc.Random().create(self.COMM)
        R.setFromOptions()
        A = PETSc.Mat().create(self.COMM)
        A.setSizes(self.A.getSizes())
        A.setType(PETSc.Mat.Type.AIJ)
        A.setUp()
        A.setRandom(R)
        B = PETSc.Mat().create(self.COMM)
        B.setSizes(self.A.getSizes())
        B.setType(PETSc.Mat.Type.AIJ)
        B.setUp()
        B.setRandom(R)
        I = PETSc.Mat().create(self.COMM)
        I.setSizes(self.A.getSizes())
        I.setType(PETSc.Mat.Type.AIJ)
        I.setUp()
        I.assemble()
        I.shift(s)

        self.assertTrue(self.A.matMult(A).equal(I.matMult(A)))
        self.assertTrue(A.matMult(self.A).equal(A.matMult(I)))
        if self.A.getComm().Get_size() == 1:
            self.assertTrue(self.A.matTransposeMult(A).equal(I.matTransposeMult(A)))
            self.assertTrue(A.matTransposeMult(self.A).equal(A.matTransposeMult(I)))
        self.assertTrue(self.A.transposeMatMult(A).equal(I.transposeMatMult(A)))
        self.assertTrue(A.transposeMatMult(self.A).equal(A.transposeMatMult(I)))
        self.assertAlmostEqual((self.A.ptap(A) - I.ptap(A)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.ptap(self.A) - A.ptap(I)).norm(), 0.0, places=5)
        if self.A.getComm().Get_size() == 1:
            self.assertAlmostEqual((self.A.rart(A) - I.rart(A)).norm(), 0.0, places=5)
            self.assertAlmostEqual((A.rart(self.A) - A.rart(I)).norm(), 0.0, places=5)
        self.assertAlmostEqual((self.A.matMatMult(A,B)-I.matMatMult(A,B)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.matMatMult(self.A,B)-A.matMatMult(I,B)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.matMatMult(B,self.A)-A.matMatMult(B,I)).norm(), 0.0, places=5)

    def testShift(self):
        sold = self._getCtx().s
        self.A.shift(-0.5)
        s = self._getCtx().s
        self.assertTrue(s == sold - 0.5)

    def testScale(self):
        sold = self._getCtx().s
        self.A.scale(-0.5)
        s = self._getCtx().s
        self.assertTrue(s == sold * -0.5)


class TestDiagonal(TestMatrix):

    PYCLS = 'Diagonal'

    def setUp(self):
        super(TestDiagonal, self).setUp()
        # diagonal entries 1, 2, ..., N
        D = self.A.createVecLeft()
        s, e = D.getOwnershipRange()
        for i in range(s, e):
            D[i] = i+1
        D.assemble()
        self.A.setDiagonal(D)

    def testZeroEntries(self):
        self.A.zeroEntries()
        D = self._getCtx().D
        self.assertEqual(D.norm(), 0)

    def testMult(self):
        x, y = self.A.createVecs()
        x.set(1)
        self.A.mult(x,y)
        self.assertTrue(y.equal(self._getCtx().D))

    def testMultTransposeSymmKnown(self):
        x, y = self.A.createVecs()
        x.set(1)
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x,y)
        self.assertTrue(y.equal(self._getCtx().D))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        f = lambda : self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testMultTransposeNewMeth(self):
        x, y = self.A.createVecs()
        x.set(1)
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        self.A.multTranspose(x,y)
        del AA.multTranspose
        self.assertTrue(y.equal(self._getCtx().D))

    def testDuplicate(self):
        B = self.A.duplicate(False)
        self.assertEqual(B.getSizes(), self.A.getSizes())
        B = self.A.duplicate(True)
        self.assertTrue(B.getPythonContext().D.equal(self.A.getPythonContext().D))

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        self.A.getDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        d.setRandom()
        self.A.setDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        x.set(2)
        y.set(3)
        old = self._getCtx().D.copy()
        self.A.diagonalScale(x, y)
        D = self._getCtx().D
        self.assertTrue(D.equal(old*6))

    def testCreateTranspose(self):
        A = self.A
        A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        AT = PETSc.Mat().createTranspose(A)
        x, y = A.createVecs()
        xt, yt = AT.createVecs()
        #
        y.setRandom()
        A.multTranspose(y, x)
        y.copy(xt)
        AT.mult(xt, yt)
        self.assertTrue(yt.equal(x))
        #
        x.setRandom()
        A.mult(x, y)
        x.copy(yt)
        AT.multTranspose(yt, xt)
        self.assertTrue(xt.equal(y))
        del A

    def testConvert(self):
        self.assertTrue(self.A.convert(PETSc.Mat.Type.AIJ,PETSc.Mat()).equal(self.A))
        self.assertTrue(self.A.convert(PETSc.Mat.Type.BAIJ,PETSc.Mat()).equal(self.A))
        self.assertTrue(self.A.convert(PETSc.Mat.Type.SBAIJ,PETSc.Mat()).equal(self.A))
        self.assertTrue(self.A.convert(PETSc.Mat.Type.DENSE,PETSc.Mat()).equal(self.A))

    def testShift(self):
        old = self._getCtx().D.copy()
        self.A.shift(-0.5)
        D = self._getCtx().D
        self.assertTrue(D.equal(old-0.5))

    def testScale(self):
        old = self._getCtx().D.copy()
        self.A.scale(-0.5)
        D = self._getCtx().D
        self.assertTrue(D.equal(-0.5*old))

# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()