xref: /petsc/src/mat/impls/submat/submat.c (revision 7b1c77167952b0557943a232cdd64e09e7c985ed)
1 
2 #include <petsc/private/matimpl.h>          /*I "petscmat.h" I*/
3 
4 typedef struct {
5   IS          isrow,iscol;      /* rows and columns in submatrix, only used to check consistency */
6   Vec         left,right;       /* optional scaling */
7   Vec         olwork,orwork;    /* work vectors outside the scatters, only touched by PreScale and only created if needed*/
8   Vec         lwork,rwork;      /* work vectors inside the scatters */
9   VecScatter  lrestrict,rprolong;
10   Mat         A;
11   PetscScalar scale;
12 } Mat_SubVirtual;
13 
14 static PetscErrorCode PreScaleLeft(Mat N,Vec x,Vec *xx)
15 {
16   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
17   PetscErrorCode ierr;
18 
19   PetscFunctionBegin;
20   if (!Na->left) {
21     *xx = x;
22   } else {
23     if (!Na->olwork) {
24       ierr = VecDuplicate(Na->left,&Na->olwork);CHKERRQ(ierr);
25     }
26     ierr = VecPointwiseMult(Na->olwork,x,Na->left);CHKERRQ(ierr);
27     *xx  = Na->olwork;
28   }
29   PetscFunctionReturn(0);
30 }
31 
32 static PetscErrorCode PreScaleRight(Mat N,Vec x,Vec *xx)
33 {
34   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
35   PetscErrorCode ierr;
36 
37   PetscFunctionBegin;
38   if (!Na->right) {
39     *xx = x;
40   } else {
41     if (!Na->orwork) {
42       ierr = VecDuplicate(Na->right,&Na->orwork);CHKERRQ(ierr);
43     }
44     ierr = VecPointwiseMult(Na->orwork,x,Na->right);CHKERRQ(ierr);
45     *xx  = Na->orwork;
46   }
47   PetscFunctionReturn(0);
48 }
49 
50 static PetscErrorCode PostScaleLeft(Mat N,Vec x)
51 {
52   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
53   PetscErrorCode ierr;
54 
55   PetscFunctionBegin;
56   if (Na->left) {
57     ierr = VecPointwiseMult(x,x,Na->left);CHKERRQ(ierr);
58   }
59   PetscFunctionReturn(0);
60 }
61 
62 static PetscErrorCode PostScaleRight(Mat N,Vec x)
63 {
64   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
65   PetscErrorCode ierr;
66 
67   PetscFunctionBegin;
68   if (Na->right) {
69     ierr = VecPointwiseMult(x,x,Na->right);CHKERRQ(ierr);
70   }
71   PetscFunctionReturn(0);
72 }
73 
74 static PetscErrorCode MatScale_SubMatrix(Mat N,PetscScalar scale)
75 {
76   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
77 
78   PetscFunctionBegin;
79   Na->scale *= scale;
80   PetscFunctionReturn(0);
81 }
82 
83 static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N,Vec left,Vec right)
84 {
85   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
86   PetscErrorCode ierr;
87 
88   PetscFunctionBegin;
89   if (left) {
90     if (!Na->left) {
91       ierr = VecDuplicate(left,&Na->left);CHKERRQ(ierr);
92       ierr = VecCopy(left,Na->left);CHKERRQ(ierr);
93     } else {
94       ierr = VecPointwiseMult(Na->left,left,Na->left);CHKERRQ(ierr);
95     }
96   }
97   if (right) {
98     if (!Na->right) {
99       ierr = VecDuplicate(right,&Na->right);CHKERRQ(ierr);
100       ierr = VecCopy(right,Na->right);CHKERRQ(ierr);
101     } else {
102       ierr = VecPointwiseMult(Na->right,right,Na->right);CHKERRQ(ierr);
103     }
104   }
105   PetscFunctionReturn(0);
106 }
107 
108 static PetscErrorCode MatSolve_SubMatrix(Mat N,Vec x,Vec y)
109 {
110   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
111   Vec            lwork, rwork, xx  = 0;
112   PetscErrorCode ierr;
113 
114   PetscFunctionBegin;
115   ierr = PreScaleLeft(N,x,&xx);CHKERRQ(ierr);
116   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
117   ierr = VecGetSubVector(Na->lwork, Na->iscol, &lwork);CHKERRQ(ierr);
118   ierr = VecCopy(x, lwork);CHKERRQ(ierr);
119   ierr = VecRestoreSubVector(Na->lwork, Na->iscol, &lwork);CHKERRQ(ierr);
120   ierr = MatSolve(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
121   ierr = VecGetSubVector(Na->rwork, Na->isrow, &rwork);CHKERRQ(ierr);
122   ierr = VecCopy(rwork, y);CHKERRQ(ierr);
123   ierr = VecRestoreSubVector(Na->rwork, Na->isrow, &rwork);CHKERRQ(ierr);
124   ierr = PostScaleRight(N,y);CHKERRQ(ierr);
125   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
126   PetscFunctionReturn(0);
127 }
128 
129 static PetscErrorCode MatMult_SubMatrix(Mat N,Vec x,Vec y)
130 {
131   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
132   Vec            xx  = 0;
133   PetscErrorCode ierr;
134 
135   PetscFunctionBegin;
136   ierr = PreScaleRight(N,x,&xx);CHKERRQ(ierr);
137   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
138   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
139   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
140   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
141   ierr = VecScatterBegin(Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
142   ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
143   ierr = PostScaleLeft(N,y);CHKERRQ(ierr);
144   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
145   PetscFunctionReturn(0);
146 }
147 
148 static PetscErrorCode MatMultAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
149 {
150   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
151   Vec            xx  = 0;
152   PetscErrorCode ierr;
153 
154   PetscFunctionBegin;
155   ierr = PreScaleRight(N,v1,&xx);CHKERRQ(ierr);
156   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
157   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
158   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
159   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
160   if (v2 == v3) {
161     if (Na->scale == (PetscScalar)1.0 && !Na->left) {
162       ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
163       ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
164     } else {
165       if (!Na->olwork) {ierr = VecDuplicate(v3,&Na->olwork);CHKERRQ(ierr);}
166       ierr = VecScatterBegin(Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
167       ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
168       ierr = PostScaleLeft(N,Na->olwork);CHKERRQ(ierr);
169       ierr = VecAXPY(v3,Na->scale,Na->olwork);CHKERRQ(ierr);
170     }
171   } else {
172     ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
173     ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
174     ierr = PostScaleLeft(N,v3);CHKERRQ(ierr);
175     ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr);
176   }
177   PetscFunctionReturn(0);
178 }
179 
180 static PetscErrorCode MatMultTranspose_SubMatrix(Mat N,Vec x,Vec y)
181 {
182   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
183   Vec            xx  = 0;
184   PetscErrorCode ierr;
185 
186   PetscFunctionBegin;
187   ierr = PreScaleLeft(N,x,&xx);CHKERRQ(ierr);
188   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
189   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
190   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
191   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
192   ierr = VecScatterBegin(Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
193   ierr = VecScatterEnd  (Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
194   ierr = PostScaleRight(N,y);CHKERRQ(ierr);
195   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
196   PetscFunctionReturn(0);
197 }
198 
199 static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
200 {
201   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
202   Vec            xx  = 0;
203   PetscErrorCode ierr;
204 
205   PetscFunctionBegin;
206   ierr = PreScaleLeft(N,v1,&xx);CHKERRQ(ierr);
207   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
208   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
209   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
210   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
211   if (v2 == v3) {
212     if (Na->scale == (PetscScalar)1.0 && !Na->right) {
213       ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
214       ierr = VecScatterEnd  (Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
215     } else {
216       if (!Na->orwork) {ierr = VecDuplicate(v3,&Na->orwork);CHKERRQ(ierr);}
217       ierr = VecScatterBegin(Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
218       ierr = VecScatterEnd  (Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
219       ierr = PostScaleRight(N,Na->orwork);CHKERRQ(ierr);
220       ierr = VecAXPY(v3,Na->scale,Na->orwork);CHKERRQ(ierr);
221     }
222   } else {
223     ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
224     ierr = VecScatterEnd  (Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
225     ierr = PostScaleRight(N,v3);CHKERRQ(ierr);
226     ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr);
227   }
228   PetscFunctionReturn(0);
229 }
230 
231 static PetscErrorCode MatDestroy_SubMatrix(Mat N)
232 {
233   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
234   PetscErrorCode ierr;
235 
236   PetscFunctionBegin;
237   ierr = ISDestroy(&Na->isrow);CHKERRQ(ierr);
238   ierr = ISDestroy(&Na->iscol);CHKERRQ(ierr);
239   ierr = VecDestroy(&Na->left);CHKERRQ(ierr);
240   ierr = VecDestroy(&Na->right);CHKERRQ(ierr);
241   ierr = VecDestroy(&Na->olwork);CHKERRQ(ierr);
242   ierr = VecDestroy(&Na->orwork);CHKERRQ(ierr);
243   ierr = VecDestroy(&Na->lwork);CHKERRQ(ierr);
244   ierr = VecDestroy(&Na->rwork);CHKERRQ(ierr);
245   ierr = VecScatterDestroy(&Na->lrestrict);CHKERRQ(ierr);
246   ierr = VecScatterDestroy(&Na->rprolong);CHKERRQ(ierr);
247   ierr = MatDestroy(&Na->A);CHKERRQ(ierr);
248   ierr = PetscFree(N->data);CHKERRQ(ierr);
249   PetscFunctionReturn(0);
250 }
251 
252 /*@
253    MatCreateSubMatrixVirtual - Creates a virtual matrix that acts as a submatrix
254 
255    Collective on Mat
256 
257    Input Parameters:
258 +  A - matrix that we will extract a submatrix of
259 .  isrow - rows to be present in the submatrix
260 -  iscol - columns to be present in the submatrix
261 
262    Output Parameters:
263 .  newmat - new matrix
264 
265    Level: developer
266 
267    Notes:
268    Most will use MatCreateSubMatrix which provides a more efficient representation if it is available.
269 
270 .seealso: MatCreateSubMatrix(), MatSubMatrixVirtualUpdate()
271 @*/
272 PetscErrorCode MatCreateSubMatrixVirtual(Mat A,IS isrow,IS iscol,Mat *newmat)
273 {
274   Vec            left,right;
275   PetscInt       m,n;
276   Mat            N;
277   Mat_SubVirtual *Na;
278   PetscErrorCode ierr;
279 
280   PetscFunctionBegin;
281   PetscValidHeaderSpecific(A,MAT_CLASSID,1);
282   PetscValidHeaderSpecific(isrow,IS_CLASSID,2);
283   PetscValidHeaderSpecific(iscol,IS_CLASSID,3);
284   PetscValidPointer(newmat,4);
285   *newmat = 0;
286 
287   ierr = MatCreate(PetscObjectComm((PetscObject)A),&N);CHKERRQ(ierr);
288   ierr = ISGetLocalSize(isrow,&m);CHKERRQ(ierr);
289   ierr = ISGetLocalSize(iscol,&n);CHKERRQ(ierr);
290   ierr = MatSetSizes(N,m,n,PETSC_DETERMINE,PETSC_DETERMINE);CHKERRQ(ierr);
291   ierr = PetscObjectChangeTypeName((PetscObject)N,MATSUBMATRIX);CHKERRQ(ierr);
292 
293   ierr      = PetscNewLog(N,&Na);CHKERRQ(ierr);
294   N->data   = (void*)Na;
295   ierr      = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
296   ierr      = PetscObjectReference((PetscObject)isrow);CHKERRQ(ierr);
297   ierr      = PetscObjectReference((PetscObject)iscol);CHKERRQ(ierr);
298   Na->A     = A;
299   Na->isrow = isrow;
300   Na->iscol = iscol;
301   Na->scale = 1.0;
302 
303   N->ops->destroy          = MatDestroy_SubMatrix;
304   N->ops->solve            = MatSolve_SubMatrix;
305   N->ops->mult             = MatMult_SubMatrix;
306   N->ops->multadd          = MatMultAdd_SubMatrix;
307   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
308   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
309   N->ops->scale            = MatScale_SubMatrix;
310   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;
311 
312   ierr = MatSetBlockSizesFromMats(N,A,A);CHKERRQ(ierr);
313   ierr = PetscLayoutSetUp(N->rmap);CHKERRQ(ierr);
314   ierr = PetscLayoutSetUp(N->cmap);CHKERRQ(ierr);
315 
316   ierr = MatCreateVecs(A,&Na->rwork,&Na->lwork);CHKERRQ(ierr);
317   ierr = VecCreate(PetscObjectComm((PetscObject)isrow),&left);CHKERRQ(ierr);
318   ierr = VecCreate(PetscObjectComm((PetscObject)iscol),&right);CHKERRQ(ierr);
319   ierr = VecSetSizes(left,m,PETSC_DETERMINE);CHKERRQ(ierr);
320   ierr = VecSetSizes(right,n,PETSC_DETERMINE);CHKERRQ(ierr);
321   ierr = VecSetUp(left);CHKERRQ(ierr);
322   ierr = VecSetUp(right);CHKERRQ(ierr);
323   ierr = VecScatterCreate(Na->lwork,isrow,left,NULL,&Na->lrestrict);CHKERRQ(ierr);
324   ierr = VecScatterCreate(right,NULL,Na->rwork,iscol,&Na->rprolong);CHKERRQ(ierr);
325   ierr = VecDestroy(&left);CHKERRQ(ierr);
326   ierr = VecDestroy(&right);CHKERRQ(ierr);
327 
328   N->assembled = PETSC_TRUE;
329 
330   ierr = MatSetUp(N);CHKERRQ(ierr);
331 
332   *newmat      = N;
333   PetscFunctionReturn(0);
334 }
335 
336 
337 /*@
338    MatSubMatrixVirtualUpdate - Updates a submatrix
339 
340    Collective on Mat
341 
342    Input Parameters:
343 +  N - submatrix to update
344 .  A - full matrix in the submatrix
345 .  isrow - rows in the update (same as the first time the submatrix was created)
346 -  iscol - columns in the update (same as the first time the submatrix was created)
347 
348    Level: developer
349 
350    Notes:
351    Most will use MatCreateSubMatrix which provides a more efficient representation if it is available.
352 
353 .seealso: MatCreateSubMatrixVirtual()
354 @*/
355 PetscErrorCode  MatSubMatrixVirtualUpdate(Mat N,Mat A,IS isrow,IS iscol)
356 {
357   PetscErrorCode ierr;
358   PetscBool      flg;
359   Mat_SubVirtual *Na;
360 
361   PetscFunctionBegin;
362   PetscValidHeaderSpecific(N,MAT_CLASSID,1);
363   PetscValidHeaderSpecific(A,MAT_CLASSID,2);
364   PetscValidHeaderSpecific(isrow,IS_CLASSID,3);
365   PetscValidHeaderSpecific(iscol,IS_CLASSID,4);
366   ierr = PetscObjectTypeCompare((PetscObject)N,MATSUBMATRIX,&flg);CHKERRQ(ierr);
367   if (!flg) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONG,"Matrix has wrong type");
368 
369   Na   = (Mat_SubVirtual*)N->data;
370   ierr = ISEqual(isrow,Na->isrow,&flg);CHKERRQ(ierr);
371   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different row indices");
372   ierr = ISEqual(iscol,Na->iscol,&flg);CHKERRQ(ierr);
373   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different column indices");
374 
375   ierr  = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
376   ierr  = MatDestroy(&Na->A);CHKERRQ(ierr);
377   Na->A = A;
378 
379   Na->scale = 1.0;
380   ierr      = VecDestroy(&Na->left);CHKERRQ(ierr);
381   ierr      = VecDestroy(&Na->right);CHKERRQ(ierr);
382   PetscFunctionReturn(0);
383 }
384