# xref: /petsc/src/binding/petsc4py/test/test_mat_py.py (revision f575958e429d4b56056f077d08c15ad47397da1a)
from petsc4py import PETSc
import unittest
import numpy
from sys import getrefcount
# --------------------------------------------------------------------
65808f684SSatish Balay
75808f684SSatish Balay
class Matrix:
    """Minimal Python context used behind a PETSc matrix of type 'python'.

    All hooks are no-ops except ``setUp``, which counts how many times it
    has been invoked so that the tests can check when setup is triggered.
    """

    # Class-level default; the first ``setUp`` call on an instance shadows
    # it with a per-instance counter.
    setupcalled = 0

    def __init__(self):
        pass

    def create(self, mat):
        """Do nothing; subclasses extend this and receive the matrix."""

    def destroy(self, mat):
        """Do nothing; subclasses extend this and receive the matrix."""

    def setUp(self, mat):
        """Record one more setup invocation on this instance."""
        self.setupcalled += 1
2222fceea1SStefano Zampini
23*f575958eSStefano Zampini
class ScaledIdentity(Matrix):
    """Python matrix context behaving as ``s`` times the identity.

    ``s`` starts at 2.0 and is modified by ``scale`` and ``shift``.  The
    ``product*`` methods implement matrix products in which this matrix is
    one of the factors: the result takes its type, sizes and entries from
    the remaining factor(s) and is then scaled by ``s`` (or ``s**2`` when
    this matrix appears twice, as in PtAP / RARt with ``mat is B``).
    """

    s = 2.0

    def scale(self, mat, s):
        """Multiply the stored factor by ``s``."""
        self.s *= s

    def shift(self, mat, s):
        """Add ``s`` to the stored factor."""
        self.s += s

    def mult(self, mat, x, y):
        """Copy ``x`` into ``y`` and scale ``y`` by the stored factor."""
        x.copy(y)
        y.scale(self.s)

    def duplicate(self, mat, op):
        """Return a new python matrix with the same sizes and communicator.

        The factor is carried over (and the new matrix set up) only when
        ``op`` is ``COPY_VALUES``; otherwise the duplicate keeps the class
        default factor.
        """
        dmat = PETSc.Mat()
        dctx = ScaledIdentity()
        dmat.createPython(mat.getSizes(), dctx, comm=mat.getComm())
        if op == PETSc.Mat.DuplicateOption.COPY_VALUES:
            dctx.s = self.s
            dmat.setUp()
        return dmat

    def getDiagonal(self, mat, vd):
        """Set every entry of ``vd`` to the stored factor."""
        vd.set(self.s)

    def productSetFromOptions(self, mat, producttype, A, B, C):
        """Accept every product request."""
        return True

    def _productTemplate(self, mat, producttype, A, B, C):
        """Return the matrix the product is modelled on.

        ``mat`` is the factor played by this context.  For 'PtAP', 'RARt'
        and 'ABC' the template is a product of the other factors and is
        kept in ``self.tmp``, because ``productNumeric`` passes that same
        matrix as the output of the corresponding call.

        Raises RuntimeError for an unknown ``producttype`` or when ``mat``
        is not one of the factors expected for that product.
        """
        if producttype not in ('AB', 'AtB', 'ABt', 'PtAP', 'RARt', 'ABC'):
            raise RuntimeError(f'Product {producttype} not implemented')

        if producttype == 'AB':
            if mat is A:  # product = identity * B
                return B
            if mat is B:  # product = A * identity
                return A
        elif producttype == 'AtB':
            if mat is A:  # product = identity^T * B
                return B
            if mat is B:  # product = A^T * identity
                tmp = PETSc.Mat()
                A.transpose(tmp)
                return tmp
        elif producttype == 'ABt':
            if mat is A:  # product = identity * B^T
                tmp = PETSc.Mat()
                B.transpose(tmp)
                return tmp
            if mat is B:  # product = A * identity^T
                return A
        elif producttype == 'PtAP':
            if mat is A:  # product = P^T * identity * P
                self.tmp = PETSc.Mat()
                B.transposeMatMult(B, self.tmp)
                return self.tmp
            if mat is B:  # product = identity^T * A * identity
                return A
        elif producttype == 'RARt':
            if mat is A:  # product = R * identity * R^t
                self.tmp = PETSc.Mat()
                B.matTransposeMult(B, self.tmp)
                return self.tmp
            if mat is B:  # product = identity * A * identity^T
                return A
        else:  # 'ABC'
            if mat is A:  # product = identity * B * C
                self.tmp = PETSc.Mat()
                B.matMult(C, self.tmp)
                return self.tmp
            if mat is B:  # product = A * identity * C
                self.tmp = PETSc.Mat()
                A.matMult(C, self.tmp)
                return self.tmp
            if mat is C:  # product = A * B * identity
                self.tmp = PETSc.Mat()
                A.matMult(B, self.tmp)
                return self.tmp
        raise RuntimeError('wrong configuration')

    def productSymbolic(self, mat, product, producttype, A, B, C):
        """Give ``product`` the type, sizes and layout of its template.

        The template is copied into ``product`` and the entries are then
        zeroed; the values are filled in by ``productNumeric``.
        """
        template = self._productTemplate(mat, producttype, A, B, C)
        product.setType(template.getType())
        product.setSizes(template.getSizes())
        product.setUp()
        product.assemble()
        template.copy(product)
        product.zeroEntries()

    def productNumeric(self, mat, product, producttype, A, B, C):
        """Fill ``product`` with the values of the requested product.

        Raises RuntimeError for an unknown ``producttype`` or when ``mat``
        is not one of the factors expected for that product.
        """
        # The identity contributes one factor of s, except where it appears
        # on both sides of the product (PtAP / RARt with mat is B).
        factor = self.s
        if producttype == 'AB':
            if mat is A:  # product = identity * B
                B.copy(product, structure=True)
            elif mat is B:  # product = A * identity
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'AtB':
            if mat is A:  # product = identity^T * B
                B.copy(product, structure=True)
            elif mat is B:  # product = A^T * identity
                A.setTransposePrecursor(product)
                A.transpose(product)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABt':
            if mat is A:  # product = identity * B^T
                B.setTransposePrecursor(product)
                B.transpose(product)
            elif mat is B:  # product = A * identity^T
                A.copy(product, structure=True)
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'PtAP':
            if mat is A:  # product = P^T * identity * P
                B.transposeMatMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is B:  # product = identity^T * A * identity
                A.copy(product, structure=True)
                factor = self.s**2
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'RARt':
            if mat is A:  # product = R * identity * R^t
                B.matTransposeMult(B, self.tmp)
                self.tmp.copy(product, structure=True)
            elif mat is B:  # product = identity * A * identity^T
                A.copy(product, structure=True)
                factor = self.s**2
            else:
                raise RuntimeError('wrong configuration')
        elif producttype == 'ABC':
            if mat is A:  # product = identity * B * C
                B.matMult(C, self.tmp)
            elif mat is B:  # product = A * identity * C
                A.matMult(C, self.tmp)
            elif mat is C:  # product = A * B * identity
                A.matMult(B, self.tmp)
            else:
                raise RuntimeError('wrong configuration')
            self.tmp.copy(product, structure=True)
        else:
            raise RuntimeError(f'Product {producttype} not implemented')
        product.scale(factor)
2296f336411SStefano Zampini
230ee6c7c31SStefano Zampini
class Diagonal(Matrix):
    """Python matrix context for a diagonal matrix stored as a vector.

    The diagonal lives in ``self.D``; every operation is expressed as a
    vector operation on it.
    """

    def create(self, mat):
        """Set up ``mat`` and allocate the diagonal from its left vector."""
        super().create(mat)
        mat.setUp()
        self.D = mat.createVecLeft()

    def destroy(self, mat):
        """Destroy the diagonal vector before the base-class hook runs."""
        self.D.destroy()
        super().destroy(mat)

    def scale(self, mat, a):
        """Scale the diagonal by ``a``."""
        self.D.scale(a)

    def shift(self, mat, a):
        """Shift the diagonal by ``a``."""
        self.D.shift(a)

    def zeroEntries(self, mat):
        """Zero the diagonal."""
        self.D.zeroEntries()

    def mult(self, mat, x, y):
        """Store the pointwise product of ``x`` and the diagonal in ``y``."""
        y.pointwiseMult(x, self.D)

    def duplicate(self, mat, op):
        """Return a new python matrix with the same sizes and communicator.

        The new context always gets a vector duplicated from ``self.D``;
        its values are copied (and the new matrix set up) only when ``op``
        is ``COPY_VALUES``.
        """
        dmat = PETSc.Mat()
        dctx = Diagonal()
        dmat.createPython(mat.getSizes(), dctx, comm=mat.getComm())
        dctx.D = self.D.duplicate()
        if op == PETSc.Mat.DuplicateOption.COPY_VALUES:
            self.D.copy(dctx.D)
            dmat.setUp()
        return dmat

    def getDiagonal(self, mat, vd):
        """Copy the diagonal into ``vd``."""
        self.D.copy(vd)

    def setDiagonal(self, mat, vd, im):
        """Insert ``vd`` as the diagonal, or add it to the current one.

        ``im`` is either a bool (true means add) or a ``PETSc.InsertMode``
        value.  Raises ValueError for any other insert mode.
        """
        if isinstance(im, bool):
            add = im
        elif im == PETSc.InsertMode.INSERT_VALUES:
            add = False
        elif im == PETSc.InsertMode.ADD_VALUES:
            add = True
        else:
            raise ValueError('wrong InsertMode %d' % im)

        if add:
            self.D.axpy(1, vd)
        else:
            vd.copy(self.D)

    def diagonalScale(self, mat, vl, vr):
        """Scale the diagonal pointwise by ``vl`` and/or ``vr``.

        Either vector is skipped when it evaluates false.
        """
        if vl:
            self.D.pointwiseMult(self.D, vl)
        if vr:
            self.D.pointwiseMult(self.D, vr)
2856f336411SStefano Zampini
2865808f684SSatish Balay
# --------------------------------------------------------------------
2885808f684SSatish Balay
2895808f684SSatish Balay
class TestMatrix(unittest.TestCase):
    """Tests for a python-type matrix backed by the bare ``Matrix`` context.

    ``Matrix`` implements no operations, so most tests here only check that
    the corresponding call raises.  Subclasses switch ``PYCLS`` to a richer
    context and override the tests accordingly.
    """

    COMM = PETSc.COMM_WORLD
    PYMOD = __name__
    PYCLS = 'Matrix'
    CREATE_WITH_NONE = False

    def _getCtx(self):
        """Return the Python context attached to the matrix under test."""
        return self.A.getPythonContext()

    def setUp(self):
        """Create ``self.A`` as an N x N python matrix and check refcounts."""
        N = self.N = 13
        self.A = PETSc.Mat()
        if 0:  # command line way
            self.A.create(self.COMM)
            self.A.setSizes([N, N])
            self.A.setType('python')
            OptDB = PETSc.Options(self.A)
            OptDB['mat_python_type'] = f'{self.PYMOD}.{self.PYCLS}'
            self.A.setFromOptions()
            del OptDB['mat_python_type']
            self.assertTrue(self._getCtx() is not None)
        else:  # python way
            context = globals()[self.PYCLS]()
            if self.CREATE_WITH_NONE:  # test passing None as context
                self.A.createPython([N, N], None, comm=self.COMM)
                self.A.setPythonContext(context)
                self.A.setUp()
            else:
                self.A.createPython([N, N], context, comm=self.COMM)
            self.assertIs(self._getCtx(), context)
            # Presumably: the local name, the matrix, and getrefcount's own
            # argument.
            self.assertEqual(getrefcount(context), 3)
            del context
            self.assertEqual(getrefcount(self._getCtx()), 2)

    def tearDown(self):
        """Destroy the matrix and check it releases its context reference."""
        ctx = self.A.getPythonContext()
        self.assertEqual(getrefcount(ctx), 3)
        self.A.destroy()  # XXX
        self.A = None
        PETSc.garbage_cleanup()
        self.assertEqual(getrefcount(ctx), 2)

    def testBasic(self):
        ctx = self.A.getPythonContext()
        self.assertIs(self._getCtx(), ctx)
        self.assertEqual(getrefcount(ctx), 3)

    def testSetUp(self):
        """setUp on the context runs again only after the context is reset."""
        ctx = self.A.getPythonContext()
        setupcalled = ctx.setupcalled
        self.A.setUp()
        self.assertEqual(setupcalled, ctx.setupcalled)
        self.A.setPythonContext(ctx)
        self.A.setUp()
        self.assertEqual(setupcalled + 1, ctx.setupcalled)

    def testZeroEntries(self):
        self.assertRaises(Exception, self.A.zeroEntries)

    def testMult(self):
        x, y = self.A.createVecs()
        self.assertRaises(Exception, self.A.mult, x, y)

    def testMultTranspose(self):
        x, y = self.A.createVecs()
        self.assertRaises(Exception, self.A.multTranspose, x, y)

    def testGetDiagonal(self):
        d = self.A.createVecLeft()
        self.assertRaises(Exception, self.A.getDiagonal, d)

    def testSetDiagonal(self):
        d = self.A.createVecLeft()
        self.assertRaises(Exception, self.A.setDiagonal, d)

    def testDiagonalScale(self):
        x, y = self.A.createVecs()
        self.assertRaises(Exception, self.A.diagonalScale, x, y)

    def testDuplicate(self):
        self.assertRaises(Exception, self.A.duplicate, True)
        self.assertRaises(Exception, self.A.duplicate, False)

    def testSetVecType(self):
        self.A.setVecType('mpi')
        self.assertEqual(self.A.getVecType(), 'mpi')

    def testH2Opus(self):
        """Exercise the H2Opus API on top of the python matrix.

        Returns early when h2opus is not available, when running on more
        than one process, or when the context has no ``mult``.
        """
        if not PETSc.Sys.hasExternalPackage('h2opus'):
            return
        if self.A.getComm().Get_size() > 1:
            return
        h = PETSc.Mat()

        # need matrix vector and its transpose for norm estimation
        AA = self.A.getPythonContext()
        if not hasattr(AA, 'mult'):
            return
        AA.multTranspose = AA.mult
        try:
            # without coordinates
            h.createH2OpusFromMat(self.A, leafsize=2)
            h.assemble()
            h.destroy()

            # with coordinates
            coords = numpy.linspace(
                (1, 2, 3), (10, 20, 30), self.A.getSize()[0], dtype=PETSc.RealType
            )
            h.createH2OpusFromMat(self.A, coords, leafsize=2)
            h.assemble()

            # test API
            h.H2OpusOrthogonalize()
            h.H2OpusCompress(1.0e-1)

            # Low-rank update
            U = PETSc.Mat()
            U.createDense([h.getSizes()[0], 3], comm=h.getComm())
            U.setUp()
            U.setRandom()

            he = PETSc.Mat()
            h.convert('dense', he)
            he.axpy(1.0, U.matTransposeMult(U))

            h.H2OpusLowRankUpdate(U)
            self.assertTrue(he.equal(h))

            h.destroy()
        finally:
            # Always drop the temporary attribute, even when the test fails,
            # so it is not left on the context when tearDown checks its
            # reference count.
            del AA.multTranspose

    def testGetType(self):
        ctx = self.A.getPythonContext()
        pytype = f'{ctx.__module__}.{type(ctx).__name__}'
        self.assertEqual(self.A.getPythonType(), pytype)
435300d917bSStefano Zampini
4365808f684SSatish Balay
class TestScaledIdentity(TestMatrix):
    """Run the matrix tests against the ``ScaledIdentity`` context."""

    PYCLS = 'ScaledIdentity'

    def testMult(self):
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.mult(x, y)
        self.assertTrue(y.equal(s * x))

    def testMultTransposeSymmKnown(self):
        """multTranspose works only while the matrix is flagged symmetric."""
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(s * x))
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        self.assertRaises(Exception, self.A.multTranspose, x, y)

    def testMultTransposeNewMeth(self):
        """multTranspose works once the context gains such a method."""
        s = self._getCtx().s
        x, y = self.A.createVecs()
        x.setRandom()
        AA = self.A.getPythonContext()
        AA.multTranspose = AA.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            # Remove the temporary attribute even if the call raises, so it
            # is not left on the context when tearDown checks its reference
            # count.
            del AA.multTranspose
        self.assertTrue(y.equal(s * x))

    def testGetDiagonal(self):
        s = self._getCtx().s
        d = self.A.createVecLeft()
        o = d.duplicate()
        o.set(s)
        self.A.getDiagonal(d)
        self.assertTrue(o.equal(d))

    def testDuplicate(self):
        """Without copying, the duplicate has the default factor of 2."""
        B = self.A.duplicate(False)
        self.assertEqual(B.getPythonContext().s, 2)
        B = self.A.duplicate(True)
        self.assertEqual(B.getPythonContext().s, self.A.getPythonContext().s)

    def testMatMat(self):
        """Compare products with ``self.A`` against an explicit AIJ s*I."""
        s = self._getCtx().s
        R = PETSc.Random().create(self.COMM)
        R.setFromOptions()
        A = PETSc.Mat().create(self.COMM)
        A.setSizes(self.A.getSizes())
        A.setType(PETSc.Mat.Type.AIJ)
        A.setPreallocationNNZ(None)
        A.setRandom(R)
        B = PETSc.Mat().create(self.COMM)
        B.setSizes(self.A.getSizes())
        B.setType(PETSc.Mat.Type.AIJ)
        B.setPreallocationNNZ(None)
        B.setRandom(R)
        Id = PETSc.Mat().create(self.COMM)
        Id.setSizes(self.A.getSizes())
        Id.setType(PETSc.Mat.Type.AIJ)
        Id.setUp()
        Id.assemble()
        Id.shift(s)

        def assertNear(X, Y):
            # The two products must agree up to 5 decimal places in norm.
            self.assertAlmostEqual((X - Y).norm(), 0.0, places=5)

        sequential = self.A.getComm().Get_size() == 1

        self.assertTrue(self.A.matMult(A).equal(Id.matMult(A)))
        self.assertTrue(A.matMult(self.A).equal(A.matMult(Id)))
        if sequential:
            self.assertTrue(self.A.matTransposeMult(A).equal(Id.matTransposeMult(A)))
            self.assertTrue(A.matTransposeMult(self.A).equal(A.matTransposeMult(Id)))
        self.assertTrue(self.A.transposeMatMult(A).equal(Id.transposeMatMult(A)))
        self.assertTrue(A.transposeMatMult(self.A).equal(A.transposeMatMult(Id)))
        assertNear(self.A.ptap(A), Id.ptap(A))
        assertNear(A.ptap(self.A), A.ptap(Id))
        if sequential:
            assertNear(self.A.rart(A), Id.rart(A))
            assertNear(A.rart(self.A), A.rart(Id))
        assertNear(self.A.matMatMult(A, B), Id.matMatMult(A, B))
        assertNear(A.matMatMult(self.A, B), A.matMatMult(Id, B))
        assertNear(A.matMatMult(B, self.A), A.matMatMult(B, Id))

    def testShift(self):
        sold = self._getCtx().s
        self.A.shift(-0.5)
        s = self._getCtx().s
        self.assertEqual(s, sold - 0.5)

    def testScale(self):
        sold = self._getCtx().s
        self.A.scale(-0.5)
        s = self._getCtx().s
        self.assertEqual(s, sold * -0.5)

    def testDiagonalMat(self):
        s = self._getCtx().s
        B = PETSc.Mat().createConstantDiagonal(
            self.A.getSizes(), s, comm=self.A.getComm()
        )
        self.assertTrue(self.A.equal(B))
5439e7eb791SStefano Zampini
5445808f684SSatish Balay
class TestDiagonal(TestMatrix):
    # Tests for the Python-context matrix selected by PYCLS = 'Diagonal'.
    # self.A and self._getCtx() are inherited from TestMatrix (defined
    # elsewhere in this file); the tests compare the outcome of Mat
    # operations against the `D` vector stored on the context object.
    PYCLS = 'Diagonal'
    CREATE_WITH_NONE = True

    def setUp(self):
        # Run the base fixture, then install the diagonal whose global
        # entry i equals i + 1.
        super().setUp()
        diag = self.A.createVecLeft()
        start, end = diag.getOwnershipRange()
        # Only the range reported by getOwnershipRange() is written here.
        for row in range(start, end):
            diag[row] = row + 1
        diag.assemble()
        self.A.setDiagonal(diag)

    def testZeroEntries(self):
        # After zeroEntries() the stored diagonal must have zero norm.
        self.A.zeroEntries()
        self.assertEqual(self._getCtx().D.norm(), 0)

    def testMult(self):
        # Multiplying by a vector of ones must reproduce the diagonal.
        x, y = self.A.createVecs()
        x.set(1)
        self.A.mult(x, y)
        self.assertTrue(y.equal(self._getCtx().D))

    def testMultTransposeSymmKnown(self):
        x, y = self.A.createVecs()
        x.set(1)
        # With the SYMMETRIC option set, multTranspose must succeed and give
        # the same result as mult on a vector of ones.
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        self.A.multTranspose(x, y)
        self.assertTrue(y.equal(self._getCtx().D))
        # With the option cleared, the same call is expected to raise.
        self.A.setOption(PETSc.Mat.Option.SYMMETRIC, False)
        with self.assertRaises(Exception):
            self.A.multTranspose(x, y)

    def testMultTransposeNewMeth(self):
        x, y = self.A.createVecs()
        x.set(1)
        ctx = self.A.getPythonContext()
        # Temporarily give the context a multTranspose that aliases mult.
        ctx.multTranspose = ctx.mult
        try:
            self.A.multTranspose(x, y)
        finally:
            # Always remove the temporary attribute, even if the call above
            # raised, so the context is left as it was found.
            del ctx.multTranspose
        self.assertTrue(y.equal(self._getCtx().D))

    def testDuplicate(self):
        # Duplicating without copying values only has to succeed; its result
        # is not inspected.
        self.A.duplicate(False)
        # Duplicating with values copied must yield an equal diagonal.
        dup = self.A.duplicate(True)
        self.assertTrue(dup.getPythonContext().D.equal(self.A.getPythonContext().D))

    def testGetDiagonal(self):
        # getDiagonal must fill the vector with the stored diagonal.
        d = self.A.createVecLeft()
        self.A.getDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testSetDiagonal(self):
        # After setDiagonal with a random vector, the stored diagonal must
        # equal that vector.
        d = self.A.createVecLeft()
        d.setRandom()
        self.A.setDiagonal(d)
        self.assertTrue(d.equal(self._getCtx().D))

    def testDiagonalScale(self):
        # Scaling with constant vectors 2 and 3 must multiply the stored
        # diagonal by 6.
        x, y = self.A.createVecs()
        x.set(2)
        y.set(3)
        old = self._getCtx().D.copy()
        self.A.diagonalScale(x, y)
        self.assertTrue(self._getCtx().D.equal(old * 6))

    def testCreateTranspose(self):
        A = self.A
        A.setOption(PETSc.Mat.Option.SYMMETRIC, True)
        AT = PETSc.Mat().createTranspose(A)
        x, y = A.createVecs()
        xt, yt = AT.createVecs()
        # A.multTranspose and AT.mult must agree on the same random input.
        y.setRandom()
        A.multTranspose(y, x)
        y.copy(xt)
        AT.mult(xt, yt)
        self.assertTrue(yt.equal(x))
        # A.mult and AT.multTranspose must agree on the same random input.
        x.setRandom()
        A.mult(x, y)
        x.copy(yt)
        AT.multTranspose(yt, xt)
        self.assertTrue(xt.equal(y))

    def testConvert(self):
        # Converting to each of these types must give a matrix that compares
        # equal to the original.
        targets = (
            PETSc.Mat.Type.AIJ,
            PETSc.Mat.Type.BAIJ,
            PETSc.Mat.Type.SBAIJ,
            PETSc.Mat.Type.DENSE,
        )
        for mtype in targets:
            converted = self.A.convert(mtype, PETSc.Mat())
            # The message names the failing target, since all four checks
            # now share one source line.
            self.assertTrue(converted.equal(self.A), msg='convert to %s' % (mtype,))

    def testShift(self):
        # Shifting by -0.5 must subtract 0.5 from the stored diagonal.
        old = self._getCtx().D.copy()
        self.A.shift(-0.5)
        self.assertTrue(self._getCtx().D.equal(old - 0.5))

    def testScale(self):
        # Scaling by -0.5 must multiply the stored diagonal by -0.5.
        old = self._getCtx().D.copy()
        self.A.scale(-0.5)
        self.assertTrue(self._getCtx().D.equal(-0.5 * old))

    def testDiagonalMat(self):
        # A diagonal matrix created from a copy of the stored diagonal must
        # compare equal to self.A.
        reference = PETSc.Mat().createDiagonal(self._getCtx().D.copy())
        self.assertTrue(self.A.equal(reference))
6559e7eb791SStefano Zampini
65622fceea1SStefano Zampini
6575808f684SSatish Balay# --------------------------------------------------------------------
6585808f684SSatish Balay
if __name__ == '__main__':
    # Executed as a script: hand control to unittest's command-line runner.
    unittest.main()
661