xref: /petsc/src/binding/petsc4py/test/test_mat_py.py (revision ee6c7c3114175da046f35334ed12bfa2b0ac9657)
15808f684SSatish Balayfrom petsc4py import PETSc
253022affSStefano Zampiniimport unittest, numpy
35808f684SSatish Balayfrom sys import getrefcount
45808f684SSatish Balay# --------------------------------------------------------------------
55808f684SSatish Balay
class Matrix:
    """Bare Python context for a ``'python'`` PETSc matrix.

    Only the ``create``/``destroy`` lifecycle hooks are provided, and both are
    no-ops.  No matrix operation is defined, so the tests in this module
    expect calls such as ``mult`` or ``zeroEntries`` on a matrix backed by this
    context to raise.  Subclasses add the operations they support.
    """

    def __init__(self):
        """Nothing to initialise."""

    def create(self, mat):
        """Lifecycle hook for ``mat`` (creation); intentionally does nothing."""

    def destroy(self, mat):
        """Lifecycle hook for ``mat`` (destruction); intentionally does nothing."""
165808f684SSatish Balay
class Identity(Matrix):
    """Python context implementing the identity operator.

    Besides ``mult`` and ``getDiagonal`` it implements the MatProduct hooks.
    When this matrix is one of the factors (``mat is A``, ``B`` or ``C``), the
    product is obtained from the remaining factors alone, by copying,
    transposing or multiplying them.

    NOTE(review): for 'PtAP', 'RARt' and 'ABC' the intermediate product built
    in ``productSymbolic`` is kept in the single attribute ``self.tmp`` and
    reused by ``productNumeric``.  A later symbolic call on the same context
    overwrites it.
    """

    def mult(self, mat, x, y):
        # y = I x = x
        x.copy(y)

    def getDiagonal(self, mat, vd):
        # every diagonal entry of the identity equals one
        vd.set(1)

    def productSetFromOptions(self, mat, producttype, A, B, C):
        # Accept every product type here; unsupported types are rejected with
        # RuntimeError by productSymbolic / productNumeric.
        return True

    @staticmethod
    def _fillProduct(product, src):
        """Give ``product`` the type and sizes of ``src``, set it up and copy ``src`` into it."""
        product.setType(src.getType())
        product.setSizes(src.getSizes())
        product.setUp()
        product.assemble()
        src.copy(product)

    def productSymbolic(self, mat, product, producttype, A, B, C):
        """Prepare ``product`` for ``producttype``, where ``mat`` is one of A, B, C.

        The product is left with the right type/sizes and zeroed entries.
        Raises RuntimeError if ``mat`` is not a factor this product type
        uses, or if ``producttype`` is not supported.
        """
        if producttype == 'AB':
            if mat is A:        # I * B
                self._fillProduct(product, B)
            elif mat is B:      # A * I
                self._fillProduct(product, A)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'AtB':
            if mat is A:        # I^T * B
                self._fillProduct(product, B)
            elif mat is B:      # A^T * I
                At = PETSc.Mat()
                A.transpose(At)
                self._fillProduct(product, At)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABt':
            if mat is A:        # I * B^T
                Bt = PETSc.Mat()
                B.transpose(Bt)
                self._fillProduct(product, Bt)
            elif mat is B:      # A * I^T
                self._fillProduct(product, A)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'PtAP':
            if mat is A:        # P^T * I * P
                self.tmp = PETSc.Mat()
                B.transposeMatMult(B, self.tmp)
                self._fillProduct(product, self.tmp)
            elif mat is B:      # I^T * A * I
                self._fillProduct(product, A)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A:        # R * I * R^T
                self.tmp = PETSc.Mat()
                B.matTransposeMult(B, self.tmp)
                self._fillProduct(product, self.tmp)
            elif mat is B:      # I * A * I^T
                self._fillProduct(product, A)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            if mat is A:        # I * B * C
                self.tmp = PETSc.Mat()
                B.matMult(C, self.tmp)
                self._fillProduct(product, self.tmp)
            elif mat is B:      # A * I * C
                self.tmp = PETSc.Mat()
                A.matMult(C, self.tmp)
                self._fillProduct(product, self.tmp)
            elif mat is C:      # A * B * I
                self.tmp = PETSc.Mat()
                A.matMult(B, self.tmp)
                self._fillProduct(product, self.tmp)
            else:
                raise RuntimeError('wrong configuration')
        else:
            raise RuntimeError('Product {} not implemented'.format(producttype))
        # symbolic phase only: the values are produced by productNumeric
        product.zeroEntries()

    def productNumeric(self, mat, product, producttype, A, B, C):
        """Fill ``product`` (prepared by productSymbolic) with its values.

        Raises RuntimeError under the same conditions as productSymbolic.
        """
        if producttype == 'AB':
            if mat is A:        # I * B
                B.copy(product, structure=True)
            elif mat is B:      # A * I
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'AtB':
            if mat is A:        # I^T * B
                B.copy(product, structure=True)
            elif mat is B:      # A^T * I: transpose straight into product
                A.transpose(product)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABt':
            if mat is A:        # I * B^T: transpose straight into product
                B.transpose(product)
            elif mat is B:      # A * I^T
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'PtAP':
            if mat is A:        # P^T * I * P, recomputed into the symbolic tmp
                B.transposeMatMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is B:      # I^T * A * I
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A:        # R * I * R^T, recomputed into the symbolic tmp
                B.matTransposeMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is B:      # I * A * I^T
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            if mat is A:        # I * B * C
                B.matMult(C, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is B:      # A * I * C
                A.matMult(C, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is C:      # A * B * I
                A.matMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        else:
            raise RuntimeError('Product {} not implemented'.format(producttype))
195*ee6c7c31SStefano Zampini
class Diagonal(Matrix):
    """Python context for a diagonal matrix stored as the Vec ``self.D``."""

    def create(self, mat):
        super().create(mat)
        mat.setUp()
        # diagonal storage, laid out like mat's left (row) vectors
        self.D = mat.createVecLeft()

    def destroy(self, mat):
        self.D.destroy()
        super().destroy(mat)

    def scale(self, mat, a):
        # a * diag(D) == diag(a * D)
        self.D.scale(a)

    def shift(self, mat, a):
        # diag(D) + a * I == diag(D + a)
        self.D.shift(a)

    def zeroEntries(self, mat):
        self.D.zeroEntries()

    def mult(self, mat, x, y):
        # y_i = x_i * D_i
        y.pointwiseMult(x, self.D)

    def getDiagonal(self, mat, vd):
        self.D.copy(vd)

    def setDiagonal(self, mat, vd, im):
        """Insert or add ``vd`` into the diagonal.

        ``im`` is either a bool (True meaning add) or a PETSc.InsertMode;
        modes other than INSERT_VALUES / ADD_VALUES raise ValueError.
        """
        if isinstance(im, bool):
            add = im
        elif im == PETSc.InsertMode.INSERT_VALUES:
            add = False
        elif im == PETSc.InsertMode.ADD_VALUES:
            add = True
        else:
            raise ValueError('wrong InsertMode %d'% im)
        if add:
            self.D.axpy(1, vd)
        else:
            vd.copy(self.D)

    def diagonalScale(self, mat, vl, vr):
        # diag(vl) * diag(D) * diag(vr): multiply D by each scaling vector that
        # is truthy (presumably: not a null Vec -- confirm with petsc4py)
        for v in (vl, vr):
            if v:
                self.D.pointwiseMult(self.D, v)
2395808f684SSatish Balay
2405808f684SSatish Balay# --------------------------------------------------------------------
2415808f684SSatish Balay
class TestMatrix(unittest.TestCase):
    """Tests for a ``'python'`` Mat backed by the context class named ``PYCLS``.

    With the default ``Matrix`` context no operation is defined, so every
    operation test below expects the Mat call to raise.  Subclasses choose
    another context through ``PYCLS`` and override the tests for operations
    that context implements.  setUp/tearDown also check the reference count
    of the Python context, so leaked references make the tests fail.
    """

    COMM = PETSc.COMM_WORLD
    PYMOD = __name__
    PYCLS = 'Matrix'

    def _getCtx(self):
        # Python context object attached to the Mat under test.
        return self.A.getPythonContext()

    def setUp(self):
        N = self.N = 10
        self.A = PETSc.Mat()
        if 0:
            # Alternative "command line" construction via the options
            # database; kept for reference, never executed.
            self.A.create(self.COMM)
            self.A.setSizes([N,N])
            self.A.setType('python')
            OptDB = PETSc.Options(self.A)
            OptDB['mat_python_type'] = '%s.%s' % (self.PYMOD,self.PYCLS)
            self.A.setFromOptions()
            self.A.setUp()
            del OptDB['mat_python_type']
            self.assertTrue(self._getCtx() is not None)
        else:
            # "Python" construction: instantiate the context class directly.
            ctx = globals()[self.PYCLS]()
            self.A.createPython([N,N], ctx, comm=self.COMM)
            self.A.setUp()
            self.assertIs(self._getCtx(), ctx)
            # 3 = this local + the Mat's reference + getrefcount's argument
            self.assertEqual(getrefcount(ctx), 3)
            del ctx
            # now only the Mat (plus getrefcount's argument) holds it
            self.assertEqual(getrefcount(self._getCtx()), 2)

    def tearDown(self):
        ctx = self.A.getPythonContext()
        self.assertEqual(getrefcount(ctx), 3)
        self.A.destroy() # XXX
        self.A = None
        # destroying the Mat must drop its reference to the context; when this
        # fails, gc.get_referrers(ctx) shows who still holds it
        self.assertEqual(getrefcount(ctx), 2)

    def testBasic(self):
        ctx = self.A.getPythonContext()
        self.assertIs(self._getCtx(), ctx)
        self.assertEqual(getrefcount(ctx), 3)

    def testZeroEntries(self):
        with self.assertRaises(Exception):
            self.A.zeroEntries()

    def testMult(self):
        x, y = self.A.createVecs()
        with self.assertRaises(Exception):
            self.A.mult(x, y)

    def testMultTranspose(self):
        x, y = self.A.createVecs()
        with self.assertRaises(Exception):
            self.A.multTranspose(x, y)

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        with self.assertRaises(Exception):
            self.A.getDiagonal(d)

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        with self.assertRaises(Exception):
            self.A.setDiagonal(d)

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        with self.assertRaises(Exception):
            self.A.diagonalScale(x, y)
3145808f684SSatish Balay
3158c2316a8SJeremy Tillay
class TestIdentity(TestMatrix):
    """Run the TestMatrix checks against the ``Identity`` Python context."""

    PYCLS = 'Identity'

    def testMult(self):
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.mult(x, y)
        self.assertTrue(y.equal(x))

    def testMultTransposeSymmKnown(self):
        # Identity defines no multTranspose: it must work (giving x) while the
        # Mat is flagged symmetric, and raise again once the flag is cleared.
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(x))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        with self.assertRaises(Exception):
            self.A.multTranspose(x, y)

    def testMultTransposeNewMeth(self):
        x, y = self.A.createVecs()
        x.setRandom()
        AA = self.A.getPythonContext()
        # Temporarily give the context a multTranspose. The bound method
        # references AA, so it must always be removed again or tearDown's
        # reference-count check fails as well.
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            del AA.multTranspose
        self.assertTrue(y.equal(x))

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        o = d.duplicate()
        o.set(1)
        self.A.getDiagonal(d)
        self.assertTrue(o.equal(d))

    def testMatMat(self):
        """Products with the Python identity must match those with an AIJ identity."""
        R = PETSc.Random().create(self.COMM)
        R.setFromOptions()

        def newAIJ():
            # Set-up, empty AIJ matrix sized like the operator under test.
            M = PETSc.Mat().create(self.COMM)
            M.setSizes(self.A.getSizes())
            M.setType(PETSc.Mat.Type.AIJ)
            M.setUp()
            return M

        A = newAIJ()
        A.setRandom(R)
        B = newAIJ()
        B.setRandom(R)
        Id = newAIJ()  # explicitly assembled identity used as the reference
        Id.assemble()
        Id.shift(1.)

        # matTransposeMult / rart are only checked on a single process, as in
        # the original test.
        serial = self.A.getComm().Get_size() == 1

        self.assertTrue(self.A.matMult(A).equal(Id.matMult(A)))
        self.assertTrue(A.matMult(self.A).equal(A.matMult(Id)))
        if serial:
            self.assertTrue(self.A.matTransposeMult(A).equal(Id.matTransposeMult(A)))
            self.assertTrue(A.matTransposeMult(self.A).equal(A.matTransposeMult(Id)))
        self.assertTrue(self.A.transposeMatMult(A).equal(Id.transposeMatMult(A)))
        self.assertTrue(A.transposeMatMult(self.A).equal(A.transposeMatMult(Id)))
        self.assertAlmostEqual((self.A.ptap(A) - Id.ptap(A)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.ptap(self.A) - A.ptap(Id)).norm(), 0.0, places=5)
        if serial:
            self.assertAlmostEqual((self.A.rart(A) - Id.rart(A)).norm(), 0.0, places=5)
            self.assertAlmostEqual((A.rart(self.A) - A.rart(Id)).norm(), 0.0, places=5)
        self.assertAlmostEqual((self.A.matMatMult(A,B)-Id.matMatMult(A,B)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.matMatMult(self.A,B)-A.matMatMult(Id,B)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.matMatMult(B,self.A)-A.matMatMult(B,Id)).norm(), 0.0, places=5)

    def testH2Opus(self):
        # Report the test as skipped (instead of silently passing) when it
        # cannot run.
        if not PETSc.Sys.hasExternalPackage("h2opus"):
            self.skipTest("h2opus not available")
        if self.A.getComm().Get_size() > 1:
            self.skipTest("sequential only")
        h = PETSc.Mat()

        # need transpose operation for norm estimation
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            # without coordinates
            h.createH2OpusFromMat(self.A, leafsize=2)
            h.assemble()
            h.destroy()

            # with coordinates
            coords = numpy.linspace((1,2,3),(10,20,30),self.A.getSize()[0],dtype=PETSc.RealType)
            h.createH2OpusFromMat(self.A, coords, leafsize=2)
            h.assemble()
            h.destroy()
        finally:
            # the bound method references AA; always drop it so tearDown's
            # reference-count check is unaffected
            del AA.multTranspose
4115808f684SSatish Balay
class TestDiagonal(TestMatrix):
    """Run the TestMatrix checks against the ``Diagonal`` Python context."""

    PYCLS = 'Diagonal'

    def setUp(self):
        super().setUp()
        # Diagonal entries 1..N; each process sets the entries it owns.
        D = self.A.createVecLeft()
        s, e = D.getOwnershipRange()
        for i in range(s, e):
            D[i] = i+1
        D.assemble()
        self.A.setDiagonal(D)

    def testZeroEntries(self):
        self.A.zeroEntries()
        D = self._getCtx().D
        self.assertEqual(D.norm(), 0)

    def testMult(self):
        x, y = self.A.createVecs()
        x.set(1)
        self.A.mult(x, y)
        self.assertTrue(y.equal(self._getCtx().D))

    def testMultTransposeSymmKnown(self):
        # Diagonal defines no multTranspose: it must work while the Mat is
        # flagged symmetric, and raise again once the flag is cleared.
        x, y = self.A.createVecs()
        x.set(1)
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(self._getCtx().D))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        with self.assertRaises(Exception):
            self.A.multTranspose(x, y)

    def testMultTransposeNewMeth(self):
        x, y = self.A.createVecs()
        x.set(1)
        AA = self.A.getPythonContext()
        # Temporarily give the context a multTranspose. The bound method
        # references AA, so it must always be removed again or tearDown's
        # reference-count check fails as well.
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            del AA.multTranspose
        self.assertTrue(y.equal(self._getCtx().D))

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        self.A.getDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        d.setRandom()
        self.A.setDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testDiagonalScale(self):
        # diag(2) * D * diag(3) == 6 * D
        x, y = self.A.createVecs()
        x.set(2)
        y.set(3)
        old = self._getCtx().D.copy()
        self.A.diagonalScale(x, y)
        D = self._getCtx().D
        self.assertTrue(D.equal(old*6))

    def testCreateTranspose(self):
        """A Mat created with createTranspose(A) must act as A^T."""
        A = self.A
        A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        AT = PETSc.Mat().createTranspose(A)
        x, y = A.createVecs()
        xt, yt = AT.createVecs()
        # A^T y computed directly vs. AT * y
        y.setRandom()
        A.multTranspose(y, x)
        y.copy(xt)
        AT.mult(xt, yt)
        self.assertTrue(yt.equal(x))
        # A x computed directly vs. AT^T * x
        x.setRandom()
        A.mult(x, y)
        x.copy(yt)
        AT.multTranspose(yt, xt)
        self.assertTrue(xt.equal(y))
4955808f684SSatish Balay
4968c2316a8SJeremy Tillay
4975808f684SSatish Balay# --------------------------------------------------------------------
4985808f684SSatish Balay
if __name__ == "__main__":
    # Run every unittest.TestCase defined in this module.
    unittest.main()
501