xref: /petsc/src/mat/impls/submat/submat.c (revision e0ed867b6c64a799ff0332d58c8126f3e917b638)
1 
2 #include <petsc/private/matimpl.h>          /*I "petscmat.h" I*/
3 
4 typedef struct {
5   IS          isrow,iscol;      /* rows and columns in submatrix, only used to check consistency */
6   Vec         left,right;       /* optional scaling */
7   Vec         olwork,orwork;    /* work vectors outside the scatters, only touched by PreScale and only created if needed*/
8   Vec         lwork,rwork;      /* work vectors inside the scatters */
9   VecScatter  lrestrict,rprolong;
10   Mat         A;
11   PetscScalar scale;
12 } Mat_SubVirtual;
13 
14 static PetscErrorCode PreScaleLeft(Mat N,Vec x,Vec *xx)
15 {
16   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
17   PetscErrorCode ierr;
18 
19   PetscFunctionBegin;
20   if (!Na->left) {
21     *xx = x;
22   } else {
23     if (!Na->olwork) {
24       ierr = VecDuplicate(Na->left,&Na->olwork);CHKERRQ(ierr);
25     }
26     ierr = VecPointwiseMult(Na->olwork,x,Na->left);CHKERRQ(ierr);
27     *xx  = Na->olwork;
28   }
29   PetscFunctionReturn(0);
30 }
31 
32 static PetscErrorCode PreScaleRight(Mat N,Vec x,Vec *xx)
33 {
34   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
35   PetscErrorCode ierr;
36 
37   PetscFunctionBegin;
38   if (!Na->right) {
39     *xx = x;
40   } else {
41     if (!Na->orwork) {
42       ierr = VecDuplicate(Na->right,&Na->orwork);CHKERRQ(ierr);
43     }
44     ierr = VecPointwiseMult(Na->orwork,x,Na->right);CHKERRQ(ierr);
45     *xx  = Na->orwork;
46   }
47   PetscFunctionReturn(0);
48 }
49 
50 static PetscErrorCode PostScaleLeft(Mat N,Vec x)
51 {
52   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
53   PetscErrorCode ierr;
54 
55   PetscFunctionBegin;
56   if (Na->left) {
57     ierr = VecPointwiseMult(x,x,Na->left);CHKERRQ(ierr);
58   }
59   PetscFunctionReturn(0);
60 }
61 
62 static PetscErrorCode PostScaleRight(Mat N,Vec x)
63 {
64   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
65   PetscErrorCode ierr;
66 
67   PetscFunctionBegin;
68   if (Na->right) {
69     ierr = VecPointwiseMult(x,x,Na->right);CHKERRQ(ierr);
70   }
71   PetscFunctionReturn(0);
72 }
73 
74 static PetscErrorCode MatShift_SubMatrix(Mat N,PetscScalar shift)
75 {
76   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
77   PetscErrorCode ierr;
78 
79   PetscFunctionBegin;
80   ierr = MatShift(Na->A,shift);CHKERRQ(ierr);
81   PetscFunctionReturn(0);
82 }
83 
84 static PetscErrorCode MatScale_SubMatrix(Mat N,PetscScalar scale)
85 {
86   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
87 
88   PetscFunctionBegin;
89   Na->scale *= scale;
90   PetscFunctionReturn(0);
91 }
92 
93 static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N,Vec left,Vec right)
94 {
95   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
96   PetscErrorCode ierr;
97 
98   PetscFunctionBegin;
99   if (left) {
100     if (!Na->left) {
101       ierr = VecDuplicate(left,&Na->left);CHKERRQ(ierr);
102       ierr = VecCopy(left,Na->left);CHKERRQ(ierr);
103     } else {
104       ierr = VecPointwiseMult(Na->left,left,Na->left);CHKERRQ(ierr);
105     }
106   }
107   if (right) {
108     if (!Na->right) {
109       ierr = VecDuplicate(right,&Na->right);CHKERRQ(ierr);
110       ierr = VecCopy(right,Na->right);CHKERRQ(ierr);
111     } else {
112       ierr = VecPointwiseMult(Na->right,right,Na->right);CHKERRQ(ierr);
113     }
114   }
115   PetscFunctionReturn(0);
116 }
117 
118 static PetscErrorCode MatSolve_SubMatrix(Mat N,Vec x,Vec y)
119 {
120   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
121   Vec            lwork, rwork, xx  = 0;
122   PetscErrorCode ierr;
123 
124   PetscFunctionBegin;
125   ierr = PreScaleLeft(N,x,&xx);CHKERRQ(ierr);
126   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
127   ierr = VecGetSubVector(Na->lwork, Na->iscol, &lwork);CHKERRQ(ierr);
128   ierr = VecCopy(x, lwork);CHKERRQ(ierr);
129   ierr = VecRestoreSubVector(Na->lwork, Na->iscol, &lwork);CHKERRQ(ierr);
130   ierr = MatSolve(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
131   ierr = VecGetSubVector(Na->rwork, Na->isrow, &rwork);CHKERRQ(ierr);
132   ierr = VecCopy(rwork, y);CHKERRQ(ierr);
133   ierr = VecRestoreSubVector(Na->rwork, Na->isrow, &rwork);CHKERRQ(ierr);
134   ierr = PostScaleRight(N,y);CHKERRQ(ierr);
135   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
136   PetscFunctionReturn(0);
137 }
138 
139 static PetscErrorCode MatMult_SubMatrix(Mat N,Vec x,Vec y)
140 {
141   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
142   Vec            xx  = 0;
143   PetscErrorCode ierr;
144 
145   PetscFunctionBegin;
146   ierr = PreScaleRight(N,x,&xx);CHKERRQ(ierr);
147   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
148   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
149   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
150   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
151   ierr = VecScatterBegin(Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
152   ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
153   ierr = PostScaleLeft(N,y);CHKERRQ(ierr);
154   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
155   PetscFunctionReturn(0);
156 }
157 
158 static PetscErrorCode MatMultAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
159 {
160   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
161   Vec            xx  = 0;
162   PetscErrorCode ierr;
163 
164   PetscFunctionBegin;
165   ierr = PreScaleRight(N,v1,&xx);CHKERRQ(ierr);
166   ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr);
167   ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
168   ierr = VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
169   ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr);
170   if (v2 == v3) {
171     if (Na->scale == (PetscScalar)1.0 && !Na->left) {
172       ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
173       ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
174     } else {
175       if (!Na->olwork) {ierr = VecDuplicate(v3,&Na->olwork);CHKERRQ(ierr);}
176       ierr = VecScatterBegin(Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
177       ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
178       ierr = PostScaleLeft(N,Na->olwork);CHKERRQ(ierr);
179       ierr = VecAXPY(v3,Na->scale,Na->olwork);CHKERRQ(ierr);
180     }
181   } else {
182     ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
183     ierr = VecScatterEnd  (Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
184     ierr = PostScaleLeft(N,v3);CHKERRQ(ierr);
185     ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr);
186   }
187   PetscFunctionReturn(0);
188 }
189 
190 static PetscErrorCode MatMultTranspose_SubMatrix(Mat N,Vec x,Vec y)
191 {
192   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
193   Vec            xx  = 0;
194   PetscErrorCode ierr;
195 
196   PetscFunctionBegin;
197   ierr = PreScaleLeft(N,x,&xx);CHKERRQ(ierr);
198   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
199   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
200   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
201   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
202   ierr = VecScatterBegin(Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
203   ierr = VecScatterEnd  (Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
204   ierr = PostScaleRight(N,y);CHKERRQ(ierr);
205   ierr = VecScale(y,Na->scale);CHKERRQ(ierr);
206   PetscFunctionReturn(0);
207 }
208 
209 static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
210 {
211   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
212   Vec            xx  = 0;
213   PetscErrorCode ierr;
214 
215   PetscFunctionBegin;
216   ierr = PreScaleLeft(N,v1,&xx);CHKERRQ(ierr);
217   ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr);
218   ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
219   ierr = VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
220   ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr);
221   if (v2 == v3) {
222     if (Na->scale == (PetscScalar)1.0 && !Na->right) {
223       ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
224       ierr = VecScatterEnd  (Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
225     } else {
226       if (!Na->orwork) {ierr = VecDuplicate(v3,&Na->orwork);CHKERRQ(ierr);}
227       ierr = VecScatterBegin(Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
228       ierr = VecScatterEnd  (Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
229       ierr = PostScaleRight(N,Na->orwork);CHKERRQ(ierr);
230       ierr = VecAXPY(v3,Na->scale,Na->orwork);CHKERRQ(ierr);
231     }
232   } else {
233     ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
234     ierr = VecScatterEnd  (Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
235     ierr = PostScaleRight(N,v3);CHKERRQ(ierr);
236     ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr);
237   }
238   PetscFunctionReturn(0);
239 }
240 
241 static PetscErrorCode MatDestroy_SubMatrix(Mat N)
242 {
243   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
244   PetscErrorCode ierr;
245 
246   PetscFunctionBegin;
247   ierr = ISDestroy(&Na->isrow);CHKERRQ(ierr);
248   ierr = ISDestroy(&Na->iscol);CHKERRQ(ierr);
249   ierr = VecDestroy(&Na->left);CHKERRQ(ierr);
250   ierr = VecDestroy(&Na->right);CHKERRQ(ierr);
251   ierr = VecDestroy(&Na->olwork);CHKERRQ(ierr);
252   ierr = VecDestroy(&Na->orwork);CHKERRQ(ierr);
253   ierr = VecDestroy(&Na->lwork);CHKERRQ(ierr);
254   ierr = VecDestroy(&Na->rwork);CHKERRQ(ierr);
255   ierr = VecScatterDestroy(&Na->lrestrict);CHKERRQ(ierr);
256   ierr = VecScatterDestroy(&Na->rprolong);CHKERRQ(ierr);
257   ierr = MatDestroy(&Na->A);CHKERRQ(ierr);
258   ierr = PetscFree(N->data);CHKERRQ(ierr);
259   PetscFunctionReturn(0);
260 }
261 
262 /*@
263    MatCreateSubMatrixVirtual - Creates a virtual matrix that acts as a submatrix
264 
265    Collective on Mat
266 
267    Input Parameters:
268 +  A - matrix that we will extract a submatrix of
269 .  isrow - rows to be present in the submatrix
270 -  iscol - columns to be present in the submatrix
271 
272    Output Parameters:
273 .  newmat - new matrix
274 
275    Level: developer
276 
277    Notes:
278    Most will use MatCreateSubMatrix which provides a more efficient representation if it is available.
279 
280 .seealso: MatCreateSubMatrix(), MatSubMatrixVirtualUpdate()
281 @*/
282 PetscErrorCode MatCreateSubMatrixVirtual(Mat A,IS isrow,IS iscol,Mat *newmat)
283 {
284   Vec            left,right;
285   PetscInt       m,n;
286   Mat            N;
287   Mat_SubVirtual *Na;
288   PetscErrorCode ierr;
289 
290   PetscFunctionBegin;
291   PetscValidHeaderSpecific(A,MAT_CLASSID,1);
292   PetscValidHeaderSpecific(isrow,IS_CLASSID,2);
293   PetscValidHeaderSpecific(iscol,IS_CLASSID,3);
294   PetscValidPointer(newmat,4);
295   *newmat = 0;
296 
297   ierr = MatCreate(PetscObjectComm((PetscObject)A),&N);CHKERRQ(ierr);
298   ierr = ISGetLocalSize(isrow,&m);CHKERRQ(ierr);
299   ierr = ISGetLocalSize(iscol,&n);CHKERRQ(ierr);
300   ierr = MatSetSizes(N,m,n,PETSC_DETERMINE,PETSC_DETERMINE);CHKERRQ(ierr);
301   ierr = PetscObjectChangeTypeName((PetscObject)N,MATSUBMATRIX);CHKERRQ(ierr);
302 
303   ierr      = PetscNewLog(N,&Na);CHKERRQ(ierr);
304   N->data   = (void*)Na;
305   ierr      = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
306   ierr      = PetscObjectReference((PetscObject)isrow);CHKERRQ(ierr);
307   ierr      = PetscObjectReference((PetscObject)iscol);CHKERRQ(ierr);
308   Na->A     = A;
309   Na->isrow = isrow;
310   Na->iscol = iscol;
311   Na->scale = 1.0;
312 
313   N->ops->destroy          = MatDestroy_SubMatrix;
314   N->ops->solve            = MatSolve_SubMatrix;
315   N->ops->mult             = MatMult_SubMatrix;
316   N->ops->multadd          = MatMultAdd_SubMatrix;
317   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
318   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
319   N->ops->scale            = MatScale_SubMatrix;
320   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;
321   if (A->ops->shift) N->ops->shift = MatShift_SubMatrix;
322 
323   ierr = MatSetBlockSizesFromMats(N,A,A);CHKERRQ(ierr);
324   ierr = PetscLayoutSetUp(N->rmap);CHKERRQ(ierr);
325   ierr = PetscLayoutSetUp(N->cmap);CHKERRQ(ierr);
326 
327   ierr = MatCreateVecs(A,&Na->rwork,&Na->lwork);CHKERRQ(ierr);
328   ierr = VecCreate(PetscObjectComm((PetscObject)isrow),&left);CHKERRQ(ierr);
329   ierr = VecCreate(PetscObjectComm((PetscObject)iscol),&right);CHKERRQ(ierr);
330   ierr = VecSetSizes(left,m,PETSC_DETERMINE);CHKERRQ(ierr);
331   ierr = VecSetSizes(right,n,PETSC_DETERMINE);CHKERRQ(ierr);
332   ierr = VecSetUp(left);CHKERRQ(ierr);
333   ierr = VecSetUp(right);CHKERRQ(ierr);
334   ierr = VecScatterCreate(Na->lwork,isrow,left,NULL,&Na->lrestrict);CHKERRQ(ierr);
335   ierr = VecScatterCreate(right,NULL,Na->rwork,iscol,&Na->rprolong);CHKERRQ(ierr);
336   ierr = VecDestroy(&left);CHKERRQ(ierr);
337   ierr = VecDestroy(&right);CHKERRQ(ierr);
338 
339   N->assembled = PETSC_TRUE;
340 
341   ierr = MatSetUp(N);CHKERRQ(ierr);
342 
343   *newmat      = N;
344   PetscFunctionReturn(0);
345 }
346 
347 
348 /*@
349    MatSubMatrixVirtualUpdate - Updates a submatrix
350 
351    Collective on Mat
352 
353    Input Parameters:
354 +  N - submatrix to update
355 .  A - full matrix in the submatrix
356 .  isrow - rows in the update (same as the first time the submatrix was created)
357 -  iscol - columns in the update (same as the first time the submatrix was created)
358 
359    Level: developer
360 
361    Notes:
362    Most will use MatCreateSubMatrix which provides a more efficient representation if it is available.
363 
364 .seealso: MatCreateSubMatrixVirtual()
365 @*/
366 PetscErrorCode  MatSubMatrixVirtualUpdate(Mat N,Mat A,IS isrow,IS iscol)
367 {
368   PetscErrorCode ierr;
369   PetscBool      flg;
370   Mat_SubVirtual *Na;
371 
372   PetscFunctionBegin;
373   PetscValidHeaderSpecific(N,MAT_CLASSID,1);
374   PetscValidHeaderSpecific(A,MAT_CLASSID,2);
375   PetscValidHeaderSpecific(isrow,IS_CLASSID,3);
376   PetscValidHeaderSpecific(iscol,IS_CLASSID,4);
377   ierr = PetscObjectTypeCompare((PetscObject)N,MATSUBMATRIX,&flg);CHKERRQ(ierr);
378   if (!flg) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONG,"Matrix has wrong type");
379 
380   Na   = (Mat_SubVirtual*)N->data;
381   ierr = ISEqual(isrow,Na->isrow,&flg);CHKERRQ(ierr);
382   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different row indices");
383   ierr = ISEqual(iscol,Na->iscol,&flg);CHKERRQ(ierr);
384   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different column indices");
385 
386   ierr  = PetscObjectReference((PetscObject)A);CHKERRQ(ierr);
387   ierr  = MatDestroy(&Na->A);CHKERRQ(ierr);
388   Na->A = A;
389 
390   Na->scale = 1.0;
391   ierr      = VecDestroy(&Na->left);CHKERRQ(ierr);
392   ierr      = VecDestroy(&Na->right);CHKERRQ(ierr);
393   PetscFunctionReturn(0);
394 }
395