1 #include <../src/mat/impls/baij/seq/baij.h> 2 #include <petsc/private/kernels/blockinvert.h> 3 4 PetscErrorCode MatSolveTranspose_SeqBAIJ_5_inplace(Mat A,Vec bb,Vec xx) 5 { 6 Mat_SeqBAIJ *a =(Mat_SeqBAIJ*)A->data; 7 IS iscol=a->col,isrow=a->row; 8 const PetscInt *r,*c,*rout,*cout; 9 const PetscInt *diag=a->diag,n=a->mbs,*vi,*ai=a->i,*aj=a->j; 10 PetscInt i,nz,idx,idt,ii,ic,ir,oidx; 11 const MatScalar *aa=a->a,*v; 12 PetscScalar s1,s2,s3,s4,s5,x1,x2,x3,x4,x5,*x,*t; 13 const PetscScalar *b; 14 15 PetscFunctionBegin; 16 PetscCall(VecGetArrayRead(bb,&b)); 17 PetscCall(VecGetArray(xx,&x)); 18 t = a->solve_work; 19 20 PetscCall(ISGetIndices(isrow,&rout)); r = rout; 21 PetscCall(ISGetIndices(iscol,&cout)); c = cout; 22 23 /* copy the b into temp work space according to permutation */ 24 ii = 0; 25 for (i=0; i<n; i++) { 26 ic = 5*c[i]; 27 t[ii] = b[ic]; 28 t[ii+1] = b[ic+1]; 29 t[ii+2] = b[ic+2]; 30 t[ii+3] = b[ic+3]; 31 t[ii+4] = b[ic+4]; 32 ii += 5; 33 } 34 35 /* forward solve the U^T */ 36 idx = 0; 37 for (i=0; i<n; i++) { 38 39 v = aa + 25*diag[i]; 40 /* multiply by the inverse of the block diagonal */ 41 x1 = t[idx]; x2 = t[1+idx]; x3 = t[2+idx]; x4 = t[3+idx]; x5 = t[4+idx]; 42 s1 = v[0]*x1 + v[1]*x2 + v[2]*x3 + v[3]*x4 + v[4]*x5; 43 s2 = v[5]*x1 + v[6]*x2 + v[7]*x3 + v[8]*x4 + v[9]*x5; 44 s3 = v[10]*x1 + v[11]*x2 + v[12]*x3 + v[13]*x4 + v[14]*x5; 45 s4 = v[15]*x1 + v[16]*x2 + v[17]*x3 + v[18]*x4 + v[19]*x5; 46 s5 = v[20]*x1 + v[21]*x2 + v[22]*x3 + v[23]*x4 + v[24]*x5; 47 v += 25; 48 49 vi = aj + diag[i] + 1; 50 nz = ai[i+1] - diag[i] - 1; 51 while (nz--) { 52 oidx = 5*(*vi++); 53 t[oidx] -= v[0]*s1 + v[1]*s2 + v[2]*s3 + v[3]*s4 + v[4]*s5; 54 t[oidx+1] -= v[5]*s1 + v[6]*s2 + v[7]*s3 + v[8]*s4 + v[9]*s5; 55 t[oidx+2] -= v[10]*s1 + v[11]*s2 + v[12]*s3 + v[13]*s4 + v[14]*s5; 56 t[oidx+3] -= v[15]*s1 + v[16]*s2 + v[17]*s3 + v[18]*s4 + v[19]*s5; 57 t[oidx+4] -= v[20]*s1 + v[21]*s2 + v[22]*s3 + v[23]*s4 + v[24]*s5; 58 v += 25; 59 } 60 t[idx] = s1;t[1+idx] = s2; t[2+idx] = s3;t[3+idx] = s4; t[4+idx] = s5; 61 idx += 5; 62 } 63 /* backward solve the L^T */ 64 for (i=n-1; i>=0; i--) { 65 v = aa + 25*diag[i] - 25; 66 vi = aj + diag[i] - 1; 67 nz = diag[i] - ai[i]; 68 idt = 5*i; 69 s1 = t[idt]; s2 = t[1+idt]; s3 = t[2+idt];s4 = t[3+idt]; s5 = t[4+idt]; 70 while (nz--) { 71 idx = 5*(*vi--); 72 t[idx] -= v[0]*s1 + v[1]*s2 + v[2]*s3 + v[3]*s4 + v[4]*s5; 73 t[idx+1] -= v[5]*s1 + v[6]*s2 + v[7]*s3 + v[8]*s4 + v[9]*s5; 74 t[idx+2] -= v[10]*s1 + v[11]*s2 + v[12]*s3 + v[13]*s4 + v[14]*s5; 75 t[idx+3] -= v[15]*s1 + v[16]*s2 + v[17]*s3 + v[18]*s4 + v[19]*s5; 76 t[idx+4] -= v[20]*s1 + v[21]*s2 + v[22]*s3 + v[23]*s4 + v[24]*s5; 77 v -= 25; 78 } 79 } 80 81 /* copy t into x according to permutation */ 82 ii = 0; 83 for (i=0; i<n; i++) { 84 ir = 5*r[i]; 85 x[ir] = t[ii]; 86 x[ir+1] = t[ii+1]; 87 x[ir+2] = t[ii+2]; 88 x[ir+3] = t[ii+3]; 89 x[ir+4] = t[ii+4]; 90 ii += 5; 91 } 92 93 PetscCall(ISRestoreIndices(isrow,&rout)); 94 PetscCall(ISRestoreIndices(iscol,&cout)); 95 PetscCall(VecRestoreArrayRead(bb,&b)); 96 PetscCall(VecRestoreArray(xx,&x)); 97 PetscCall(PetscLogFlops(2.0*25*(a->nz) - 5.0*A->cmap->n)); 98 PetscFunctionReturn(0); 99 } 100 101 PetscErrorCode MatSolveTranspose_SeqBAIJ_5(Mat A,Vec bb,Vec xx) 102 { 103 Mat_SeqBAIJ *a=(Mat_SeqBAIJ*)A->data; 104 IS iscol=a->col,isrow=a->row; 105 const PetscInt n =a->mbs,*vi,*ai=a->i,*aj=a->j,*diag=a->diag; 106 const PetscInt *r,*c,*rout,*cout; 107 PetscInt nz,idx,idt,j,i,oidx,ii,ic,ir; 108 const PetscInt bs =A->rmap->bs,bs2=a->bs2; 109 const MatScalar *aa=a->a,*v; 110 PetscScalar s1,s2,s3,s4,s5,x1,x2,x3,x4,x5,*x,*t; 111 const PetscScalar *b; 112 113 PetscFunctionBegin; 114 PetscCall(VecGetArrayRead(bb,&b)); 115 PetscCall(VecGetArray(xx,&x)); 116 t = a->solve_work; 117 118 PetscCall(ISGetIndices(isrow,&rout)); r = rout; 119 PetscCall(ISGetIndices(iscol,&cout)); c = cout; 120 121 /* copy b into temp work space according to permutation */ 122 for (i=0; i<n; i++) { 123 ii = bs*i; ic = bs*c[i]; 124 t[ii] = b[ic]; t[ii+1] = b[ic+1]; t[ii+2] = b[ic+2]; t[ii+3] = b[ic+3]; 125 t[ii+4] = b[ic+4]; 126 } 127 128 /* forward solve the U^T */ 129 idx = 0; 130 for (i=0; i<n; i++) { 131 v = aa + bs2*diag[i]; 132 /* multiply by the inverse of the block diagonal */ 133 x1 = t[idx]; x2 = t[1+idx]; x3 = t[2+idx]; x4 = t[3+idx]; x5 = t[4+idx]; 134 s1 = v[0]*x1 + v[1]*x2 + v[2]*x3 + v[3]*x4 + v[4]*x5; 135 s2 = v[5]*x1 + v[6]*x2 + v[7]*x3 + v[8]*x4 + v[9]*x5; 136 s3 = v[10]*x1 + v[11]*x2 + v[12]*x3 + v[13]*x4 + v[14]*x5; 137 s4 = v[15]*x1 + v[16]*x2 + v[17]*x3 + v[18]*x4 + v[19]*x5; 138 s5 = v[20]*x1 + v[21]*x2 + v[22]*x3 + v[23]*x4 + v[24]*x5; 139 v -= bs2; 140 141 vi = aj + diag[i] - 1; 142 nz = diag[i] - diag[i+1] - 1; 143 for (j=0; j>-nz; j--) { 144 oidx = bs*vi[j]; 145 t[oidx] -= v[0]*s1 + v[1]*s2 + v[2]*s3 + v[3]*s4 + v[4]*s5; 146 t[oidx+1] -= v[5]*s1 + v[6]*s2 + v[7]*s3 + v[8]*s4 + v[9]*s5; 147 t[oidx+2] -= v[10]*s1 + v[11]*s2 + v[12]*s3 + v[13]*s4 + v[14]*s5; 148 t[oidx+3] -= v[15]*s1 + v[16]*s2 + v[17]*s3 + v[18]*s4 + v[19]*s5; 149 t[oidx+4] -= v[20]*s1 + v[21]*s2 + v[22]*s3 + v[23]*s4 + v[24]*s5; 150 v -= bs2; 151 } 152 t[idx] = s1;t[1+idx] = s2; t[2+idx] = s3; t[3+idx] = s4; t[4+idx] =s5; 153 idx += bs; 154 } 155 /* backward solve the L^T */ 156 for (i=n-1; i>=0; i--) { 157 v = aa + bs2*ai[i]; 158 vi = aj + ai[i]; 159 nz = ai[i+1] - ai[i]; 160 idt = bs*i; 161 s1 = t[idt]; s2 = t[1+idt]; s3 = t[2+idt]; s4 = t[3+idt]; s5 = t[4+idt]; 162 for (j=0; j<nz; j++) { 163 idx = bs*vi[j]; 164 t[idx] -= v[0]*s1 + v[1]*s2 + v[2]*s3 + v[3]*s4 + v[4]*s5; 165 t[idx+1] -= v[5]*s1 + v[6]*s2 + v[7]*s3 + v[8]*s4 + v[9]*s5; 166 t[idx+2] -= v[10]*s1 + v[11]*s2 + v[12]*s3 + v[13]*s4 + v[14]*s5; 167 t[idx+3] -= v[15]*s1 + v[16]*s2 + v[17]*s3 + v[18]*s4 + v[19]*s5; 168 t[idx+4] -= v[20]*s1 + v[21]*s2 + v[22]*s3 + v[23]*s4 + v[24]*s5; 169 v += bs2; 170 } 171 } 172 173 /* copy t into x according to permutation */ 174 for (i=0; i<n; i++) { 175 ii = bs*i; ir = bs*r[i]; 176 x[ir] = t[ii]; x[ir+1] = t[ii+1]; x[ir+2] = t[ii+2]; x[ir+3] = t[ii+3]; 177 x[ir+4] = t[ii+4]; 178 } 179 180 PetscCall(ISRestoreIndices(isrow,&rout)); 181 PetscCall(ISRestoreIndices(iscol,&cout)); 182 PetscCall(VecRestoreArrayRead(bb,&b)); 183 PetscCall(VecRestoreArray(xx,&x)); 184 PetscCall(PetscLogFlops(2.0*bs2*(a->nz) - bs*A->cmap->n)); 185 PetscFunctionReturn(0); 186 } 187