1 #include <../src/mat/impls/baij/seq/baij.h> 2 #include <petsc/private/kernels/blockinvert.h> 3 4 PetscErrorCode MatSolveTranspose_SeqBAIJ_2_inplace(Mat A,Vec bb,Vec xx) 5 { 6 Mat_SeqBAIJ *a =(Mat_SeqBAIJ*)A->data; 7 IS iscol=a->col,isrow=a->row; 8 const PetscInt *r,*c,*rout,*cout; 9 const PetscInt *diag=a->diag,n=a->mbs,*vi,*ai=a->i,*aj=a->j; 10 PetscInt i,nz,idx,idt,ii,ic,ir,oidx; 11 const MatScalar *aa=a->a,*v; 12 PetscScalar s1,s2,x1,x2,*x,*t; 13 const PetscScalar *b; 14 15 PetscFunctionBegin; 16 PetscCall(VecGetArrayRead(bb,&b)); 17 PetscCall(VecGetArray(xx,&x)); 18 t = a->solve_work; 19 20 PetscCall(ISGetIndices(isrow,&rout)); r = rout; 21 PetscCall(ISGetIndices(iscol,&cout)); c = cout; 22 23 /* copy the b into temp work space according to permutation */ 24 ii = 0; 25 for (i=0; i<n; i++) { 26 ic = 2*c[i]; 27 t[ii] = b[ic]; 28 t[ii+1] = b[ic+1]; 29 ii += 2; 30 } 31 32 /* forward solve the U^T */ 33 idx = 0; 34 for (i=0; i<n; i++) { 35 36 v = aa + 4*diag[i]; 37 /* multiply by the inverse of the block diagonal */ 38 x1 = t[idx]; x2 = t[1+idx]; 39 s1 = v[0]*x1 + v[1]*x2; 40 s2 = v[2]*x1 + v[3]*x2; 41 v += 4; 42 43 vi = aj + diag[i] + 1; 44 nz = ai[i+1] - diag[i] - 1; 45 while (nz--) { 46 oidx = 2*(*vi++); 47 t[oidx] -= v[0]*s1 + v[1]*s2; 48 t[oidx+1] -= v[2]*s1 + v[3]*s2; 49 v += 4; 50 } 51 t[idx] = s1;t[1+idx] = s2; 52 idx += 2; 53 } 54 /* backward solve the L^T */ 55 for (i=n-1; i>=0; i--) { 56 v = aa + 4*diag[i] - 4; 57 vi = aj + diag[i] - 1; 58 nz = diag[i] - ai[i]; 59 idt = 2*i; 60 s1 = t[idt]; s2 = t[1+idt]; 61 while (nz--) { 62 idx = 2*(*vi--); 63 t[idx] -= v[0]*s1 + v[1]*s2; 64 t[idx+1] -= v[2]*s1 + v[3]*s2; 65 v -= 4; 66 } 67 } 68 69 /* copy t into x according to permutation */ 70 ii = 0; 71 for (i=0; i<n; i++) { 72 ir = 2*r[i]; 73 x[ir] = t[ii]; 74 x[ir+1] = t[ii+1]; 75 ii += 2; 76 } 77 78 PetscCall(ISRestoreIndices(isrow,&rout)); 79 PetscCall(ISRestoreIndices(iscol,&cout)); 80 PetscCall(VecRestoreArrayRead(bb,&b)); 81 PetscCall(VecRestoreArray(xx,&x)); 82 PetscCall(PetscLogFlops(2.0*4*(a->nz) - 2.0*A->cmap->n)); 83 PetscFunctionReturn(0); 84 } 85 86 PetscErrorCode MatSolveTranspose_SeqBAIJ_2(Mat A,Vec bb,Vec xx) 87 { 88 Mat_SeqBAIJ *a=(Mat_SeqBAIJ*)A->data; 89 IS iscol=a->col,isrow=a->row; 90 const PetscInt n =a->mbs,*vi,*ai=a->i,*aj=a->j,*diag=a->diag; 91 const PetscInt *r,*c,*rout,*cout; 92 PetscInt nz,idx,idt,j,i,oidx,ii,ic,ir; 93 const PetscInt bs =A->rmap->bs,bs2=a->bs2; 94 const MatScalar *aa=a->a,*v; 95 PetscScalar s1,s2,x1,x2,*x,*t; 96 const PetscScalar *b; 97 98 PetscFunctionBegin; 99 PetscCall(VecGetArrayRead(bb,&b)); 100 PetscCall(VecGetArray(xx,&x)); 101 t = a->solve_work; 102 103 PetscCall(ISGetIndices(isrow,&rout)); r = rout; 104 PetscCall(ISGetIndices(iscol,&cout)); c = cout; 105 106 /* copy b into temp work space according to permutation */ 107 for (i=0; i<n; i++) { 108 ii = bs*i; ic = bs*c[i]; 109 t[ii] = b[ic]; t[ii+1] = b[ic+1]; 110 } 111 112 /* forward solve the U^T */ 113 idx = 0; 114 for (i=0; i<n; i++) { 115 v = aa + bs2*diag[i]; 116 /* multiply by the inverse of the block diagonal */ 117 x1 = t[idx]; x2 = t[1+idx]; 118 s1 = v[0]*x1 + v[1]*x2; 119 s2 = v[2]*x1 + v[3]*x2; 120 v -= bs2; 121 122 vi = aj + diag[i] - 1; 123 nz = diag[i] - diag[i+1] - 1; 124 for (j=0; j>-nz; j--) { 125 oidx = bs*vi[j]; 126 t[oidx] -= v[0]*s1 + v[1]*s2; 127 t[oidx+1] -= v[2]*s1 + v[3]*s2; 128 v -= bs2; 129 } 130 t[idx] = s1;t[1+idx] = s2; 131 idx += bs; 132 } 133 /* backward solve the L^T */ 134 for (i=n-1; i>=0; i--) { 135 v = aa + bs2*ai[i]; 136 vi = aj + ai[i]; 137 nz = ai[i+1] - ai[i]; 138 idt = bs*i; 139 s1 = t[idt]; s2 = t[1+idt]; 140 for (j=0; j<nz; j++) { 141 idx = bs*vi[j]; 142 t[idx] -= v[0]*s1 + v[1]*s2; 143 t[idx+1] -= v[2]*s1 + v[3]*s2; 144 v += bs2; 145 } 146 } 147 148 /* copy t into x according to permutation */ 149 for (i=0; i<n; i++) { 150 ii = bs*i; ir = bs*r[i]; 151 x[ir] = t[ii]; x[ir+1] = t[ii+1]; 152 } 153 154 PetscCall(ISRestoreIndices(isrow,&rout)); 155 PetscCall(ISRestoreIndices(iscol,&cout)); 156 PetscCall(VecRestoreArrayRead(bb,&b)); 157 PetscCall(VecRestoreArray(xx,&x)); 158 PetscCall(PetscLogFlops(2.0*bs2*(a->nz) - bs*A->cmap->n)); 159 PetscFunctionReturn(0); 160 } 161