xref: /petsc/src/mat/impls/baij/seq/baijsolvtran3.c (revision 503c0ea9b45bcfbcebbb1ea5341243bbc69f0bea)
1 #include <../src/mat/impls/baij/seq/baij.h>
2 #include <petsc/private/kernels/blockinvert.h>
3 
4 PetscErrorCode MatSolveTranspose_SeqBAIJ_3_inplace(Mat A,Vec bb,Vec xx)
5 {
6   Mat_SeqBAIJ       *a   =(Mat_SeqBAIJ*)A->data;
7   IS                iscol=a->col,isrow=a->row;
8   const PetscInt    *r,*c,*rout,*cout;
9   const PetscInt    *diag=a->diag,n=a->mbs,*vi,*ai=a->i,*aj=a->j;
10   PetscInt          i,nz,idx,idt,ii,ic,ir,oidx;
11   const MatScalar   *aa=a->a,*v;
12   PetscScalar       s1,s2,s3,x1,x2,x3,*x,*t;
13   const PetscScalar *b;
14 
15   PetscFunctionBegin;
16   PetscCall(VecGetArrayRead(bb,&b));
17   PetscCall(VecGetArray(xx,&x));
18   t    = a->solve_work;
19 
20   PetscCall(ISGetIndices(isrow,&rout)); r = rout;
21   PetscCall(ISGetIndices(iscol,&cout)); c = cout;
22 
23   /* copy the b into temp work space according to permutation */
24   ii = 0;
25   for (i=0; i<n; i++) {
26     ic      = 3*c[i];
27     t[ii]   = b[ic];
28     t[ii+1] = b[ic+1];
29     t[ii+2] = b[ic+2];
30     ii     += 3;
31   }
32 
33   /* forward solve the U^T */
34   idx = 0;
35   for (i=0; i<n; i++) {
36 
37     v = aa + 9*diag[i];
38     /* multiply by the inverse of the block diagonal */
39     x1 = t[idx];   x2 = t[1+idx]; x3    = t[2+idx];
40     s1 = v[0]*x1  +  v[1]*x2 +  v[2]*x3;
41     s2 = v[3]*x1  +  v[4]*x2 +  v[5]*x3;
42     s3 = v[6]*x1  +  v[7]*x2 + v[8]*x3;
43     v += 9;
44 
45     vi = aj + diag[i] + 1;
46     nz = ai[i+1] - diag[i] - 1;
47     while (nz--) {
48       oidx       = 3*(*vi++);
49       t[oidx]   -= v[0]*s1  +  v[1]*s2 +  v[2]*s3;
50       t[oidx+1] -= v[3]*s1  +  v[4]*s2 +  v[5]*s3;
51       t[oidx+2] -= v[6]*s1 + v[7]*s2 + v[8]*s3;
52       v         += 9;
53     }
54     t[idx] = s1;t[1+idx] = s2; t[2+idx] = s3;
55     idx   += 3;
56   }
57   /* backward solve the L^T */
58   for (i=n-1; i>=0; i--) {
59     v   = aa + 9*diag[i] - 9;
60     vi  = aj + diag[i] - 1;
61     nz  = diag[i] - ai[i];
62     idt = 3*i;
63     s1  = t[idt];  s2 = t[1+idt]; s3 = t[2+idt];
64     while (nz--) {
65       idx       = 3*(*vi--);
66       t[idx]   -=  v[0]*s1 +  v[1]*s2 +  v[2]*s3;
67       t[idx+1] -=  v[3]*s1 +  v[4]*s2 +  v[5]*s3;
68       t[idx+2] -= v[6]*s1 + v[7]*s2 + v[8]*s3;
69       v        -= 9;
70     }
71   }
72 
73   /* copy t into x according to permutation */
74   ii = 0;
75   for (i=0; i<n; i++) {
76     ir      = 3*r[i];
77     x[ir]   = t[ii];
78     x[ir+1] = t[ii+1];
79     x[ir+2] = t[ii+2];
80     ii     += 3;
81   }
82 
83   PetscCall(ISRestoreIndices(isrow,&rout));
84   PetscCall(ISRestoreIndices(iscol,&cout));
85   PetscCall(VecRestoreArrayRead(bb,&b));
86   PetscCall(VecRestoreArray(xx,&x));
87   PetscCall(PetscLogFlops(2.0*9*(a->nz) - 3.0*A->cmap->n));
88   PetscFunctionReturn(0);
89 }
90 
91 PetscErrorCode MatSolveTranspose_SeqBAIJ_3(Mat A,Vec bb,Vec xx)
92 {
93   Mat_SeqBAIJ       *a=(Mat_SeqBAIJ*)A->data;
94   IS                iscol=a->col,isrow=a->row;
95   const PetscInt    n    =a->mbs,*vi,*ai=a->i,*aj=a->j,*diag=a->diag;
96   const PetscInt    *r,*c,*rout,*cout;
97   PetscInt          nz,idx,idt,j,i,oidx,ii,ic,ir;
98   const PetscInt    bs =A->rmap->bs,bs2=a->bs2;
99   const MatScalar   *aa=a->a,*v;
100   PetscScalar       s1,s2,s3,x1,x2,x3,*x,*t;
101   const PetscScalar *b;
102 
103   PetscFunctionBegin;
104   PetscCall(VecGetArrayRead(bb,&b));
105   PetscCall(VecGetArray(xx,&x));
106   t    = a->solve_work;
107 
108   PetscCall(ISGetIndices(isrow,&rout)); r = rout;
109   PetscCall(ISGetIndices(iscol,&cout)); c = cout;
110 
111   /* copy b into temp work space according to permutation */
112   for (i=0; i<n; i++) {
113     ii    = bs*i; ic = bs*c[i];
114     t[ii] = b[ic]; t[ii+1] = b[ic+1]; t[ii+2] = b[ic+2];
115   }
116 
117   /* forward solve the U^T */
118   idx = 0;
119   for (i=0; i<n; i++) {
120     v = aa + bs2*diag[i];
121     /* multiply by the inverse of the block diagonal */
122     x1 = t[idx];   x2 = t[1+idx]; x3    = t[2+idx];
123     s1 = v[0]*x1  +  v[1]*x2 +  v[2]*x3;
124     s2 = v[3]*x1  +  v[4]*x2 +  v[5]*x3;
125     s3 = v[6]*x1  +  v[7]*x2 + v[8]*x3;
126     v -= bs2;
127 
128     vi = aj + diag[i] - 1;
129     nz = diag[i] - diag[i+1] - 1;
130     for (j=0; j>-nz; j--) {
131       oidx       = bs*vi[j];
132       t[oidx]   -= v[0]*s1  +  v[1]*s2 +  v[2]*s3;
133       t[oidx+1] -= v[3]*s1  +  v[4]*s2 +  v[5]*s3;
134       t[oidx+2] -= v[6]*s1 + v[7]*s2 + v[8]*s3;
135       v         -= bs2;
136     }
137     t[idx] = s1;t[1+idx] = s2;  t[2+idx] = s3;
138     idx   += bs;
139   }
140   /* backward solve the L^T */
141   for (i=n-1; i>=0; i--) {
142     v   = aa + bs2*ai[i];
143     vi  = aj + ai[i];
144     nz  = ai[i+1] - ai[i];
145     idt = bs*i;
146     s1  = t[idt];  s2 = t[1+idt];  s3 = t[2+idt];
147     for (j=0; j<nz; j++) {
148       idx       = bs*vi[j];
149       t[idx]   -= v[0]*s1  +  v[1]*s2 +  v[2]*s3;
150       t[idx+1] -= v[3]*s1  +  v[4]*s2 +  v[5]*s3;
151       t[idx+2] -= v[6]*s1 + v[7]*s2 + v[8]*s3;
152       v        += bs2;
153     }
154   }
155 
156   /* copy t into x according to permutation */
157   for (i=0; i<n; i++) {
158     ii    = bs*i;  ir = bs*r[i];
159     x[ir] = t[ii];  x[ir+1] = t[ii+1]; x[ir+2] = t[ii+2];
160   }
161 
162   PetscCall(ISRestoreIndices(isrow,&rout));
163   PetscCall(ISRestoreIndices(iscol,&cout));
164   PetscCall(VecRestoreArrayRead(bb,&b));
165   PetscCall(VecRestoreArray(xx,&x));
166   PetscCall(PetscLogFlops(2.0*bs2*(a->nz) - bs*A->cmap->n));
167   PetscFunctionReturn(0);
168 }
169