xref: /petsc/src/mat/impls/baij/seq/baijfact13.c (revision fbf9dbe564678ed6eff1806adbc4c4f01b9743f4)
1 
2 /*
3     Factorization code for BAIJ format.
4 */
5 #include <../src/mat/impls/baij/seq/baij.h>
6 #include <petsc/private/kernels/blockinvert.h>
7 
8 /*
9       Version for when blocks are 3 by 3
10 */
11 PetscErrorCode MatLUFactorNumeric_SeqBAIJ_3_inplace(Mat C, Mat A, const MatFactorInfo *info)
12 {
13   Mat_SeqBAIJ    *a = (Mat_SeqBAIJ *)A->data, *b = (Mat_SeqBAIJ *)C->data;
14   IS              isrow = b->row, isicol = b->icol;
15   const PetscInt *r, *ic;
16   PetscInt        i, j, n = a->mbs, *bi = b->i, *bj = b->j;
17   PetscInt       *ajtmpold, *ajtmp, nz, row, *ai = a->i, *aj = a->j;
18   PetscInt       *diag_offset = b->diag, idx, *pj;
19   MatScalar      *pv, *v, *rtmp, *pc, *w, *x;
20   MatScalar       p1, p2, p3, p4, m1, m2, m3, m4, m5, m6, m7, m8, m9, x1, x2, x3, x4;
21   MatScalar       p5, p6, p7, p8, p9, x5, x6, x7, x8, x9;
22   MatScalar      *ba = b->a, *aa = a->a;
23   PetscReal       shift = info->shiftamount;
24   PetscBool       allowzeropivot, zeropivotdetected;
25 
26   PetscFunctionBegin;
27   PetscCall(ISGetIndices(isrow, &r));
28   PetscCall(ISGetIndices(isicol, &ic));
29   PetscCall(PetscMalloc1(9 * (n + 1), &rtmp));
30   allowzeropivot = PetscNot(A->erroriffailure);
31 
32   for (i = 0; i < n; i++) {
33     nz    = bi[i + 1] - bi[i];
34     ajtmp = bj + bi[i];
35     for (j = 0; j < nz; j++) {
36       x    = rtmp + 9 * ajtmp[j];
37       x[0] = x[1] = x[2] = x[3] = x[4] = x[5] = x[6] = x[7] = x[8] = 0.0;
38     }
39     /* load in initial (unfactored row) */
40     idx      = r[i];
41     nz       = ai[idx + 1] - ai[idx];
42     ajtmpold = aj + ai[idx];
43     v        = aa + 9 * ai[idx];
44     for (j = 0; j < nz; j++) {
45       x    = rtmp + 9 * ic[ajtmpold[j]];
46       x[0] = v[0];
47       x[1] = v[1];
48       x[2] = v[2];
49       x[3] = v[3];
50       x[4] = v[4];
51       x[5] = v[5];
52       x[6] = v[6];
53       x[7] = v[7];
54       x[8] = v[8];
55       v += 9;
56     }
57     row = *ajtmp++;
58     while (row < i) {
59       pc = rtmp + 9 * row;
60       p1 = pc[0];
61       p2 = pc[1];
62       p3 = pc[2];
63       p4 = pc[3];
64       p5 = pc[4];
65       p6 = pc[5];
66       p7 = pc[6];
67       p8 = pc[7];
68       p9 = pc[8];
69       if (p1 != 0.0 || p2 != 0.0 || p3 != 0.0 || p4 != 0.0 || p5 != 0.0 || p6 != 0.0 || p7 != 0.0 || p8 != 0.0 || p9 != 0.0) {
70         pv    = ba + 9 * diag_offset[row];
71         pj    = bj + diag_offset[row] + 1;
72         x1    = pv[0];
73         x2    = pv[1];
74         x3    = pv[2];
75         x4    = pv[3];
76         x5    = pv[4];
77         x6    = pv[5];
78         x7    = pv[6];
79         x8    = pv[7];
80         x9    = pv[8];
81         pc[0] = m1 = p1 * x1 + p4 * x2 + p7 * x3;
82         pc[1] = m2 = p2 * x1 + p5 * x2 + p8 * x3;
83         pc[2] = m3 = p3 * x1 + p6 * x2 + p9 * x3;
84 
85         pc[3] = m4 = p1 * x4 + p4 * x5 + p7 * x6;
86         pc[4] = m5 = p2 * x4 + p5 * x5 + p8 * x6;
87         pc[5] = m6 = p3 * x4 + p6 * x5 + p9 * x6;
88 
89         pc[6] = m7 = p1 * x7 + p4 * x8 + p7 * x9;
90         pc[7] = m8 = p2 * x7 + p5 * x8 + p8 * x9;
91         pc[8] = m9 = p3 * x7 + p6 * x8 + p9 * x9;
92         nz         = bi[row + 1] - diag_offset[row] - 1;
93         pv += 9;
94         for (j = 0; j < nz; j++) {
95           x1 = pv[0];
96           x2 = pv[1];
97           x3 = pv[2];
98           x4 = pv[3];
99           x5 = pv[4];
100           x6 = pv[5];
101           x7 = pv[6];
102           x8 = pv[7];
103           x9 = pv[8];
104           x  = rtmp + 9 * pj[j];
105           x[0] -= m1 * x1 + m4 * x2 + m7 * x3;
106           x[1] -= m2 * x1 + m5 * x2 + m8 * x3;
107           x[2] -= m3 * x1 + m6 * x2 + m9 * x3;
108 
109           x[3] -= m1 * x4 + m4 * x5 + m7 * x6;
110           x[4] -= m2 * x4 + m5 * x5 + m8 * x6;
111           x[5] -= m3 * x4 + m6 * x5 + m9 * x6;
112 
113           x[6] -= m1 * x7 + m4 * x8 + m7 * x9;
114           x[7] -= m2 * x7 + m5 * x8 + m8 * x9;
115           x[8] -= m3 * x7 + m6 * x8 + m9 * x9;
116           pv += 9;
117         }
118         PetscCall(PetscLogFlops(54.0 * nz + 36.0));
119       }
120       row = *ajtmp++;
121     }
122     /* finished row so stick it into b->a */
123     pv = ba + 9 * bi[i];
124     pj = bj + bi[i];
125     nz = bi[i + 1] - bi[i];
126     for (j = 0; j < nz; j++) {
127       x     = rtmp + 9 * pj[j];
128       pv[0] = x[0];
129       pv[1] = x[1];
130       pv[2] = x[2];
131       pv[3] = x[3];
132       pv[4] = x[4];
133       pv[5] = x[5];
134       pv[6] = x[6];
135       pv[7] = x[7];
136       pv[8] = x[8];
137       pv += 9;
138     }
139     /* invert diagonal block */
140     w = ba + 9 * diag_offset[i];
141     PetscCall(PetscKernel_A_gets_inverse_A_3(w, shift, allowzeropivot, &zeropivotdetected));
142     if (zeropivotdetected) C->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;
143   }
144 
145   PetscCall(PetscFree(rtmp));
146   PetscCall(ISRestoreIndices(isicol, &ic));
147   PetscCall(ISRestoreIndices(isrow, &r));
148 
149   C->ops->solve          = MatSolve_SeqBAIJ_3_inplace;
150   C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_3_inplace;
151   C->assembled           = PETSC_TRUE;
152 
153   PetscCall(PetscLogFlops(1.333333333333 * 3 * 3 * 3 * b->mbs)); /* from inverting diagonal blocks */
154   PetscFunctionReturn(PETSC_SUCCESS);
155 }
156 
157 /* MatLUFactorNumeric_SeqBAIJ_3 -
158      copied from MatLUFactorNumeric_SeqBAIJ_N_inplace() and manually re-implemented
159        PetscKernel_A_gets_A_times_B()
160        PetscKernel_A_gets_A_minus_B_times_C()
161        PetscKernel_A_gets_inverse_A()
162 */
163 PetscErrorCode MatLUFactorNumeric_SeqBAIJ_3(Mat B, Mat A, const MatFactorInfo *info)
164 {
165   Mat             C = B;
166   Mat_SeqBAIJ    *a = (Mat_SeqBAIJ *)A->data, *b = (Mat_SeqBAIJ *)C->data;
167   IS              isrow = b->row, isicol = b->icol;
168   const PetscInt *r, *ic;
169   PetscInt        i, j, k, nz, nzL, row;
170   const PetscInt  n = a->mbs, *ai = a->i, *aj = a->j, *bi = b->i, *bj = b->j;
171   const PetscInt *ajtmp, *bjtmp, *bdiag = b->diag, *pj, bs2 = a->bs2;
172   MatScalar      *rtmp, *pc, *mwork, *v, *pv, *aa = a->a;
173   PetscInt        flg;
174   PetscReal       shift = info->shiftamount;
175   PetscBool       allowzeropivot, zeropivotdetected;
176 
177   PetscFunctionBegin;
178   PetscCall(ISGetIndices(isrow, &r));
179   PetscCall(ISGetIndices(isicol, &ic));
180   allowzeropivot = PetscNot(A->erroriffailure);
181 
182   /* generate work space needed by the factorization */
183   PetscCall(PetscMalloc2(bs2 * n, &rtmp, bs2, &mwork));
184   PetscCall(PetscArrayzero(rtmp, bs2 * n));
185 
186   for (i = 0; i < n; i++) {
187     /* zero rtmp */
188     /* L part */
189     nz    = bi[i + 1] - bi[i];
190     bjtmp = bj + bi[i];
191     for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));
192 
193     /* U part */
194     nz    = bdiag[i] - bdiag[i + 1];
195     bjtmp = bj + bdiag[i + 1] + 1;
196     for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));
197 
198     /* load in initial (unfactored row) */
199     nz    = ai[r[i] + 1] - ai[r[i]];
200     ajtmp = aj + ai[r[i]];
201     v     = aa + bs2 * ai[r[i]];
202     for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(rtmp + bs2 * ic[ajtmp[j]], v + bs2 * j, bs2));
203 
204     /* elimination */
205     bjtmp = bj + bi[i];
206     nzL   = bi[i + 1] - bi[i];
207     for (k = 0; k < nzL; k++) {
208       row = bjtmp[k];
209       pc  = rtmp + bs2 * row;
210       for (flg = 0, j = 0; j < bs2; j++) {
211         if (pc[j] != 0.0) {
212           flg = 1;
213           break;
214         }
215       }
216       if (flg) {
217         pv = b->a + bs2 * bdiag[row];
218         /* PetscKernel_A_gets_A_times_B(bs,pc,pv,mwork); *pc = *pc * (*pv); */
219         PetscCall(PetscKernel_A_gets_A_times_B_3(pc, pv, mwork));
220 
221         pj = b->j + bdiag[row + 1] + 1; /* beginning of U(row,:) */
222         pv = b->a + bs2 * (bdiag[row + 1] + 1);
223         nz = bdiag[row] - bdiag[row + 1] - 1; /* num of entries in U(row,:) excluding diag */
224         for (j = 0; j < nz; j++) {
225           /* PetscKernel_A_gets_A_minus_B_times_C(bs,rtmp+bs2*pj[j],pc,pv+bs2*j); */
226           /* rtmp+bs2*pj[j] = rtmp+bs2*pj[j] - (*pc)*(pv+bs2*j) */
227           v = rtmp + bs2 * pj[j];
228           PetscCall(PetscKernel_A_gets_A_minus_B_times_C_3(v, pc, pv));
229           pv += bs2;
230         }
231         PetscCall(PetscLogFlops(54.0 * nz + 45)); /* flops = 2*bs^3*nz + 2*bs^3 - bs2) */
232       }
233     }
234 
235     /* finished row so stick it into b->a */
236     /* L part */
237     pv = b->a + bs2 * bi[i];
238     pj = b->j + bi[i];
239     nz = bi[i + 1] - bi[i];
240     for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));
241 
242     /* Mark diagonal and invert diagonal for simpler triangular solves */
243     pv = b->a + bs2 * bdiag[i];
244     pj = b->j + bdiag[i];
245     PetscCall(PetscArraycpy(pv, rtmp + bs2 * pj[0], bs2));
246     PetscCall(PetscKernel_A_gets_inverse_A_3(pv, shift, allowzeropivot, &zeropivotdetected));
247     if (zeropivotdetected) B->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;
248 
249     /* U part */
250     pj = b->j + bdiag[i + 1] + 1;
251     pv = b->a + bs2 * (bdiag[i + 1] + 1);
252     nz = bdiag[i] - bdiag[i + 1] - 1;
253     for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));
254   }
255 
256   PetscCall(PetscFree2(rtmp, mwork));
257   PetscCall(ISRestoreIndices(isicol, &ic));
258   PetscCall(ISRestoreIndices(isrow, &r));
259 
260   C->ops->solve          = MatSolve_SeqBAIJ_3;
261   C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_3;
262   C->assembled           = PETSC_TRUE;
263 
264   PetscCall(PetscLogFlops(1.333333333333 * 3 * 3 * 3 * n)); /* from inverting diagonal blocks */
265   PetscFunctionReturn(PETSC_SUCCESS);
266 }
267 
268 PetscErrorCode MatLUFactorNumeric_SeqBAIJ_3_NaturalOrdering_inplace(Mat C, Mat A, const MatFactorInfo *info)
269 {
270   Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data, *b = (Mat_SeqBAIJ *)C->data;
271   PetscInt     i, j, n = a->mbs, *bi = b->i, *bj = b->j;
272   PetscInt    *ajtmpold, *ajtmp, nz, row;
273   PetscInt    *diag_offset = b->diag, *ai = a->i, *aj = a->j, *pj;
274   MatScalar   *pv, *v, *rtmp, *pc, *w, *x;
275   MatScalar    p1, p2, p3, p4, m1, m2, m3, m4, m5, m6, m7, m8, m9, x1, x2, x3, x4;
276   MatScalar    p5, p6, p7, p8, p9, x5, x6, x7, x8, x9;
277   MatScalar   *ba = b->a, *aa = a->a;
278   PetscReal    shift = info->shiftamount;
279   PetscBool    allowzeropivot, zeropivotdetected;
280 
281   PetscFunctionBegin;
282   PetscCall(PetscMalloc1(9 * (n + 1), &rtmp));
283   allowzeropivot = PetscNot(A->erroriffailure);
284 
285   for (i = 0; i < n; i++) {
286     nz    = bi[i + 1] - bi[i];
287     ajtmp = bj + bi[i];
288     for (j = 0; j < nz; j++) {
289       x    = rtmp + 9 * ajtmp[j];
290       x[0] = x[1] = x[2] = x[3] = x[4] = x[5] = x[6] = x[7] = x[8] = 0.0;
291     }
292     /* load in initial (unfactored row) */
293     nz       = ai[i + 1] - ai[i];
294     ajtmpold = aj + ai[i];
295     v        = aa + 9 * ai[i];
296     for (j = 0; j < nz; j++) {
297       x    = rtmp + 9 * ajtmpold[j];
298       x[0] = v[0];
299       x[1] = v[1];
300       x[2] = v[2];
301       x[3] = v[3];
302       x[4] = v[4];
303       x[5] = v[5];
304       x[6] = v[6];
305       x[7] = v[7];
306       x[8] = v[8];
307       v += 9;
308     }
309     row = *ajtmp++;
310     while (row < i) {
311       pc = rtmp + 9 * row;
312       p1 = pc[0];
313       p2 = pc[1];
314       p3 = pc[2];
315       p4 = pc[3];
316       p5 = pc[4];
317       p6 = pc[5];
318       p7 = pc[6];
319       p8 = pc[7];
320       p9 = pc[8];
321       if (p1 != 0.0 || p2 != 0.0 || p3 != 0.0 || p4 != 0.0 || p5 != 0.0 || p6 != 0.0 || p7 != 0.0 || p8 != 0.0 || p9 != 0.0) {
322         pv    = ba + 9 * diag_offset[row];
323         pj    = bj + diag_offset[row] + 1;
324         x1    = pv[0];
325         x2    = pv[1];
326         x3    = pv[2];
327         x4    = pv[3];
328         x5    = pv[4];
329         x6    = pv[5];
330         x7    = pv[6];
331         x8    = pv[7];
332         x9    = pv[8];
333         pc[0] = m1 = p1 * x1 + p4 * x2 + p7 * x3;
334         pc[1] = m2 = p2 * x1 + p5 * x2 + p8 * x3;
335         pc[2] = m3 = p3 * x1 + p6 * x2 + p9 * x3;
336 
337         pc[3] = m4 = p1 * x4 + p4 * x5 + p7 * x6;
338         pc[4] = m5 = p2 * x4 + p5 * x5 + p8 * x6;
339         pc[5] = m6 = p3 * x4 + p6 * x5 + p9 * x6;
340 
341         pc[6] = m7 = p1 * x7 + p4 * x8 + p7 * x9;
342         pc[7] = m8 = p2 * x7 + p5 * x8 + p8 * x9;
343         pc[8] = m9 = p3 * x7 + p6 * x8 + p9 * x9;
344 
345         nz = bi[row + 1] - diag_offset[row] - 1;
346         pv += 9;
347         for (j = 0; j < nz; j++) {
348           x1 = pv[0];
349           x2 = pv[1];
350           x3 = pv[2];
351           x4 = pv[3];
352           x5 = pv[4];
353           x6 = pv[5];
354           x7 = pv[6];
355           x8 = pv[7];
356           x9 = pv[8];
357           x  = rtmp + 9 * pj[j];
358           x[0] -= m1 * x1 + m4 * x2 + m7 * x3;
359           x[1] -= m2 * x1 + m5 * x2 + m8 * x3;
360           x[2] -= m3 * x1 + m6 * x2 + m9 * x3;
361 
362           x[3] -= m1 * x4 + m4 * x5 + m7 * x6;
363           x[4] -= m2 * x4 + m5 * x5 + m8 * x6;
364           x[5] -= m3 * x4 + m6 * x5 + m9 * x6;
365 
366           x[6] -= m1 * x7 + m4 * x8 + m7 * x9;
367           x[7] -= m2 * x7 + m5 * x8 + m8 * x9;
368           x[8] -= m3 * x7 + m6 * x8 + m9 * x9;
369           pv += 9;
370         }
371         PetscCall(PetscLogFlops(54.0 * nz + 36.0));
372       }
373       row = *ajtmp++;
374     }
375     /* finished row so stick it into b->a */
376     pv = ba + 9 * bi[i];
377     pj = bj + bi[i];
378     nz = bi[i + 1] - bi[i];
379     for (j = 0; j < nz; j++) {
380       x     = rtmp + 9 * pj[j];
381       pv[0] = x[0];
382       pv[1] = x[1];
383       pv[2] = x[2];
384       pv[3] = x[3];
385       pv[4] = x[4];
386       pv[5] = x[5];
387       pv[6] = x[6];
388       pv[7] = x[7];
389       pv[8] = x[8];
390       pv += 9;
391     }
392     /* invert diagonal block */
393     w = ba + 9 * diag_offset[i];
394     PetscCall(PetscKernel_A_gets_inverse_A_3(w, shift, allowzeropivot, &zeropivotdetected));
395     if (zeropivotdetected) C->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;
396   }
397 
398   PetscCall(PetscFree(rtmp));
399 
400   C->ops->solve          = MatSolve_SeqBAIJ_3_NaturalOrdering_inplace;
401   C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_3_NaturalOrdering_inplace;
402   C->assembled           = PETSC_TRUE;
403 
404   PetscCall(PetscLogFlops(1.333333333333 * 3 * 3 * 3 * b->mbs)); /* from inverting diagonal blocks */
405   PetscFunctionReturn(PETSC_SUCCESS);
406 }
407 
408 /*
409   MatLUFactorNumeric_SeqBAIJ_3_NaturalOrdering -
410     copied from MatLUFactorNumeric_SeqBAIJ_2_NaturalOrdering_inplace()
411 */
412 PetscErrorCode MatLUFactorNumeric_SeqBAIJ_3_NaturalOrdering(Mat B, Mat A, const MatFactorInfo *info)
413 {
414   Mat             C = B;
415   Mat_SeqBAIJ    *a = (Mat_SeqBAIJ *)A->data, *b = (Mat_SeqBAIJ *)C->data;
416   PetscInt        i, j, k, nz, nzL, row;
417   const PetscInt  n = a->mbs, *ai = a->i, *aj = a->j, *bi = b->i, *bj = b->j;
418   const PetscInt *ajtmp, *bjtmp, *bdiag = b->diag, *pj, bs2 = a->bs2;
419   MatScalar      *rtmp, *pc, *mwork, *v, *pv, *aa = a->a;
420   PetscInt        flg;
421   PetscReal       shift = info->shiftamount;
422   PetscBool       allowzeropivot, zeropivotdetected;
423 
424   PetscFunctionBegin;
425   allowzeropivot = PetscNot(A->erroriffailure);
426 
427   /* generate work space needed by the factorization */
428   PetscCall(PetscMalloc2(bs2 * n, &rtmp, bs2, &mwork));
429   PetscCall(PetscArrayzero(rtmp, bs2 * n));
430 
431   for (i = 0; i < n; i++) {
432     /* zero rtmp */
433     /* L part */
434     nz    = bi[i + 1] - bi[i];
435     bjtmp = bj + bi[i];
436     for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));
437 
438     /* U part */
439     nz    = bdiag[i] - bdiag[i + 1];
440     bjtmp = bj + bdiag[i + 1] + 1;
441     for (j = 0; j < nz; j++) PetscCall(PetscArrayzero(rtmp + bs2 * bjtmp[j], bs2));
442 
443     /* load in initial (unfactored row) */
444     nz    = ai[i + 1] - ai[i];
445     ajtmp = aj + ai[i];
446     v     = aa + bs2 * ai[i];
447     for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(rtmp + bs2 * ajtmp[j], v + bs2 * j, bs2));
448 
449     /* elimination */
450     bjtmp = bj + bi[i];
451     nzL   = bi[i + 1] - bi[i];
452     for (k = 0; k < nzL; k++) {
453       row = bjtmp[k];
454       pc  = rtmp + bs2 * row;
455       for (flg = 0, j = 0; j < bs2; j++) {
456         if (pc[j] != 0.0) {
457           flg = 1;
458           break;
459         }
460       }
461       if (flg) {
462         pv = b->a + bs2 * bdiag[row];
463         /* PetscKernel_A_gets_A_times_B(bs,pc,pv,mwork); *pc = *pc * (*pv); */
464         PetscCall(PetscKernel_A_gets_A_times_B_3(pc, pv, mwork));
465 
466         pj = b->j + bdiag[row + 1] + 1; /* beginning of U(row,:) */
467         pv = b->a + bs2 * (bdiag[row + 1] + 1);
468         nz = bdiag[row] - bdiag[row + 1] - 1; /* num of entries in U(row,:) excluding diag */
469         for (j = 0; j < nz; j++) {
470           /* PetscKernel_A_gets_A_minus_B_times_C(bs,rtmp+bs2*pj[j],pc,pv+bs2*j); */
471           /* rtmp+bs2*pj[j] = rtmp+bs2*pj[j] - (*pc)*(pv+bs2*j) */
472           v = rtmp + bs2 * pj[j];
473           PetscCall(PetscKernel_A_gets_A_minus_B_times_C_3(v, pc, pv));
474           pv += bs2;
475         }
476         PetscCall(PetscLogFlops(54.0 * nz + 45)); /* flops = 2*bs^3*nz + 2*bs^3 - bs2) */
477       }
478     }
479 
480     /* finished row so stick it into b->a */
481     /* L part */
482     pv = b->a + bs2 * bi[i];
483     pj = b->j + bi[i];
484     nz = bi[i + 1] - bi[i];
485     for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));
486 
487     /* Mark diagonal and invert diagonal for simpler triangular solves */
488     pv = b->a + bs2 * bdiag[i];
489     pj = b->j + bdiag[i];
490     PetscCall(PetscArraycpy(pv, rtmp + bs2 * pj[0], bs2));
491     PetscCall(PetscKernel_A_gets_inverse_A_3(pv, shift, allowzeropivot, &zeropivotdetected));
492     if (zeropivotdetected) B->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;
493 
494     /* U part */
495     pv = b->a + bs2 * (bdiag[i + 1] + 1);
496     pj = b->j + bdiag[i + 1] + 1;
497     nz = bdiag[i] - bdiag[i + 1] - 1;
498     for (j = 0; j < nz; j++) PetscCall(PetscArraycpy(pv + bs2 * j, rtmp + bs2 * pj[j], bs2));
499   }
500   PetscCall(PetscFree2(rtmp, mwork));
501 
502   C->ops->solve          = MatSolve_SeqBAIJ_3_NaturalOrdering;
503   C->ops->forwardsolve   = MatForwardSolve_SeqBAIJ_3_NaturalOrdering;
504   C->ops->backwardsolve  = MatBackwardSolve_SeqBAIJ_3_NaturalOrdering;
505   C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_3_NaturalOrdering;
506   C->assembled           = PETSC_TRUE;
507 
508   PetscCall(PetscLogFlops(1.333333333333 * 3 * 3 * 3 * n)); /* from inverting diagonal blocks */
509   PetscFunctionReturn(PETSC_SUCCESS);
510 }
511