xref: /petsc/src/mat/impls/transpose/htransm.c (revision 2b938fe496202fa1a08b5a98b06ece7455dbcdc4)
1 #include <../src/mat/impls/shell/shell.h> /*I "petscmat.h" I*/
2 
3 PETSC_INTERN PetscErrorCode MatProductSetFromOptions_HT(Mat D)
4 {
5   Mat            A, B, C, Ain, Bin, Cin;
6   PetscBool      Aistrans, Bistrans, Cistrans;
7   PetscInt       Atrans, Btrans, Ctrans;
8   MatProductType ptype;
9 
10   PetscFunctionBegin;
11   MatCheckProduct(D, 1);
12   A = D->product->A;
13   B = D->product->B;
14   C = D->product->C;
15   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATHERMITIANTRANSPOSEVIRTUAL, &Aistrans));
16   PetscCall(PetscObjectTypeCompare((PetscObject)B, MATHERMITIANTRANSPOSEVIRTUAL, &Bistrans));
17   PetscCall(PetscObjectTypeCompare((PetscObject)C, MATHERMITIANTRANSPOSEVIRTUAL, &Cistrans));
18   PetscCheck(Aistrans || Bistrans || Cistrans, PetscObjectComm((PetscObject)D), PETSC_ERR_PLIB, "This should not happen");
19   Atrans = 0;
20   Ain    = A;
21   while (Aistrans) {
22     Atrans++;
23     PetscCall(MatHermitianTransposeGetMat(Ain, &Ain));
24     PetscCall(PetscObjectTypeCompare((PetscObject)Ain, MATHERMITIANTRANSPOSEVIRTUAL, &Aistrans));
25   }
26   Btrans = 0;
27   Bin    = B;
28   while (Bistrans) {
29     Btrans++;
30     PetscCall(MatHermitianTransposeGetMat(Bin, &Bin));
31     PetscCall(PetscObjectTypeCompare((PetscObject)Bin, MATHERMITIANTRANSPOSEVIRTUAL, &Bistrans));
32   }
33   Ctrans = 0;
34   Cin    = C;
35   while (Cistrans) {
36     Ctrans++;
37     PetscCall(MatHermitianTransposeGetMat(Cin, &Cin));
38     PetscCall(PetscObjectTypeCompare((PetscObject)Cin, MATHERMITIANTRANSPOSEVIRTUAL, &Cistrans));
39   }
40   Atrans = Atrans % 2;
41   Btrans = Btrans % 2;
42   Ctrans = Ctrans % 2;
43   ptype  = D->product->type; /* same product type by default */
44   if (Ain->symmetric == PETSC_BOOL3_TRUE) Atrans = 0;
45   if (Bin->symmetric == PETSC_BOOL3_TRUE) Btrans = 0;
46   if (Cin && Cin->symmetric == PETSC_BOOL3_TRUE) Ctrans = 0;
47 
48   if (Atrans || Btrans || Ctrans) {
49     PetscCheck(!PetscDefined(USE_COMPLEX), PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "No support for complex Hermitian transpose matrices");
50     ptype = MATPRODUCT_UNSPECIFIED;
51     switch (D->product->type) {
52     case MATPRODUCT_AB:
53       if (Atrans && Btrans) { /* At * Bt we do not have support for this */
54         /* TODO custom implementation ? */
55       } else if (Atrans) { /* At * B */
56         ptype = MATPRODUCT_AtB;
57       } else { /* A * Bt */
58         ptype = MATPRODUCT_ABt;
59       }
60       break;
61     case MATPRODUCT_AtB:
62       if (Atrans && Btrans) { /* A * Bt */
63         ptype = MATPRODUCT_ABt;
64       } else if (Atrans) { /* A * B */
65         ptype = MATPRODUCT_AB;
66       } else { /* At * Bt we do not have support for this */
67         /* TODO custom implementation ? */
68       }
69       break;
70     case MATPRODUCT_ABt:
71       if (Atrans && Btrans) { /* At * B */
72         ptype = MATPRODUCT_AtB;
73       } else if (Atrans) { /* At * Bt we do not have support for this */
74         /* TODO custom implementation ? */
75       } else { /* A * B */
76         ptype = MATPRODUCT_AB;
77       }
78       break;
79     case MATPRODUCT_PtAP:
80       if (Atrans) { /* PtAtP */
81         /* TODO custom implementation ? */
82       } else { /* RARt */
83         ptype = MATPRODUCT_RARt;
84       }
85       break;
86     case MATPRODUCT_RARt:
87       if (Atrans) { /* RAtRt */
88         /* TODO custom implementation ? */
89       } else { /* PtAP */
90         ptype = MATPRODUCT_PtAP;
91       }
92       break;
93     case MATPRODUCT_ABC:
94       /* TODO custom implementation ? */
95       break;
96     default:
97       SETERRQ(PetscObjectComm((PetscObject)D), PETSC_ERR_SUP, "ProductType %s is not supported", MatProductTypes[D->product->type]);
98     }
99   }
100   PetscCall(MatProductReplaceMats(Ain, Bin, Cin, D));
101   PetscCall(MatProductSetType(D, ptype));
102   PetscCall(MatProductSetFromOptions(D));
103   PetscFunctionReturn(PETSC_SUCCESS);
104 }
105 
106 static PetscErrorCode MatMult_HT(Mat N, Vec x, Vec y)
107 {
108   Mat A;
109 
110   PetscFunctionBegin;
111   PetscCall(MatShellGetContext(N, &A));
112   PetscCall(MatMultHermitianTranspose(A, x, y));
113   PetscFunctionReturn(PETSC_SUCCESS);
114 }
115 
116 static PetscErrorCode MatMultHermitianTranspose_HT(Mat N, Vec x, Vec y)
117 {
118   Mat A;
119 
120   PetscFunctionBegin;
121   PetscCall(MatShellGetContext(N, &A));
122   PetscCall(MatMult(A, x, y));
123   PetscFunctionReturn(PETSC_SUCCESS);
124 }
125 
126 static PetscErrorCode MatSolve_HT_LU(Mat N, Vec b, Vec x)
127 {
128   Mat A;
129   Vec w;
130 
131   PetscFunctionBegin;
132   PetscCall(MatShellGetContext(N, &A));
133   PetscCall(VecDuplicate(b, &w));
134   PetscCall(VecCopy(b, w));
135   PetscCall(VecConjugate(w));
136   PetscCall(MatSolveTranspose(A, w, x));
137   PetscCall(VecConjugate(x));
138   PetscCall(VecDestroy(&w));
139   PetscFunctionReturn(PETSC_SUCCESS);
140 }
141 
142 static PetscErrorCode MatSolveAdd_HT_LU(Mat N, Vec b, Vec y, Vec x)
143 {
144   Mat A;
145   Vec v, w;
146 
147   PetscFunctionBegin;
148   PetscCall(MatShellGetContext(N, &A));
149   PetscCall(VecDuplicate(b, &v));
150   PetscCall(VecDuplicate(b, &w));
151   PetscCall(VecCopy(y, v));
152   PetscCall(VecCopy(b, w));
153   PetscCall(VecConjugate(v));
154   PetscCall(VecConjugate(w));
155   PetscCall(MatSolveTransposeAdd(A, w, v, x));
156   PetscCall(VecConjugate(x));
157   PetscCall(VecDestroy(&v));
158   PetscCall(VecDestroy(&w));
159   PetscFunctionReturn(PETSC_SUCCESS);
160 }
161 
162 static PetscErrorCode MatMatSolve_HT_LU(Mat N, Mat B, Mat X)
163 {
164   Mat A, W;
165 
166   PetscFunctionBegin;
167   PetscCall(MatShellGetContext(N, &A));
168   PetscCall(MatDuplicate(B, MAT_COPY_VALUES, &W));
169   PetscCall(MatConjugate(W));
170   PetscCall(MatMatSolveTranspose(A, W, X));
171   PetscCall(MatConjugate(X));
172   PetscCall(MatDestroy(&W));
173   PetscFunctionReturn(PETSC_SUCCESS);
174 }
175 
176 static PetscErrorCode MatLUFactor_HT(Mat N, IS row, IS col, const MatFactorInfo *minfo)
177 {
178   Mat A;
179 
180   PetscFunctionBegin;
181   PetscCall(MatShellGetContext(N, &A));
182   PetscCall(MatLUFactor(A, col, row, minfo));
183   PetscCall(MatShellSetOperation(N, MATOP_SOLVE, (void (*)(void))MatSolve_HT_LU));
184   PetscCall(MatShellSetOperation(N, MATOP_SOLVE_ADD, (void (*)(void))MatSolveAdd_HT_LU));
185   PetscCall(MatShellSetOperation(N, MATOP_MAT_SOLVE, (void (*)(void))MatMatSolve_HT_LU));
186   PetscFunctionReturn(PETSC_SUCCESS);
187 }
188 
189 static PetscErrorCode MatSolve_HT_Cholesky(Mat N, Vec b, Vec x)
190 {
191   Mat A;
192 
193   PetscFunctionBegin;
194   PetscCall(MatShellGetContext(N, &A));
195   PetscCall(MatSolve(A, b, x));
196   PetscFunctionReturn(PETSC_SUCCESS);
197 }
198 
199 static PetscErrorCode MatSolveAdd_HT_Cholesky(Mat N, Vec b, Vec y, Vec x)
200 {
201   Mat A;
202   Vec v, w;
203 
204   PetscFunctionBegin;
205   PetscCall(MatShellGetContext(N, &A));
206   PetscCall(VecDuplicate(b, &v));
207   PetscCall(VecDuplicate(b, &w));
208   PetscCall(VecCopy(y, v));
209   PetscCall(VecCopy(b, w));
210   PetscCall(VecConjugate(v));
211   PetscCall(VecConjugate(w));
212   PetscCall(MatSolveTransposeAdd(A, w, v, x));
213   PetscCall(VecConjugate(x));
214   PetscCall(VecDestroy(&v));
215   PetscCall(VecDestroy(&w));
216   PetscFunctionReturn(PETSC_SUCCESS);
217 }
218 
219 static PetscErrorCode MatMatSolve_HT_Cholesky(Mat N, Mat B, Mat X)
220 {
221   Mat A, W;
222 
223   PetscFunctionBegin;
224   PetscCall(MatShellGetContext(N, &A));
225   PetscCall(MatDuplicate(B, MAT_COPY_VALUES, &W));
226   PetscCall(MatConjugate(W));
227   PetscCall(MatMatSolveTranspose(A, W, X));
228   PetscCall(MatConjugate(X));
229   PetscCall(MatDestroy(&W));
230   PetscFunctionReturn(PETSC_SUCCESS);
231 }
232 
233 static PetscErrorCode MatCholeskyFactor_HT(Mat N, IS perm, const MatFactorInfo *minfo)
234 {
235   Mat A;
236 
237   PetscFunctionBegin;
238   PetscCall(MatShellGetContext(N, &A));
239   PetscCheck(!PetscDefined(USE_COMPLEX) || A->hermitian == PETSC_BOOL3_TRUE, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Cholesky supported only if original matrix is Hermitian");
240   PetscCall(MatCholeskyFactor(A, perm, minfo));
241   PetscCall(MatShellSetOperation(N, MATOP_SOLVE, (void (*)(void))MatSolve_HT_Cholesky));
242   PetscCall(MatShellSetOperation(N, MATOP_SOLVE_ADD, (void (*)(void))MatSolveAdd_HT_Cholesky));
243   PetscCall(MatShellSetOperation(N, MATOP_MAT_SOLVE, (void (*)(void))MatMatSolve_HT_Cholesky));
244   PetscFunctionReturn(PETSC_SUCCESS);
245 }
246 
247 static PetscErrorCode MatDestroy_HT(Mat N)
248 {
249   Mat A;
250 
251   PetscFunctionBegin;
252   PetscCall(MatShellGetContext(N, &A));
253   PetscCall(MatDestroy(&A));
254   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatHermitianTransposeGetMat_C", NULL));
255 #if !defined(PETSC_USE_COMPLEX)
256   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatTransposeGetMat_C", NULL));
257 #endif
258   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatProductSetFromOptions_anytype_C", NULL));
259   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatShellSetContext_C", NULL));
260   PetscFunctionReturn(PETSC_SUCCESS);
261 }
262 
263 static PetscErrorCode MatDuplicate_HT(Mat N, MatDuplicateOption op, Mat *m)
264 {
265   Mat A, C;
266 
267   PetscFunctionBegin;
268   PetscCall(MatShellGetContext(N, &A));
269   PetscCall(MatDuplicate(A, op, &C));
270   PetscCall(MatCreateHermitianTranspose(C, m));
271   if (op == MAT_COPY_VALUES) {
272     PetscCall(MatCopy(N, *m, SAME_NONZERO_PATTERN));
273     PetscCall(MatPropagateSymmetryOptions(A, C));
274   }
275   PetscCall(MatDestroy(&C));
276   PetscFunctionReturn(PETSC_SUCCESS);
277 }
278 
279 static PetscErrorCode MatHasOperation_HT(Mat mat, MatOperation op, PetscBool *has)
280 {
281   Mat A;
282 
283   PetscFunctionBegin;
284   PetscCall(MatShellGetContext(mat, &A));
285   *has = PETSC_FALSE;
286   if (op == MATOP_MULT || op == MATOP_MULT_ADD) {
287     PetscCall(MatHasOperation(A, MATOP_MULT_HERMITIAN_TRANSPOSE, has));
288     if (!*has) PetscCall(MatHasOperation(A, MATOP_MULT_TRANSPOSE, has));
289   } else if (op == MATOP_MULT_HERMITIAN_TRANSPOSE || op == MATOP_MULT_HERMITIAN_TRANS_ADD || op == MATOP_MULT_TRANSPOSE || op == MATOP_MULT_TRANSPOSE_ADD) {
290     PetscCall(MatHasOperation(A, MATOP_MULT, has));
291   } else if (((void **)mat->ops)[op]) *has = PETSC_TRUE;
292   PetscFunctionReturn(PETSC_SUCCESS);
293 }
294 
295 static PetscErrorCode MatHermitianTransposeGetMat_HT(Mat N, Mat *M)
296 {
297   PetscFunctionBegin;
298   PetscCall(MatShellGetContext(N, M));
299   PetscFunctionReturn(PETSC_SUCCESS);
300 }
301 
302 /*@
303   MatHermitianTransposeGetMat - Gets the `Mat` object stored inside a `MATHERMITIANTRANSPOSEVIRTUAL`
304 
305   Logically Collective
306 
307   Input Parameter:
308 . A - the `MATHERMITIANTRANSPOSEVIRTUAL` matrix
309 
310   Output Parameter:
311 . M - the matrix object stored inside A
312 
313   Level: intermediate
314 
315 .seealso: [](ch_matrices), `Mat`, `MATHERMITIANTRANSPOSEVIRTUAL`, `MatCreateHermitianTranspose()`
316 @*/
317 PetscErrorCode MatHermitianTransposeGetMat(Mat A, Mat *M)
318 {
319   PetscFunctionBegin;
320   PetscValidHeaderSpecific(A, MAT_CLASSID, 1);
321   PetscValidType(A, 1);
322   PetscAssertPointer(M, 2);
323   PetscUseMethod(A, "MatHermitianTransposeGetMat_C", (Mat, Mat *), (A, M));
324   PetscFunctionReturn(PETSC_SUCCESS);
325 }
326 
327 static PetscErrorCode MatGetDiagonal_HT(Mat N, Vec v)
328 {
329   Mat A;
330 
331   PetscFunctionBegin;
332   PetscCall(MatShellGetContext(N, &A));
333   PetscCall(MatGetDiagonal(A, v));
334   PetscCall(VecConjugate(v));
335   PetscFunctionReturn(PETSC_SUCCESS);
336 }
337 
338 static PetscErrorCode MatCopy_HT(Mat A, Mat B, MatStructure str)
339 {
340   Mat a, b;
341 
342   PetscFunctionBegin;
343   PetscCall(MatShellGetContext(A, &a));
344   PetscCall(MatShellGetContext(B, &b));
345   PetscCall(MatCopy(a, b, str));
346   PetscFunctionReturn(PETSC_SUCCESS);
347 }
348 
349 static PetscErrorCode MatConvert_HT(Mat N, MatType newtype, MatReuse reuse, Mat *newmat)
350 {
351   Mat         A;
352   PetscScalar vscale = 1.0, vshift = 0.0;
353   PetscBool   flg;
354 
355   PetscFunctionBegin;
356   PetscCall(MatShellGetContext(N, &A));
357   PetscCall(MatHasOperation(A, MATOP_HERMITIAN_TRANSPOSE, &flg));
358   if (flg || N->ops->getrow) { /* if this condition is false, MatConvert_Shell() will be called in MatConvert_Basic(), so the following checks are not needed */
359     PetscCheck(!((Mat_Shell *)N->data)->zrows && !((Mat_Shell *)N->data)->zcols, PetscObjectComm((PetscObject)N), PETSC_ERR_SUP, "Cannot call MatConvert() if MatZeroRows() or MatZeroRowsColumns() has been called on the input Mat");
360     PetscCheck(!((Mat_Shell *)N->data)->axpy, PetscObjectComm((PetscObject)N), PETSC_ERR_SUP, "Cannot call MatConvert() if MatAXPY() has been called on the input Mat");
361     PetscCheck(!((Mat_Shell *)N->data)->left && !((Mat_Shell *)N->data)->right, PetscObjectComm((PetscObject)N), PETSC_ERR_SUP, "Cannot call MatConvert() if MatDiagonalScale() has been called on the input Mat");
362     PetscCheck(!((Mat_Shell *)N->data)->dshift, PetscObjectComm((PetscObject)N), PETSC_ERR_SUP, "Cannot call MatConvert() if MatDiagonalSet() has been called on the input Mat");
363     vscale = ((Mat_Shell *)N->data)->vscale;
364     vshift = ((Mat_Shell *)N->data)->vshift;
365   }
366   if (flg) {
367     Mat B;
368 
369     PetscCall(MatHermitianTranspose(A, MAT_INITIAL_MATRIX, &B));
370     if (reuse != MAT_INPLACE_MATRIX) {
371       PetscCall(MatConvert(B, newtype, reuse, newmat));
372       PetscCall(MatDestroy(&B));
373     } else {
374       PetscCall(MatConvert(B, newtype, MAT_INPLACE_MATRIX, &B));
375       PetscCall(MatHeaderReplace(N, &B));
376     }
377   } else { /* use basic converter as fallback */
378     flg = (PetscBool)(N->ops->getrow != NULL);
379     PetscCall(MatConvert_Basic(N, newtype, reuse, newmat));
380   }
381   if (flg) {
382     PetscCall(MatScale(*newmat, vscale));
383     PetscCall(MatShift(*newmat, vshift));
384   }
385   PetscFunctionReturn(PETSC_SUCCESS);
386 }
387 
388 /*MC
   MATHERMITIANTRANSPOSEVIRTUAL - "hermitiantranspose" - A matrix type that represents a virtual Hermitian (conjugate) transpose of a matrix
390 
391   Level: advanced
392 
393   Developer Notes:
394   This is implemented on top of `MATSHELL` to get support for scaling and shifting without requiring duplicate code
395 
  Users cannot call `MatShellSetOperation()` on this class; there is some error checking to catch that incorrect usage
397 
.seealso: [](ch_matrices), `Mat`, `MATTRANSPOSEVIRTUAL`, `MatCreateHermitianTranspose()`, `MatCreateTranspose()`
399 M*/
400 
401 /*@
402   MatCreateHermitianTranspose - Creates a new matrix object of `MatType` `MATHERMITIANTRANSPOSEVIRTUAL` that behaves like A'*
403 
404   Collective
405 
406   Input Parameter:
407 . A - the (possibly rectangular) matrix
408 
409   Output Parameter:
410 . N - the matrix that represents A'*
411 
412   Level: intermediate
413 
414   Note:
415   The Hermitian transpose A' is NOT actually formed! Rather the new matrix
416   object performs the matrix-vector product, `MatMult()`, by using the `MatMultHermitianTranspose()` on
417   the original matrix
418 
419 .seealso: [](ch_matrices), `Mat`, `MatCreateNormal()`, `MatMult()`, `MatMultHermitianTranspose()`, `MatCreate()`,
420           `MATTRANSPOSEVIRTUAL`, `MatCreateTranspose()`, `MatHermitianTransposeGetMat()`, `MATNORMAL`, `MATNORMALHERMITIAN`
421 @*/
422 PetscErrorCode MatCreateHermitianTranspose(Mat A, Mat *N)
423 {
424   VecType vtype;
425 
426   PetscFunctionBegin;
427   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), N));
428   PetscCall(PetscLayoutReference(A->rmap, &((*N)->cmap)));
429   PetscCall(PetscLayoutReference(A->cmap, &((*N)->rmap)));
430   PetscCall(MatSetType(*N, MATSHELL));
431   PetscCall(MatShellSetContext(*N, A));
432   PetscCall(PetscObjectReference((PetscObject)A));
433 
434   PetscCall(MatSetBlockSizes(*N, PetscAbs(A->cmap->bs), PetscAbs(A->rmap->bs)));
435   PetscCall(MatGetVecType(A, &vtype));
436   PetscCall(MatSetVecType(*N, vtype));
437 #if defined(PETSC_HAVE_DEVICE)
438   PetscCall(MatBindToCPU(*N, A->boundtocpu));
439 #endif
440   PetscCall(MatSetUp(*N));
441 
442   PetscCall(MatShellSetOperation(*N, MATOP_DESTROY, (void (*)(void))MatDestroy_HT));
443   PetscCall(MatShellSetOperation(*N, MATOP_MULT, (void (*)(void))MatMult_HT));
444   PetscCall(MatShellSetOperation(*N, MATOP_MULT_HERMITIAN_TRANSPOSE, (void (*)(void))MatMultHermitianTranspose_HT));
445 #if !defined(PETSC_USE_COMPLEX)
446   PetscCall(MatShellSetOperation(*N, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultHermitianTranspose_HT));
447 #endif
448   PetscCall(MatShellSetOperation(*N, MATOP_LUFACTOR, (void (*)(void))MatLUFactor_HT));
449   PetscCall(MatShellSetOperation(*N, MATOP_CHOLESKYFACTOR, (void (*)(void))MatCholeskyFactor_HT));
450   PetscCall(MatShellSetOperation(*N, MATOP_DUPLICATE, (void (*)(void))MatDuplicate_HT));
451   PetscCall(MatShellSetOperation(*N, MATOP_HAS_OPERATION, (void (*)(void))MatHasOperation_HT));
452   PetscCall(MatShellSetOperation(*N, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_HT));
453   PetscCall(MatShellSetOperation(*N, MATOP_COPY, (void (*)(void))MatCopy_HT));
454   PetscCall(MatShellSetOperation(*N, MATOP_CONVERT, (void (*)(void))MatConvert_HT));
455 
456   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatHermitianTransposeGetMat_C", MatHermitianTransposeGetMat_HT));
457 #if !defined(PETSC_USE_COMPLEX)
458   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatTransposeGetMat_C", MatHermitianTransposeGetMat_HT));
459 #endif
460   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatProductSetFromOptions_anytype_C", MatProductSetFromOptions_HT));
461   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatShellSetContext_C", MatShellSetContext_Immutable));
462   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatShellSetContextDestroy_C", MatShellSetContextDestroy_Immutable));
463   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatShellSetManageScalingShifts_C", MatShellSetManageScalingShifts_Immutable));
464   PetscCall(PetscObjectChangeTypeName((PetscObject)*N, MATHERMITIANTRANSPOSEVIRTUAL));
465   PetscFunctionReturn(PETSC_SUCCESS);
466 }
467