xref: /petsc/src/mat/impls/baij/seq/baijsolvtran5.c (revision 503c0ea9b45bcfbcebbb1ea5341243bbc69f0bea)
1 #include <../src/mat/impls/baij/seq/baij.h>
2 #include <petsc/private/kernels/blockinvert.h>
3 
4 PetscErrorCode MatSolveTranspose_SeqBAIJ_5_inplace(Mat A,Vec bb,Vec xx)
5 {
6   Mat_SeqBAIJ       *a   =(Mat_SeqBAIJ*)A->data;
7   IS                iscol=a->col,isrow=a->row;
8   const PetscInt    *r,*c,*rout,*cout;
9   const PetscInt    *diag=a->diag,n=a->mbs,*vi,*ai=a->i,*aj=a->j;
10   PetscInt          i,nz,idx,idt,ii,ic,ir,oidx;
11   const MatScalar   *aa=a->a,*v;
12   PetscScalar       s1,s2,s3,s4,s5,x1,x2,x3,x4,x5,*x,*t;
13   const PetscScalar *b;
14 
15   PetscFunctionBegin;
16   PetscCall(VecGetArrayRead(bb,&b));
17   PetscCall(VecGetArray(xx,&x));
18   t    = a->solve_work;
19 
20   PetscCall(ISGetIndices(isrow,&rout)); r = rout;
21   PetscCall(ISGetIndices(iscol,&cout)); c = cout;
22 
23   /* copy the b into temp work space according to permutation */
24   ii = 0;
25   for (i=0; i<n; i++) {
26     ic      = 5*c[i];
27     t[ii]   = b[ic];
28     t[ii+1] = b[ic+1];
29     t[ii+2] = b[ic+2];
30     t[ii+3] = b[ic+3];
31     t[ii+4] = b[ic+4];
32     ii     += 5;
33   }
34 
35   /* forward solve the U^T */
36   idx = 0;
37   for (i=0; i<n; i++) {
38 
39     v = aa + 25*diag[i];
40     /* multiply by the inverse of the block diagonal */
41     x1 = t[idx];   x2 = t[1+idx]; x3    = t[2+idx]; x4 = t[3+idx]; x5 = t[4+idx];
42     s1 = v[0]*x1  +  v[1]*x2 +  v[2]*x3 +  v[3]*x4 +  v[4]*x5;
43     s2 = v[5]*x1  +  v[6]*x2 +  v[7]*x3 +  v[8]*x4 +  v[9]*x5;
44     s3 = v[10]*x1 + v[11]*x2 + v[12]*x3 + v[13]*x4 + v[14]*x5;
45     s4 = v[15]*x1 + v[16]*x2 + v[17]*x3 + v[18]*x4 + v[19]*x5;
46     s5 = v[20]*x1 + v[21]*x2 + v[22]*x3 + v[23]*x4 + v[24]*x5;
47     v += 25;
48 
49     vi = aj + diag[i] + 1;
50     nz = ai[i+1] - diag[i] - 1;
51     while (nz--) {
52       oidx       = 5*(*vi++);
53       t[oidx]   -= v[0]*s1  +  v[1]*s2 +  v[2]*s3 +  v[3]*s4 +  v[4]*s5;
54       t[oidx+1] -= v[5]*s1  +  v[6]*s2 +  v[7]*s3 +  v[8]*s4 +  v[9]*s5;
55       t[oidx+2] -= v[10]*s1 + v[11]*s2 + v[12]*s3 + v[13]*s4 + v[14]*s5;
56       t[oidx+3] -= v[15]*s1 + v[16]*s2 + v[17]*s3 + v[18]*s4 + v[19]*s5;
57       t[oidx+4] -= v[20]*s1 + v[21]*s2 + v[22]*s3 + v[23]*s4 + v[24]*s5;
58       v         += 25;
59     }
60     t[idx] = s1;t[1+idx] = s2; t[2+idx] = s3;t[3+idx] = s4; t[4+idx] = s5;
61     idx   += 5;
62   }
63   /* backward solve the L^T */
64   for (i=n-1; i>=0; i--) {
65     v   = aa + 25*diag[i] - 25;
66     vi  = aj + diag[i] - 1;
67     nz  = diag[i] - ai[i];
68     idt = 5*i;
69     s1  = t[idt];  s2 = t[1+idt]; s3 = t[2+idt];s4 = t[3+idt]; s5 = t[4+idt];
70     while (nz--) {
71       idx       = 5*(*vi--);
72       t[idx]   -=  v[0]*s1 +  v[1]*s2 +  v[2]*s3 +  v[3]*s4 +  v[4]*s5;
73       t[idx+1] -=  v[5]*s1 +  v[6]*s2 +  v[7]*s3 +  v[8]*s4 +  v[9]*s5;
74       t[idx+2] -= v[10]*s1 + v[11]*s2 + v[12]*s3 + v[13]*s4 + v[14]*s5;
75       t[idx+3] -= v[15]*s1 + v[16]*s2 + v[17]*s3 + v[18]*s4 + v[19]*s5;
76       t[idx+4] -= v[20]*s1 + v[21]*s2 + v[22]*s3 + v[23]*s4 + v[24]*s5;
77       v        -= 25;
78     }
79   }
80 
81   /* copy t into x according to permutation */
82   ii = 0;
83   for (i=0; i<n; i++) {
84     ir      = 5*r[i];
85     x[ir]   = t[ii];
86     x[ir+1] = t[ii+1];
87     x[ir+2] = t[ii+2];
88     x[ir+3] = t[ii+3];
89     x[ir+4] = t[ii+4];
90     ii     += 5;
91   }
92 
93   PetscCall(ISRestoreIndices(isrow,&rout));
94   PetscCall(ISRestoreIndices(iscol,&cout));
95   PetscCall(VecRestoreArrayRead(bb,&b));
96   PetscCall(VecRestoreArray(xx,&x));
97   PetscCall(PetscLogFlops(2.0*25*(a->nz) - 5.0*A->cmap->n));
98   PetscFunctionReturn(0);
99 }
100 
101 PetscErrorCode MatSolveTranspose_SeqBAIJ_5(Mat A,Vec bb,Vec xx)
102 {
103   Mat_SeqBAIJ       *a=(Mat_SeqBAIJ*)A->data;
104   IS                iscol=a->col,isrow=a->row;
105   const PetscInt    n    =a->mbs,*vi,*ai=a->i,*aj=a->j,*diag=a->diag;
106   const PetscInt    *r,*c,*rout,*cout;
107   PetscInt          nz,idx,idt,j,i,oidx,ii,ic,ir;
108   const PetscInt    bs =A->rmap->bs,bs2=a->bs2;
109   const MatScalar   *aa=a->a,*v;
110   PetscScalar       s1,s2,s3,s4,s5,x1,x2,x3,x4,x5,*x,*t;
111   const PetscScalar *b;
112 
113   PetscFunctionBegin;
114   PetscCall(VecGetArrayRead(bb,&b));
115   PetscCall(VecGetArray(xx,&x));
116   t    = a->solve_work;
117 
118   PetscCall(ISGetIndices(isrow,&rout)); r = rout;
119   PetscCall(ISGetIndices(iscol,&cout)); c = cout;
120 
121   /* copy b into temp work space according to permutation */
122   for (i=0; i<n; i++) {
123     ii      = bs*i; ic = bs*c[i];
124     t[ii]   = b[ic]; t[ii+1] = b[ic+1]; t[ii+2] = b[ic+2]; t[ii+3] = b[ic+3];
125     t[ii+4] = b[ic+4];
126   }
127 
128   /* forward solve the U^T */
129   idx = 0;
130   for (i=0; i<n; i++) {
131     v = aa + bs2*diag[i];
132     /* multiply by the inverse of the block diagonal */
133     x1 = t[idx];   x2 = t[1+idx]; x3    = t[2+idx]; x4 = t[3+idx]; x5 = t[4+idx];
134     s1 = v[0]*x1  +  v[1]*x2 +  v[2]*x3 +  v[3]*x4 +  v[4]*x5;
135     s2 = v[5]*x1  +  v[6]*x2 +  v[7]*x3 +  v[8]*x4 +  v[9]*x5;
136     s3 = v[10]*x1 + v[11]*x2 + v[12]*x3 + v[13]*x4 + v[14]*x5;
137     s4 = v[15]*x1 + v[16]*x2 + v[17]*x3 + v[18]*x4 + v[19]*x5;
138     s5 = v[20]*x1 + v[21]*x2 + v[22]*x3 + v[23]*x4 + v[24]*x5;
139     v -= bs2;
140 
141     vi = aj + diag[i] - 1;
142     nz = diag[i] - diag[i+1] - 1;
143     for (j=0; j>-nz; j--) {
144       oidx       = bs*vi[j];
145       t[oidx]   -= v[0]*s1  +  v[1]*s2 +  v[2]*s3 +  v[3]*s4 +  v[4]*s5;
146       t[oidx+1] -= v[5]*s1  +  v[6]*s2 +  v[7]*s3 +  v[8]*s4 +  v[9]*s5;
147       t[oidx+2] -= v[10]*s1 + v[11]*s2 + v[12]*s3 + v[13]*s4 + v[14]*s5;
148       t[oidx+3] -= v[15]*s1 + v[16]*s2 + v[17]*s3 + v[18]*s4 + v[19]*s5;
149       t[oidx+4] -= v[20]*s1 + v[21]*s2 + v[22]*s3 + v[23]*s4 + v[24]*s5;
150       v         -= bs2;
151     }
152     t[idx] = s1;t[1+idx] = s2;  t[2+idx] = s3;  t[3+idx] = s4; t[4+idx] =s5;
153     idx   += bs;
154   }
155   /* backward solve the L^T */
156   for (i=n-1; i>=0; i--) {
157     v   = aa + bs2*ai[i];
158     vi  = aj + ai[i];
159     nz  = ai[i+1] - ai[i];
160     idt = bs*i;
161     s1  = t[idt];  s2 = t[1+idt];  s3 = t[2+idt];  s4 = t[3+idt]; s5 = t[4+idt];
162     for (j=0; j<nz; j++) {
163       idx       = bs*vi[j];
164       t[idx]   -= v[0]*s1  +  v[1]*s2 +  v[2]*s3 +  v[3]*s4 +  v[4]*s5;
165       t[idx+1] -= v[5]*s1  +  v[6]*s2 +  v[7]*s3 +  v[8]*s4 +  v[9]*s5;
166       t[idx+2] -= v[10]*s1 + v[11]*s2 + v[12]*s3 + v[13]*s4 + v[14]*s5;
167       t[idx+3] -= v[15]*s1 + v[16]*s2 + v[17]*s3 + v[18]*s4 + v[19]*s5;
168       t[idx+4] -= v[20]*s1 + v[21]*s2 + v[22]*s3 + v[23]*s4 + v[24]*s5;
169       v        += bs2;
170     }
171   }
172 
173   /* copy t into x according to permutation */
174   for (i=0; i<n; i++) {
175     ii      = bs*i;  ir = bs*r[i];
176     x[ir]   = t[ii];  x[ir+1] = t[ii+1]; x[ir+2] = t[ii+2];  x[ir+3] = t[ii+3];
177     x[ir+4] = t[ii+4];
178   }
179 
180   PetscCall(ISRestoreIndices(isrow,&rout));
181   PetscCall(ISRestoreIndices(iscol,&cout));
182   PetscCall(VecRestoreArrayRead(bb,&b));
183   PetscCall(VecRestoreArray(xx,&x));
184   PetscCall(PetscLogFlops(2.0*bs2*(a->nz) - bs*A->cmap->n));
185   PetscFunctionReturn(0);
186 }
187