xref: /petsc/src/mat/impls/baij/seq/baijsolvtran2.c (revision 503c0ea9b45bcfbcebbb1ea5341243bbc69f0bea)
1 #include <../src/mat/impls/baij/seq/baij.h>
2 #include <petsc/private/kernels/blockinvert.h>
3 
4 PetscErrorCode MatSolveTranspose_SeqBAIJ_2_inplace(Mat A,Vec bb,Vec xx)
5 {
6   Mat_SeqBAIJ       *a   =(Mat_SeqBAIJ*)A->data;
7   IS                iscol=a->col,isrow=a->row;
8   const PetscInt    *r,*c,*rout,*cout;
9   const PetscInt    *diag=a->diag,n=a->mbs,*vi,*ai=a->i,*aj=a->j;
10   PetscInt          i,nz,idx,idt,ii,ic,ir,oidx;
11   const MatScalar   *aa=a->a,*v;
12   PetscScalar       s1,s2,x1,x2,*x,*t;
13   const PetscScalar *b;
14 
15   PetscFunctionBegin;
16   PetscCall(VecGetArrayRead(bb,&b));
17   PetscCall(VecGetArray(xx,&x));
18   t    = a->solve_work;
19 
20   PetscCall(ISGetIndices(isrow,&rout)); r = rout;
21   PetscCall(ISGetIndices(iscol,&cout)); c = cout;
22 
23   /* copy the b into temp work space according to permutation */
24   ii = 0;
25   for (i=0; i<n; i++) {
26     ic      = 2*c[i];
27     t[ii]   = b[ic];
28     t[ii+1] = b[ic+1];
29     ii     += 2;
30   }
31 
32   /* forward solve the U^T */
33   idx = 0;
34   for (i=0; i<n; i++) {
35 
36     v = aa + 4*diag[i];
37     /* multiply by the inverse of the block diagonal */
38     x1 = t[idx];   x2 = t[1+idx];
39     s1 = v[0]*x1  +  v[1]*x2;
40     s2 = v[2]*x1  +  v[3]*x2;
41     v += 4;
42 
43     vi = aj + diag[i] + 1;
44     nz = ai[i+1] - diag[i] - 1;
45     while (nz--) {
46       oidx       = 2*(*vi++);
47       t[oidx]   -= v[0]*s1  +  v[1]*s2;
48       t[oidx+1] -= v[2]*s1  +  v[3]*s2;
49       v         += 4;
50     }
51     t[idx] = s1;t[1+idx] = s2;
52     idx   += 2;
53   }
54   /* backward solve the L^T */
55   for (i=n-1; i>=0; i--) {
56     v   = aa + 4*diag[i] - 4;
57     vi  = aj + diag[i] - 1;
58     nz  = diag[i] - ai[i];
59     idt = 2*i;
60     s1  = t[idt];  s2 = t[1+idt];
61     while (nz--) {
62       idx       = 2*(*vi--);
63       t[idx]   -=  v[0]*s1 +  v[1]*s2;
64       t[idx+1] -=  v[2]*s1 +  v[3]*s2;
65       v        -= 4;
66     }
67   }
68 
69   /* copy t into x according to permutation */
70   ii = 0;
71   for (i=0; i<n; i++) {
72     ir      = 2*r[i];
73     x[ir]   = t[ii];
74     x[ir+1] = t[ii+1];
75     ii     += 2;
76   }
77 
78   PetscCall(ISRestoreIndices(isrow,&rout));
79   PetscCall(ISRestoreIndices(iscol,&cout));
80   PetscCall(VecRestoreArrayRead(bb,&b));
81   PetscCall(VecRestoreArray(xx,&x));
82   PetscCall(PetscLogFlops(2.0*4*(a->nz) - 2.0*A->cmap->n));
83   PetscFunctionReturn(0);
84 }
85 
86 PetscErrorCode MatSolveTranspose_SeqBAIJ_2(Mat A,Vec bb,Vec xx)
87 {
88   Mat_SeqBAIJ       *a=(Mat_SeqBAIJ*)A->data;
89   IS                iscol=a->col,isrow=a->row;
90   const PetscInt    n    =a->mbs,*vi,*ai=a->i,*aj=a->j,*diag=a->diag;
91   const PetscInt    *r,*c,*rout,*cout;
92   PetscInt          nz,idx,idt,j,i,oidx,ii,ic,ir;
93   const PetscInt    bs =A->rmap->bs,bs2=a->bs2;
94   const MatScalar   *aa=a->a,*v;
95   PetscScalar       s1,s2,x1,x2,*x,*t;
96   const PetscScalar *b;
97 
98   PetscFunctionBegin;
99   PetscCall(VecGetArrayRead(bb,&b));
100   PetscCall(VecGetArray(xx,&x));
101   t    = a->solve_work;
102 
103   PetscCall(ISGetIndices(isrow,&rout)); r = rout;
104   PetscCall(ISGetIndices(iscol,&cout)); c = cout;
105 
106   /* copy b into temp work space according to permutation */
107   for (i=0; i<n; i++) {
108     ii    = bs*i; ic = bs*c[i];
109     t[ii] = b[ic]; t[ii+1] = b[ic+1];
110   }
111 
112   /* forward solve the U^T */
113   idx = 0;
114   for (i=0; i<n; i++) {
115     v = aa + bs2*diag[i];
116     /* multiply by the inverse of the block diagonal */
117     x1 = t[idx];   x2 = t[1+idx];
118     s1 = v[0]*x1  +  v[1]*x2;
119     s2 = v[2]*x1  +  v[3]*x2;
120     v -= bs2;
121 
122     vi = aj + diag[i] - 1;
123     nz = diag[i] - diag[i+1] - 1;
124     for (j=0; j>-nz; j--) {
125       oidx       = bs*vi[j];
126       t[oidx]   -= v[0]*s1  +  v[1]*s2;
127       t[oidx+1] -= v[2]*s1  +  v[3]*s2;
128       v         -= bs2;
129     }
130     t[idx] = s1;t[1+idx] = s2;
131     idx   += bs;
132   }
133   /* backward solve the L^T */
134   for (i=n-1; i>=0; i--) {
135     v   = aa + bs2*ai[i];
136     vi  = aj + ai[i];
137     nz  = ai[i+1] - ai[i];
138     idt = bs*i;
139     s1  = t[idt];  s2 = t[1+idt];
140     for (j=0; j<nz; j++) {
141       idx       = bs*vi[j];
142       t[idx]   -=  v[0]*s1 +  v[1]*s2;
143       t[idx+1] -=  v[2]*s1 +  v[3]*s2;
144       v        += bs2;
145     }
146   }
147 
148   /* copy t into x according to permutation */
149   for (i=0; i<n; i++) {
150     ii    = bs*i;  ir = bs*r[i];
151     x[ir] = t[ii];  x[ir+1] = t[ii+1];
152   }
153 
154   PetscCall(ISRestoreIndices(isrow,&rout));
155   PetscCall(ISRestoreIndices(iscol,&cout));
156   PetscCall(VecRestoreArrayRead(bb,&b));
157   PetscCall(VecRestoreArray(xx,&x));
158   PetscCall(PetscLogFlops(2.0*bs2*(a->nz) - bs*A->cmap->n));
159   PetscFunctionReturn(0);
160 }
161