# xref: /petsc/src/binding/petsc4py/test/test_mat_py.py (revision b2584804908b6ae8fffb813f76258847e9469937)
import unittest
from sys import getrefcount

import numpy
from petsc4py import PETSc
# --------------------------------------------------------------------
65808f684SSatish Balay
75808f684SSatish Balay
class Matrix:
    """Minimal context for a 'python' type matrix.

    Implements only the lifecycle hooks (``create``, ``destroy``,
    ``setUp``) and no numerical operation.
    """

    # Number of times ``setUp`` has run.  Declared on the class so that
    # it reads as 0 before the first call; ``+=`` in ``setUp`` then
    # stores the running count on the instance.
    setupcalled = 0

    def __init__(self):
        pass

    def create(self, mat):
        """Hook taking the matrix ``mat``; does nothing here."""
        pass

    def destroy(self, mat):
        """Hook taking the matrix ``mat``; does nothing here."""
        pass

    def setUp(self, mat):
        """Record one more set-up call in ``setupcalled``."""
        self.setupcalled += 1
2222fceea1SStefano Zampini
class ScaledIdentity(Matrix):
    """Matrix context that behaves as ``s`` times the identity."""

    # Scaling factor.  Class-level default; ``scale``/``shift`` store the
    # updated value on the instance.
    s = 2.0

    def scale(self, mat, s):
        """Multiply the stored factor by ``s``."""
        self.s *= s

    def shift(self, mat, s):
        """Add ``s`` to the stored factor."""
        self.s += s

    def mult(self, mat, x, y):
        """Set ``y`` to ``self.s * x``."""
        x.copy(y)
        y.scale(self.s)

    def duplicate(self, mat, op):
        """Return a new python matrix with a fresh ``ScaledIdentity`` context.

        The factor is copied, and the new matrix set up, only when ``op``
        is ``COPY_VALUES``; otherwise the new context keeps the class
        default.
        """
        dmat = PETSc.Mat()
        dctx = ScaledIdentity()
        dmat.createPython(mat.getSizes(), dctx, comm=mat.getComm())
        if op == PETSc.Mat.DuplicateOption.COPY_VALUES:
            dctx.s = self.s
            dmat.setUp()
        return dmat

    def getDiagonal(self, mat, vd):
        """Fill ``vd`` with the stored factor."""
        vd.set(self.s)

    def productSetFromOptions(self, mat, producttype, A, B, C):
        """Accept every product configuration."""
        return True

    @staticmethod
    def _shapeProductLike(product, template):
        # Give ``product`` the type and sizes of ``template``, set it up,
        # assemble it and copy ``template`` into it.
        product.setType(template.getType())
        product.setSizes(template.getSizes())
        product.setUp()
        product.assemble()
        template.copy(product)

    def productSymbolic(self, mat, product, producttype, A, B, C):
        """Prepare ``product`` for the requested product type.

        ``mat`` is the operand backed by this context.  For each supported
        ``producttype`` a template matrix is selected (an operand, its
        transpose, or a product of the other operands kept in
        ``self.tmp``); ``product`` is shaped like it and then zeroed.

        Raises ``RuntimeError`` when ``mat`` is not one of the expected
        operands or the product type is not handled.
        """
        if producttype == 'AB':
            if mat is A:  # product = identity * B
                template = B
            elif mat is B:  # product = A * identity
                template = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'AtB':
            if mat is A:  # product = identity^T * B
                template = B
            elif mat is B:  # product = A^T * identity
                template = PETSc.Mat()
                A.transpose(template)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABt':
            if mat is A:  # product = identity * B^T
                template = PETSc.Mat()
                B.transpose(template)
            elif mat is B:  # product = A * identity^T
                template = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'PtAP':
            if mat is A:  # product = P^T * identity * P
                self.tmp = PETSc.Mat()
                B.transposeMatMult(B, self.tmp)
                template = self.tmp
            elif mat is B:  # product = identity^T * A * identity
                template = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A:  # product = R * identity * R^t
                self.tmp = PETSc.Mat()
                B.matTransposeMult(B, self.tmp)
                template = self.tmp
            elif mat is B:  # product = identity * A * identity^T
                template = A
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            if mat is A:  # product = identity * B * C
                self.tmp = PETSc.Mat()
                B.matMult(C, self.tmp)
            elif mat is B:  # product = A * identity * C
                self.tmp = PETSc.Mat()
                A.matMult(C, self.tmp)
            elif mat is C:  # product = A * B * identity
                self.tmp = PETSc.Mat()
                A.matMult(B, self.tmp)
            else:
                raise RuntimeError('wrong configuration')
            template = self.tmp
        else:
            raise RuntimeError(f'Product {producttype} not implemented')
        self._shapeProductLike(product, template)
        product.zeroEntries()

    def productNumeric(self, mat, product, producttype, A, B, C):
        """Fill ``product`` with the values of the requested product.

        The unscaled result is written into ``product`` first and then
        scaled once: by ``self.s``, or by ``self.s**2`` when this matrix
        appears on both sides ('PtAP'/'RARt' with ``mat is B``).
        ``self.tmp`` is the work matrix created by ``productSymbolic``.

        Raises ``RuntimeError`` when ``mat`` is not one of the expected
        operands or the product type is not handled.
        """
        factor = self.s
        if producttype == 'AB':
            if mat is A:  # product = identity * B
                B.copy(product, structure=True)
            elif mat is B:  # product = A * identity
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'AtB':
            if mat is A:  # product = identity^T * B
                B.copy(product, structure=True)
            elif mat is B:  # product = A^T * identity
                A.setTransposePrecursor(product)
                A.transpose(product)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABt':
            if mat is A:  # product = identity * B^T
                B.setTransposePrecursor(product)
                B.transpose(product)
            elif mat is B:  # product = A * identity^T
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'PtAP':
            if mat is A:  # product = P^T * identity * P
                B.transposeMatMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is B:  # product = identity^T * A * identity
                A.copy(product, structure=True)
                factor = self.s**2
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A:  # product = R * identity * R^t
                B.matTransposeMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is B:  # product = identity * A * identity^T
                A.copy(product, structure=True)
                factor = self.s**2
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            if mat is A:  # product = identity * B * C
                B.matMult(C, self.tmp)
            elif mat is B:  # product = A * identity * C
                A.matMult(C, self.tmp)
            elif mat is C:  # product = A * B * identity
                A.matMult(B, self.tmp)
            else:
                raise RuntimeError('wrong configuration')
            self.tmp.copy(product, structure=True)
        else:
            raise RuntimeError(f'Product {producttype} not implemented')
        product.scale(factor)
2286f336411SStefano Zampini
229ee6c7c31SStefano Zampini
class Diagonal(Matrix):
    """Matrix context that stores its diagonal in the vector ``self.D``."""

    def create(self, mat):
        """Set up ``mat`` and allocate the diagonal vector from it."""
        super().create(mat)
        mat.setUp()
        self.D = mat.createVecLeft()

    def destroy(self, mat):
        """Destroy the diagonal vector."""
        self.D.destroy()
        super().destroy(mat)

    def scale(self, mat, a):
        """Scale the diagonal by ``a``."""
        self.D.scale(a)

    def shift(self, mat, a):
        """Add ``a`` to every diagonal entry."""
        self.D.shift(a)

    def zeroEntries(self, mat):
        """Zero the diagonal."""
        self.D.zeroEntries()

    def mult(self, mat, x, y):
        """Set ``y`` to the entrywise product of ``x`` and the diagonal."""
        y.pointwiseMult(x, self.D)

    def duplicate(self, mat, op):
        """Return a new python matrix with a fresh ``Diagonal`` context.

        The new context gets a vector duplicated from ``self.D``.  Its
        values are copied, and the new matrix set up, only when ``op`` is
        ``COPY_VALUES``.
        """
        dmat = PETSc.Mat()
        dctx = Diagonal()
        dmat.createPython(mat.getSizes(), dctx, comm=mat.getComm())
        dctx.D = self.D.duplicate()
        if op == PETSc.Mat.DuplicateOption.COPY_VALUES:
            self.D.copy(dctx.D)
            dmat.setUp()
        return dmat

    def getDiagonal(self, mat, vd):
        """Copy the diagonal into ``vd``."""
        self.D.copy(vd)

    def setDiagonal(self, mat, vd, im):
        """Insert ``vd`` as, or add ``vd`` to, the diagonal.

        ``im`` is either a bool (true means add) or a ``PETSc.InsertMode``
        value.  Raises ``ValueError`` for any other insert mode.
        """
        # The bool check must come first: bool is an int subclass, so a
        # bool could otherwise compare equal to an InsertMode value.
        if isinstance(im, bool):
            addv = im
        elif im == PETSc.InsertMode.INSERT_VALUES:
            addv = False
        elif im == PETSc.InsertMode.ADD_VALUES:
            addv = True
        else:
            raise ValueError('wrong InsertMode %d' % im)
        if addv:
            self.D.axpy(1, vd)
        else:
            vd.copy(self.D)

    def diagonalScale(self, mat, vl, vr):
        """Multiply the diagonal entrywise by ``vl`` and ``vr``.

        Each factor is applied only when the vector is truthy.
        """
        if vl:
            self.D.pointwiseMult(self.D, vl)
        if vr:
            self.D.pointwiseMult(self.D, vr)
2846f336411SStefano Zampini
2855808f684SSatish Balay
# --------------------------------------------------------------------
2875808f684SSatish Balay
2885808f684SSatish Balay
class TestMatrix(unittest.TestCase):
    """Tests for a python matrix backed by the bare ``Matrix`` context.

    Subclasses select another context class through ``PYCLS`` and
    override the tests whose operation that context implements.
    """

    COMM = PETSc.COMM_WORLD
    PYMOD = __name__
    PYCLS = 'Matrix'
    # When true, the matrix is created with a ``None`` context and the
    # real context is attached afterwards.
    CREATE_WITH_NONE = False

    def _getCtx(self):
        return self.A.getPythonContext()

    def setUp(self):
        N = self.N = 13
        self.A = PETSc.Mat()
        if 0:  # command line way
            self.A.create(self.COMM)
            self.A.setSizes([N, N])
            self.A.setType('python')
            OptDB = PETSc.Options(self.A)
            OptDB['mat_python_type'] = f'{self.PYMOD}.{self.PYCLS}'
            self.A.setFromOptions()
            del OptDB['mat_python_type']
            self.assertIsNotNone(self._getCtx())
        else:  # python way
            context = globals()[self.PYCLS]()
            if self.CREATE_WITH_NONE:  # test passing None as context
                self.A.createPython([N, N], None, comm=self.COMM)
                self.A.setPythonContext(context)
                self.A.setUp()
            else:
                self.A.createPython([N, N], context, comm=self.COMM)
            self.assertIs(self._getCtx(), context)
            # Reference-count bookkeeping: the expected counts drop by
            # one once the local name is deleted.
            self.assertEqual(getrefcount(context), 3)
            del context
            self.assertEqual(getrefcount(self._getCtx()), 2)

    def tearDown(self):
        ctx = self.A.getPythonContext()
        self.assertEqual(getrefcount(ctx), 3)
        self.A.destroy()  # XXX
        self.A = None
        PETSc.garbage_cleanup()
        # Destroying the matrix must release one reference to the context.
        self.assertEqual(getrefcount(ctx), 2)

    def testBasic(self):
        ctx = self.A.getPythonContext()
        self.assertIs(self._getCtx(), ctx)
        self.assertEqual(getrefcount(ctx), 3)

    def testSetUp(self):
        ctx = self.A.getPythonContext()
        setupcalled = ctx.setupcalled
        # A further setUp on the existing matrix must not reach the context.
        self.A.setUp()
        self.assertEqual(setupcalled, ctx.setupcalled)
        # After the context is set again, setUp must reach it exactly once.
        self.A.setPythonContext(ctx)
        self.A.setUp()
        self.assertEqual(setupcalled + 1, ctx.setupcalled)

    def testZeroEntries(self):
        with self.assertRaises(Exception):
            self.A.zeroEntries()

    def testMult(self):
        x, y = self.A.createVecs()
        with self.assertRaises(Exception):
            self.A.mult(x, y)

    def testMultTranspose(self):
        x, y = self.A.createVecs()
        with self.assertRaises(Exception):
            self.A.multTranspose(x, y)

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        with self.assertRaises(Exception):
            self.A.getDiagonal(d)

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        with self.assertRaises(Exception):
            self.A.setDiagonal(d)

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        with self.assertRaises(Exception):
            self.A.diagonalScale(x, y)

    def testDuplicate(self):
        with self.assertRaises(Exception):
            self.A.duplicate(True)
        with self.assertRaises(Exception):
            self.A.duplicate(False)

    def testSetVecType(self):
        self.A.setVecType('mpi')
        self.assertEqual('mpi', self.A.getVecType())

    def testH2Opus(self):
        if not PETSc.Sys.hasExternalPackage('h2opus'):
            return
        if self.A.getComm().Get_size() > 1:
            return
        h = PETSc.Mat()

        # need matrix vector and its transpose for norm estimation
        AA = self.A.getPythonContext()
        if not hasattr(AA, 'mult'):
            return
        AA.multTranspose = AA.mult
        try:
            # without coordinates
            h.createH2OpusFromMat(self.A, leafsize=2)
            h.assemble()
            h.destroy()

            # with coordinates
            coords = numpy.linspace(
                (1, 2, 3), (10, 20, 30), self.A.getSize()[0], dtype=PETSc.RealType
            )
            h.createH2OpusFromMat(self.A, coords, leafsize=2)
            h.assemble()

            # test API
            h.H2OpusOrthogonalize()
            h.H2OpusCompress(1.0e-1)

            # Low-rank update
            U = PETSc.Mat()
            U.createDense([h.getSizes()[0], 3], comm=h.getComm())
            U.setUp()
            U.setRandom()

            he = PETSc.Mat()
            h.convert('dense', he)
            he.axpy(1.0, U.matTransposeMult(U))

            h.H2OpusLowRankUpdate(U)
            self.assertTrue(he.equal(h))

            h.destroy()
        finally:
            # Always drop the temporary method: it is a bound method kept
            # on the context itself, and tearDown checks the context's
            # reference count.
            del AA.multTranspose

    def testGetType(self):
        ctx = self.A.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        self.assertEqual(self.A.getPythonType(), pytype)
434300d917bSStefano Zampini
4355808f684SSatish Balay
class TestScaledIdentity(TestMatrix):
    """Tests for a python matrix backed by the ``ScaledIdentity`` context."""

    PYCLS = 'ScaledIdentity'

    def testMult(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.mult(x, y)
        self.assertTrue(y.equal(s * x))

    def testMultTransposeSymmKnown(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        # With the symmetric option set, multTranspose must succeed even
        # though the context defines no multTranspose.
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(s * x))
        # With the option cleared, the same call must raise.
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        with self.assertRaises(Exception):
            self.A.multTranspose(x, y)

    def testMultTransposeNewMeth(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            # Always drop the temporary method: it is a bound method kept
            # on the context itself, and tearDown checks the context's
            # reference count.
            del AA.multTranspose
        self.assertTrue(y.equal(s * x))

    def testGetDiagonal(self):
        s = self._getCtx().s
        d = self.A.createVecLeft()
        o = d.duplicate()
        o.set(s)
        self.A.getDiagonal(d)
        self.assertTrue(o.equal(d))

    def testDuplicate(self):
        # Without copying values the duplicate keeps the default factor.
        B = self.A.duplicate(False)
        self.assertEqual(B.getPythonContext().s, 2)
        B = self.A.duplicate(True)
        self.assertEqual(B.getPythonContext().s, self.A.getPythonContext().s)

    def testMatMat(self):
        s = self._getCtx().s
        R = PETSc.Random().create(self.COMM)
        R.setFromOptions()
        A = PETSc.Mat().create(self.COMM)
        A.setSizes(self.A.getSizes())
        A.setType(PETSc.Mat.Type.AIJ)
        A.setPreallocationNNZ(None)
        A.setRandom(R)
        B = PETSc.Mat().create(self.COMM)
        B.setSizes(self.A.getSizes())
        B.setType(PETSc.Mat.Type.AIJ)
        B.setPreallocationNNZ(None)
        B.setRandom(R)
        # Explicit AIJ matrix shifted by ``s``, used as the reference
        # operand in every product below.
        Id = PETSc.Mat().create(self.COMM)
        Id.setSizes(self.A.getSizes())
        Id.setType(PETSc.Mat.Type.AIJ)
        Id.setUp()
        Id.assemble()
        Id.shift(s)

        self.assertTrue(self.A.matMult(A).equal(Id.matMult(A)))
        self.assertTrue(A.matMult(self.A).equal(A.matMult(Id)))
        if self.A.getComm().Get_size() == 1:
            self.assertTrue(self.A.matTransposeMult(A).equal(Id.matTransposeMult(A)))
            self.assertTrue(A.matTransposeMult(self.A).equal(A.matTransposeMult(Id)))
        self.assertTrue(self.A.transposeMatMult(A).equal(Id.transposeMatMult(A)))
        self.assertTrue(A.transposeMatMult(self.A).equal(A.transposeMatMult(Id)))
        self.assertAlmostEqual((self.A.ptap(A) - Id.ptap(A)).norm(), 0.0, places=5)
        self.assertAlmostEqual((A.ptap(self.A) - A.ptap(Id)).norm(), 0.0, places=5)
        if self.A.getComm().Get_size() == 1:
            self.assertAlmostEqual((self.A.rart(A) - Id.rart(A)).norm(), 0.0, places=5)
            self.assertAlmostEqual((A.rart(self.A) - A.rart(Id)).norm(), 0.0, places=5)
        self.assertAlmostEqual(
            (self.A.matMatMult(A, B) - Id.matMatMult(A, B)).norm(), 0.0, places=5
        )
        self.assertAlmostEqual(
            (A.matMatMult(self.A, B) - A.matMatMult(Id, B)).norm(), 0.0, places=5
        )
        self.assertAlmostEqual(
            (A.matMatMult(B, self.A) - A.matMatMult(B, Id)).norm(), 0.0, places=5
        )

    def testShift(self):
        sold = self._getCtx().s
        self.A.shift(-0.5)
        s = self._getCtx().s
        self.assertEqual(s, sold - 0.5)

    def testScale(self):
        sold = self._getCtx().s
        self.A.scale(-0.5)
        s = self._getCtx().s
        self.assertEqual(s, sold * -0.5)

    def testDiagonalMat(self):
        s = self._getCtx().s
        B = PETSc.Mat().createConstantDiagonal(
            self.A.getSizes(), s, comm=self.A.getComm()
        )
        self.assertTrue(self.A.equal(B))
5429e7eb791SStefano Zampini
5435808f684SSatish Balay
class TestDiagonal(TestMatrix):
    """Tests for a Python-context matrix whose context exposes a vector ``D``.

    Every test compares the result of a matrix operation on ``self.A`` with
    the ``D`` attribute of the matrix's Python context.

    NOTE(review): ``PYCLS`` and ``CREATE_WITH_NONE`` are presumably read by
    the ``TestMatrix`` fixture when it builds ``self.A`` -- confirm there.
    """

    PYCLS = 'Diagonal'
    CREATE_WITH_NONE = True

    def setUp(self):
        # Run the base fixture first, then install a diagonal whose entry i
        # (over the locally owned range) holds the value i + 1.
        super().setUp()
        D = self.A.createVecLeft()
        s, e = D.getOwnershipRange()
        for i in range(s, e):
            D[i] = i + 1
        D.assemble()
        self.A.setDiagonal(D)

    def testZeroEntries(self):
        # Zeroing the matrix must leave the context's D with zero norm.
        self.A.zeroEntries()
        D = self._getCtx().D
        self.assertEqual(D.norm(), 0)

    def testMult(self):
        # Applying the matrix to a vector of ones must reproduce D.
        x, y = self.A.createVecs()
        x.set(1)
        self.A.mult(x, y)
        self.assertTrue(y.equal(self._getCtx().D))

    def testMultTransposeSymmKnown(self):
        x, y = self.A.createVecs()
        x.set(1)
        # With the SYMMETRIC option set, multTranspose on ones must give D.
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(self._getCtx().D))
        # With the option cleared again, the same call is expected to raise.
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        with self.assertRaises(Exception):
            self.A.multTranspose(x, y)

    def testMultTransposeNewMeth(self):
        x, y = self.A.createVecs()
        x.set(1)
        ctx = self.A.getPythonContext()
        # Temporarily give the context a multTranspose by aliasing its mult.
        ctx.multTranspose = ctx.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            # Always remove the temporary attribute, even if the call raised,
            # so the context is left as it was found.
            del ctx.multTranspose
        self.assertTrue(y.equal(self._getCtx().D))

    def testDuplicate(self):
        # duplicate(False) is only exercised; its result is not inspected.
        self.A.duplicate(False)
        # The duplicate(True) result must carry a D equal to the original's.
        B = self.A.duplicate(True)
        self.assertTrue(B.getPythonContext().D.equal(self.A.getPythonContext().D))

    def testGetDiagonal(self):
        # getDiagonal must fill the supplied vector with D.
        d = self.A.createVecLeft()
        self.A.getDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testSetDiagonal(self):
        # After setDiagonal with a random vector, D must equal that vector.
        d = self.A.createVecLeft()
        d.setRandom()
        self.A.setDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testDiagonalScale(self):
        # Scaling with constant vectors 2 and 3 must multiply D by 6.
        x, y = self.A.createVecs()
        x.set(2)
        y.set(3)
        old = self._getCtx().D.copy()
        self.A.diagonalScale(x, y)
        D = self._getCtx().D
        self.assertTrue(D.equal(old * 6))

    def testCreateTranspose(self):
        A = self.A
        A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        AT = PETSc.Mat().createTranspose(A)
        x, y = A.createVecs()
        xt, yt = AT.createVecs()
        # A.multTranspose and AT.mult must agree on the same random input.
        y.setRandom()
        A.multTranspose(y, x)
        y.copy(xt)
        AT.mult(xt, yt)
        self.assertTrue(yt.equal(x))
        # A.mult and AT.multTranspose must agree on the same random input.
        x.setRandom()
        A.mult(x, y)
        x.copy(yt)
        AT.multTranspose(yt, xt)
        self.assertTrue(xt.equal(y))
        del A

    def testConvert(self):
        # Converting to each of these types must give a matrix that compares
        # equal to the original; the message names the type that failed.
        mat_types = (
            PETSc.Mat.Type.AIJ,
            PETSc.Mat.Type.BAIJ,
            PETSc.Mat.Type.SBAIJ,
            PETSc.Mat.Type.DENSE,
        )
        for mat_type in mat_types:
            converted = self.A.convert(mat_type, PETSc.Mat())
            self.assertTrue(
                converted.equal(self.A), msg='convert to ' + str(mat_type)
            )

    def testShift(self):
        # Shifting the matrix by -0.5 must subtract 0.5 from D.
        old = self._getCtx().D.copy()
        self.A.shift(-0.5)
        D = self._getCtx().D
        self.assertTrue(D.equal(old - 0.5))

    def testScale(self):
        # Scaling the matrix by -0.5 must multiply D by -0.5.
        old = self._getCtx().D.copy()
        self.A.scale(-0.5)
        D = self._getCtx().D
        self.assertTrue(D.equal(-0.5 * old))

    def testDiagonalMat(self):
        # A diagonal matrix created from a copy of D must equal self.A.
        D = self._getCtx().D.copy()
        B = PETSc.Mat().createDiagonal(D)
        self.assertTrue(self.A.equal(B))
6549e7eb791SStefano Zampini
65522fceea1SStefano Zampini
6565808f684SSatish Balay# --------------------------------------------------------------------
6575808f684SSatish Balay
# Run the unittest command-line runner when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
660