1 #include <../src/mat/impls/baij/seq/baij.h> 2 #include <petsc/private/kernels/blockinvert.h> 3 4 /* ----------------------------------------------------------- */ 5 PetscErrorCode MatSolveTranspose_SeqBAIJ_N_inplace(Mat A,Vec bb,Vec xx) 6 { 7 Mat_SeqBAIJ *a =(Mat_SeqBAIJ*)A->data; 8 IS iscol=a->col,isrow=a->row; 9 const PetscInt *r,*c,*rout,*cout,*ai=a->i,*aj=a->j,*vi; 10 PetscInt i,nz,j; 11 const PetscInt n =a->mbs,bs=A->rmap->bs,bs2=a->bs2; 12 const MatScalar *aa=a->a,*v; 13 PetscScalar *x,*t,*ls; 14 const PetscScalar *b; 15 16 PetscFunctionBegin; 17 PetscCall(VecGetArrayRead(bb,&b)); 18 PetscCall(VecGetArray(xx,&x)); 19 t = a->solve_work; 20 21 PetscCall(ISGetIndices(isrow,&rout)); r = rout; 22 PetscCall(ISGetIndices(iscol,&cout)); c = cout; 23 24 /* copy the b into temp work space according to permutation */ 25 for (i=0; i<n; i++) { 26 for (j=0; j<bs; j++) { 27 t[i*bs+j] = b[c[i]*bs+j]; 28 } 29 } 30 31 /* forward solve the upper triangular transpose */ 32 ls = a->solve_work + A->cmap->n; 33 for (i=0; i<n; i++) { 34 PetscCall(PetscArraycpy(ls,t+i*bs,bs)); 35 PetscKernel_w_gets_transA_times_v(bs,ls,aa+bs2*a->diag[i],t+i*bs); 36 v = aa + bs2*(a->diag[i] + 1); 37 vi = aj + a->diag[i] + 1; 38 nz = ai[i+1] - a->diag[i] - 1; 39 while (nz--) { 40 PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(*vi++),v,t+i*bs); 41 v += bs2; 42 } 43 } 44 45 /* backward solve the lower triangular transpose */ 46 for (i=n-1; i>=0; i--) { 47 v = aa + bs2*ai[i]; 48 vi = aj + ai[i]; 49 nz = a->diag[i] - ai[i]; 50 while (nz--) { 51 PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(*vi++),v,t+i*bs); 52 v += bs2; 53 } 54 } 55 56 /* copy t into x according to permutation */ 57 for (i=0; i<n; i++) { 58 for (j=0; j<bs; j++) { 59 x[bs*r[i]+j] = t[bs*i+j]; 60 } 61 } 62 63 PetscCall(ISRestoreIndices(isrow,&rout)); 64 PetscCall(ISRestoreIndices(iscol,&cout)); 65 PetscCall(VecRestoreArrayRead(bb,&b)); 66 PetscCall(VecRestoreArray(xx,&x)); 67 PetscCall(PetscLogFlops(2.0*(a->bs2)*(a->nz) - A->rmap->bs*A->cmap->n)); 68 PetscFunctionReturn(0); 69 } 70 71 PetscErrorCode MatSolveTranspose_SeqBAIJ_N(Mat A,Vec bb,Vec xx) 72 { 73 Mat_SeqBAIJ *a =(Mat_SeqBAIJ*)A->data; 74 IS iscol=a->col,isrow=a->row; 75 const PetscInt *r,*c,*rout,*cout; 76 const PetscInt n=a->mbs,*ai=a->i,*aj=a->j,*vi,*diag=a->diag; 77 PetscInt i,j,nz; 78 const PetscInt bs =A->rmap->bs,bs2=a->bs2; 79 const MatScalar *aa=a->a,*v; 80 PetscScalar *x,*t,*ls; 81 const PetscScalar *b; 82 83 PetscFunctionBegin; 84 PetscCall(VecGetArrayRead(bb,&b)); 85 PetscCall(VecGetArray(xx,&x)); 86 t = a->solve_work; 87 88 PetscCall(ISGetIndices(isrow,&rout)); r = rout; 89 PetscCall(ISGetIndices(iscol,&cout)); c = cout; 90 91 /* copy the b into temp work space according to permutation */ 92 for (i=0; i<n; i++) { 93 for (j=0; j<bs; j++) { 94 t[i*bs+j] = b[c[i]*bs+j]; 95 } 96 } 97 98 /* forward solve the upper triangular transpose */ 99 ls = a->solve_work + A->cmap->n; 100 for (i=0; i<n; i++) { 101 PetscCall(PetscArraycpy(ls,t+i*bs,bs)); 102 PetscKernel_w_gets_transA_times_v(bs,ls,aa+bs2*diag[i],t+i*bs); 103 v = aa + bs2*(diag[i] - 1); 104 vi = aj + diag[i] - 1; 105 nz = diag[i] - diag[i+1] - 1; 106 for (j=0; j>-nz; j--) { 107 PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(vi[j]),v,t+i*bs); 108 v -= bs2; 109 } 110 } 111 112 /* backward solve the lower triangular transpose */ 113 for (i=n-1; i>=0; i--) { 114 v = aa + bs2*ai[i]; 115 vi = aj + ai[i]; 116 nz = ai[i+1] - ai[i]; 117 for (j=0; j<nz; j++) { 118 PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(vi[j]),v,t+i*bs); 119 v += bs2; 120 } 121 } 122 123 /* copy t into x according to permutation */ 124 for (i=0; i<n; i++) { 125 for (j=0; j<bs; j++) { 126 x[bs*r[i]+j] = t[bs*i+j]; 127 } 128 } 129 130 PetscCall(ISRestoreIndices(isrow,&rout)); 131 PetscCall(ISRestoreIndices(iscol,&cout)); 132 PetscCall(VecRestoreArrayRead(bb,&b)); 133 PetscCall(VecRestoreArray(xx,&x)); 134 PetscCall(PetscLogFlops(2.0*(a->bs2)*(a->nz) - A->rmap->bs*A->cmap->n)); 135 PetscFunctionReturn(0); 136 } 137