xref: /petsc/src/binding/petsc4py/test/test_mat_py.py (revision 300d917ba61fc16924780ffe9b581fa5705848da)
15808f684SSatish Balayfrom petsc4py import PETSc
253022affSStefano Zampiniimport unittest, numpy
35808f684SSatish Balayfrom sys import getrefcount
45808f684SSatish Balay# --------------------------------------------------------------------
55808f684SSatish Balay
class Matrix(object):
    """Bare Python context for a PETSc matrix of type 'python'.

    Only the lifecycle hooks are defined here and all of them do nothing;
    no matrix operation (mult, getDiagonal, ...) is provided.
    """

    def __init__(self):
        """Nothing to initialise."""
        return None

    def create(self, mat):
        """Creation hook for ``mat``; intentionally a no-op."""
        return None

    def destroy(self, mat):
        """Destruction hook for ``mat``; intentionally a no-op."""
        return None
165808f684SSatish Balay
class ScaledIdentity(Matrix):
    """Python matrix context that acts as ``s`` times the identity.

    ``s`` starts from the class default below; ``scale`` and ``shift``
    rebind it on the instance they are called on.
    """

    s = 2.0

    def scale(self, mat, s):
        """Multiply the stored scalar by ``s``."""
        self.s *= s

    def shift(self, mat, s):
        """Add ``s`` to the stored scalar."""
        self.s += s

    def mult(self, mat, x, y):
        """Store ``self.s * x`` in ``y``."""
        x.copy(y)
        y.scale(self.s)

    def duplicate(self, mat, op):
        """Return a new Python matrix with its own ScaledIdentity context.

        The new matrix takes its sizes and communicator from ``mat``.  The
        scalar is copied, and the new matrix set up, only when ``op`` is
        ``COPY_VALUES``; otherwise the duplicate keeps the class default.
        """
        dmat = PETSc.Mat()
        dctx = ScaledIdentity()
        dmat.createPython(mat.getSizes(), dctx, comm=mat.getComm())
        if op == PETSc.Mat.DuplicateOption.COPY_VALUES:
            dctx.s = self.s
            dmat.setUp()
        return dmat

    def getDiagonal(self, mat, vd):
        """Fill ``vd`` with the stored scalar."""
        vd.set(self.s)

    def productSetFromOptions(self, mat, producttype, A, B, C):
        """Always return True, for every product type and operand layout."""
        # NOTE(review): presumably True tells the caller that this context
        # handles the requested product -- confirm against petsc4py.
        return True

    @staticmethod
    def _layoutLike(product, template):
        """Give ``product`` the type and sizes of ``template`` and copy it in."""
        product.setType(template.getType())
        product.setSizes(template.getSizes())
        product.setUp()
        product.assemble()
        template.copy(product)

    def productSymbolic(self, mat, product, producttype, A, B, C):
        """Prepare ``product`` for the requested matrix product.

        ``mat`` is the operand backed by this context; its position among
        ``A``, ``B`` and ``C`` selects which matrix ``product`` is modelled
        on.  Products that need an intermediate result keep it in
        ``self.tmp`` so that ``productNumeric`` can reuse it.  ``product`` is
        zeroed before returning.

        Raises:
            RuntimeError: if ``mat`` is not one of the expected operands, or
                if ``producttype`` is not handled.
        """
        if producttype == 'AB':
            if mat is A:    # product = identity * B
                template = B
            elif mat is B:  # product = A * identity
                template = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'AtB':
            if mat is A:    # product = identity^T * B
                template = B
            elif mat is B:  # product = A^T * identity
                template = PETSc.Mat()
                A.transpose(template)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABt':
            if mat is A:    # product = identity * B^T
                template = PETSc.Mat()
                B.transpose(template)
            elif mat is B:  # product = A * identity^T
                template = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype in ('PtAP', 'RARt'):
            if mat is A:
                # product = P^T * identity * P   (PtAP)
                # product = R * identity * R^T   (RARt)
                self.tmp = PETSc.Mat()
                if producttype == 'PtAP':
                    B.transposeMatMult(B, self.tmp)
                else:
                    B.matTransposeMult(B, self.tmp)
                template = self.tmp
            elif mat is B:
                # identity on both sides: the result has the layout of A
                template = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            # Drop the identity operand and multiply the remaining two.
            if mat is A:    # product = identity * B * C
                left, right = B, C
            elif mat is B:  # product = A * identity * C
                left, right = A, C
            elif mat is C:  # product = A * B * identity
                left, right = A, B
            else:
                raise RuntimeError('wrong configuration')
            self.tmp = PETSc.Mat()
            left.matMult(right, self.tmp)
            template = self.tmp
        else:
            raise RuntimeError('Product {} not implemented'.format(producttype))
        self._layoutLike(product, template)
        product.zeroEntries()

    def productNumeric(self, mat, product, producttype, A, B, C):
        """Fill ``product`` with the values of the requested product.

        The product of the remaining operands is written into ``product``,
        which is then scaled by ``self.s`` -- or by ``self.s**2`` when this
        matrix sits on both sides of a PtAP/RARt product.

        Raises:
            RuntimeError: if ``mat`` is not one of the expected operands, or
                if ``producttype`` is not handled.
        """
        factor = self.s
        if producttype == 'AB':
            if mat is A:    # product = identity * B
                B.copy(product, structure=True)
            elif mat is B:  # product = A * identity
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'AtB':
            if mat is A:    # product = identity^T * B
                B.copy(product, structure=True)
            elif mat is B:  # product = A^T * identity
                A.transpose(product)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABt':
            if mat is A:    # product = identity * B^T
                B.transpose(product)
            elif mat is B:  # product = A * identity^T
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype in ('PtAP', 'RARt'):
            if mat is A:
                # product = P^T * identity * P   (PtAP)
                # product = R * identity * R^T   (RARt)
                if producttype == 'PtAP':
                    B.transposeMatMult(B, self.tmp)
                else:
                    B.matTransposeMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is B:
                # the scaled identity appears twice, hence s squared
                A.copy(product, structure=True)
                factor = self.s**2
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            if mat is A:    # product = identity * B * C
                left, right = B, C
            elif mat is B:  # product = A * identity * C
                left, right = A, C
            elif mat is C:  # product = A * B * identity
                left, right = A, B
            else:
                raise RuntimeError('wrong configuration')
            left.matMult(right, self.tmp)
            self.tmp.copy(product, structure=True)
        else:
            raise RuntimeError('Product {} not implemented'.format(producttype))
        product.scale(factor)
221ee6c7c31SStefano Zampini
class Diagonal(Matrix):
    """Python matrix context for a diagonal matrix stored in the vector ``D``."""

    def create(self, mat):
        """Set up ``mat`` and allocate the vector holding the diagonal."""
        super(Diagonal, self).create(mat)
        mat.setUp()
        self.D = mat.createVecLeft()

    def destroy(self, mat):
        """Destroy the diagonal vector, then run the base-class hook."""
        self.D.destroy()
        super(Diagonal, self).destroy(mat)

    def scale(self, mat, a):
        """Scale the diagonal by ``a``."""
        self.D.scale(a)

    def shift(self, mat, a):
        """Shift the diagonal by ``a``."""
        self.D.shift(a)

    def zeroEntries(self, mat):
        """Zero the diagonal."""
        self.D.zeroEntries()

    def mult(self, mat, x, y):
        """Store the entrywise product of ``x`` and the diagonal in ``y``."""
        y.pointwiseMult(x, self.D)

    def duplicate(self, mat, op):
        """Return a new Python matrix with its own Diagonal context.

        The new context gets a vector duplicated from ``self.D``.  The
        diagonal is copied into it, and the new matrix set up, only when
        ``op`` is ``COPY_VALUES``.
        """
        twin = PETSc.Mat()
        twin_ctx = Diagonal()
        twin.createPython(mat.getSizes(), twin_ctx, comm=mat.getComm())
        twin_ctx.D = self.D.duplicate()
        if op == PETSc.Mat.DuplicateOption.COPY_VALUES:
            self.D.copy(twin_ctx.D)
            twin.setUp()
        return twin

    def getDiagonal(self, mat, vd):
        """Copy the diagonal into ``vd``."""
        self.D.copy(vd)

    def setDiagonal(self, mat, vd, im):
        """Insert ``vd`` as the diagonal, or add it to the current one.

        ``im`` is either a bool (True means add) or an InsertMode value.

        Raises:
            ValueError: if ``im`` is neither a bool, INSERT_VALUES nor
                ADD_VALUES.
        """
        # The bool check comes first, as in the original dispatch.
        if isinstance(im, bool):
            accumulate = im
        elif im == PETSc.InsertMode.INSERT_VALUES:
            accumulate = False
        elif im == PETSc.InsertMode.ADD_VALUES:
            accumulate = True
        else:
            raise ValueError('wrong InsertMode %d'% im)
        if accumulate:
            self.D.axpy(1, vd)
        else:
            vd.copy(self.D)

    def diagonalScale(self, mat, vl, vr):
        """Multiply the diagonal entrywise by ``vl`` and by ``vr``.

        A factor that evaluates as false is skipped.
        """
        for factor in (vl, vr):
            if factor:
                self.D.pointwiseMult(self.D, factor)
2755808f684SSatish Balay
2765808f684SSatish Balay# --------------------------------------------------------------------
2775808f684SSatish Balay
class TestMatrix(unittest.TestCase):
    """Tests for a Python matrix whose context class is named by ``PYCLS``.

    With the default context ('Matrix') no matrix operation is implemented,
    so most tests here only check that calling one raises.  Subclasses
    change ``PYCLS`` and override the tests with real checks.
    """

    COMM = PETSc.COMM_WORLD
    PYMOD = __name__
    PYCLS = 'Matrix'

    def _getCtx(self):
        """Return the Python context attached to the matrix under test."""
        return self.A.getPythonContext()

    def setUp(self):
        """Create the N x N Python matrix and check who references its context."""
        N = self.N = 13
        self.A = PETSc.Mat()
        if 0: # command line way
            self.A.create(self.COMM)
            self.A.setSizes([N,N])
            self.A.setType('python')
            OptDB = PETSc.Options(self.A)
            OptDB['mat_python_type'] = '%s.%s' % (self.PYMOD,self.PYCLS)
            self.A.setFromOptions()
            self.A.setUp()
            del OptDB['mat_python_type']
            self.assertTrue(self._getCtx() is not None)
        else: # python way
            context = globals()[self.PYCLS]()
            self.A.createPython([N,N], context, comm=self.COMM)
            self.A.setUp()
            self.assertTrue(self._getCtx() is context)
            # Presumably the three references are: the local name, the
            # matrix, and getrefcount's own argument.
            self.assertEqual(getrefcount(context), 3)
            del context
            self.assertEqual(getrefcount(self._getCtx()), 2)

    def tearDown(self):
        """Destroy the matrix and check that it released its context."""
        ctx = self.A.getPythonContext()
        self.assertEqual(getrefcount(ctx), 3)
        self.A.destroy() # XXX
        self.A = None
        self.assertEqual(getrefcount(ctx), 2)
        #import gc,pprint; pprint.pprint(gc.get_referrers(ctx))

    def testBasic(self):
        """The context is stable across lookups and has no stray references."""
        ctx = self.A.getPythonContext()
        self.assertTrue(self._getCtx() is ctx)
        self.assertEqual(getrefcount(ctx), 3)

    def testZeroEntries(self):
        """zeroEntries must raise with this context."""
        self.assertRaises(Exception, self.A.zeroEntries)

    def testMult(self):
        """mult must raise with this context."""
        x, y = self.A.createVecs()
        self.assertRaises(Exception, self.A.mult, x, y)

    def testMultTranspose(self):
        """multTranspose must raise with this context."""
        x, y = self.A.createVecs()
        self.assertRaises(Exception, self.A.multTranspose, x, y)

    def testGetDiagonal(self):
        """getDiagonal must raise with this context."""
        d = self.A.createVecLeft()
        self.assertRaises(Exception, self.A.getDiagonal, d)

    def testSetDiagonal(self):
        """setDiagonal must raise with this context."""
        d = self.A.createVecLeft()
        self.assertRaises(Exception, self.A.setDiagonal, d)

    def testDiagonalScale(self):
        """diagonalScale must raise with this context."""
        x, y = self.A.createVecs()
        self.assertRaises(Exception, self.A.diagonalScale, x, y)

    def testDuplicate(self):
        """duplicate must raise with this context, with and without copy."""
        # The previous version passed an undefined name ``x`` as first
        # argument, so the NameError satisfied assertRaises and duplicate
        # itself was never called.
        self.assertRaises(Exception, self.A.duplicate, True)
        self.assertRaises(Exception, self.A.duplicate, False)

    def testSetVecType(self):
        """The vector type set on the matrix is reported back."""
        self.A.setVecType('mpi')
        self.assertEqual('mpi', self.A.getVecType())

    def testH2Opus(self):
        """Exercise the H2Opus wrappers on top of the Python matrix.

        Returns without testing anything when H2Opus is not available, when
        running on more than one process, or when the context has no
        ``mult`` method.
        """
        if not PETSc.Sys.hasExternalPackage("h2opus"):
            return
        if self.A.getComm().Get_size() > 1:
            return
        h = PETSc.Mat()

        # need matrix vector and its transpose for norm estimation
        AA = self.A.getPythonContext()
        if not hasattr(AA,'mult'):
            return
        AA.multTranspose = AA.mult
        try:
            # without coordinates
            h.createH2OpusFromMat(self.A,leafsize=2)
            h.assemble()
            h.destroy()

            # with coordinates
            coords = numpy.linspace((1,2,3),(10,20,30),self.A.getSize()[0],dtype=PETSc.RealType)
            h.createH2OpusFromMat(self.A,coords,leafsize=2)
            h.assemble()

            # test API
            h.H2OpusOrthogonalize()
            h.H2OpusCompress(1.e-1)

            # Low-rank update
            U = PETSc.Mat()
            U.createDense([h.getSizes()[0],3],comm=h.getComm())
            U.setUp()
            U.setRandom()

            he = PETSc.Mat()
            h.convert('dense',he)
            he.axpy(1.0, U.matTransposeMult(U))

            h.H2OpusLowRankUpdate(U)
            self.assertTrue(he.equal(h))

            h.destroy()
        finally:
            # Always remove the temporary method: it references the context,
            # so leaving it behind after a failure would also trip the
            # reference-count check in tearDown.
            del AA.multTranspose
405*300d917bSStefano Zampini
406*300d917bSStefano Zampini
class TestScaledIdentity(TestMatrix):
    """Runs the matrix tests against the ScaledIdentity context."""

    PYCLS = 'ScaledIdentity'

    def testMult(self):
        """A*x equals s*x."""
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.mult(x,y)
        self.assertTrue(y.equal(s*x))

    def testMultTransposeSymmKnown(self):
        """multTranspose works while the matrix is flagged symmetric only."""
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x,y)
        self.assertTrue(y.equal(s*x))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        self.assertRaises(Exception, self.A.multTranspose, x, y)

    def testMultTransposeNewMeth(self):
        """multTranspose works once the context provides that method."""
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x,y)
        finally:
            # Remove the temporary method even if the call fails; it holds a
            # reference to the context, which tearDown counts.
            del AA.multTranspose
        self.assertTrue(y.equal(s*x))

    def testGetDiagonal(self):
        """Every diagonal entry equals s."""
        s = self._getCtx().s
        d = self.A.createVecLeft()
        o = d.duplicate()
        o.set(s)
        self.A.getDiagonal(d)
        self.assertTrue(o.equal(d))

    def testDuplicate(self):
        """A plain duplicate has s == 2; a copying duplicate matches A."""
        B = self.A.duplicate(False)
        self.assertEqual(B.getPythonContext().s, 2)
        B = self.A.duplicate(True)
        self.assertEqual(B.getPythonContext().s, self.A.getPythonContext().s)

    def testMatMat(self):
        """Products with the Python matrix match those with an explicit s*I."""
        s = self._getCtx().s
        R = PETSc.Random().create(self.COMM)
        R.setFromOptions()

        def newAIJ():
            # AIJ matrix with the same sizes as the matrix under test
            M = PETSc.Mat().create(self.COMM)
            M.setSizes(self.A.getSizes())
            M.setType(PETSc.Mat.Type.AIJ)
            M.setUp()
            return M

        A = newAIJ()
        A.setRandom(R)
        B = newAIJ()
        B.setRandom(R)
        # explicit reference matrix: assemble it empty, then shift by s
        I = newAIJ()
        I.assemble()
        I.shift(s)

        serial = self.A.getComm().Get_size() == 1

        self.assertTrue(self.A.matMult(A).equal(I.matMult(A)))
        self.assertTrue(A.matMult(self.A).equal(A.matMult(I)))
        if serial:
            self.assertTrue(self.A.matTransposeMult(A).equal(I.matTransposeMult(A)))
            self.assertTrue(A.matTransposeMult(self.A).equal(A.matTransposeMult(I)))
        self.assertTrue(self.A.transposeMatMult(A).equal(I.transposeMatMult(A)))
        self.assertTrue(A.transposeMatMult(self.A).equal(A.transposeMatMult(I)))
        self.assertAlmostEqual((self.A.ptap(A) - I.ptap(A)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.ptap(self.A) - A.ptap(I)).norm(), 0.0, places=5)
        if serial:
            self.assertAlmostEqual((self.A.rart(A) - I.rart(A)).norm(), 0.0, places=5)
            self.assertAlmostEqual((A.rart(self.A) - A.rart(I)).norm(), 0.0, places=5)
        self.assertAlmostEqual((self.A.matMatMult(A,B)-I.matMatMult(A,B)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.matMatMult(self.A,B)-A.matMatMult(I,B)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.matMatMult(B,self.A)-A.matMatMult(B,I)).norm(), 0.0, places=5)

    def testShift(self):
        """Shifting the matrix by -0.5 lowers s by 0.5."""
        sold = self._getCtx().s
        self.A.shift(-0.5)
        s = self._getCtx().s
        self.assertEqual(s, sold - 0.5)

    def testScale(self):
        """Scaling the matrix by -0.5 multiplies s by -0.5."""
        sold = self._getCtx().s
        self.A.scale(-0.5)
        s = self._getCtx().s
        self.assertEqual(s, sold * -0.5)
50122fceea1SStefano Zampini
class TestDiagonal(TestMatrix):
    """Tests for a Python-type matrix whose context class is ``Diagonal``.

    The context is reached through ``self._getCtx()`` /
    ``self.A.getPythonContext()`` and keeps the diagonal entries in its
    ``D`` vector; every test compares the result of a ``Mat`` operation
    on ``self.A`` against that vector.

    NOTE: the test methods are documented with comments rather than
    docstrings on purpose -- unittest prints the first docstring line in
    place of the test name in verbose mode.
    """

    PYCLS = 'Diagonal'

    def setUp(self):
        # The base fixture is expected to provide self.A; afterwards the
        # diagonal is set to (global row index + 1) on the locally owned rows.
        super(TestDiagonal, self).setUp()
        diag = self.A.createVecLeft()
        first, last = diag.getOwnershipRange()
        for row in range(first, last):
            diag[row] = row + 1
        diag.assemble()
        self.A.setDiagonal(diag)

    def testZeroEntries(self):
        # Zeroing the matrix must leave a stored diagonal of zero norm.
        self.A.zeroEntries()
        self.assertEqual(self._getCtx().D.norm(), 0)

    def testMult(self):
        # A * ones reproduces the stored diagonal.
        x, y = self.A.createVecs()
        x.set(1)
        self.A.mult(x, y)
        self.assertTrue(y.equal(self._getCtx().D))

    def testMultTransposeSymmKnown(self):
        x, y = self.A.createVecs()
        x.set(1)
        # With the SYMMETRIC option set, multTranspose gives the same
        # result as mult on a vector of ones.
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(self._getCtx().D))
        # Once the option is cleared the same call is expected to raise.
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        self.assertRaises(Exception, self.A.multTranspose, x, y)

    def testMultTransposeNewMeth(self):
        x, y = self.A.createVecs()
        x.set(1)
        ctx = self.A.getPythonContext()
        # Attach multTranspose to the context at run time.  The bound
        # method references the context itself, so it is removed in a
        # ``finally`` clause: the patch must not outlive this test even
        # when the call below raises.
        ctx.multTranspose = ctx.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            del ctx.multTranspose
        self.assertTrue(y.equal(self._getCtx().D))

    def testDuplicate(self):
        # Duplicating without values is only checked for not raising;
        # the copy made with values must carry an equal diagonal.
        B = self.A.duplicate(False)
        B = self.A.duplicate(True)
        self.assertTrue(B.getPythonContext().D.equal(self.A.getPythonContext().D))

    def testGetDiagonal(self):
        # getDiagonal must hand back the stored diagonal.
        d = self.A.createVecLeft()
        self.A.getDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testSetDiagonal(self):
        # setDiagonal must store exactly the vector it is given.
        d = self.A.createVecLeft()
        d.setRandom()
        self.A.setDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testDiagonalScale(self):
        # Scaling with constant vectors 2 and 3 multiplies every
        # diagonal entry by 2 * 3 = 6.
        x, y = self.A.createVecs()
        x.set(2)
        y.set(3)
        old = self._getCtx().D.copy()
        self.A.diagonalScale(x, y)
        D = self._getCtx().D
        self.assertTrue(D.equal(old * 6))

    def testCreateTranspose(self):
        A = self.A
        A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        AT = PETSc.Mat().createTranspose(A)
        x, y = A.createVecs()
        xt, yt = AT.createVecs()
        # AT.mult must agree with A.multTranspose on the same input.
        y.setRandom()
        A.multTranspose(y, x)
        y.copy(xt)
        AT.mult(xt, yt)
        self.assertTrue(yt.equal(x))
        # AT.multTranspose must agree with A.mult on the same input.
        x.setRandom()
        A.mult(x, y)
        x.copy(yt)
        AT.multTranspose(yt, xt)
        self.assertTrue(xt.equal(y))
        del A

    def testConvert(self):
        # Converting to each assembled format must give a matrix that
        # compares equal to the Python matrix.
        for mtype in (PETSc.Mat.Type.AIJ,
                      PETSc.Mat.Type.BAIJ,
                      PETSc.Mat.Type.SBAIJ,
                      PETSc.Mat.Type.DENSE):
            self.assertTrue(self.A.convert(mtype, PETSc.Mat()).equal(self.A))

    def testShift(self):
        # Shifting by -0.5 subtracts 0.5 from every diagonal entry.
        old = self._getCtx().D.copy()
        self.A.shift(-0.5)
        D = self._getCtx().D
        self.assertTrue(D.equal(old-0.5))

    def testScale(self):
        # Scaling by -0.5 multiplies every diagonal entry by -0.5.
        old = self._getCtx().D.copy()
        self.A.scale(-0.5)
        D = self._getCtx().D
        self.assertTrue(D.equal(-0.5*old))
60722fceea1SStefano Zampini
60822fceea1SStefano Zampini
6095808f684SSatish Balay# --------------------------------------------------------------------
6105808f684SSatish Balay
if __name__ == '__main__':
    # Run as a script: hand control to unittest's command-line runner.
    unittest.main()
613