xref: /petsc/src/mat/impls/transpose/transm.c (revision 2b938fe496202fa1a08b5a98b06ece7455dbcdc4)
1 #include <../src/mat/impls/shell/shell.h> /*I "petscmat.h" I*/
2 
3 static PetscErrorCode MatMult_Transpose(Mat N, Vec x, Vec y)
4 {
5   Mat A;
6 
7   PetscFunctionBegin;
8   PetscCall(MatShellGetContext(N, &A));
9   PetscCall(MatMultTranspose(A, x, y));
10   PetscFunctionReturn(PETSC_SUCCESS);
11 }
12 
13 static PetscErrorCode MatMultTranspose_Transpose(Mat N, Vec x, Vec y)
14 {
15   Mat A;
16 
17   PetscFunctionBegin;
18   PetscCall(MatShellGetContext(N, &A));
19   PetscCall(MatMult(A, x, y));
20   PetscFunctionReturn(PETSC_SUCCESS);
21 }
22 
23 static PetscErrorCode MatSolve_Transpose_LU(Mat N, Vec b, Vec x)
24 {
25   Mat A;
26 
27   PetscFunctionBegin;
28   PetscCall(MatShellGetContext(N, &A));
29   PetscCall(MatSolveTranspose(A, b, x));
30   PetscFunctionReturn(PETSC_SUCCESS);
31 }
32 
33 static PetscErrorCode MatSolveAdd_Transpose_LU(Mat N, Vec b, Vec y, Vec x)
34 {
35   Mat A;
36 
37   PetscFunctionBegin;
38   PetscCall(MatShellGetContext(N, &A));
39   PetscCall(MatSolveTransposeAdd(A, b, y, x));
40   PetscFunctionReturn(PETSC_SUCCESS);
41 }
42 
43 static PetscErrorCode MatSolveTranspose_Transpose_LU(Mat N, Vec b, Vec x)
44 {
45   Mat A;
46 
47   PetscFunctionBegin;
48   PetscCall(MatShellGetContext(N, &A));
49   PetscCall(MatSolve(A, b, x));
50   PetscFunctionReturn(PETSC_SUCCESS);
51 }
52 
53 static PetscErrorCode MatSolveTransposeAdd_Transpose_LU(Mat N, Vec b, Vec y, Vec x)
54 {
55   Mat A;
56 
57   PetscFunctionBegin;
58   PetscCall(MatShellGetContext(N, &A));
59   PetscCall(MatSolveAdd(A, b, y, x));
60   PetscFunctionReturn(PETSC_SUCCESS);
61 }
62 
63 static PetscErrorCode MatMatSolve_Transpose_LU(Mat N, Mat B, Mat X)
64 {
65   Mat A;
66 
67   PetscFunctionBegin;
68   PetscCall(MatShellGetContext(N, &A));
69   PetscCall(MatMatSolveTranspose(A, B, X));
70   PetscFunctionReturn(PETSC_SUCCESS);
71 }
72 
73 static PetscErrorCode MatMatSolveTranspose_Transpose_LU(Mat N, Mat B, Mat X)
74 {
75   Mat A;
76 
77   PetscFunctionBegin;
78   PetscCall(MatShellGetContext(N, &A));
79   PetscCall(MatMatSolve(A, B, X));
80   PetscFunctionReturn(PETSC_SUCCESS);
81 }
82 
83 static PetscErrorCode MatLUFactor_Transpose(Mat N, IS row, IS col, const MatFactorInfo *minfo)
84 {
85   Mat A;
86 
87   PetscFunctionBegin;
88   PetscCall(MatShellGetContext(N, &A));
89   PetscCall(MatLUFactor(A, col, row, minfo));
90   PetscCall(MatShellSetOperation(N, MATOP_SOLVE, (void (*)(void))MatSolve_Transpose_LU));
91   PetscCall(MatShellSetOperation(N, MATOP_SOLVE_ADD, (void (*)(void))MatSolveAdd_Transpose_LU));
92   PetscCall(MatShellSetOperation(N, MATOP_SOLVE_TRANSPOSE, (void (*)(void))MatSolveTranspose_Transpose_LU));
93   PetscCall(MatShellSetOperation(N, MATOP_SOLVE_TRANSPOSE_ADD, (void (*)(void))MatSolveTransposeAdd_Transpose_LU));
94   PetscCall(MatShellSetOperation(N, MATOP_MAT_SOLVE, (void (*)(void))MatMatSolve_Transpose_LU));
95   PetscCall(MatShellSetOperation(N, MATOP_MAT_SOLVE_TRANSPOSE, (void (*)(void))MatMatSolveTranspose_Transpose_LU));
96   PetscFunctionReturn(PETSC_SUCCESS);
97 }
98 
99 static PetscErrorCode MatSolve_Transpose_Cholesky(Mat N, Vec b, Vec x)
100 {
101   Mat A;
102 
103   PetscFunctionBegin;
104   PetscCall(MatShellGetContext(N, &A));
105   PetscCall(MatSolveTranspose(A, b, x));
106   PetscFunctionReturn(PETSC_SUCCESS);
107 }
108 
109 static PetscErrorCode MatSolveAdd_Transpose_Cholesky(Mat N, Vec b, Vec y, Vec x)
110 {
111   Mat A;
112 
113   PetscFunctionBegin;
114   PetscCall(MatShellGetContext(N, &A));
115   PetscCall(MatSolveTransposeAdd(A, b, y, x));
116   PetscFunctionReturn(PETSC_SUCCESS);
117 }
118 
119 static PetscErrorCode MatSolveTranspose_Transpose_Cholesky(Mat N, Vec b, Vec x)
120 {
121   Mat A;
122 
123   PetscFunctionBegin;
124   PetscCall(MatShellGetContext(N, &A));
125   PetscCall(MatSolve(A, b, x));
126   PetscFunctionReturn(PETSC_SUCCESS);
127 }
128 
129 static PetscErrorCode MatSolveTransposeAdd_Transpose_Cholesky(Mat N, Vec b, Vec y, Vec x)
130 {
131   Mat A;
132 
133   PetscFunctionBegin;
134   PetscCall(MatShellGetContext(N, &A));
135   PetscCall(MatSolveAdd(A, b, y, x));
136   PetscFunctionReturn(PETSC_SUCCESS);
137 }
138 
139 static PetscErrorCode MatMatSolve_Transpose_Cholesky(Mat N, Mat B, Mat X)
140 {
141   Mat A;
142 
143   PetscFunctionBegin;
144   PetscCall(MatShellGetContext(N, &A));
145   PetscCall(MatMatSolveTranspose(A, B, X));
146   PetscFunctionReturn(PETSC_SUCCESS);
147 }
148 
149 static PetscErrorCode MatMatSolveTranspose_Transpose_Cholesky(Mat N, Mat B, Mat X)
150 {
151   Mat A;
152 
153   PetscFunctionBegin;
154   PetscCall(MatShellGetContext(N, &A));
155   PetscCall(MatMatSolve(A, B, X));
156   PetscFunctionReturn(PETSC_SUCCESS);
157 }
158 
159 static PetscErrorCode MatCholeskyFactor_Transpose(Mat N, IS perm, const MatFactorInfo *minfo)
160 {
161   Mat A;
162 
163   PetscFunctionBegin;
164   PetscCall(MatShellGetContext(N, &A));
165   PetscCall(MatCholeskyFactor(A, perm, minfo));
166   PetscCall(MatShellSetOperation(N, MATOP_SOLVE, (void (*)(void))MatSolve_Transpose_Cholesky));
167   PetscCall(MatShellSetOperation(N, MATOP_SOLVE_ADD, (void (*)(void))MatSolveAdd_Transpose_Cholesky));
168   PetscCall(MatShellSetOperation(N, MATOP_SOLVE_TRANSPOSE, (void (*)(void))MatSolveTranspose_Transpose_Cholesky));
169   PetscCall(MatShellSetOperation(N, MATOP_SOLVE_TRANSPOSE_ADD, (void (*)(void))MatSolveTransposeAdd_Transpose_Cholesky));
170   PetscCall(MatShellSetOperation(N, MATOP_MAT_SOLVE, (void (*)(void))MatMatSolve_Transpose_Cholesky));
171   PetscCall(MatShellSetOperation(N, MATOP_MAT_SOLVE_TRANSPOSE, (void (*)(void))MatMatSolveTranspose_Transpose_Cholesky));
172   PetscFunctionReturn(PETSC_SUCCESS);
173 }
174 
175 static PetscErrorCode MatDestroy_Transpose(Mat N)
176 {
177   Mat A;
178 
179   PetscFunctionBegin;
180   PetscCall(MatShellGetContext(N, &A));
181   PetscCall(MatDestroy(&A));
182   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatTransposeGetMat_C", NULL));
183   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatProductSetFromOptions_anytype_C", NULL));
184   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatShellSetContext_C", NULL));
185   PetscFunctionReturn(PETSC_SUCCESS);
186 }
187 
188 static PetscErrorCode MatDuplicate_Transpose(Mat N, MatDuplicateOption op, Mat *m)
189 {
190   Mat A, C;
191 
192   PetscFunctionBegin;
193   PetscCall(MatShellGetContext(N, &A));
194   PetscCall(MatDuplicate(A, op, &C));
195   PetscCall(MatCreateTranspose(C, m));
196   if (op == MAT_COPY_VALUES) {
197     PetscCall(MatCopy(N, *m, SAME_NONZERO_PATTERN));
198     PetscCall(MatPropagateSymmetryOptions(A, C));
199   }
200   PetscCall(MatDestroy(&C));
201   PetscFunctionReturn(PETSC_SUCCESS);
202 }
203 
204 static PetscErrorCode MatHasOperation_Transpose(Mat mat, MatOperation op, PetscBool *has)
205 {
206   Mat A;
207 
208   PetscFunctionBegin;
209   PetscCall(MatShellGetContext(mat, &A));
210   *has = PETSC_FALSE;
211   if (op == MATOP_MULT || op == MATOP_MULT_ADD) {
212     PetscCall(MatHasOperation(A, MATOP_MULT_TRANSPOSE, has));
213   } else if (op == MATOP_MULT_TRANSPOSE || op == MATOP_MULT_TRANSPOSE_ADD) {
214     PetscCall(MatHasOperation(A, MATOP_MULT, has));
215   } else if (((void **)mat->ops)[op]) *has = PETSC_TRUE;
216   PetscFunctionReturn(PETSC_SUCCESS);
217 }
218 
219 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_Transpose(Mat D)
220 {
221   Mat            A, B, C, Ain, Bin, Cin;
222   PetscBool      Aistrans, Bistrans, Cistrans;
223   PetscInt       Atrans, Btrans, Ctrans;
224   MatProductType ptype;
225 
226   PetscFunctionBegin;
227   MatCheckProduct(D, 1);
228   A = D->product->A;
229   B = D->product->B;
230   C = D->product->C;
231   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATTRANSPOSEVIRTUAL, &Aistrans));
232   PetscCall(PetscObjectTypeCompare((PetscObject)B, MATTRANSPOSEVIRTUAL, &Bistrans));
233   PetscCall(PetscObjectTypeCompare((PetscObject)C, MATTRANSPOSEVIRTUAL, &Cistrans));
234   PetscCheck(Aistrans || Bistrans || Cistrans, PetscObjectComm((PetscObject)D), PETSC_ERR_PLIB, "This should not happen");
235   Atrans = 0;
236   Ain    = A;
237   while (Aistrans) {
238     Atrans++;
239     PetscCall(MatTransposeGetMat(Ain, &Ain));
240     PetscCall(PetscObjectTypeCompare((PetscObject)Ain, MATTRANSPOSEVIRTUAL, &Aistrans));
241   }
242   Btrans = 0;
243   Bin    = B;
244   while (Bistrans) {
245     Btrans++;
246     PetscCall(MatTransposeGetMat(Bin, &Bin));
247     PetscCall(PetscObjectTypeCompare((PetscObject)Bin, MATTRANSPOSEVIRTUAL, &Bistrans));
248   }
249   Ctrans = 0;
250   Cin    = C;
251   while (Cistrans) {
252     Ctrans++;
253     PetscCall(MatTransposeGetMat(Cin, &Cin));
254     PetscCall(PetscObjectTypeCompare((PetscObject)Cin, MATTRANSPOSEVIRTUAL, &Cistrans));
255   }
256   Atrans = Atrans % 2;
257   Btrans = Btrans % 2;
258   Ctrans = Ctrans % 2;
259   ptype  = D->product->type; /* same product type by default */
260   if (Ain->symmetric == PETSC_BOOL3_TRUE) Atrans = 0;
261   if (Bin->symmetric == PETSC_BOOL3_TRUE) Btrans = 0;
262   if (Cin && Cin->symmetric == PETSC_BOOL3_TRUE) Ctrans = 0;
263 
264   if (Atrans || Btrans || Ctrans) {
265     ptype = MATPRODUCT_UNSPECIFIED;
266     switch (D->product->type) {
267     case MATPRODUCT_AB:
268       if (Atrans && Btrans) { /* At * Bt we do not have support for this */
269         /* TODO custom implementation ? */
270       } else if (Atrans) { /* At * B */
271         ptype = MATPRODUCT_AtB;
272       } else { /* A * Bt */
273         ptype = MATPRODUCT_ABt;
274       }
275       break;
276     case MATPRODUCT_AtB:
277       if (Atrans && Btrans) { /* A * Bt */
278         ptype = MATPRODUCT_ABt;
279       } else if (Atrans) { /* A * B */
280         ptype = MATPRODUCT_AB;
281       } else { /* At * Bt we do not have support for this */
282         /* TODO custom implementation ? */
283       }
284       break;
285     case MATPRODUCT_ABt:
286       if (Atrans && Btrans) { /* At * B */
287         ptype = MATPRODUCT_AtB;
288       } else if (Atrans) { /* At * Bt we do not have support for this */
289         /* TODO custom implementation ? */
290       } else { /* A * B */
291         ptype = MATPRODUCT_AB;
292       }
293       break;
294     case MATPRODUCT_PtAP:
295       if (Atrans) { /* PtAtP */
296         /* TODO custom implementation ? */
297       } else { /* RARt */
298         ptype = MATPRODUCT_RARt;
299       }
300       break;
301     case MATPRODUCT_RARt:
302       if (Atrans) { /* RAtRt */
303         /* TODO custom implementation ? */
304       } else { /* PtAP */
305         ptype = MATPRODUCT_PtAP;
306       }
307       break;
308     case MATPRODUCT_ABC:
309       /* TODO custom implementation ? */
310       break;
311     default:
312       SETERRQ(PetscObjectComm((PetscObject)D), PETSC_ERR_SUP, "ProductType %s is not supported", MatProductTypes[D->product->type]);
313     }
314   }
315   PetscCall(MatProductReplaceMats(Ain, Bin, Cin, D));
316   PetscCall(MatProductSetType(D, ptype));
317   PetscCall(MatProductSetFromOptions(D));
318   PetscFunctionReturn(PETSC_SUCCESS);
319 }
320 
321 static PetscErrorCode MatGetDiagonal_Transpose(Mat N, Vec v)
322 {
323   Mat A;
324 
325   PetscFunctionBegin;
326   PetscCall(MatShellGetContext(N, &A));
327   PetscCall(MatGetDiagonal(A, v));
328   PetscFunctionReturn(PETSC_SUCCESS);
329 }
330 
331 static PetscErrorCode MatCopy_Transpose(Mat A, Mat B, MatStructure str)
332 {
333   Mat a, b;
334 
335   PetscFunctionBegin;
336   PetscCall(MatShellGetContext(A, &a));
337   PetscCall(MatShellGetContext(B, &b));
338   PetscCall(MatCopy(a, b, str));
339   PetscFunctionReturn(PETSC_SUCCESS);
340 }
341 
342 static PetscErrorCode MatConvert_Transpose(Mat N, MatType newtype, MatReuse reuse, Mat *newmat)
343 {
344   Mat         A;
345   PetscScalar vscale = 1.0, vshift = 0.0;
346   PetscBool   flg;
347 
348   PetscFunctionBegin;
349   PetscCall(MatShellGetContext(N, &A));
350   PetscCall(MatHasOperation(A, MATOP_TRANSPOSE, &flg));
351   if (flg || N->ops->getrow) { /* if this condition is false, MatConvert_Shell() will be called in MatConvert_Basic(), so the following checks are not needed */
352     PetscCheck(!((Mat_Shell *)N->data)->zrows && !((Mat_Shell *)N->data)->zcols, PetscObjectComm((PetscObject)N), PETSC_ERR_SUP, "Cannot call MatConvert() if MatZeroRows() or MatZeroRowsColumns() has been called on the input Mat");
353     PetscCheck(!((Mat_Shell *)N->data)->axpy, PetscObjectComm((PetscObject)N), PETSC_ERR_SUP, "Cannot call MatConvert() if MatAXPY() has been called on the input Mat");
354     PetscCheck(!((Mat_Shell *)N->data)->left && !((Mat_Shell *)N->data)->right, PetscObjectComm((PetscObject)N), PETSC_ERR_SUP, "Cannot call MatConvert() if MatDiagonalScale() has been called on the input Mat");
355     PetscCheck(!((Mat_Shell *)N->data)->dshift, PetscObjectComm((PetscObject)N), PETSC_ERR_SUP, "Cannot call MatConvert() if MatDiagonalSet() has been called on the input Mat");
356     vscale = ((Mat_Shell *)N->data)->vscale;
357     vshift = ((Mat_Shell *)N->data)->vshift;
358   }
359   if (flg) {
360     Mat B;
361 
362     PetscCall(MatTranspose(A, MAT_INITIAL_MATRIX, &B));
363     if (reuse != MAT_INPLACE_MATRIX) {
364       PetscCall(MatConvert(B, newtype, reuse, newmat));
365       PetscCall(MatDestroy(&B));
366     } else {
367       PetscCall(MatConvert(B, newtype, MAT_INPLACE_MATRIX, &B));
368       PetscCall(MatHeaderReplace(N, &B));
369     }
370   } else { /* use basic converter as fallback */
371     flg = (PetscBool)(N->ops->getrow != NULL);
372     PetscCall(MatConvert_Basic(N, newtype, reuse, newmat));
373   }
374   if (flg) {
375     PetscCall(MatScale(*newmat, vscale));
376     PetscCall(MatShift(*newmat, vshift));
377   }
378   PetscFunctionReturn(PETSC_SUCCESS);
379 }
380 
381 static PetscErrorCode MatTransposeGetMat_Transpose(Mat N, Mat *M)
382 {
383   PetscFunctionBegin;
384   PetscCall(MatShellGetContext(N, M));
385   PetscFunctionReturn(PETSC_SUCCESS);
386 }
387 
388 /*@
389   MatTransposeGetMat - Gets the `Mat` object stored inside a `MATTRANSPOSEVIRTUAL`
390 
391   Logically Collective
392 
393   Input Parameter:
394 . A - the `MATTRANSPOSEVIRTUAL` matrix
395 
396   Output Parameter:
397 . M - the matrix object stored inside `A`
398 
399   Level: intermediate
400 
401 .seealso: [](ch_matrices), `Mat`, `MATTRANSPOSEVIRTUAL`, `MatCreateTranspose()`
402 @*/
403 PetscErrorCode MatTransposeGetMat(Mat A, Mat *M)
404 {
405   PetscFunctionBegin;
406   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
407   PetscValidType(A, 1);
408   PetscAssertPointer(M, 2);
409   PetscUseMethod(A, "MatTransposeGetMat_C", (Mat, Mat *), (A, M));
410   PetscFunctionReturn(PETSC_SUCCESS);
411 }
412 
/*MC
   MATTRANSPOSEVIRTUAL - "transpose" - A matrix type that represents a virtual transpose of a matrix

  Level: advanced

  Developer Notes:
  This is implemented on top of `MATSHELL` to get support for scaling and shifting without requiring duplicate code

  Users cannot call `MatShellSetOperation()` on this class; there is some error checking for that incorrect usage

.seealso: [](ch_matrices), `Mat`, `MATHERMITIANTRANSPOSEVIRTUAL`, `MatCreateHermitianTranspose()`, `MatCreateTranspose()`,
          `MATNORMALHERMITIAN`, `MATNORMAL`
M*/
426 
427 /*@
428   MatCreateTranspose - Creates a new matrix `MATTRANSPOSEVIRTUAL` object that behaves like A'
429 
430   Collective
431 
432   Input Parameter:
433 . A - the (possibly rectangular) matrix
434 
435   Output Parameter:
436 . N - the matrix that represents A'
437 
438   Level: intermediate
439 
440   Note:
441   The transpose A' is NOT actually formed! Rather the new matrix
442   object performs the matrix-vector product by using the `MatMultTranspose()` on
443   the original matrix
444 
445 .seealso: [](ch_matrices), `Mat`, `MATTRANSPOSEVIRTUAL`, `MatCreateNormal()`, `MatMult()`, `MatMultTranspose()`, `MatCreate()`,
446           `MATNORMALHERMITIAN`
447 @*/
448 PetscErrorCode MatCreateTranspose(Mat A, Mat *N)
449 {
450   VecType vtype;
451 
452   PetscFunctionBegin;
453   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), N));
454   PetscCall(PetscLayoutReference(A->rmap, &((*N)->cmap)));
455   PetscCall(PetscLayoutReference(A->cmap, &((*N)->rmap)));
456   PetscCall(MatSetType(*N, MATSHELL));
457   PetscCall(MatShellSetContext(*N, A));
458   PetscCall(PetscObjectReference((PetscObject)A));
459 
460   PetscCall(MatSetBlockSizes(*N, PetscAbs(A->cmap->bs), PetscAbs(A->rmap->bs)));
461   PetscCall(MatGetVecType(A, &vtype));
462   PetscCall(MatSetVecType(*N, vtype));
463 #if defined(PETSC_HAVE_DEVICE)
464   PetscCall(MatBindToCPU(*N, A->boundtocpu));
465 #endif
466   PetscCall(MatSetUp(*N));
467 
468   PetscCall(MatShellSetOperation(*N, MATOP_DESTROY, (void (*)(void))MatDestroy_Transpose));
469   PetscCall(MatShellSetOperation(*N, MATOP_MULT, (void (*)(void))MatMult_Transpose));
470   PetscCall(MatShellSetOperation(*N, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Transpose));
471   PetscCall(MatShellSetOperation(*N, MATOP_LUFACTOR, (void (*)(void))MatLUFactor_Transpose));
472   PetscCall(MatShellSetOperation(*N, MATOP_CHOLESKYFACTOR, (void (*)(void))MatCholeskyFactor_Transpose));
473   PetscCall(MatShellSetOperation(*N, MATOP_DUPLICATE, (void (*)(void))MatDuplicate_Transpose));
474   PetscCall(MatShellSetOperation(*N, MATOP_HAS_OPERATION, (void (*)(void))MatHasOperation_Transpose));
475   PetscCall(MatShellSetOperation(*N, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Transpose));
476   PetscCall(MatShellSetOperation(*N, MATOP_COPY, (void (*)(void))MatCopy_Transpose));
477   PetscCall(MatShellSetOperation(*N, MATOP_CONVERT, (void (*)(void))MatConvert_Transpose));
478 
479   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatTransposeGetMat_C", MatTransposeGetMat_Transpose));
480   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatProductSetFromOptions_anytype_C", MatProductSetFromOptions_Transpose));
481   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatShellSetContext_C", MatShellSetContext_Immutable));
482   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatShellSetContextDestroy_C", MatShellSetContextDestroy_Immutable));
483   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatShellSetManageScalingShifts_C", MatShellSetManageScalingShifts_Immutable));
484   PetscCall(PetscObjectChangeTypeName((PetscObject)*N, MATTRANSPOSEVIRTUAL));
485   PetscFunctionReturn(PETSC_SUCCESS);
486 }
487