xref: /petsc/src/binding/petsc4py/test/test_mat_py.py (revision 6f33641175f69f1db294cc9ba81c3f4ad4f81d49)
15808f684SSatish Balayfrom petsc4py import PETSc
2*6f336411SStefano Zampiniimport unittest
3*6f336411SStefano Zampiniimport numpy
45808f684SSatish Balayfrom sys import getrefcount
55808f684SSatish Balay# --------------------------------------------------------------------
65808f684SSatish Balay
75808f684SSatish Balay
class Matrix:
    """Minimal Python context for a Python-type PETSc matrix.

    Only the ``create``/``destroy`` lifecycle hooks are provided, and both do
    nothing.  No matrix operation (``mult``, ``getDiagonal``, ...) is defined,
    so the base test case can check that PETSc reports an error for each one.
    """

    def __init__(self):
        """Build an empty context; there is no state to initialise."""

    def create(self, mat):
        """Lifecycle hook called with the owning matrix; intentionally a no-op."""

    def destroy(self, mat):
        """Lifecycle hook called with the owning matrix; intentionally a no-op."""
175808f684SSatish Balay
1822fceea1SStefano Zampini
class ScaledIdentity(Matrix):
    """Python matrix context representing the scaled identity ``s * I``.

    The scalar ``s`` starts from the class-level default 2.0; :meth:`scale`
    and :meth:`shift` update it on the instance.  Besides ``mult`` and
    ``getDiagonal``, the context implements the Python matrix-product hooks.
    Every supported product with ordinary matrices becomes a copy of the other
    operand(s), or their transpose or product, scaled by one power of ``s``
    per identity factor.
    """

    s = 2.0

    def scale(self, mat, s):
        """Multiply the stored scalar by *s*."""
        self.s *= s

    def shift(self, mat, s):
        """Add *s* to the stored scalar."""
        self.s += s

    def mult(self, mat, x, y):
        """Store ``s * x`` in *y*."""
        x.copy(y)
        y.scale(self.s)

    def duplicate(self, mat, op):
        """Return a new Python matrix with the same sizes and communicator.

        The scalar is carried over, and the new matrix set up, only when *op*
        is ``COPY_VALUES``.  Otherwise the new context keeps the class default.
        """
        ctx = ScaledIdentity()
        dup = PETSc.Mat()
        dup.createPython(mat.getSizes(), ctx, comm=mat.getComm())
        if op != PETSc.Mat.DuplicateOption.COPY_VALUES:
            return dup
        ctx.s = self.s
        dup.setUp()
        return dup

    def getDiagonal(self, mat, vd):
        """Fill *vd* with the stored scalar."""
        vd.set(self.s)

    def productSetFromOptions(self, mat, producttype, A, B, C):
        """Accept every product request (always returns ``True``)."""
        return True

    @staticmethod
    def _setup_like(product, template):
        """Give *product* the type, sizes and entries of *template*."""
        product.setType(template.getType())
        product.setSizes(template.getSizes())
        product.setUp()
        product.assemble()
        template.copy(product)

    def productSymbolic(self, mat, product, producttype, A, B, C):
        """Prepare the structure of *product* for the requested product type.

        *mat* is the identity operand and must be one of ``A``, ``B`` or ``C``.
        Otherwise ``RuntimeError('wrong configuration')`` is raised, as is a
        ``RuntimeError`` for unsupported product types.  The intermediate
        result is kept in ``self.tmp`` in these cases, because
        :meth:`productNumeric` recomputes into it:

        - ``PtAP`` and ``RARt`` with the identity as ``A``;
        - every ``ABC`` case.

        The product is zeroed at the end; :meth:`productNumeric` fills it.
        """
        if producttype == 'AB':
            if mat is A:  # product = identity * B
                template = B
            elif mat is B:  # product = A * identity
                template = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'AtB':
            if mat is A:  # product = identity^T * B
                template = B
            elif mat is B:  # product = A^T * identity
                template = PETSc.Mat()
                A.transpose(template)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABt':
            if mat is A:  # product = identity * B^T
                template = PETSc.Mat()
                B.transpose(template)
            elif mat is B:  # product = A * identity^T
                template = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'PtAP':
            if mat is A:  # product = P^T * identity * P
                self.tmp = PETSc.Mat()
                B.transposeMatMult(B, self.tmp)
                template = self.tmp
            elif mat is B:  # product = identity^T * A * identity
                template = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A:  # product = R * identity * R^t
                self.tmp = PETSc.Mat()
                B.matTransposeMult(B, self.tmp)
                template = self.tmp
            elif mat is B:  # product = identity * A * identity^T
                template = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            # the two remaining operands, in their original order
            if mat is A:  # product = identity * B * C
                left, right = B, C
            elif mat is B:  # product = A * identity * C
                left, right = A, C
            elif mat is C:  # product = A * B * identity
                left, right = A, B
            else:
                raise RuntimeError('wrong configuration')
            self.tmp = PETSc.Mat()
            left.matMult(right, self.tmp)
            template = self.tmp
        else:
            raise RuntimeError(f'Product {producttype} not implemented')
        self._setup_like(product, template)
        product.zeroEntries()

    def productNumeric(self, mat, product, producttype, A, B, C):
        """Compute the values of *product*.

        The non-identity operand(s) are copied, transposed, or multiplied
        (through ``self.tmp``, prepared by :meth:`productSymbolic`) into
        *product*.  The result is then scaled once by ``s``, or by ``s**2``
        when the identity appears twice (``PtAP``/``RARt`` with the identity
        as ``B``).  Raises ``RuntimeError`` on a bad configuration or
        unsupported product type.
        """
        if producttype == 'AB':
            if mat is A:  # product = identity * B
                B.copy(product, structure=True)
            elif mat is B:  # product = A * identity
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
            factor = self.s
        elif producttype == 'AtB':
            if mat is A:  # product = identity^T * B
                B.copy(product, structure=True)
            elif mat is B:  # product = A^T * identity
                A.setTransposePrecursor(product)
                A.transpose(product)
            else:
                raise RuntimeError('wrong configuration')
            factor = self.s
        elif producttype == 'ABt':
            if mat is A:  # product = identity * B^T
                B.setTransposePrecursor(product)
                B.transpose(product)
            elif mat is B:  # product = A * identity^T
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
            factor = self.s
        elif producttype == 'PtAP':
            if mat is A:  # product = P^T * identity * P
                B.transposeMatMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
                factor = self.s
            elif mat is B:  # product = identity^T * A * identity
                A.copy(product, structure=True)
                factor = self.s**2
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A:  # product = R * identity * R^t
                B.matTransposeMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
                factor = self.s
            elif mat is B:  # product = identity * A * identity^T
                A.copy(product, structure=True)
                factor = self.s**2
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            if mat is A:  # product = identity * B * C
                left, right = B, C
            elif mat is B:  # product = A * identity * C
                left, right = A, C
            elif mat is C:  # product = A * B * identity
                left, right = A, B
            else:
                raise RuntimeError('wrong configuration')
            left.matMult(right, self.tmp)
            self.tmp.copy(product, structure=True)
            factor = self.s
        else:
            raise RuntimeError(f'Product {producttype} not implemented')
        product.scale(factor)
224*6f336411SStefano Zampini
225ee6c7c31SStefano Zampini
class Diagonal(Matrix):
    """Python matrix context for a diagonal matrix whose entries live in ``D``."""

    def create(self, mat):
        """Set up *mat* and allocate ``D`` as a left vector of it."""
        super().create(mat)
        mat.setUp()
        self.D = mat.createVecLeft()

    def destroy(self, mat):
        """Destroy ``D``, then run the base-class hook."""
        self.D.destroy()
        super().destroy(mat)

    def scale(self, mat, a):
        """Scale every diagonal entry by *a*."""
        self.D.scale(a)

    def shift(self, mat, a):
        """Add *a* to every diagonal entry."""
        self.D.shift(a)

    def zeroEntries(self, mat):
        """Set every diagonal entry to zero."""
        self.D.zeroEntries()

    def mult(self, mat, x, y):
        """Store the entrywise product of ``D`` and *x* in *y*."""
        y.pointwiseMult(x, self.D)

    def duplicate(self, mat, op):
        """Return a new Python matrix with its own diagonal vector.

        The new context receives ``self.D.duplicate()``.  The values of ``D``
        are copied into it, and the new matrix is set up, only when *op* is
        ``COPY_VALUES``.
        """
        ctx = Diagonal()
        dup = PETSc.Mat()
        dup.createPython(mat.getSizes(), ctx, comm=mat.getComm())
        ctx.D = self.D.duplicate()
        if op != PETSc.Mat.DuplicateOption.COPY_VALUES:
            return dup
        self.D.copy(ctx.D)
        dup.setUp()
        return dup

    def getDiagonal(self, mat, vd):
        """Copy ``D`` into *vd*."""
        self.D.copy(vd)

    def setDiagonal(self, mat, vd, im):
        """Insert or add *vd* into ``D``.

        *im* is either a bool (``True`` means add) or an InsertMode.
        ``INSERT_VALUES`` overwrites ``D`` and ``ADD_VALUES`` adds *vd* to it.
        Any other value raises ``ValueError``.
        """
        # bools are ints in Python, so handle them before comparing against
        # the InsertMode constants
        if isinstance(im, bool):
            accumulate = im
        elif im == PETSc.InsertMode.INSERT_VALUES:
            accumulate = False
        elif im == PETSc.InsertMode.ADD_VALUES:
            accumulate = True
        else:
            raise ValueError('wrong InsertMode %d' % im)
        if accumulate:
            self.D.axpy(1, vd)
        else:
            vd.copy(self.D)

    def diagonalScale(self, mat, vl, vr):
        """Multiply ``D`` entrywise by *vl* and then by *vr*.

        Either scaling vector may be absent; a falsy argument is skipped.
        """
        for factor in (vl, vr):
            if factor:
                self.D.pointwiseMult(self.D, factor)
280*6f336411SStefano Zampini
2815808f684SSatish Balay
2825808f684SSatish Balay# --------------------------------------------------------------------
2835808f684SSatish Balay
2845808f684SSatish Balay
class TestMatrix(unittest.TestCase):
    """Exercise a Python-type PETSc matrix whose context class is ``PYCLS``.

    ``PYCLS`` names a context class in this module.  The base case uses
    ``Matrix``, which defines no operations, so the operation tests below
    expect PETSc to raise.  Subclasses override ``PYCLS`` and the tests whose
    outcome differs.
    """

    COMM = PETSc.COMM_WORLD
    PYMOD = __name__
    PYCLS = 'Matrix'

    def _getCtx(self):
        """Return the Python context attached to the matrix under test."""
        return self.A.getPythonContext()

    def setUp(self):
        N = self.N = 13
        self.A = PETSc.Mat()
        if 0:  # command line way
            self.A.create(self.COMM)
            self.A.setSizes([N, N])
            self.A.setType('python')
            OptDB = PETSc.Options(self.A)
            OptDB['mat_python_type'] = f'{self.PYMOD}.{self.PYCLS}'
            self.A.setFromOptions()
            self.A.setUp()
            del OptDB['mat_python_type']
            self.assertIsNotNone(self._getCtx())
        else:  # python way
            context = globals()[self.PYCLS]()
            self.A.createPython([N, N], context, comm=self.COMM)
            self.A.setUp()
            self.assertIs(self._getCtx(), context)
            # references: the matrix, the local name, getrefcount's argument
            self.assertEqual(getrefcount(context), 3)
            del context
            # only the matrix (plus getrefcount's argument) remains
            self.assertEqual(getrefcount(self._getCtx()), 2)

    def tearDown(self):
        ctx = self.A.getPythonContext()
        self.assertEqual(getrefcount(ctx), 3)
        self.A.destroy()  # XXX
        self.A = None
        PETSc.garbage_cleanup()
        # destroying the matrix must have released its reference to ctx
        self.assertEqual(getrefcount(ctx), 2)
        # to debug leaks: import gc, pprint; pprint.pprint(gc.get_referrers(ctx))

    def testBasic(self):
        ctx = self.A.getPythonContext()
        self.assertIs(self._getCtx(), ctx)
        self.assertEqual(getrefcount(ctx), 3)

    def testZeroEntries(self):
        f = lambda: self.A.zeroEntries()
        self.assertRaises(Exception, f)

    def testMult(self):
        x, y = self.A.createVecs()
        f = lambda: self.A.mult(x, y)
        self.assertRaises(Exception, f)

    def testMultTranspose(self):
        x, y = self.A.createVecs()
        f = lambda: self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        f = lambda: self.A.getDiagonal(d)
        self.assertRaises(Exception, f)

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        f = lambda: self.A.setDiagonal(d)
        self.assertRaises(Exception, f)

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        f = lambda: self.A.diagonalScale(x, y)
        self.assertRaises(Exception, f)

    def testDuplicate(self):
        f1 = lambda: self.A.duplicate(True)
        f2 = lambda: self.A.duplicate(False)
        self.assertRaises(Exception, f1)
        self.assertRaises(Exception, f2)

    def testSetVecType(self):
        self.A.setVecType('mpi')
        self.assertEqual(self.A.getVecType(), 'mpi')

    def testH2Opus(self):
        # Report skips instead of returning silently (which counts as a pass).
        # All checks run before the context is bound to a local name, so a
        # skip leaves no extra reference for tearDown's refcount check.
        if not PETSc.Sys.hasExternalPackage('h2opus'):
            self.skipTest('PETSc was not built with h2opus')
        if self.A.getComm().Get_size() > 1:
            self.skipTest('h2opus test runs on a single process only')
        # need matrix vector and its transpose for norm estimation
        if not hasattr(self._getCtx(), 'mult'):
            self.skipTest('Python context does not implement mult')

        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            h = PETSc.Mat()

            # without coordinates
            h.createH2OpusFromMat(self.A, leafsize=2)
            h.assemble()
            h.destroy()

            # with coordinates
            coords = numpy.linspace(
                (1, 2, 3), (10, 20, 30), self.A.getSize()[0], dtype=PETSc.RealType
            )
            h.createH2OpusFromMat(self.A, coords, leafsize=2)
            h.assemble()

            # test API
            h.H2OpusOrthogonalize()
            h.H2OpusCompress(1.0e-1)

            # Low-rank update: compare with a dense copy updated by U * U^T
            U = PETSc.Mat()
            U.createDense([h.getSizes()[0], 3], comm=h.getComm())
            U.setUp()
            U.setRandom()

            he = PETSc.Mat()
            h.convert('dense', he)
            he.axpy(1.0, U.matTransposeMult(U))

            h.H2OpusLowRankUpdate(U)
            self.assertTrue(he.equal(h))

            h.destroy()
        finally:
            # Always drop the patch.  The bound method references AA, and a
            # leftover attribute would break tearDown's refcount check.
            del AA.multTranspose

    def testGetType(self):
        ctx = self.A.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        self.assertEqual(self.A.getPythonType(), pytype)
418300d917bSStefano Zampini
4195808f684SSatish Balay
class TestScaledIdentity(TestMatrix):
    """Run the matrix tests against the ``ScaledIdentity`` context (``s * I``)."""

    PYCLS = 'ScaledIdentity'

    def testMult(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.mult(x, y)
        self.assertTrue(y.equal(s * x))

    def testMultTransposeSymmKnown(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        # with the symmetry flag set, multTranspose can fall back to mult
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(s * x))
        # without it the context has no multTranspose, so PETSc must raise
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        f = lambda: self.A.multTranspose(x, y)
        self.assertRaises(Exception, f)

    def testMultTransposeNewMeth(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            # Always drop the patch.  The bound method references AA, and a
            # leftover attribute would break tearDown's refcount check.
            del AA.multTranspose
        self.assertTrue(y.equal(s * x))

    def testGetDiagonal(self):
        s = self._getCtx().s
        d = self.A.createVecLeft()
        o = d.duplicate()
        o.set(s)
        self.A.getDiagonal(d)
        self.assertTrue(o.equal(d))

    def testDuplicate(self):
        # without copying values the duplicate keeps the default scalar
        B = self.A.duplicate(False)
        self.assertEqual(B.getPythonContext().s, 2)
        B = self.A.duplicate(True)
        self.assertEqual(B.getPythonContext().s, self.A.getPythonContext().s)

    def testMatMat(self):
        """Compare every product against the same product with an AIJ ``s*I``."""
        s = self._getCtx().s
        serial = self.A.getComm().Get_size() == 1
        R = PETSc.Random().create(self.COMM)
        R.setFromOptions()
        A = PETSc.Mat().create(self.COMM)
        A.setSizes(self.A.getSizes())
        A.setType(PETSc.Mat.Type.AIJ)
        A.setPreallocationNNZ(None)
        A.setRandom(R)
        B = PETSc.Mat().create(self.COMM)
        B.setSizes(self.A.getSizes())
        B.setType(PETSc.Mat.Type.AIJ)
        B.setPreallocationNNZ(None)
        B.setRandom(R)
        # explicit AIJ version of the scaled identity
        Id = PETSc.Mat().create(self.COMM)
        Id.setSizes(self.A.getSizes())
        Id.setType(PETSc.Mat.Type.AIJ)
        Id.setUp()
        Id.assemble()
        Id.shift(s)

        self.assertTrue(self.A.matMult(A).equal(Id.matMult(A)))
        self.assertTrue(A.matMult(self.A).equal(A.matMult(Id)))
        if serial:
            self.assertTrue(self.A.matTransposeMult(A).equal(Id.matTransposeMult(A)))
            self.assertTrue(A.matTransposeMult(self.A).equal(A.matTransposeMult(Id)))
        self.assertTrue(self.A.transposeMatMult(A).equal(Id.transposeMatMult(A)))
        self.assertTrue(A.transposeMatMult(self.A).equal(A.transposeMatMult(Id)))
        self.assertAlmostEqual((self.A.ptap(A) - Id.ptap(A)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.ptap(self.A) - A.ptap(Id)).norm(), 0.0, places=5)
        if serial:
            self.assertAlmostEqual((self.A.rart(A) - Id.rart(A)).norm(), 0.0, places=5)
            self.assertAlmostEqual((A.rart(self.A) - A.rart(Id)).norm(), 0.0, places=5)
        self.assertAlmostEqual(
            (self.A.matMatMult(A, B) - Id.matMatMult(A, B)).norm(), 0.0, places=5
        )
        self.assertAlmostEqual(
            (A.matMatMult(self.A, B) - A.matMatMult(Id, B)).norm(), 0.0, places=5
        )
        self.assertAlmostEqual(
            (A.matMatMult(B, self.A) - A.matMatMult(B, Id)).norm(), 0.0, places=5
        )

    def testShift(self):
        sold = self._getCtx().s
        self.A.shift(-0.5)
        s = self._getCtx().s
        self.assertEqual(s, sold - 0.5)

    def testScale(self):
        sold = self._getCtx().s
        self.A.scale(-0.5)
        s = self._getCtx().s
        self.assertEqual(s, sold * -0.5)

    def testDiagonalMat(self):
        s = self._getCtx().s
        B = PETSc.Mat().createConstantDiagonal(
            self.A.getSizes(), s, comm=self.A.getComm()
        )
        self.assertTrue(self.A.equal(B))
5269e7eb791SStefano Zampini
5275808f684SSatish Balay
class TestDiagonal(TestMatrix):
    """Tests for the Python-context matrix whose entries live in a vector ``D``.

    ``setUp`` fills the diagonal with ``i + 1`` for every locally owned row
    ``i``, so all diagonal entries are distinct and non-zero.
    """

    PYCLS = 'Diagonal'

    def setUp(self):
        super().setUp()
        D = self.A.createVecLeft()
        s, e = D.getOwnershipRange()
        for i in range(s, e):
            D[i] = i + 1
        D.assemble()
        self.A.setDiagonal(D)

    def testZeroEntries(self):
        self.A.zeroEntries()
        D = self._getCtx().D
        self.assertEqual(D.norm(), 0)

    def testMult(self):
        # Multiplying by a vector of ones must reproduce the diagonal.
        x, y = self.A.createVecs()
        x.set(1)
        self.A.mult(x, y)
        self.assertTrue(y.equal(self._getCtx().D))

    def testMultTransposeSymmKnown(self):
        # multTranspose is expected to succeed only while the SYMMETRIC
        # option is set (presumably because the context provides no
        # multTranspose of its own) and to raise once the option is cleared.
        x, y = self.A.createVecs()
        x.set(1)
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(self._getCtx().D))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        with self.assertRaises(Exception):
            self.A.multTranspose(x, y)

    def testMultTransposeNewMeth(self):
        # Attach a multTranspose to the Python context at runtime. Reusing
        # mult is valid because the matrix is diagonal.
        x, y = self.A.createVecs()
        x.set(1)
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            # Always detach the monkey-patched method, even if the call
            # above raises, so the context is left as it was found.
            del AA.multTranspose
        self.assertTrue(y.equal(self._getCtx().D))

    def testDuplicate(self):
        # Without copying values, the duplicate must still share A's layout.
        Bnocopy = self.A.duplicate(False)
        self.assertEqual(Bnocopy.getSizes(), self.A.getSizes())
        # With copied values, the duplicate's diagonal must match A's.
        B = self.A.duplicate(True)
        self.assertTrue(B.getPythonContext().D.equal(self.A.getPythonContext().D))

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        self.A.getDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        d.setRandom()
        self.A.setDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testDiagonalScale(self):
        # Left scaling by 2 and right scaling by 3 multiplies each diagonal
        # entry by 6.
        x, y = self.A.createVecs()
        x.set(2)
        y.set(3)
        old = self._getCtx().D.copy()
        self.A.diagonalScale(x, y)
        D = self._getCtx().D
        self.assertTrue(D.equal(old * 6))

    def testCreateTranspose(self):
        A = self.A
        # A has no explicit transpose product here; flag it symmetric so
        # multTranspose is allowed (see testMultTransposeSymmKnown).
        A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        AT = PETSc.Mat().createTranspose(A)
        x, y = A.createVecs()
        xt, yt = AT.createVecs()
        # AT.mult must agree with A.multTranspose.
        y.setRandom()
        A.multTranspose(y, x)
        y.copy(xt)
        AT.mult(xt, yt)
        self.assertTrue(yt.equal(x))
        # AT.multTranspose must agree with A.mult.
        x.setRandom()
        A.mult(x, y)
        x.copy(yt)
        AT.multTranspose(yt, xt)
        self.assertTrue(xt.equal(y))

    def testConvert(self):
        # Converting to each native format must preserve every entry.
        # subTest reports each failing target type independently.
        for mtype in (
            PETSc.Mat.Type.AIJ,
            PETSc.Mat.Type.BAIJ,
            PETSc.Mat.Type.SBAIJ,
            PETSc.Mat.Type.DENSE,
        ):
            with self.subTest(mtype=mtype):
                converted = self.A.convert(mtype, PETSc.Mat())
                self.assertTrue(converted.equal(self.A))

    def testShift(self):
        old = self._getCtx().D.copy()
        self.A.shift(-0.5)
        D = self._getCtx().D
        self.assertTrue(D.equal(old - 0.5))

    def testScale(self):
        old = self._getCtx().D.copy()
        self.A.scale(-0.5)
        D = self._getCtx().D
        self.assertTrue(D.equal(-0.5 * old))

    def testDiagonalMat(self):
        # The Python matrix must equal PETSc's native diagonal matrix built
        # from a copy of its diagonal vector.
        D = self._getCtx().D.copy()
        B = PETSc.Mat().createDiagonal(D)
        self.assertTrue(self.A.equal(B))
6379e7eb791SStefano Zampini
63822fceea1SStefano Zampini
6395808f684SSatish Balay# --------------------------------------------------------------------
6405808f684SSatish Balay
if __name__ == '__main__':
    # When executed as a script, discover and run every TestCase defined in
    # this module ('__main__' is unittest.main's default module argument).
    unittest.main(module='__main__')
643