xref: /petsc/src/binding/petsc4py/test/test_mat_py.py (revision 26cec32642a02ddeaa9481e1a5bd50b8500ffeea)
15808f684SSatish Balayfrom petsc4py import PETSc
253022affSStefano Zampiniimport unittest, numpy
35808f684SSatish Balayfrom sys import getrefcount
45808f684SSatish Balay# --------------------------------------------------------------------
55808f684SSatish Balay
class Matrix(object):
    """Bare Python context for a 'python'-type matrix.

    It provides only the constructor and the ``create``/``destroy`` hooks and
    implements no matrix operation, which is why the base test class expects
    every operation on it to raise.  Subclasses add the operations.
    """

    def __init__(self):
        # No state to initialize; subclasses such as Diagonal build theirs
        # in create().
        return None

    def create(self, mat):
        """Hook receiving the owning ``mat``; deliberately does nothing."""
        return None

    def destroy(self, mat):
        """Hook receiving the owning ``mat``; deliberately does nothing."""
        return None
165808f684SSatish Balay
class ScaledIdentity(Matrix):
    """Python context implementing the scaled identity ``s * I``.

    The factor ``s`` starts at the class default 2.0 and is modified in place
    by ``scale`` and ``shift``.  Besides the basic operations, the context
    implements the MatProduct callbacks (``productSetFromOptions``,
    ``productSymbolic``, ``productNumeric``) for the product types AB, AtB,
    ABt, PtAP, RARt and ABC.  Each product reduces to the product of the
    other operand(s), scaled by ``s`` (or ``s**2`` when the identity appears
    twice).
    """

    s = 2.0

    def scale(self, mat, s):
        """Multiply the factor by ``s``."""
        self.s *= s

    def shift(self, mat, s):
        """Add ``s`` to the factor."""
        self.s += s

    def mult(self, mat, x, y):
        """Compute ``y = s * x``."""
        x.copy(y)
        y.scale(self.s)

    def duplicate(self, mat, op):
        """Return a new python Mat with a fresh ``ScaledIdentity`` context.

        The factor is copied, and the new Mat set up, only when ``op`` is
        ``COPY_VALUES``.  Otherwise the new context keeps the class default.
        """
        dmat = PETSc.Mat()
        dctx = ScaledIdentity()
        dmat.createPython(mat.getSizes(), dctx, comm=mat.getComm())
        if op == PETSc.Mat.DuplicateOption.COPY_VALUES:
            dctx.s = self.s
            dmat.setUp()
        return dmat

    def getDiagonal(self, mat, vd):
        """Fill ``vd`` with the factor ``s``."""
        vd.set(self.s)

    def productSetFromOptions(self, mat, producttype, A, B, C):
        """Accept every product type.

        Unsupported types are rejected later, in ``productSymbolic``.
        """
        return True

    def _productTemplate(self, mat, producttype, A, B, C):
        """Return the matrix whose type, sizes and pattern the product takes.

        ``mat`` is the identity operand, so the result has the structure of
        the product of the remaining operand(s).  When that requires an
        actual matrix-matrix product (PtAP/RARt with ``mat`` as A, and ABC),
        the product is stored in ``self.tmp`` so that ``productNumeric`` can
        recompute it in place.

        NOTE(review): one ``self.tmp`` is shared by every product of this
        context.  Running the symbolic phase of a second such product before
        the numeric phase of the first would make the first reuse the wrong
        temporary.

        Raises:
            RuntimeError: if ``mat`` is none of the operands, or the product
                type is not implemented.
        """
        if producttype == 'AB':
            if mat is A:  # identity * B
                return B
            if mat is B:  # A * identity
                return A
        elif producttype == 'AtB':
            if mat is A:  # identity^T * B
                return B
            if mat is B:  # A^T * identity
                tmp = PETSc.Mat()
                A.transpose(tmp)
                return tmp
        elif producttype == 'ABt':
            if mat is A:  # identity * B^T
                tmp = PETSc.Mat()
                B.transpose(tmp)
                return tmp
            if mat is B:  # A * identity^T
                return A
        elif producttype == 'PtAP':
            if mat is A:  # P^T * identity * P
                self.tmp = PETSc.Mat()
                B.transposeMatMult(B, self.tmp)
                return self.tmp
            if mat is B:  # identity^T * A * identity
                return A
        elif producttype == 'RARt':
            if mat is A:  # R * identity * R^T
                self.tmp = PETSc.Mat()
                B.matTransposeMult(B, self.tmp)
                return self.tmp
            if mat is B:  # identity * A * identity^T
                return A
        elif producttype == 'ABC':
            # The identity may sit in any of the three slots; the structure
            # is that of the product of the other two, in order.
            if mat is A:
                left, right = B, C
            elif mat is B:
                left, right = A, C
            elif mat is C:
                left, right = A, B
            else:
                raise RuntimeError('wrong configuration')
            self.tmp = PETSc.Mat()
            left.matMult(right, self.tmp)
            return self.tmp
        else:
            raise RuntimeError('Product {} not implemented'.format(producttype))
        # A known product type, but mat matched none of its operands.
        raise RuntimeError('wrong configuration')

    def productSymbolic(self, mat, product, producttype, A, B, C):
        """Give ``product`` the type, sizes and nonzero pattern of the result.

        The values are zeroed here; ``productNumeric`` fills them in.

        Raises:
            RuntimeError: see ``_productTemplate``.
        """
        template = self._productTemplate(mat, producttype, A, B, C)
        product.setType(template.getType())
        product.setSizes(template.getSizes())
        product.setUp()
        product.assemble()
        template.copy(product)
        product.zeroEntries()

    def productNumeric(self, mat, product, producttype, A, B, C):
        """Fill ``product`` with the values of the result.

        ``mat`` is the identity operand.  The result is the other operand(s),
        or their product recomputed into ``self.tmp``, scaled by ``s``.  The
        scaling is ``s**2`` when the identity appears twice (PtAP/RARt with
        ``mat`` as P/R).

        Raises:
            RuntimeError: if ``mat`` is none of the operands, or the product
                type is not implemented.
        """
        if producttype == 'AB':
            if mat is A:  # product = identity * B
                B.copy(product, structure=True)
            elif mat is B:  # product = A * identity
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
            product.scale(self.s)
        elif producttype == 'AtB':
            if mat is A:  # product = identity^T * B
                B.copy(product, structure=True)
            elif mat is B:  # product = A^T * identity
                A.setTransposePrecursor(product)
                A.transpose(product)
            else:
                raise RuntimeError('wrong configuration')
            product.scale(self.s)
        elif producttype == 'ABt':
            if mat is A:  # product = identity * B^T
                B.setTransposePrecursor(product)
                B.transpose(product)
            elif mat is B:  # product = A * identity^T
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
            product.scale(self.s)
        elif producttype == 'PtAP':
            if mat is A:  # product = P^T * identity * P
                B.transposeMatMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
                product.scale(self.s)
            elif mat is B:  # product = identity^T * A * identity
                A.copy(product, structure=True)
                product.scale(self.s**2)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A:  # product = R * identity * R^t
                B.matTransposeMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
                product.scale(self.s)
            elif mat is B:  # product = identity * A * identity^T
                A.copy(product, structure=True)
                product.scale(self.s**2)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            if mat is A:  # product = identity * B * C
                B.matMult(C, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is B:  # product = A * identity * C
                A.matMult(C, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is C:  # product = A * B * identity
                A.matMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
            product.scale(self.s)
        else:
            raise RuntimeError('Product {} not implemented'.format(producttype))
223ee6c7c31SStefano Zampini
class Diagonal(Matrix):
    """Python context for a diagonal matrix whose entries live in vector ``D``."""

    def create(self, mat):
        """Set up ``mat`` and allocate ``D`` with the matrix's row layout."""
        super(Diagonal, self).create(mat)
        mat.setUp()
        self.D = mat.createVecLeft()

    def destroy(self, mat):
        """Destroy ``D``, then run the base-class hook."""
        self.D.destroy()
        super(Diagonal, self).destroy(mat)

    def scale(self, mat, a):
        """Scale every diagonal entry by ``a``."""
        self.D.scale(a)

    def shift(self, mat, a):
        """Add ``a`` to every diagonal entry."""
        self.D.shift(a)

    def zeroEntries(self, mat):
        """Reset every diagonal entry to zero."""
        self.D.zeroEntries()

    def mult(self, mat, x, y):
        """Compute ``y`` as the entrywise product of ``x`` and ``D``."""
        y.pointwiseMult(x, self.D)

    def duplicate(self, mat, op):
        """Return a new python Mat backed by its own ``Diagonal`` context.

        The new context always receives a vector duplicated from ``self.D``.
        The values are copied into it, and the new Mat set up, only when
        ``op`` is ``COPY_VALUES``.
        """
        new_ctx = Diagonal()
        new_mat = PETSc.Mat()
        new_mat.createPython(mat.getSizes(), new_ctx, comm=mat.getComm())
        new_ctx.D = self.D.duplicate()
        if op != PETSc.Mat.DuplicateOption.COPY_VALUES:
            return new_mat
        self.D.copy(new_ctx.D)
        new_mat.setUp()
        return new_mat

    def getDiagonal(self, mat, vd):
        """Copy ``D`` into ``vd``."""
        self.D.copy(vd)

    def setDiagonal(self, mat, vd, im):
        """Insert ``vd`` into ``D`` (copy) or add it, depending on ``im``.

        ``im`` is either a bool (True means add) or a ``PETSc.InsertMode``
        (INSERT_VALUES or ADD_VALUES).

        Raises:
            ValueError: for any other insert mode.
        """
        # The bool test must come first: True/False also compare equal to
        # small integer enum values.
        if isinstance(im, bool):
            add = im
        elif im == PETSc.InsertMode.INSERT_VALUES:
            add = False
        elif im == PETSc.InsertMode.ADD_VALUES:
            add = True
        else:
            raise ValueError('wrong InsertMode %d' % im)
        if add:
            self.D.axpy(1, vd)
        else:
            vd.copy(self.D)

    def diagonalScale(self, mat, vl, vr):
        """Scale ``D`` entrywise by ``vl`` and then by ``vr``.

        Each vector is skipped when it is falsy (presumably absent/null).
        """
        for factor in (vl, vr):
            if factor:
                self.D.pointwiseMult(self.D, factor)
2775808f684SSatish Balay
2785808f684SSatish Balay# --------------------------------------------------------------------
2795808f684SSatish Balay
class TestMatrix(unittest.TestCase):
    """Tests for a python Mat whose context class is ``PYCLS``.

    ``setUp`` instantiates ``PYCLS`` (looked up in this module's globals) and
    attaches it to a 13x13 python Mat ``self.A``.  With the base ``Matrix``
    context no operation is implemented, so most tests here assert that the
    operation raises.  Subclasses switch ``PYCLS`` and override the tests of
    the operations their context implements.
    """

    COMM = PETSc.COMM_WORLD
    PYMOD = __name__
    PYCLS = 'Matrix'

    def _getCtx(self):
        """Return the Python context attached to ``self.A``."""
        return self.A.getPythonContext()

    def setUp(self):
        N = self.N = 13
        self.A = PETSc.Mat()
        if 0:  # command line way (alternative setup, disabled)
            self.A.create(self.COMM)
            self.A.setSizes([N, N])
            self.A.setType('python')
            OptDB = PETSc.Options(self.A)
            OptDB['mat_python_type'] = '%s.%s' % (self.PYMOD, self.PYCLS)
            self.A.setFromOptions()
            self.A.setUp()
            del OptDB['mat_python_type']
            self.assertIsNotNone(self._getCtx())
        else:  # python way
            context = globals()[self.PYCLS]()
            self.A.createPython([N, N], context, comm=self.COMM)
            self.A.setUp()
            self.assertIs(self._getCtx(), context)
            # Reference counts are asserted to catch leaked references to
            # the context; getrefcount also counts its own argument.
            self.assertEqual(getrefcount(context), 3)
            del context
            self.assertEqual(getrefcount(self._getCtx()), 2)

    def tearDown(self):
        ctx = self.A.getPythonContext()
        self.assertEqual(getrefcount(ctx), 3)
        self.A.destroy()  # XXX
        self.A = None
        PETSc.garbage_cleanup()
        # Expected: only `ctx` and getrefcount's argument remain.
        self.assertEqual(getrefcount(ctx), 2)
        # To debug a mismatch:
        # import gc, pprint; pprint.pprint(gc.get_referrers(ctx))

    def testBasic(self):
        ctx = self.A.getPythonContext()
        self.assertIs(self._getCtx(), ctx)
        self.assertEqual(getrefcount(ctx), 3)

    def testZeroEntries(self):
        f = lambda : self.A.zeroEntries()
        self.assertRaises(Exception, f)

    def testMult(self):
        x, y = self.A.createVecs()
        f = lambda : self.A.mult(x, y)
        self.assertRaises(Exception, f)

    def testMultTranspose(self):
        x, y = self.A.createVecs()
        f = lambda : self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        f = lambda : self.A.getDiagonal(d)
        self.assertRaises(Exception, f)

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        f = lambda : self.A.setDiagonal(d)
        self.assertRaises(Exception, f)

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        f = lambda : self.A.diagonalScale(x, y)
        self.assertRaises(Exception, f)

    def testDuplicate(self):
        # The base context defines no duplicate(), so duplicating must fail
        # with and without copying values.  (Previously these lambdas
        # referenced an undefined name and only raised NameError.)
        f1 = lambda : self.A.duplicate(True)
        f2 = lambda : self.A.duplicate(False)
        self.assertRaises(Exception, f1)
        self.assertRaises(Exception, f2)

    def testSetVecType(self):
        self.A.setVecType('mpi')
        self.assertEqual(self.A.getVecType(), 'mpi')

    def testH2Opus(self):
        if not PETSc.Sys.hasExternalPackage("h2opus"):
            return
        if self.A.getComm().Get_size() > 1:
            return

        # need matrix vector and its transpose for norm estimation
        AA = self.A.getPythonContext()
        if not hasattr(AA, 'mult'):
            return
        AA.multTranspose = AA.mult
        h = PETSc.Mat()
        try:
            # without coordinates
            h.createH2OpusFromMat(self.A, leafsize=2)
            h.assemble()
            h.destroy()

            # with coordinates
            coords = numpy.linspace((1, 2, 3), (10, 20, 30), self.A.getSize()[0],
                                    dtype=PETSc.RealType)
            h.createH2OpusFromMat(self.A, coords, leafsize=2)
            h.assemble()

            # test API
            h.H2OpusOrthogonalize()
            h.H2OpusCompress(1.e-1)

            # Low-rank update
            U = PETSc.Mat()
            U.createDense([h.getSizes()[0], 3], comm=h.getComm())
            U.setUp()
            U.setRandom()

            he = PETSc.Mat()
            h.convert('dense', he)
            he.axpy(1.0, U.matTransposeMult(U))

            h.H2OpusLowRankUpdate(U)
            self.assertTrue(he.equal(h))
        finally:
            # Always release h and the bound method (which references AA),
            # even on failure.  Otherwise tearDown's reference-count checks
            # fail and hide the real error.
            h.destroy()
            del AA.multTranspose

    def testGetType(self):
        ctx = self.A.getPythonContext()
        pytype = "{0}.{1}".format(ctx.__module__, type(ctx).__name__)
        self.assertEqual(self.A.getPythonType(), pytype)
413300d917bSStefano Zampini
class TestScaledIdentity(TestMatrix):
    """Run the python-Mat tests with a ``ScaledIdentity`` context."""

    PYCLS = 'ScaledIdentity'

    def testMult(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.mult(x, y)
        self.assertTrue(y.equal(s*x))

    def testMultTransposeSymmKnown(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(s*x))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        f = lambda : self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testMultTransposeNewMeth(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            # AA.mult is bound to AA.  Leaving it stored on AA would keep an
            # extra reference and break tearDown's reference-count checks.
            del AA.multTranspose
        self.assertTrue(y.equal(s*x))

    def testGetDiagonal(self):
        s = self._getCtx().s
        d = self.A.createVecLeft()
        o = d.duplicate()
        o.set(s)
        self.A.getDiagonal(d)
        self.assertTrue(o.equal(d))

    def testDuplicate(self):
        # Without copying values the new context keeps the class default.
        B = self.A.duplicate(False)
        self.assertEqual(B.getPythonContext().s, 2)
        B = self.A.duplicate(True)
        self.assertEqual(B.getPythonContext().s, self.A.getPythonContext().s)

    def testMatMat(self):
        # Compare every product involving self.A (= s*I) against the same
        # product with an explicitly assembled AIJ matrix I = s*I.
        s = self._getCtx().s
        R = PETSc.Random().create(self.COMM)
        R.setFromOptions()
        A = PETSc.Mat().create(self.COMM)
        A.setSizes(self.A.getSizes())
        A.setType(PETSc.Mat.Type.AIJ)
        A.setPreallocationNNZ(None)
        A.setRandom(R)
        B = PETSc.Mat().create(self.COMM)
        B.setSizes(self.A.getSizes())
        B.setType(PETSc.Mat.Type.AIJ)
        B.setPreallocationNNZ(None)
        B.setRandom(R)
        I = PETSc.Mat().create(self.COMM)
        I.setSizes(self.A.getSizes())
        I.setType(PETSc.Mat.Type.AIJ)
        I.setUp()
        I.assemble()
        I.shift(s)

        serial = self.A.getComm().Get_size() == 1
        self.assertTrue(self.A.matMult(A).equal(I.matMult(A)))
        self.assertTrue(A.matMult(self.A).equal(A.matMult(I)))
        if serial:
            self.assertTrue(self.A.matTransposeMult(A).equal(I.matTransposeMult(A)))
            self.assertTrue(A.matTransposeMult(self.A).equal(A.matTransposeMult(I)))
        self.assertTrue(self.A.transposeMatMult(A).equal(I.transposeMatMult(A)))
        self.assertTrue(A.transposeMatMult(self.A).equal(A.transposeMatMult(I)))
        self.assertAlmostEqual((self.A.ptap(A) - I.ptap(A)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.ptap(self.A) - A.ptap(I)).norm(), 0.0, places=5)
        if serial:
            self.assertAlmostEqual((self.A.rart(A) - I.rart(A)).norm(), 0.0, places=5)
            self.assertAlmostEqual((A.rart(self.A) - A.rart(I)).norm(), 0.0, places=5)
        self.assertAlmostEqual((self.A.matMatMult(A, B) - I.matMatMult(A, B)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.matMatMult(self.A, B) - A.matMatMult(I, B)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.matMatMult(B, self.A) - A.matMatMult(B, I)).norm(), 0.0, places=5)

    def testShift(self):
        sold = self._getCtx().s
        self.A.shift(-0.5)
        s = self._getCtx().s
        self.assertEqual(s, sold - 0.5)

    def testScale(self):
        sold = self._getCtx().s
        self.A.scale(-0.5)
        s = self._getCtx().s
        self.assertEqual(s, sold * -0.5)
50822fceea1SStefano Zampini
class TestDiagonal(TestMatrix):
    """Tests for the ``Diagonal`` Python matrix context.

    After the base-class setup, ``setUp`` stores the diagonal
    ``[1, 2, ..., n]`` into ``self.A``. The tests then check that operations
    on the PETSc matrix are reflected in the context's ``D`` vector, which is
    reached through ``self._getCtx()`` (from TestMatrix; assumed to return
    the Python context of ``self.A`` -- confirm in the base class).
    """

    PYCLS = 'Diagonal'

    def setUp(self):
        super(TestDiagonal, self).setUp()
        D = self.A.createVecLeft()
        s, e = D.getOwnershipRange()
        # Each process fills only the rows it owns. Entries are i+1, so no
        # diagonal entry is zero.
        for i in range(s, e):
            D[i] = i+1
        D.assemble()
        self.A.setDiagonal(D)

    def testZeroEntries(self):
        """zeroEntries() must clear the context's diagonal."""
        self.A.zeroEntries()
        D = self._getCtx().D
        self.assertEqual(D.norm(), 0)

    def testMult(self):
        """A * ones must reproduce the stored diagonal."""
        x, y = self.A.createVecs()
        x.set(1)
        self.A.mult(x,y)
        self.assertTrue(y.equal(self._getCtx().D))

    def testMultTransposeSymmKnown(self):
        """multTranspose works while flagged symmetric and fails otherwise."""
        x, y = self.A.createVecs()
        x.set(1)
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x,y)
        self.assertTrue(y.equal(self._getCtx().D))
        # Without the SYMMETRIC flag the transpose product is expected to
        # fail; presumably the context has no multTranspose of its own
        # (testMultTransposeNewMeth has to add one).
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        with self.assertRaises(Exception):
            self.A.multTranspose(x, y)

    def testMultTransposeNewMeth(self):
        """A multTranspose attached at runtime to the context must be used."""
        x, y = self.A.createVecs()
        x.set(1)
        AA = self.A.getPythonContext()
        # A diagonal matrix is its own transpose, so mult can stand in for
        # multTranspose.
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x,y)
        finally:
            # Remove the patch even if the call raises, so the context is
            # not left modified.
            del AA.multTranspose
        self.assertTrue(y.equal(self._getCtx().D))

    def testDuplicate(self):
        """duplicate() keeps the layout; duplicate(True) also copies D."""
        # duplicate(False) does not request value copying, so only the
        # layout of the result is checked.
        B0 = self.A.duplicate(False)
        self.assertEqual(B0.getSizes(), self.A.getSizes())
        B = self.A.duplicate(True)
        self.assertTrue(B.getPythonContext().D.equal(self.A.getPythonContext().D))

    def testGetDiagonal(self):
        """getDiagonal() must return the context's D."""
        d = self.A.createVecLeft()
        self.A.getDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testSetDiagonal(self):
        """setDiagonal() must store the given vector in the context."""
        d = self.A.createVecLeft()
        d.setRandom()
        self.A.setDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testDiagonalScale(self):
        """diagonalScale(L, R) with L=2, R=3 must scale D by 6."""
        x, y = self.A.createVecs()
        x.set(2)
        y.set(3)
        old = self._getCtx().D.copy()
        self.A.diagonalScale(x, y)
        D = self._getCtx().D
        self.assertTrue(D.equal(old*6))

    def testCreateTranspose(self):
        """A transpose wrapper must agree with A's (trans)products."""
        A = self.A
        # Flag A symmetric so that A.multTranspose succeeds
        # (see testMultTransposeSymmKnown).
        A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        AT = PETSc.Mat().createTranspose(A)
        x, y = A.createVecs()
        xt, yt = AT.createVecs()
        # AT * y must equal A^T * y.
        y.setRandom()
        A.multTranspose(y, x)
        y.copy(xt)
        AT.mult(xt, yt)
        self.assertTrue(yt.equal(x))
        # AT^T * x must equal A * x.
        x.setRandom()
        A.mult(x, y)
        x.copy(yt)
        AT.multTranspose(yt, xt)
        self.assertTrue(xt.equal(y))

    def testConvert(self):
        """Conversion to each assembled format must preserve the matrix."""
        for mtype in (PETSc.Mat.Type.AIJ, PETSc.Mat.Type.BAIJ,
                      PETSc.Mat.Type.SBAIJ, PETSc.Mat.Type.DENSE):
            # subTest reports every failing format, not just the first one.
            with self.subTest(mtype=mtype):
                B = self.A.convert(mtype, PETSc.Mat())
                self.assertTrue(B.equal(self.A))

    def testShift(self):
        """shift(a) must add ``a`` to every diagonal entry."""
        old = self._getCtx().D.copy()
        self.A.shift(-0.5)
        D = self._getCtx().D
        self.assertTrue(D.equal(old-0.5))

    def testScale(self):
        """scale(a) must multiply every diagonal entry by ``a``."""
        old = self._getCtx().D.copy()
        self.A.scale(-0.5)
        D = self._getCtx().D
        self.assertTrue(D.equal(-0.5*old))
61422fceea1SStefano Zampini
61522fceea1SStefano Zampini
6165808f684SSatish Balay# --------------------------------------------------------------------
6175808f684SSatish Balay
if __name__ == '__main__':
    # Script entry point: let unittest discover and run every TestCase
    # subclass defined in this module.
    unittest.main(module='__main__')
620