1 #include <../src/mat/impls/baij/seq/baij.h> 2 3 PetscErrorCode MatSolveTranspose_SeqBAIJ_2_NaturalOrdering_inplace(Mat A,Vec bb,Vec xx) 4 { 5 Mat_SeqBAIJ *a=(Mat_SeqBAIJ*)A->data; 6 PetscInt i,nz,idx,idt,oidx; 7 const PetscInt *diag = a->diag,*vi,n=a->mbs,*ai=a->i,*aj=a->j; 8 const MatScalar *aa =a->a,*v; 9 PetscScalar s1,s2,x1,x2,*x; 10 11 PetscFunctionBegin; 12 PetscCall(VecCopy(bb,xx)); 13 PetscCall(VecGetArray(xx,&x)); 14 15 /* forward solve the U^T */ 16 idx = 0; 17 for (i=0; i<n; i++) { 18 19 v = aa + 4*diag[i]; 20 /* multiply by the inverse of the block diagonal */ 21 x1 = x[idx]; x2 = x[1+idx]; 22 s1 = v[0]*x1 + v[1]*x2; 23 s2 = v[2]*x1 + v[3]*x2; 24 v += 4; 25 26 vi = aj + diag[i] + 1; 27 nz = ai[i+1] - diag[i] - 1; 28 while (nz--) { 29 oidx = 2*(*vi++); 30 x[oidx] -= v[0]*s1 + v[1]*s2; 31 x[oidx+1] -= v[2]*s1 + v[3]*s2; 32 v += 4; 33 } 34 x[idx] = s1;x[1+idx] = s2; 35 idx += 2; 36 } 37 /* backward solve the L^T */ 38 for (i=n-1; i>=0; i--) { 39 v = aa + 4*diag[i] - 4; 40 vi = aj + diag[i] - 1; 41 nz = diag[i] - ai[i]; 42 idt = 2*i; 43 s1 = x[idt]; s2 = x[1+idt]; 44 while (nz--) { 45 idx = 2*(*vi--); 46 x[idx] -= v[0]*s1 + v[1]*s2; 47 x[idx+1] -= v[2]*s1 + v[3]*s2; 48 v -= 4; 49 } 50 } 51 PetscCall(VecRestoreArray(xx,&x)); 52 PetscCall(PetscLogFlops(2.0*4.0*(a->nz) - 2.0*A->cmap->n)); 53 PetscFunctionReturn(0); 54 } 55 56 PetscErrorCode MatSolveTranspose_SeqBAIJ_2_NaturalOrdering(Mat A,Vec bb,Vec xx) 57 { 58 Mat_SeqBAIJ *a=(Mat_SeqBAIJ*)A->data; 59 const PetscInt n=a->mbs,*vi,*ai=a->i,*aj=a->j,*diag=a->diag; 60 PetscInt nz,idx,idt,j,i,oidx; 61 const PetscInt bs =A->rmap->bs,bs2=a->bs2; 62 const MatScalar *aa=a->a,*v; 63 PetscScalar s1,s2,x1,x2,*x; 64 65 PetscFunctionBegin; 66 PetscCall(VecCopy(bb,xx)); 67 PetscCall(VecGetArray(xx,&x)); 68 69 /* forward solve the U^T */ 70 idx = 0; 71 for (i=0; i<n; i++) { 72 v = aa + bs2*diag[i]; 73 /* multiply by the inverse of the block diagonal */ 74 x1 = x[idx]; x2 = x[1+idx]; 75 s1 = v[0]*x1 + v[1]*x2; 76 s2 = v[2]*x1 + v[3]*x2; 77 v -= bs2; 78 79 vi = aj + diag[i] - 1; 80 nz = diag[i] - diag[i+1] - 1; 81 for (j=0; j>-nz; j--) { 82 oidx = bs*vi[j]; 83 x[oidx] -= v[0]*s1 + v[1]*s2; 84 x[oidx+1] -= v[2]*s1 + v[3]*s2; 85 v -= bs2; 86 } 87 x[idx] = s1;x[1+idx] = s2; 88 idx += bs; 89 } 90 /* backward solve the L^T */ 91 for (i=n-1; i>=0; i--) { 92 v = aa + bs2*ai[i]; 93 vi = aj + ai[i]; 94 nz = ai[i+1] - ai[i]; 95 idt = bs*i; 96 s1 = x[idt]; s2 = x[1+idt]; 97 for (j=0; j<nz; j++) { 98 idx = bs*vi[j]; 99 x[idx] -= v[0]*s1 + v[1]*s2; 100 x[idx+1] -= v[2]*s1 + v[3]*s2; 101 v += bs2; 102 } 103 } 104 PetscCall(VecRestoreArray(xx,&x)); 105 PetscCall(PetscLogFlops(2.0*bs2*(a->nz) - bs*A->cmap->n)); 106 PetscFunctionReturn(0); 107 } 108