xref: /petsc/src/mat/impls/baij/seq/baijsolvtran6.c (revision 58d68138c660dfb4e9f5b03334792cd4f2ffd7cc)
1 #include <../src/mat/impls/baij/seq/baij.h>
2 #include <petsc/private/kernels/blockinvert.h>
3 
4 PetscErrorCode MatSolveTranspose_SeqBAIJ_6_inplace(Mat A, Vec bb, Vec xx) {
5   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
6   IS                 iscol = a->col, isrow = a->row;
7   const PetscInt    *r, *c, *rout, *cout;
8   const PetscInt    *diag = a->diag, n = a->mbs, *vi, *ai = a->i, *aj = a->j;
9   PetscInt           i, nz, idx, idt, ii, ic, ir, oidx;
10   const MatScalar   *aa = a->a, *v;
11   PetscScalar        s1, s2, s3, s4, s5, s6, x1, x2, x3, x4, x5, x6, *x, *t;
12   const PetscScalar *b;
13 
14   PetscFunctionBegin;
15   PetscCall(VecGetArrayRead(bb, &b));
16   PetscCall(VecGetArray(xx, &x));
17   t = a->solve_work;
18 
19   PetscCall(ISGetIndices(isrow, &rout));
20   r = rout;
21   PetscCall(ISGetIndices(iscol, &cout));
22   c = cout;
23 
24   /* copy the b into temp work space according to permutation */
25   ii = 0;
26   for (i = 0; i < n; i++) {
27     ic        = 6 * c[i];
28     t[ii]     = b[ic];
29     t[ii + 1] = b[ic + 1];
30     t[ii + 2] = b[ic + 2];
31     t[ii + 3] = b[ic + 3];
32     t[ii + 4] = b[ic + 4];
33     t[ii + 5] = b[ic + 5];
34     ii += 6;
35   }
36 
37   /* forward solve the U^T */
38   idx = 0;
39   for (i = 0; i < n; i++) {
40     v  = aa + 36 * diag[i];
41     /* multiply by the inverse of the block diagonal */
42     x1 = t[idx];
43     x2 = t[1 + idx];
44     x3 = t[2 + idx];
45     x4 = t[3 + idx];
46     x5 = t[4 + idx];
47     x6 = t[5 + idx];
48     s1 = v[0] * x1 + v[1] * x2 + v[2] * x3 + v[3] * x4 + v[4] * x5 + v[5] * x6;
49     s2 = v[6] * x1 + v[7] * x2 + v[8] * x3 + v[9] * x4 + v[10] * x5 + v[11] * x6;
50     s3 = v[12] * x1 + v[13] * x2 + v[14] * x3 + v[15] * x4 + v[16] * x5 + v[17] * x6;
51     s4 = v[18] * x1 + v[19] * x2 + v[20] * x3 + v[21] * x4 + v[22] * x5 + v[23] * x6;
52     s5 = v[24] * x1 + v[25] * x2 + v[26] * x3 + v[27] * x4 + v[28] * x5 + v[29] * x6;
53     s6 = v[30] * x1 + v[31] * x2 + v[32] * x3 + v[33] * x4 + v[34] * x5 + v[35] * x6;
54     v += 36;
55 
56     vi = aj + diag[i] + 1;
57     nz = ai[i + 1] - diag[i] - 1;
58     while (nz--) {
59       oidx = 6 * (*vi++);
60       t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5 + v[5] * s6;
61       t[oidx + 1] -= v[6] * s1 + v[7] * s2 + v[8] * s3 + v[9] * s4 + v[10] * s5 + v[11] * s6;
62       t[oidx + 2] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4 + v[16] * s5 + v[17] * s6;
63       t[oidx + 3] -= v[18] * s1 + v[19] * s2 + v[20] * s3 + v[21] * s4 + v[22] * s5 + v[23] * s6;
64       t[oidx + 4] -= v[24] * s1 + v[25] * s2 + v[26] * s3 + v[27] * s4 + v[28] * s5 + v[29] * s6;
65       t[oidx + 5] -= v[30] * s1 + v[31] * s2 + v[32] * s3 + v[33] * s4 + v[34] * s5 + v[35] * s6;
66       v += 36;
67     }
68     t[idx]     = s1;
69     t[1 + idx] = s2;
70     t[2 + idx] = s3;
71     t[3 + idx] = s4;
72     t[4 + idx] = s5;
73     t[5 + idx] = s6;
74     idx += 6;
75   }
76   /* backward solve the L^T */
77   for (i = n - 1; i >= 0; i--) {
78     v   = aa + 36 * diag[i] - 36;
79     vi  = aj + diag[i] - 1;
80     nz  = diag[i] - ai[i];
81     idt = 6 * i;
82     s1  = t[idt];
83     s2  = t[1 + idt];
84     s3  = t[2 + idt];
85     s4  = t[3 + idt];
86     s5  = t[4 + idt];
87     s6  = t[5 + idt];
88     while (nz--) {
89       idx = 6 * (*vi--);
90       t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5 + v[5] * s6;
91       t[idx + 1] -= v[6] * s1 + v[7] * s2 + v[8] * s3 + v[9] * s4 + v[10] * s5 + v[11] * s6;
92       t[idx + 2] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4 + v[16] * s5 + v[17] * s6;
93       t[idx + 3] -= v[18] * s1 + v[19] * s2 + v[20] * s3 + v[21] * s4 + v[22] * s5 + v[23] * s6;
94       t[idx + 4] -= v[24] * s1 + v[25] * s2 + v[26] * s3 + v[27] * s4 + v[28] * s5 + v[29] * s6;
95       t[idx + 5] -= v[30] * s1 + v[31] * s2 + v[32] * s3 + v[33] * s4 + v[34] * s5 + v[35] * s6;
96       v -= 36;
97     }
98   }
99 
100   /* copy t into x according to permutation */
101   ii = 0;
102   for (i = 0; i < n; i++) {
103     ir        = 6 * r[i];
104     x[ir]     = t[ii];
105     x[ir + 1] = t[ii + 1];
106     x[ir + 2] = t[ii + 2];
107     x[ir + 3] = t[ii + 3];
108     x[ir + 4] = t[ii + 4];
109     x[ir + 5] = t[ii + 5];
110     ii += 6;
111   }
112 
113   PetscCall(ISRestoreIndices(isrow, &rout));
114   PetscCall(ISRestoreIndices(iscol, &cout));
115   PetscCall(VecRestoreArrayRead(bb, &b));
116   PetscCall(VecRestoreArray(xx, &x));
117   PetscCall(PetscLogFlops(2.0 * 36 * (a->nz) - 6.0 * A->cmap->n));
118   PetscFunctionReturn(0);
119 }
120 
121 PetscErrorCode MatSolveTranspose_SeqBAIJ_6(Mat A, Vec bb, Vec xx) {
122   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
123   IS                 iscol = a->col, isrow = a->row;
124   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j, *diag = a->diag;
125   const PetscInt    *r, *c, *rout, *cout;
126   PetscInt           nz, idx, idt, j, i, oidx, ii, ic, ir;
127   const PetscInt     bs = A->rmap->bs, bs2 = a->bs2;
128   const MatScalar   *aa = a->a, *v;
129   PetscScalar        s1, s2, s3, s4, s5, s6, x1, x2, x3, x4, x5, x6, *x, *t;
130   const PetscScalar *b;
131 
132   PetscFunctionBegin;
133   PetscCall(VecGetArrayRead(bb, &b));
134   PetscCall(VecGetArray(xx, &x));
135   t = a->solve_work;
136 
137   PetscCall(ISGetIndices(isrow, &rout));
138   r = rout;
139   PetscCall(ISGetIndices(iscol, &cout));
140   c = cout;
141 
142   /* copy b into temp work space according to permutation */
143   for (i = 0; i < n; i++) {
144     ii        = bs * i;
145     ic        = bs * c[i];
146     t[ii]     = b[ic];
147     t[ii + 1] = b[ic + 1];
148     t[ii + 2] = b[ic + 2];
149     t[ii + 3] = b[ic + 3];
150     t[ii + 4] = b[ic + 4];
151     t[ii + 5] = b[ic + 5];
152   }
153 
154   /* forward solve the U^T */
155   idx = 0;
156   for (i = 0; i < n; i++) {
157     v  = aa + bs2 * diag[i];
158     /* multiply by the inverse of the block diagonal */
159     x1 = t[idx];
160     x2 = t[1 + idx];
161     x3 = t[2 + idx];
162     x4 = t[3 + idx];
163     x5 = t[4 + idx];
164     x6 = t[5 + idx];
165     s1 = v[0] * x1 + v[1] * x2 + v[2] * x3 + v[3] * x4 + v[4] * x5 + v[5] * x6;
166     s2 = v[6] * x1 + v[7] * x2 + v[8] * x3 + v[9] * x4 + v[10] * x5 + v[11] * x6;
167     s3 = v[12] * x1 + v[13] * x2 + v[14] * x3 + v[15] * x4 + v[16] * x5 + v[17] * x6;
168     s4 = v[18] * x1 + v[19] * x2 + v[20] * x3 + v[21] * x4 + v[22] * x5 + v[23] * x6;
169     s5 = v[24] * x1 + v[25] * x2 + v[26] * x3 + v[27] * x4 + v[28] * x5 + v[29] * x6;
170     s6 = v[30] * x1 + v[31] * x2 + v[32] * x3 + v[33] * x4 + v[34] * x5 + v[35] * x6;
171     v -= bs2;
172 
173     vi = aj + diag[i] - 1;
174     nz = diag[i] - diag[i + 1] - 1;
175     for (j = 0; j > -nz; j--) {
176       oidx = bs * vi[j];
177       t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5 + v[5] * s6;
178       t[oidx + 1] -= v[6] * s1 + v[7] * s2 + v[8] * s3 + v[9] * s4 + v[10] * s5 + v[11] * s6;
179       t[oidx + 2] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4 + v[16] * s5 + v[17] * s6;
180       t[oidx + 3] -= v[18] * s1 + v[19] * s2 + v[20] * s3 + v[21] * s4 + v[22] * s5 + v[23] * s6;
181       t[oidx + 4] -= v[24] * s1 + v[25] * s2 + v[26] * s3 + v[27] * s4 + v[28] * s5 + v[29] * s6;
182       t[oidx + 5] -= v[30] * s1 + v[31] * s2 + v[32] * s3 + v[33] * s4 + v[34] * s5 + v[35] * s6;
183       v -= bs2;
184     }
185     t[idx]     = s1;
186     t[1 + idx] = s2;
187     t[2 + idx] = s3;
188     t[3 + idx] = s4;
189     t[4 + idx] = s5;
190     t[5 + idx] = s6;
191     idx += bs;
192   }
193   /* backward solve the L^T */
194   for (i = n - 1; i >= 0; i--) {
195     v   = aa + bs2 * ai[i];
196     vi  = aj + ai[i];
197     nz  = ai[i + 1] - ai[i];
198     idt = bs * i;
199     s1  = t[idt];
200     s2  = t[1 + idt];
201     s3  = t[2 + idt];
202     s4  = t[3 + idt];
203     s5  = t[4 + idt];
204     s6  = t[5 + idt];
205     for (j = 0; j < nz; j++) {
206       idx = bs * vi[j];
207       t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5 + v[5] * s6;
208       t[idx + 1] -= v[6] * s1 + v[7] * s2 + v[8] * s3 + v[9] * s4 + v[10] * s5 + v[11] * s6;
209       t[idx + 2] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4 + v[16] * s5 + v[17] * s6;
210       t[idx + 3] -= v[18] * s1 + v[19] * s2 + v[20] * s3 + v[21] * s4 + v[22] * s5 + v[23] * s6;
211       t[idx + 4] -= v[24] * s1 + v[25] * s2 + v[26] * s3 + v[27] * s4 + v[28] * s5 + v[29] * s6;
212       t[idx + 5] -= v[30] * s1 + v[31] * s2 + v[32] * s3 + v[33] * s4 + v[34] * s5 + v[35] * s6;
213       v += bs2;
214     }
215   }
216 
217   /* copy t into x according to permutation */
218   for (i = 0; i < n; i++) {
219     ii        = bs * i;
220     ir        = bs * r[i];
221     x[ir]     = t[ii];
222     x[ir + 1] = t[ii + 1];
223     x[ir + 2] = t[ii + 2];
224     x[ir + 3] = t[ii + 3];
225     x[ir + 4] = t[ii + 4];
226     x[ir + 5] = t[ii + 5];
227   }
228 
229   PetscCall(ISRestoreIndices(isrow, &rout));
230   PetscCall(ISRestoreIndices(iscol, &cout));
231   PetscCall(VecRestoreArrayRead(bb, &b));
232   PetscCall(VecRestoreArray(xx, &x));
233   PetscCall(PetscLogFlops(2.0 * bs2 * (a->nz) - bs * A->cmap->n));
234   PetscFunctionReturn(0);
235 }
236