xref: /petsc/src/mat/impls/baij/seq/baijsolvtrann.c (revision 503c0ea9b45bcfbcebbb1ea5341243bbc69f0bea)
1 #include <../src/mat/impls/baij/seq/baij.h>
2 #include <petsc/private/kernels/blockinvert.h>
3 
4 /* ----------------------------------------------------------- */
5 PetscErrorCode MatSolveTranspose_SeqBAIJ_N_inplace(Mat A,Vec bb,Vec xx)
6 {
7   Mat_SeqBAIJ       *a   =(Mat_SeqBAIJ*)A->data;
8   IS                iscol=a->col,isrow=a->row;
9   const PetscInt    *r,*c,*rout,*cout,*ai=a->i,*aj=a->j,*vi;
10   PetscInt          i,nz,j;
11   const PetscInt    n  =a->mbs,bs=A->rmap->bs,bs2=a->bs2;
12   const MatScalar   *aa=a->a,*v;
13   PetscScalar       *x,*t,*ls;
14   const PetscScalar *b;
15 
16   PetscFunctionBegin;
17   PetscCall(VecGetArrayRead(bb,&b));
18   PetscCall(VecGetArray(xx,&x));
19   t    = a->solve_work;
20 
21   PetscCall(ISGetIndices(isrow,&rout)); r = rout;
22   PetscCall(ISGetIndices(iscol,&cout)); c = cout;
23 
24   /* copy the b into temp work space according to permutation */
25   for (i=0; i<n; i++) {
26     for (j=0; j<bs; j++) {
27       t[i*bs+j] = b[c[i]*bs+j];
28     }
29   }
30 
31   /* forward solve the upper triangular transpose */
32   ls = a->solve_work + A->cmap->n;
33   for (i=0; i<n; i++) {
34     PetscCall(PetscArraycpy(ls,t+i*bs,bs));
35     PetscKernel_w_gets_transA_times_v(bs,ls,aa+bs2*a->diag[i],t+i*bs);
36     v  = aa + bs2*(a->diag[i] + 1);
37     vi = aj + a->diag[i] + 1;
38     nz = ai[i+1] - a->diag[i] - 1;
39     while (nz--) {
40       PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(*vi++),v,t+i*bs);
41       v += bs2;
42     }
43   }
44 
45   /* backward solve the lower triangular transpose */
46   for (i=n-1; i>=0; i--) {
47     v  = aa + bs2*ai[i];
48     vi = aj + ai[i];
49     nz = a->diag[i] - ai[i];
50     while (nz--) {
51       PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(*vi++),v,t+i*bs);
52       v += bs2;
53     }
54   }
55 
56   /* copy t into x according to permutation */
57   for (i=0; i<n; i++) {
58     for (j=0; j<bs; j++) {
59       x[bs*r[i]+j]   = t[bs*i+j];
60     }
61   }
62 
63   PetscCall(ISRestoreIndices(isrow,&rout));
64   PetscCall(ISRestoreIndices(iscol,&cout));
65   PetscCall(VecRestoreArrayRead(bb,&b));
66   PetscCall(VecRestoreArray(xx,&x));
67   PetscCall(PetscLogFlops(2.0*(a->bs2)*(a->nz) - A->rmap->bs*A->cmap->n));
68   PetscFunctionReturn(0);
69 }
70 
71 PetscErrorCode MatSolveTranspose_SeqBAIJ_N(Mat A,Vec bb,Vec xx)
72 {
73   Mat_SeqBAIJ       *a   =(Mat_SeqBAIJ*)A->data;
74   IS                iscol=a->col,isrow=a->row;
75   const PetscInt    *r,*c,*rout,*cout;
76   const PetscInt    n=a->mbs,*ai=a->i,*aj=a->j,*vi,*diag=a->diag;
77   PetscInt          i,j,nz;
78   const PetscInt    bs =A->rmap->bs,bs2=a->bs2;
79   const MatScalar   *aa=a->a,*v;
80   PetscScalar       *x,*t,*ls;
81   const PetscScalar *b;
82 
83   PetscFunctionBegin;
84   PetscCall(VecGetArrayRead(bb,&b));
85   PetscCall(VecGetArray(xx,&x));
86   t    = a->solve_work;
87 
88   PetscCall(ISGetIndices(isrow,&rout)); r = rout;
89   PetscCall(ISGetIndices(iscol,&cout)); c = cout;
90 
91   /* copy the b into temp work space according to permutation */
92   for (i=0; i<n; i++) {
93     for (j=0; j<bs; j++) {
94       t[i*bs+j] = b[c[i]*bs+j];
95     }
96   }
97 
98   /* forward solve the upper triangular transpose */
99   ls = a->solve_work + A->cmap->n;
100   for (i=0; i<n; i++) {
101     PetscCall(PetscArraycpy(ls,t+i*bs,bs));
102     PetscKernel_w_gets_transA_times_v(bs,ls,aa+bs2*diag[i],t+i*bs);
103     v  = aa + bs2*(diag[i] - 1);
104     vi = aj + diag[i] - 1;
105     nz = diag[i] - diag[i+1] - 1;
106     for (j=0; j>-nz; j--) {
107       PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(vi[j]),v,t+i*bs);
108       v -= bs2;
109     }
110   }
111 
112   /* backward solve the lower triangular transpose */
113   for (i=n-1; i>=0; i--) {
114     v  = aa + bs2*ai[i];
115     vi = aj + ai[i];
116     nz = ai[i+1] - ai[i];
117     for (j=0; j<nz; j++) {
118       PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(vi[j]),v,t+i*bs);
119       v += bs2;
120     }
121   }
122 
123   /* copy t into x according to permutation */
124   for (i=0; i<n; i++) {
125     for (j=0; j<bs; j++) {
126       x[bs*r[i]+j]   = t[bs*i+j];
127     }
128   }
129 
130   PetscCall(ISRestoreIndices(isrow,&rout));
131   PetscCall(ISRestoreIndices(iscol,&cout));
132   PetscCall(VecRestoreArrayRead(bb,&b));
133   PetscCall(VecRestoreArray(xx,&x));
134   PetscCall(PetscLogFlops(2.0*(a->bs2)*(a->nz) - A->rmap->bs*A->cmap->n));
135   PetscFunctionReturn(0);
136 }
137