1 #include <../src/mat/impls/baij/seq/baij.h> 2 #include <petsc/private/kernels/blockinvert.h> 3 4 PetscErrorCode MatSolveTranspose_SeqBAIJ_4_inplace(Mat A,Vec bb,Vec xx) 5 { 6 Mat_SeqBAIJ *a =(Mat_SeqBAIJ*)A->data; 7 IS iscol=a->col,isrow=a->row; 8 const PetscInt *r,*c,*rout,*cout; 9 const PetscInt *diag=a->diag,n=a->mbs,*vi,*ai=a->i,*aj=a->j; 10 PetscInt i,nz,idx,idt,ii,ic,ir,oidx; 11 const MatScalar *aa=a->a,*v; 12 PetscScalar s1,s2,s3,s4,x1,x2,x3,x4,*x,*t; 13 const PetscScalar *b; 14 15 PetscFunctionBegin; 16 PetscCall(VecGetArrayRead(bb,&b)); 17 PetscCall(VecGetArray(xx,&x)); 18 t = a->solve_work; 19 20 PetscCall(ISGetIndices(isrow,&rout)); r = rout; 21 PetscCall(ISGetIndices(iscol,&cout)); c = cout; 22 23 /* copy the b into temp work space according to permutation */ 24 ii = 0; 25 for (i=0; i<n; i++) { 26 ic = 4*c[i]; 27 t[ii] = b[ic]; 28 t[ii+1] = b[ic+1]; 29 t[ii+2] = b[ic+2]; 30 t[ii+3] = b[ic+3]; 31 ii += 4; 32 } 33 34 /* forward solve the U^T */ 35 idx = 0; 36 for (i=0; i<n; i++) { 37 38 v = aa + 16*diag[i]; 39 /* multiply by the inverse of the block diagonal */ 40 x1 = t[idx]; x2 = t[1+idx]; x3 = t[2+idx]; x4 = t[3+idx]; 41 s1 = v[0]*x1 + v[1]*x2 + v[2]*x3 + v[3]*x4; 42 s2 = v[4]*x1 + v[5]*x2 + v[6]*x3 + v[7]*x4; 43 s3 = v[8]*x1 + v[9]*x2 + v[10]*x3 + v[11]*x4; 44 s4 = v[12]*x1 + v[13]*x2 + v[14]*x3 + v[15]*x4; 45 v += 16; 46 47 vi = aj + diag[i] + 1; 48 nz = ai[i+1] - diag[i] - 1; 49 while (nz--) { 50 oidx = 4*(*vi++); 51 t[oidx] -= v[0]*s1 + v[1]*s2 + v[2]*s3 + v[3]*s4; 52 t[oidx+1] -= v[4]*s1 + v[5]*s2 + v[6]*s3 + v[7]*s4; 53 t[oidx+2] -= v[8]*s1 + v[9]*s2 + v[10]*s3 + v[11]*s4; 54 t[oidx+3] -= v[12]*s1 + v[13]*s2 + v[14]*s3 + v[15]*s4; 55 v += 16; 56 } 57 t[idx] = s1;t[1+idx] = s2; t[2+idx] = s3;t[3+idx] = s4; 58 idx += 4; 59 } 60 /* backward solve the L^T */ 61 for (i=n-1; i>=0; i--) { 62 v = aa + 16*diag[i] - 16; 63 vi = aj + diag[i] - 1; 64 nz = diag[i] - ai[i]; 65 idt = 4*i; 66 s1 = t[idt]; s2 = t[1+idt]; s3 = t[2+idt];s4 = t[3+idt]; 67 while (nz--) { 68 idx = 4*(*vi--); 69 t[idx] -= v[0]*s1 + v[1]*s2 + v[2]*s3 + v[3]*s4; 70 t[idx+1] -= v[4]*s1 + v[5]*s2 + v[6]*s3 + v[7]*s4; 71 t[idx+2] -= v[8]*s1 + v[9]*s2 + v[10]*s3 + v[11]*s4; 72 t[idx+3] -= v[12]*s1 + v[13]*s2 + v[14]*s3 + v[15]*s4; 73 v -= 16; 74 } 75 } 76 77 /* copy t into x according to permutation */ 78 ii = 0; 79 for (i=0; i<n; i++) { 80 ir = 4*r[i]; 81 x[ir] = t[ii]; 82 x[ir+1] = t[ii+1]; 83 x[ir+2] = t[ii+2]; 84 x[ir+3] = t[ii+3]; 85 ii += 4; 86 } 87 88 PetscCall(ISRestoreIndices(isrow,&rout)); 89 PetscCall(ISRestoreIndices(iscol,&cout)); 90 PetscCall(VecRestoreArrayRead(bb,&b)); 91 PetscCall(VecRestoreArray(xx,&x)); 92 PetscCall(PetscLogFlops(2.0*16*(a->nz) - 4.0*A->cmap->n)); 93 PetscFunctionReturn(0); 94 } 95 96 PetscErrorCode MatSolveTranspose_SeqBAIJ_4(Mat A,Vec bb,Vec xx) 97 { 98 Mat_SeqBAIJ *a=(Mat_SeqBAIJ*)A->data; 99 IS iscol=a->col,isrow=a->row; 100 const PetscInt n =a->mbs,*vi,*ai=a->i,*aj=a->j,*diag=a->diag; 101 const PetscInt *r,*c,*rout,*cout; 102 PetscInt nz,idx,idt,j,i,oidx,ii,ic,ir; 103 const PetscInt bs =A->rmap->bs,bs2=a->bs2; 104 const MatScalar *aa=a->a,*v; 105 PetscScalar s1,s2,s3,s4,x1,x2,x3,x4,*x,*t; 106 const PetscScalar *b; 107 108 PetscFunctionBegin; 109 PetscCall(VecGetArrayRead(bb,&b)); 110 PetscCall(VecGetArray(xx,&x)); 111 t = a->solve_work; 112 113 PetscCall(ISGetIndices(isrow,&rout)); r = rout; 114 PetscCall(ISGetIndices(iscol,&cout)); c = cout; 115 116 /* copy b into temp work space according to permutation */ 117 for (i=0; i<n; i++) { 118 ii = bs*i; ic = bs*c[i]; 119 t[ii] = b[ic]; t[ii+1] = b[ic+1]; t[ii+2] = b[ic+2]; t[ii+3] = b[ic+3]; 120 } 121 122 /* forward solve the U^T */ 123 idx = 0; 124 for (i=0; i<n; i++) { 125 v = aa + bs2*diag[i]; 126 /* multiply by the inverse of the block diagonal */ 127 x1 = t[idx]; x2 = t[1+idx]; x3 = t[2+idx]; x4 = t[3+idx]; 128 s1 = v[0]*x1 + v[1]*x2 + v[2]*x3 + v[3]*x4; 129 s2 = v[4]*x1 + v[5]*x2 + v[6]*x3 + v[7]*x4; 130 s3 = v[8]*x1 + v[9]*x2 + v[10]*x3 + v[11]*x4; 131 s4 = v[12]*x1 + v[13]*x2 + v[14]*x3 + v[15]*x4; 132 v -= bs2; 133 134 vi = aj + diag[i] - 1; 135 nz = diag[i] - diag[i+1] - 1; 136 for (j=0; j>-nz; j--) { 137 oidx = bs*vi[j]; 138 t[oidx] -= v[0]*s1 + v[1]*s2 + v[2]*s3 + v[3]*s4; 139 t[oidx+1] -= v[4]*s1 + v[5]*s2 + v[6]*s3 + v[7]*s4; 140 t[oidx+2] -= v[8]*s1 + v[9]*s2 + v[10]*s3 + v[11]*s4; 141 t[oidx+3] -= v[12]*s1 + v[13]*s2 + v[14]*s3 + v[15]*s4; 142 v -= bs2; 143 } 144 t[idx] = s1;t[1+idx] = s2; t[2+idx] = s3; t[3+idx] = s4; 145 idx += bs; 146 } 147 /* backward solve the L^T */ 148 for (i=n-1; i>=0; i--) { 149 v = aa + bs2*ai[i]; 150 vi = aj + ai[i]; 151 nz = ai[i+1] - ai[i]; 152 idt = bs*i; 153 s1 = t[idt]; s2 = t[1+idt]; s3 = t[2+idt]; s4 = t[3+idt]; 154 for (j=0; j<nz; j++) { 155 idx = bs*vi[j]; 156 t[idx] -= v[0]*s1 + v[1]*s2 + v[2]*s3 + v[3]*s4; 157 t[idx+1] -= v[4]*s1 + v[5]*s2 + v[6]*s3 + v[7]*s4; 158 t[idx+2] -= v[8]*s1 + v[9]*s2 + v[10]*s3 + v[11]*s4; 159 t[idx+3] -= v[12]*s1 + v[13]*s2 + v[14]*s3 + v[15]*s4; 160 v += bs2; 161 } 162 } 163 164 /* copy t into x according to permutation */ 165 for (i=0; i<n; i++) { 166 ii = bs*i; ir = bs*r[i]; 167 x[ir] = t[ii]; x[ir+1] = t[ii+1]; x[ir+2] = t[ii+2]; x[ir+3] = t[ii+3]; 168 } 169 170 PetscCall(ISRestoreIndices(isrow,&rout)); 171 PetscCall(ISRestoreIndices(iscol,&cout)); 172 PetscCall(VecRestoreArrayRead(bb,&b)); 173 PetscCall(VecRestoreArray(xx,&x)); 174 PetscCall(PetscLogFlops(2.0*bs2*(a->nz) - bs*A->cmap->n)); 175 PetscFunctionReturn(0); 176 } 177