xref: /petsc/src/binding/petsc4py/test/test_mat_py.py (revision 5808f68492579297331054bd8ff190489c3b8c20)
1*5808f684SSatish Balayfrom petsc4py import PETSc
2*5808f684SSatish Balayimport unittest
3*5808f684SSatish Balayfrom sys import getrefcount
4*5808f684SSatish Balay
5*5808f684SSatish Balay# --------------------------------------------------------------------
6*5808f684SSatish Balay
class Matrix(object):
    """Minimal Python context object for a PETSc ``python``-type Mat.

    Only no-op ``create``/``destroy`` hooks are defined; no matrix operation
    (mult, getDiagonal, ...) is provided, so a Mat backed by this context
    cannot perform any of them.  Subclasses opt in by defining methods.
    """

    def __init__(self):
        """Build the context; there is no state to initialize."""

    def create(self, mat):
        """Creation hook (presumably called with the owning Mat); no-op."""

    def destroy(self, mat):
        """Destruction hook (presumably called with the owning Mat); no-op."""
17*5808f684SSatish Balay
class Identity(Matrix):
    """Python Mat context implementing the identity operator.

    Provides ``mult`` and ``getDiagonal``; ``create``/``destroy`` are the
    inherited no-op hooks from ``Matrix``.
    """

    def mult(self, mat, x, y):
        """Apply the identity: copy ``x`` into ``y``."""
        x.copy(y)

    def getDiagonal(self, mat, vd):
        """Fill ``vd`` with ones, the diagonal of the identity."""
        vd.set(1)
25*5808f684SSatish Balay
class Diagonal(Matrix):
    """Python Mat context for a diagonal operator stored in the Vec ``self.D``.

    Every supported operation is expressed as an elementwise operation on
    ``self.D``, which is allocated in ``create`` and released in ``destroy``.
    """

    def create(self, mat):
        """Set up ``mat`` and allocate ``self.D`` with ``mat.createVecLeft()``."""
        super(Diagonal, self).create(mat)
        # setUp must run before createVecLeft so the layout is known.
        mat.setUp()
        self.D = mat.createVecLeft()

    def destroy(self, mat):
        """Release ``self.D``, then run the base-class destruction hook."""
        self.D.destroy()
        super(Diagonal, self).destroy(mat)

    def scale(self, mat, a):
        """Scale the operator by ``a`` (scales every entry of ``self.D``)."""
        self.D.scale(a)

    def shift(self, mat, a):
        """Add ``a`` to the diagonal (shifts every entry of ``self.D``)."""
        self.D.shift(a)

    def zeroEntries(self, mat):
        """Zero the operator by zeroing ``self.D``."""
        self.D.zeroEntries()

    def mult(self, mat, x, y):
        """Compute ``y = D x`` as the elementwise product of ``x`` and ``D``."""
        y.pointwiseMult(x, self.D)

    def getDiagonal(self, mat, vd):
        """Copy the stored diagonal into ``vd``."""
        self.D.copy(vd)

    def setDiagonal(self, mat, vd, im):
        """Insert or accumulate ``vd`` into the stored diagonal.

        ``im`` is either a bool (True means add, False means insert) or a
        ``PETSc.InsertMode`` equal to INSERT_VALUES or ADD_VALUES.  Any other
        value raises ValueError.
        """
        # Normalise the mode to a single "accumulate?" flag, then act once.
        if isinstance(im, bool):
            accumulate = im
        elif im == PETSc.InsertMode.INSERT_VALUES:
            accumulate = False
        elif im == PETSc.InsertMode.ADD_VALUES:
            accumulate = True
        else:
            raise ValueError('wrong InsertMode %d' % im)
        if accumulate:
            self.D.axpy(1, vd)
        else:
            vd.copy(self.D)

    def diagonalScale(self, mat, vl, vr):
        """Scale ``self.D`` elementwise by ``vl`` and then by ``vr``.

        A falsy scaling vector is skipped.  NOTE(review): truthiness rather
        than ``is not None`` is kept on purpose; a wrapped null Vec may be
        passed for an absent side and evaluate falsy -- confirm against
        petsc4py before tightening this test.
        """
        for factor in (vl, vr):
            if factor:
                self.D.pointwiseMult(self.D, factor)
69*5808f684SSatish Balay
70*5808f684SSatish Balay# --------------------------------------------------------------------
71*5808f684SSatish Balay
class TestMatrix(unittest.TestCase):
    """Tests for a ``python``-type Mat driven by a Python context object.

    The context class is looked up by name (``PYCLS``) in this module's
    globals.  Subclasses therefore override only ``PYCLS`` and the tests
    whose expected outcome differs.  The base ``Matrix`` context defines no
    operations, so every operation exercised here is expected to raise.
    """

    COMM = PETSc.COMM_WORLD
    PYMOD = __name__
    PYCLS = 'Matrix'

    def _getCtx(self):
        """Return the Python context currently attached to ``self.A``."""
        return self.A.getPythonContext()

    def setUp(self):
        """Create an N x N python Mat around a fresh context and check refcounts."""
        N = self.N = 10
        self.A = PETSc.Mat()
        if False:  # command-line way (disabled)
            self.A.create(self.COMM)
            self.A.setSizes([N, N])
            self.A.setType('python')
            OptDB = PETSc.Options(self.A)
            OptDB['mat_python_type'] = '%s.%s' % (self.PYMOD, self.PYCLS)
            self.A.setFromOptions()
            self.A.setUp()
            del OptDB['mat_python_type']
            self.assertIsNotNone(self._getCtx())
        else:  # python way
            context = globals()[self.PYCLS]()
            self.A.createPython([N, N], context, comm=self.COMM)
            self.A.setUp()
            self.assertIs(self._getCtx(), context)
            # Expected holders: this local, (presumably) the Mat, and
            # getrefcount's own argument.  Do not add locals that alias it.
            self.assertEqual(getrefcount(context), 3)
            del context
            # With the local gone only the Mat's reference (+ argument) is left.
            self.assertEqual(getrefcount(self._getCtx()), 2)

    def tearDown(self):
        """Destroy the Mat and check that it gives up its context reference."""
        ctx = self.A.getPythonContext()
        self.assertEqual(getrefcount(ctx), 3)
        self.A.destroy()  # XXX
        self.A = None
        self.assertEqual(getrefcount(ctx), 2)
        # To chase a leaked reference: pprint.pprint(gc.get_referrers(ctx))

    def testBasic(self):
        ctx = self.A.getPythonContext()
        self.assertIs(self._getCtx(), ctx)
        self.assertEqual(getrefcount(ctx), 3)

    # The base context implements none of the operations below, so each
    # call must raise.

    def testZeroEntries(self):
        with self.assertRaises(Exception):
            self.A.zeroEntries()

    def testMult(self):
        x, y = self.A.createVecs()
        with self.assertRaises(Exception):
            self.A.mult(x, y)

    def testMultTranspose(self):
        x, y = self.A.createVecs()
        with self.assertRaises(Exception):
            self.A.multTranspose(x, y)

    def testGetDiagonal(self):
        diag = self.A.createVecLeft()
        with self.assertRaises(Exception):
            self.A.getDiagonal(diag)

    def testSetDiagonal(self):
        diag = self.A.createVecLeft()
        with self.assertRaises(Exception):
            self.A.setDiagonal(diag)

    def testDiagonalScale(self):
        left, right = self.A.createVecs()
        with self.assertRaises(Exception):
            self.A.diagonalScale(left, right)
144*5808f684SSatish Balay
class TestIdentity(TestMatrix):
    """Run the TestMatrix suite against the ``Identity`` context.

    Operations that Identity implements (mult, getDiagonal) are overridden
    here to check their results.  The inherited tests still expect the
    remaining operations to raise.
    """

    PYCLS = 'Identity'

    def testMult(self):
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.mult(x, y)
        self.assertTrue(y.equal(x))

    def testMultTransposeSymmKnown(self):
        # multTranspose is expected to succeed while the Mat is flagged
        # symmetric and to raise once the flag is cleared.
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(x))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        with self.assertRaises(Exception):
            self.A.multTranspose(x, y)

    def testMultTransposeNewMeth(self):
        # Give the context a multTranspose at runtime by aliasing mult.
        x, y = self.A.createVecs()
        x.setRandom()
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            # The stored bound method references AA.  Always remove it, even
            # if the call raised, so the context's refcount returns to what
            # tearDown asserts.
            del AA.multTranspose
        self.assertTrue(y.equal(x))

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        ones = d.duplicate()
        ones.set(1)
        self.A.getDiagonal(d)
        self.assertTrue(ones.equal(d))
180*5808f684SSatish Balay
181*5808f684SSatish Balay
class TestDiagonal(TestMatrix):
    """Run the TestMatrix suite against the ``Diagonal`` context.

    setUp stores the diagonal 1, 2, ..., N (entry i holds i+1), so results
    are distinguishable from the identity.
    """

    PYCLS = 'Diagonal'

    def setUp(self):
        super(TestDiagonal, self).setUp()
        D = self.A.createVecLeft()
        # Each process fills only the entries it owns.
        s, e = D.getOwnershipRange()
        for i in range(s, e):
            D[i] = i + 1
        D.assemble()
        self.A.setDiagonal(D)

    def testZeroEntries(self):
        self.A.zeroEntries()
        D = self._getCtx().D
        self.assertEqual(D.norm(), 0)

    def testMult(self):
        # D applied to the all-ones vector yields D itself.
        x, y = self.A.createVecs()
        x.set(1)
        self.A.mult(x, y)
        self.assertTrue(y.equal(self._getCtx().D))

    def testMultTransposeSymmKnown(self):
        # multTranspose is expected to succeed while the Mat is flagged
        # symmetric and to raise once the flag is cleared.
        x, y = self.A.createVecs()
        x.set(1)
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(self._getCtx().D))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        with self.assertRaises(Exception):
            self.A.multTranspose(x, y)

    def testMultTransposeNewMeth(self):
        # Give the context a multTranspose at runtime by aliasing mult.
        x, y = self.A.createVecs()
        x.set(1)
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            # The stored bound method references AA.  Always remove it, even
            # if the call raised, so the context's refcount returns to what
            # tearDown asserts.
            del AA.multTranspose
        self.assertTrue(y.equal(self._getCtx().D))

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        self.A.getDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        d.setRandom()
        self.A.setDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testDiagonalScale(self):
        # Left scaling by 2 and right scaling by 3 multiply D by 6.
        x, y = self.A.createVecs()
        x.set(2)
        y.set(3)
        old = self._getCtx().D.copy()
        self.A.diagonalScale(x, y)
        D = self._getCtx().D
        self.assertTrue(D.equal(old * 6))

    def testCreateTranspose(self):
        A = self.A
        A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        AT = PETSc.Mat().createTranspose(A)
        x, y = A.createVecs()
        xt, yt = AT.createVecs()
        # A.multTranspose(y) must agree with AT.mult applied to the same data.
        y.setRandom()
        A.multTranspose(y, x)
        y.copy(xt)
        AT.mult(xt, yt)
        self.assertTrue(yt.equal(x))
        # A.mult(x) must agree with AT.multTranspose applied to the same data.
        x.setRandom()
        A.mult(x, y)
        x.copy(yt)
        AT.multTranspose(yt, xt)
        self.assertTrue(xt.equal(y))
265*5808f684SSatish Balay
266*5808f684SSatish Balay# --------------------------------------------------------------------
267*5808f684SSatish Balay
if __name__ == '__main__':
    # Running this file directly executes every TestCase class in the module.
    unittest.main()
270