xref: /petsc/src/ksp/pc/impls/vpbjacobi/vpbjacobi.c (revision c87f018dd60dd7cbbffff5f1ec4e5d075ae0fc6f)
1 #include <../src/ksp/pc/impls/vpbjacobi/vpbjacobi.h>
2 
3 static PetscErrorCode PCApply_VPBJacobi(PC pc, Vec x, Vec y)
4 {
5   PC_VPBJacobi      *jac = (PC_VPBJacobi *)pc->data;
6   PetscInt           i, ncnt = 0;
7   const MatScalar   *diag = jac->diag;
8   PetscInt           ib, jb, bs;
9   const PetscScalar *xx;
10   PetscScalar       *yy, x0, x1, x2, x3, x4, x5, x6;
11   PetscInt           nblocks;
12   const PetscInt    *bsizes;
13 
14   PetscFunctionBegin;
15   PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes));
16   PetscCall(VecGetArrayRead(x, &xx));
17   PetscCall(VecGetArray(y, &yy));
18   for (i = 0; i < nblocks; i++) {
19     bs = bsizes[i];
20     switch (bs) {
21     case 1:
22       yy[ncnt] = *diag * xx[ncnt];
23       break;
24     case 2:
25       x0           = xx[ncnt];
26       x1           = xx[ncnt + 1];
27       yy[ncnt]     = diag[0] * x0 + diag[2] * x1;
28       yy[ncnt + 1] = diag[1] * x0 + diag[3] * x1;
29       break;
30     case 3:
31       x0           = xx[ncnt];
32       x1           = xx[ncnt + 1];
33       x2           = xx[ncnt + 2];
34       yy[ncnt]     = diag[0] * x0 + diag[3] * x1 + diag[6] * x2;
35       yy[ncnt + 1] = diag[1] * x0 + diag[4] * x1 + diag[7] * x2;
36       yy[ncnt + 2] = diag[2] * x0 + diag[5] * x1 + diag[8] * x2;
37       break;
38     case 4:
39       x0           = xx[ncnt];
40       x1           = xx[ncnt + 1];
41       x2           = xx[ncnt + 2];
42       x3           = xx[ncnt + 3];
43       yy[ncnt]     = diag[0] * x0 + diag[4] * x1 + diag[8] * x2 + diag[12] * x3;
44       yy[ncnt + 1] = diag[1] * x0 + diag[5] * x1 + diag[9] * x2 + diag[13] * x3;
45       yy[ncnt + 2] = diag[2] * x0 + diag[6] * x1 + diag[10] * x2 + diag[14] * x3;
46       yy[ncnt + 3] = diag[3] * x0 + diag[7] * x1 + diag[11] * x2 + diag[15] * x3;
47       break;
48     case 5:
49       x0           = xx[ncnt];
50       x1           = xx[ncnt + 1];
51       x2           = xx[ncnt + 2];
52       x3           = xx[ncnt + 3];
53       x4           = xx[ncnt + 4];
54       yy[ncnt]     = diag[0] * x0 + diag[5] * x1 + diag[10] * x2 + diag[15] * x3 + diag[20] * x4;
55       yy[ncnt + 1] = diag[1] * x0 + diag[6] * x1 + diag[11] * x2 + diag[16] * x3 + diag[21] * x4;
56       yy[ncnt + 2] = diag[2] * x0 + diag[7] * x1 + diag[12] * x2 + diag[17] * x3 + diag[22] * x4;
57       yy[ncnt + 3] = diag[3] * x0 + diag[8] * x1 + diag[13] * x2 + diag[18] * x3 + diag[23] * x4;
58       yy[ncnt + 4] = diag[4] * x0 + diag[9] * x1 + diag[14] * x2 + diag[19] * x3 + diag[24] * x4;
59       break;
60     case 6:
61       x0           = xx[ncnt];
62       x1           = xx[ncnt + 1];
63       x2           = xx[ncnt + 2];
64       x3           = xx[ncnt + 3];
65       x4           = xx[ncnt + 4];
66       x5           = xx[ncnt + 5];
67       yy[ncnt]     = diag[0] * x0 + diag[6] * x1 + diag[12] * x2 + diag[18] * x3 + diag[24] * x4 + diag[30] * x5;
68       yy[ncnt + 1] = diag[1] * x0 + diag[7] * x1 + diag[13] * x2 + diag[19] * x3 + diag[25] * x4 + diag[31] * x5;
69       yy[ncnt + 2] = diag[2] * x0 + diag[8] * x1 + diag[14] * x2 + diag[20] * x3 + diag[26] * x4 + diag[32] * x5;
70       yy[ncnt + 3] = diag[3] * x0 + diag[9] * x1 + diag[15] * x2 + diag[21] * x3 + diag[27] * x4 + diag[33] * x5;
71       yy[ncnt + 4] = diag[4] * x0 + diag[10] * x1 + diag[16] * x2 + diag[22] * x3 + diag[28] * x4 + diag[34] * x5;
72       yy[ncnt + 5] = diag[5] * x0 + diag[11] * x1 + diag[17] * x2 + diag[23] * x3 + diag[29] * x4 + diag[35] * x5;
73       break;
74     case 7:
75       x0           = xx[ncnt];
76       x1           = xx[ncnt + 1];
77       x2           = xx[ncnt + 2];
78       x3           = xx[ncnt + 3];
79       x4           = xx[ncnt + 4];
80       x5           = xx[ncnt + 5];
81       x6           = xx[ncnt + 6];
82       yy[ncnt]     = diag[0] * x0 + diag[7] * x1 + diag[14] * x2 + diag[21] * x3 + diag[28] * x4 + diag[35] * x5 + diag[42] * x6;
83       yy[ncnt + 1] = diag[1] * x0 + diag[8] * x1 + diag[15] * x2 + diag[22] * x3 + diag[29] * x4 + diag[36] * x5 + diag[43] * x6;
84       yy[ncnt + 2] = diag[2] * x0 + diag[9] * x1 + diag[16] * x2 + diag[23] * x3 + diag[30] * x4 + diag[37] * x5 + diag[44] * x6;
85       yy[ncnt + 3] = diag[3] * x0 + diag[10] * x1 + diag[17] * x2 + diag[24] * x3 + diag[31] * x4 + diag[38] * x5 + diag[45] * x6;
86       yy[ncnt + 4] = diag[4] * x0 + diag[11] * x1 + diag[18] * x2 + diag[25] * x3 + diag[32] * x4 + diag[39] * x5 + diag[46] * x6;
87       yy[ncnt + 5] = diag[5] * x0 + diag[12] * x1 + diag[19] * x2 + diag[26] * x3 + diag[33] * x4 + diag[40] * x5 + diag[47] * x6;
88       yy[ncnt + 6] = diag[6] * x0 + diag[13] * x1 + diag[20] * x2 + diag[27] * x3 + diag[34] * x4 + diag[41] * x5 + diag[48] * x6;
89       break;
90     default:
91       for (ib = 0; ib < bs; ib++) {
92         PetscScalar rowsum = 0;
93         for (jb = 0; jb < bs; jb++) rowsum += diag[ib + jb * bs] * xx[ncnt + jb];
94         yy[ncnt + ib] = rowsum;
95       }
96     }
97     ncnt += bsizes[i];
98     diag += bsizes[i] * bsizes[i];
99   }
100   PetscCall(VecRestoreArrayRead(x, &xx));
101   PetscCall(VecRestoreArray(y, &yy));
102   PetscFunctionReturn(PETSC_SUCCESS);
103 }
104 
105 static PetscErrorCode PCApplyTranspose_VPBJacobi(PC pc, Vec x, Vec y)
106 {
107   PC_VPBJacobi      *jac = (PC_VPBJacobi *)pc->data;
108   PetscInt           i, ncnt = 0;
109   const MatScalar   *diag = jac->diag;
110   PetscInt           ib, jb, bs;
111   const PetscScalar *xx;
112   PetscScalar       *yy, x0, x1, x2, x3, x4, x5, x6;
113   PetscInt           nblocks;
114   const PetscInt    *bsizes;
115 
116   PetscFunctionBegin;
117   PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes));
118   PetscCall(VecGetArrayRead(x, &xx));
119   PetscCall(VecGetArray(y, &yy));
120   for (i = 0; i < nblocks; i++) {
121     bs = bsizes[i];
122     switch (bs) {
123     case 1:
124       yy[ncnt] = *diag * xx[ncnt];
125       break;
126     case 2:
127       x0           = xx[ncnt];
128       x1           = xx[ncnt + 1];
129       yy[ncnt]     = diag[0] * x0 + diag[1] * x1;
130       yy[ncnt + 1] = diag[2] * x0 + diag[3] * x1;
131       break;
132     case 3:
133       x0           = xx[ncnt];
134       x1           = xx[ncnt + 1];
135       x2           = xx[ncnt + 2];
136       yy[ncnt]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2;
137       yy[ncnt + 1] = diag[3] * x0 + diag[4] * x1 + diag[5] * x2;
138       yy[ncnt + 2] = diag[6] * x0 + diag[7] * x1 + diag[8] * x2;
139       break;
140     case 4:
141       x0           = xx[ncnt];
142       x1           = xx[ncnt + 1];
143       x2           = xx[ncnt + 2];
144       x3           = xx[ncnt + 3];
145       yy[ncnt]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2 + diag[3] * x3;
146       yy[ncnt + 1] = diag[4] * x0 + diag[5] * x1 + diag[6] * x2 + diag[7] * x3;
147       yy[ncnt + 2] = diag[8] * x0 + diag[9] * x1 + diag[10] * x2 + diag[11] * x3;
148       yy[ncnt + 3] = diag[12] * x0 + diag[13] * x1 + diag[14] * x2 + diag[15] * x3;
149       break;
150     case 5:
151       x0           = xx[ncnt];
152       x1           = xx[ncnt + 1];
153       x2           = xx[ncnt + 2];
154       x3           = xx[ncnt + 3];
155       x4           = xx[ncnt + 4];
156       yy[ncnt]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2 + diag[3] * x3 + diag[4] * x4;
157       yy[ncnt + 1] = diag[5] * x0 + diag[6] * x1 + diag[7] * x2 + diag[8] * x3 + diag[9] * x4;
158       yy[ncnt + 2] = diag[10] * x0 + diag[11] * x1 + diag[12] * x2 + diag[13] * x3 + diag[14] * x4;
159       yy[ncnt + 3] = diag[15] * x0 + diag[16] * x1 + diag[17] * x2 + diag[18] * x3 + diag[19] * x4;
160       yy[ncnt + 4] = diag[20] * x0 + diag[21] * x1 + diag[22] * x2 + diag[23] * x3 + diag[24] * x4;
161       break;
162     case 6:
163       x0           = xx[ncnt];
164       x1           = xx[ncnt + 1];
165       x2           = xx[ncnt + 2];
166       x3           = xx[ncnt + 3];
167       x4           = xx[ncnt + 4];
168       x5           = xx[ncnt + 5];
169       yy[ncnt]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2 + diag[3] * x3 + diag[4] * x4 + diag[5] * x5;
170       yy[ncnt + 1] = diag[6] * x0 + diag[7] * x1 + diag[8] * x2 + diag[9] * x3 + diag[10] * x4 + diag[11] * x5;
171       yy[ncnt + 2] = diag[12] * x0 + diag[13] * x1 + diag[14] * x2 + diag[15] * x3 + diag[16] * x4 + diag[17] * x5;
172       yy[ncnt + 3] = diag[18] * x0 + diag[19] * x1 + diag[20] * x2 + diag[21] * x3 + diag[22] * x4 + diag[23] * x5;
173       yy[ncnt + 4] = diag[24] * x0 + diag[25] * x1 + diag[26] * x2 + diag[27] * x3 + diag[28] * x4 + diag[29] * x5;
174       yy[ncnt + 5] = diag[30] * x0 + diag[31] * x1 + diag[32] * x2 + diag[33] * x3 + diag[34] * x4 + diag[35] * x5;
175       break;
176     case 7:
177       x0           = xx[ncnt];
178       x1           = xx[ncnt + 1];
179       x2           = xx[ncnt + 2];
180       x3           = xx[ncnt + 3];
181       x4           = xx[ncnt + 4];
182       x5           = xx[ncnt + 5];
183       x6           = xx[ncnt + 6];
184       yy[ncnt]     = diag[0] * x0 + diag[1] * x1 + diag[2] * x2 + diag[3] * x3 + diag[4] * x4 + diag[5] * x5 + diag[6] * x6;
185       yy[ncnt + 1] = diag[7] * x0 + diag[8] * x1 + diag[9] * x2 + diag[10] * x3 + diag[11] * x4 + diag[12] * x5 + diag[13] * x6;
186       yy[ncnt + 2] = diag[14] * x0 + diag[15] * x1 + diag[16] * x2 + diag[17] * x3 + diag[18] * x4 + diag[19] * x5 + diag[20] * x6;
187       yy[ncnt + 3] = diag[21] * x0 + diag[22] * x1 + diag[23] * x2 + diag[24] * x3 + diag[25] * x4 + diag[26] * x5 + diag[27] * x6;
188       yy[ncnt + 4] = diag[28] * x0 + diag[29] * x1 + diag[30] * x2 + diag[31] * x3 + diag[32] * x4 + diag[33] * x5 + diag[34] * x6;
189       yy[ncnt + 5] = diag[35] * x0 + diag[36] * x1 + diag[37] * x2 + diag[38] * x3 + diag[39] * x4 + diag[40] * x5 + diag[41] * x6;
190       yy[ncnt + 6] = diag[42] * x0 + diag[43] * x1 + diag[44] * x2 + diag[45] * x3 + diag[46] * x4 + diag[47] * x5 + diag[48] * x6;
191       break;
192     default:
193       for (ib = 0; ib < bs; ib++) {
194         PetscScalar rowsum = 0;
195         for (jb = 0; jb < bs; jb++) rowsum += diag[ib * bs + jb] * xx[ncnt + jb];
196         yy[ncnt + ib] = rowsum;
197       }
198     }
199     ncnt += bsizes[i];
200     diag += bsizes[i] * bsizes[i];
201   }
202   PetscCall(VecRestoreArrayRead(x, &xx));
203   PetscCall(VecRestoreArray(y, &yy));
204   PetscFunctionReturn(PETSC_SUCCESS);
205 }
206 
207 PETSC_INTERN PetscErrorCode PCSetUp_VPBJacobi_Host(PC pc)
208 {
209   PC_VPBJacobi   *jac = (PC_VPBJacobi *)pc->data;
210   Mat             A   = pc->pmat;
211   MatFactorError  err;
212   PetscInt        i, nsize = 0, nlocal;
213   PetscInt        nblocks;
214   const PetscInt *bsizes;
215 
216   PetscFunctionBegin;
217   PetscCall(MatGetVariableBlockSizes(pc->pmat, &nblocks, &bsizes));
218   PetscCall(MatGetLocalSize(pc->pmat, &nlocal, NULL));
219   PetscCheck(!nlocal || nblocks, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must call MatSetVariableBlockSizes() before using PCVPBJACOBI");
220   if (!jac->diag) {
221     PetscInt max_bs = -1, min_bs = PETSC_MAX_INT;
222     for (i = 0; i < nblocks; i++) {
223       min_bs = PetscMin(min_bs, bsizes[i]);
224       max_bs = PetscMax(max_bs, bsizes[i]);
225       nsize += bsizes[i] * bsizes[i];
226     }
227     PetscCall(PetscMalloc1(nsize, &jac->diag));
228     jac->nblocks = nblocks;
229     jac->min_bs  = min_bs;
230     jac->max_bs  = max_bs;
231   }
232   PetscCall(MatInvertVariableBlockDiagonal(A, nblocks, bsizes, jac->diag));
233   PetscCall(MatFactorGetError(A, &err));
234   if (err) pc->failedreason = (PCFailedReason)err;
235   pc->ops->apply          = PCApply_VPBJacobi;
236   pc->ops->applytranspose = PCApplyTranspose_VPBJacobi;
237   PetscFunctionReturn(PETSC_SUCCESS);
238 }
239 
240 static PetscErrorCode PCSetUp_VPBJacobi(PC pc)
241 {
242   PetscFunctionBegin;
243   /* In PCCreate_VPBJacobi() pmat might have not been set, so we wait to the last minute to do the dispatch */
244 #if defined(PETSC_HAVE_CUDA)
245   PetscBool isCuda;
246   PetscCall(PetscObjectTypeCompareAny((PetscObject)pc->pmat, &isCuda, MATSEQAIJCUSPARSE, MATMPIAIJCUSPARSE, ""));
247 #endif
248 #if defined(PETSC_HAVE_KOKKOS_KERNELS)
249   PetscBool isKok;
250   PetscCall(PetscObjectTypeCompareAny((PetscObject)pc->pmat, &isKok, MATSEQAIJKOKKOS, MATMPIAIJKOKKOS, ""));
251 #endif
252 
253 #if defined(PETSC_HAVE_CUDA)
254   if (isCuda) PetscCall(PCSetUp_VPBJacobi_CUDA(pc));
255   else
256 #endif
257 #if defined(PETSC_HAVE_KOKKOS_KERNELS)
258     if (isKok)
259     PetscCall(PCSetUp_VPBJacobi_Kokkos(pc));
260   else
261 #endif
262   {
263     PetscCall(PCSetUp_VPBJacobi_Host(pc));
264   }
265   PetscFunctionReturn(PETSC_SUCCESS);
266 }
267 
268 static PetscErrorCode PCView_VPBJacobi(PC pc, PetscViewer viewer)
269 {
270   PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data;
271   PetscBool     iascii;
272 
273   PetscFunctionBegin;
274   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &iascii));
275   if (iascii) {
276     PetscCall(PetscViewerASCIIPrintf(viewer, "  number of blocks: %" PetscInt_FMT "\n", jac->nblocks));
277     PetscCall(PetscViewerASCIIPrintf(viewer, "  block sizes: min=%" PetscInt_FMT " max=%" PetscInt_FMT "\n", jac->min_bs, jac->max_bs));
278   }
279   PetscFunctionReturn(PETSC_SUCCESS);
280 }
281 
282 PETSC_INTERN PetscErrorCode PCDestroy_VPBJacobi(PC pc)
283 {
284   PC_VPBJacobi *jac = (PC_VPBJacobi *)pc->data;
285 
286   PetscFunctionBegin;
287   /*
288       Free the private data structure that was hanging off the PC
289   */
290   PetscCall(PetscFree(jac->diag));
291   PetscCall(PetscFree(pc->data));
292   PetscFunctionReturn(PETSC_SUCCESS);
293 }
294 
295 /*MC
296      PCVPBJACOBI - Variable size point block Jacobi preconditioner
297 
298    Level: beginner
299 
300    Notes:
301      See `PCJACOBI` for point Jacobi preconditioning, `PCPBJACOBI` for fixed point block size, and `PCBJACOBI` for large size blocks
302 
303      This works for `MATAIJ` matrices
304 
305      Uses dense LU factorization with partial pivoting to invert the blocks; if a zero pivot
306      is detected a PETSc error is generated.
307 
308      One must call `MatSetVariableBlockSizes()` to use this preconditioner
309 
310    Developer Notes:
311      This should support the `PCSetErrorIfFailure()` flag set to `PETSC_TRUE` to allow
312      the factorization to continue even after a zero pivot is found, resulting in a NaN and hence
313      terminating `KSP` with a `KSP_DIVERGED_NANORINF`, allowing
314      a nonlinear solver/ODE integrator to recover without stopping the program, as currently happens.
315 
316      Perhaps should provide an option that allows generation of a valid preconditioner
317      even if a block is singular as the `PCJACOBI` does.
318 
319 .seealso: [](ch_ksp), `MatSetVariableBlockSizes()`, `PCCreate()`, `PCSetType()`, `PCType`, `PC`, `PCJACOBI`, `PCPBJACOBI`, `PCBJACOBI`
320 M*/
321 
322 PETSC_EXTERN PetscErrorCode PCCreate_VPBJacobi(PC pc)
323 {
324   PC_VPBJacobi *jac;
325 
326   PetscFunctionBegin;
327   /*
328      Creates the private data structure for this preconditioner and
329      attach it to the PC object.
330   */
331   PetscCall(PetscNew(&jac));
332   pc->data = (void *)jac;
333 
334   /*
335      Initialize the pointers to vectors to ZERO; these will be used to store
336      diagonal entries of the matrix for fast preconditioner application.
337   */
338   jac->diag = NULL;
339 
340   /*
341       Set the pointers for the functions that are provided above.
342       Now when the user-level routines (such as PCApply(), PCDestroy(), etc.)
343       are called, they will automatically call these functions.  Note we
344       choose not to provide a couple of these functions since they are
345       not needed.
346   */
347   pc->ops->apply               = PCApply_VPBJacobi;
348   pc->ops->applytranspose      = NULL;
349   pc->ops->setup               = PCSetUp_VPBJacobi;
350   pc->ops->destroy             = PCDestroy_VPBJacobi;
351   pc->ops->setfromoptions      = NULL;
352   pc->ops->view                = PCView_VPBJacobi;
353   pc->ops->applyrichardson     = NULL;
354   pc->ops->applysymmetricleft  = NULL;
355   pc->ops->applysymmetricright = NULL;
356   PetscFunctionReturn(PETSC_SUCCESS);
357 }
358