xref: /petsc/src/mat/impls/submat/submat.c (revision cd929ea3f739fd9f7b6394f772cb40b9d4e6d97c)
1 
2 #include <petsc/private/matimpl.h>          /*I "petscmat.h" I*/
3 
4 typedef struct {
5   IS          isrow,iscol;      /* rows and columns in submatrix, only used to check consistency */
6   Vec         left,right;       /* optional scaling */
7   Vec         olwork,orwork;    /* work vectors outside the scatters, only touched by PreScale and only created if needed*/
8   Vec         lwork,rwork;      /* work vectors inside the scatters */
9   VecScatter  lrestrict,rprolong;
10   Mat         A;
11   PetscScalar scale;
12 } Mat_SubVirtual;
13 
14 static PetscErrorCode PreScaleLeft(Mat N,Vec x,Vec *xx)
15 {
16   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
17   PetscErrorCode ierr;
18 
19   PetscFunctionBegin;
20   if (!Na->left) {
21     *xx = x;
22   } else {
23     if (!Na->olwork) {
24       ierr = VecDuplicate(Na->left,&Na->olwork);CHKERRQ(ierr);
25     }
26     ierr = VecPointwiseMult(Na->olwork,x,Na->left);CHKERRQ(ierr);
27     *xx  = Na->olwork;
28   }
29   PetscFunctionReturn(0);
30 }
31 
32 static PetscErrorCode PreScaleRight(Mat N,Vec x,Vec *xx)
33 {
34   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
35   PetscErrorCode ierr;
36 
37   PetscFunctionBegin;
38   if (!Na->right) {
39     *xx = x;
40   } else {
41     if (!Na->orwork) {
42       ierr = VecDuplicate(Na->right,&Na->orwork);CHKERRQ(ierr);
43     }
44     ierr = VecPointwiseMult(Na->orwork,x,Na->right);CHKERRQ(ierr);
45     *xx  = Na->orwork;
46   }
47   PetscFunctionReturn(0);
48 }
49 
50 static PetscErrorCode PostScaleLeft(Mat N,Vec x)
51 {
52   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
53   PetscErrorCode ierr;
54 
55   PetscFunctionBegin;
56   if (Na->left) {
57     ierr = VecPointwiseMult(x,x,Na->left);CHKERRQ(ierr);
58   }
59   PetscFunctionReturn(0);
60 }
61 
62 static PetscErrorCode PostScaleRight(Mat N,Vec x)
63 {
64   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
65   PetscErrorCode ierr;
66 
67   PetscFunctionBegin;
68   if (Na->right) {
69     ierr = VecPointwiseMult(x,x,Na->right);CHKERRQ(ierr);
70   }
71   PetscFunctionReturn(0);
72 }
73 
74 static PetscErrorCode MatScale_SubMatrix(Mat N,PetscScalar scale)
75 {
76   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
77 
78   PetscFunctionBegin;
79   Na->scale *= scale;
80   PetscFunctionReturn(0);
81 }
82 
83 static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N,Vec left,Vec right)
84 {
85   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
86   PetscErrorCode ierr;
87 
88   PetscFunctionBegin;
89   if (left) {
90     if (!Na->left) {
91       ierr = VecDuplicate(left,&Na->left);CHKERRQ(ierr);
92       ierr = VecCopy(left,Na->left);CHKERRQ(ierr);
93     } else {
94       ierr = VecPointwiseMult(Na->left,left,Na->left);CHKERRQ(ierr);
95     }
96   }
97   if (right) {
98     if (!Na->right) {
99       ierr = VecDuplicate(right,&Na->right);CHKERRQ(ierr);
100       ierr = VecCopy(right,Na->right);CHKERRQ(ierr);
101     } else {
102       ierr = VecPointwiseMult(Na->right,right,Na->right);CHKERRQ(ierr);
103     }
104   }
105   PetscFunctionReturn(0);
106 }
107 
108 static PetscErrorCode MatSolve_SubMatrix(Mat N,Vec x,Vec y)
109 {
110   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
111   Vec            xx  = 0;
112   PetscErrorCode ierr;
113 
114   PetscFunctionBegin;
115   ierr = PreScaleLeft(N,x,&xx);CHKERRQ(ierr);
116   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
117   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
118   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
119   ierr = MatSolve(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
120   ierr = VecScatterBegin(Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
121   ierr = VecScatterEnd  (Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
122   ierr = PostScaleRight(N,y);CHKERRQ(ierr);
123   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
124   PetscFunctionReturn(0);
125 }
126 
127 static PetscErrorCode MatMult_SubMatrix(Mat N,Vec x,Vec y)
128 {
129   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
130   Vec            xx  = 0;
131   PetscErrorCode ierr;
132 
133   PetscFunctionBegin;
134   ierr = PreScaleRight(N,x,&xx);CHKERRQ(ierr);
135   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
136   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
137   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
138   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
139   ierr = VecScatterBegin(Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
140   ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
141   ierr = PostScaleLeft(N,y);CHKERRQ(ierr);
142   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
143   PetscFunctionReturn(0);
144 }
145 
146 static PetscErrorCode MatMultAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
147 {
148   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
149   Vec            xx  = 0;
150   PetscErrorCode ierr;
151 
152   PetscFunctionBegin;
153   ierr = PreScaleRight(N,v1,&xx);CHKERRQ(ierr);
154   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
155   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
156   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
157   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
158   if (v2 == v3) {
159     if (Na->scale == (PetscScalar)1.0 && !Na->left) {
160       ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
161       ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
162     } else {
163       if (!Na->olwork) {ierr = VecDuplicate(v3,&Na->olwork);CHKERRQ(ierr);}
164       ierr = VecScatterBegin(Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
165       ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
166       ierr = PostScaleLeft(N,Na->olwork);CHKERRQ(ierr);
167       ierr = VecAXPY(v3,Na->scale,Na->olwork);CHKERRQ(ierr);
168     }
169   } else {
170     ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
171     ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
172     ierr = PostScaleLeft(N,v3);CHKERRQ(ierr);
173     ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr);
174   }
175   PetscFunctionReturn(0);
176 }
177 
178 static PetscErrorCode MatMultTranspose_SubMatrix(Mat N,Vec x,Vec y)
179 {
180   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
181   Vec            xx  = 0;
182   PetscErrorCode ierr;
183 
184   PetscFunctionBegin;
185   ierr = PreScaleLeft(N,x,&xx);CHKERRQ(ierr);
186   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
187   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
188   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
189   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
190   ierr = VecScatterBegin(Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
191   ierr = VecScatterEnd  (Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
192   ierr = PostScaleRight(N,y);CHKERRQ(ierr);
193   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
194   PetscFunctionReturn(0);
195 }
196 
197 static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
198 {
199   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
200   Vec            xx  = 0;
201   PetscErrorCode ierr;
202 
203   PetscFunctionBegin;
204   ierr = PreScaleLeft(N,v1,&xx);CHKERRQ(ierr);
205   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
206   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
207   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
208   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
209   if (v2 == v3) {
210     if (Na->scale == (PetscScalar)1.0 && !Na->right) {
211       ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
212       ierr = VecScatterEnd  (Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
213     } else {
214       if (!Na->orwork) {ierr = VecDuplicate(v3,&Na->orwork);CHKERRQ(ierr);}
215       ierr = VecScatterBegin(Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
216       ierr = VecScatterEnd  (Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
217       ierr = PostScaleRight(N,Na->orwork);CHKERRQ(ierr);
218       ierr = VecAXPY(v3,Na->scale,Na->orwork);CHKERRQ(ierr);
219     }
220   } else {
221     ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
222     ierr = VecScatterEnd  (Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
223     ierr = PostScaleRight(N,v3);CHKERRQ(ierr);
224     ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr);
225   }
226   PetscFunctionReturn(0);
227 }
228 
229 static PetscErrorCode MatDestroy_SubMatrix(Mat N)
230 {
231   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
232   PetscErrorCode ierr;
233 
234   PetscFunctionBegin;
235   ierr = ISDestroy(&Na->isrow);CHKERRQ(ierr);
236   ierr = ISDestroy(&Na->iscol);CHKERRQ(ierr);
237   ierr = VecDestroy(&Na->left);CHKERRQ(ierr);
238   ierr = VecDestroy(&Na->right);CHKERRQ(ierr);
239   ierr = VecDestroy(&Na->olwork);CHKERRQ(ierr);
240   ierr = VecDestroy(&Na->orwork);CHKERRQ(ierr);
241   ierr = VecDestroy(&Na->lwork);CHKERRQ(ierr);
242   ierr = VecDestroy(&Na->rwork);CHKERRQ(ierr);
243   ierr = VecScatterDestroy(&Na->lrestrict);CHKERRQ(ierr);
244   ierr = VecScatterDestroy(&Na->rprolong);CHKERRQ(ierr);
245   ierr = MatDestroy(&Na->A);CHKERRQ(ierr);
246   ierr = PetscFree(N->data);CHKERRQ(ierr);
247   PetscFunctionReturn(0);
248 }
249 
250 /*@
251    MatCreateSubMatrixVirtual - Creates a virtual matrix that acts as a submatrix
252 
253    Collective on Mat
254 
255    Input Parameters:
256 +  A - matrix that we will extract a submatrix of
257 .  isrow - rows to be present in the submatrix
258 -  iscol - columns to be present in the submatrix
259 
260    Output Parameters:
261 .  newmat - new matrix
262 
263    Level: developer
264 
265    Notes:
266    Most will use MatCreateSubMatrix which provides a more efficient representation if it is available.
267 
268 .seealso: MatCreateSubMatrix(), MatSubMatrixVirtualUpdate()
269 @*/
270 PetscErrorCode MatCreateSubMatrixVirtual(Mat A,IS isrow,IS iscol,Mat *newmat)
271 {
272   Vec            left,right;
273   PetscInt       m,n;
274   Mat            N;
275   Mat_SubVirtual *Na;
276   PetscErrorCode ierr;
277 
278   PetscFunctionBegin;
279   PetscValidHeaderSpecific(A,MAT_CLASSID,1);
280   PetscValidHeaderSpecific(isrow,IS_CLASSID,2);
281   PetscValidHeaderSpecific(iscol,IS_CLASSID,3);
282   PetscValidPointer(newmat,4);
283   *newmat = 0;
284 
285   ierr = MatCreate(PetscObjectComm((PetscObject)A),&N);CHKERRQ(ierr);
286   ierr = ISGetLocalSize(isrow,&m);CHKERRQ(ierr);
287   ierr = ISGetLocalSize(iscol,&n);CHKERRQ(ierr);
288   ierr = MatSetSizes(N,m,n,PETSC_DETERMINE,PETSC_DETERMINE);CHKERRQ(ierr);
289   ierr = PetscObjectChangeTypeName((PetscObject)N,MATSUBMATRIX);CHKERRQ(ierr);
290 
291   ierr      = PetscNewLog(N,&Na);CHKERRQ(ierr);
292   N->data   = (void*)Na;
293   ierr      = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
294   ierr      = PetscObjectReference((PetscObject)isrow);CHKERRQ(ierr);
295   ierr      = PetscObjectReference((PetscObject)iscol);CHKERRQ(ierr);
296   Na->A     = A;
297   Na->isrow = isrow;
298   Na->iscol = iscol;
299   Na->scale = 1.0;
300 
301   N->ops->destroy          = MatDestroy_SubMatrix;
302   N->ops->solve            = MatSolve_SubMatrix;
303   N->ops->mult             = MatMult_SubMatrix;
304   N->ops->multadd          = MatMultAdd_SubMatrix;
305   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
306   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
307   N->ops->scale            = MatScale_SubMatrix;
308   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;
309 
310   ierr = MatSetBlockSizesFromMats(N,A,A);CHKERRQ(ierr);
311   ierr = PetscLayoutSetUp(N->rmap);CHKERRQ(ierr);
312   ierr = PetscLayoutSetUp(N->cmap);CHKERRQ(ierr);
313 
314   ierr = MatCreateVecs(A,&Na->rwork,&Na->lwork);CHKERRQ(ierr);
315   ierr = VecCreate(PetscObjectComm((PetscObject)isrow),&left);CHKERRQ(ierr);
316   ierr = VecCreate(PetscObjectComm((PetscObject)iscol),&right);CHKERRQ(ierr);
317   ierr = VecSetSizes(left,m,PETSC_DETERMINE);CHKERRQ(ierr);
318   ierr = VecSetSizes(right,n,PETSC_DETERMINE);CHKERRQ(ierr);
319   ierr = VecSetUp(left);CHKERRQ(ierr);
320   ierr = VecSetUp(right);CHKERRQ(ierr);
321   ierr = VecScatterCreate(Na->lwork,isrow,left,NULL,&Na->lrestrict);CHKERRQ(ierr);
322   ierr = VecScatterCreate(right,NULL,Na->rwork,iscol,&Na->rprolong);CHKERRQ(ierr);
323   ierr = VecDestroy(&left);CHKERRQ(ierr);
324   ierr = VecDestroy(&right);CHKERRQ(ierr);
325 
326   N->assembled = PETSC_TRUE;
327 
328   ierr = MatSetUp(N);CHKERRQ(ierr);
329 
330   *newmat      = N;
331   PetscFunctionReturn(0);
332 }
333 
334 
335 /*@
336    MatSubMatrixVirtualUpdate - Updates a submatrix
337 
338    Collective on Mat
339 
340    Input Parameters:
341 +  N - submatrix to update
342 .  A - full matrix in the submatrix
343 .  isrow - rows in the update (same as the first time the submatrix was created)
344 -  iscol - columns in the update (same as the first time the submatrix was created)
345 
346    Level: developer
347 
348    Notes:
349    Most will use MatCreateSubMatrix which provides a more efficient representation if it is available.
350 
351 .seealso: MatCreateSubMatrixVirtual()
352 @*/
353 PetscErrorCode  MatSubMatrixVirtualUpdate(Mat N,Mat A,IS isrow,IS iscol)
354 {
355   PetscErrorCode ierr;
356   PetscBool      flg;
357   Mat_SubVirtual *Na;
358 
359   PetscFunctionBegin;
360   PetscValidHeaderSpecific(N,MAT_CLASSID,1);
361   PetscValidHeaderSpecific(A,MAT_CLASSID,2);
362   PetscValidHeaderSpecific(isrow,IS_CLASSID,3);
363   PetscValidHeaderSpecific(iscol,IS_CLASSID,4);
364   ierr = PetscObjectTypeCompare((PetscObject)N,MATSUBMATRIX,&flg);CHKERRQ(ierr);
365   if (!flg) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONG,"Matrix has wrong type");
366 
367   Na   = (Mat_SubVirtual*)N->data;
368   ierr = ISEqual(isrow,Na->isrow,&flg);CHKERRQ(ierr);
369   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different row indices");
370   ierr = ISEqual(iscol,Na->iscol,&flg);CHKERRQ(ierr);
371   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different column indices");
372 
373   ierr  = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
374   ierr  = MatDestroy(&Na->A);CHKERRQ(ierr);
375   Na->A = A;
376 
377   Na->scale = 1.0;
378   ierr      = VecDestroy(&Na->left);CHKERRQ(ierr);
379   ierr      = VecDestroy(&Na->right);CHKERRQ(ierr);
380   PetscFunctionReturn(0);
381 }
382