1 #include <../src/mat/impls/baij/seq/baij.h> 2 #include <petsc/private/kernels/blockinvert.h> 3 4 PetscErrorCode MatSolveTranspose_SeqBAIJ_3_inplace(Mat A, Vec bb, Vec xx) 5 { 6 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data; 7 IS iscol = a->col, isrow = a->row; 8 const PetscInt *r, *c, *rout, *cout; 9 const PetscInt *diag = a->diag, n = a->mbs, *vi, *ai = a->i, *aj = a->j; 10 PetscInt i, nz, idx, idt, ii, ic, ir, oidx; 11 const MatScalar *aa = a->a, *v; 12 PetscScalar s1, s2, s3, x1, x2, x3, *x, *t; 13 const PetscScalar *b; 14 15 PetscFunctionBegin; 16 PetscCall(VecGetArrayRead(bb, &b)); 17 PetscCall(VecGetArray(xx, &x)); 18 t = a->solve_work; 19 20 PetscCall(ISGetIndices(isrow, &rout)); 21 r = rout; 22 PetscCall(ISGetIndices(iscol, &cout)); 23 c = cout; 24 25 /* copy the b into temp work space according to permutation */ 26 ii = 0; 27 for (i = 0; i < n; i++) { 28 ic = 3 * c[i]; 29 t[ii] = b[ic]; 30 t[ii + 1] = b[ic + 1]; 31 t[ii + 2] = b[ic + 2]; 32 ii += 3; 33 } 34 35 /* forward solve the U^T */ 36 idx = 0; 37 for (i = 0; i < n; i++) { 38 v = aa + 9 * diag[i]; 39 /* multiply by the inverse of the block diagonal */ 40 x1 = t[idx]; 41 x2 = t[1 + idx]; 42 x3 = t[2 + idx]; 43 s1 = v[0] * x1 + v[1] * x2 + v[2] * x3; 44 s2 = v[3] * x1 + v[4] * x2 + v[5] * x3; 45 s3 = v[6] * x1 + v[7] * x2 + v[8] * x3; 46 v += 9; 47 48 vi = aj + diag[i] + 1; 49 nz = ai[i + 1] - diag[i] - 1; 50 while (nz--) { 51 oidx = 3 * (*vi++); 52 t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3; 53 t[oidx + 1] -= v[3] * s1 + v[4] * s2 + v[5] * s3; 54 t[oidx + 2] -= v[6] * s1 + v[7] * s2 + v[8] * s3; 55 v += 9; 56 } 57 t[idx] = s1; 58 t[1 + idx] = s2; 59 t[2 + idx] = s3; 60 idx += 3; 61 } 62 /* backward solve the L^T */ 63 for (i = n - 1; i >= 0; i--) { 64 v = aa + 9 * diag[i] - 9; 65 vi = aj + diag[i] - 1; 66 nz = diag[i] - ai[i]; 67 idt = 3 * i; 68 s1 = t[idt]; 69 s2 = t[1 + idt]; 70 s3 = t[2 + idt]; 71 while (nz--) { 72 idx = 3 * (*vi--); 73 t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3; 74 t[idx + 1] -= v[3] * s1 + v[4] * s2 + v[5] * s3; 75 t[idx + 2] -= v[6] * s1 + v[7] * s2 + v[8] * s3; 76 v -= 9; 77 } 78 } 79 80 /* copy t into x according to permutation */ 81 ii = 0; 82 for (i = 0; i < n; i++) { 83 ir = 3 * r[i]; 84 x[ir] = t[ii]; 85 x[ir + 1] = t[ii + 1]; 86 x[ir + 2] = t[ii + 2]; 87 ii += 3; 88 } 89 90 PetscCall(ISRestoreIndices(isrow, &rout)); 91 PetscCall(ISRestoreIndices(iscol, &cout)); 92 PetscCall(VecRestoreArrayRead(bb, &b)); 93 PetscCall(VecRestoreArray(xx, &x)); 94 PetscCall(PetscLogFlops(2.0 * 9 * (a->nz) - 3.0 * A->cmap->n)); 95 PetscFunctionReturn(0); 96 } 97 98 PetscErrorCode MatSolveTranspose_SeqBAIJ_3(Mat A, Vec bb, Vec xx) 99 { 100 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data; 101 IS iscol = a->col, isrow = a->row; 102 const PetscInt n = a->mbs, *vi, *ai = a->i, *aj = a->j, *diag = a->diag; 103 const PetscInt *r, *c, *rout, *cout; 104 PetscInt nz, idx, idt, j, i, oidx, ii, ic, ir; 105 const PetscInt bs = A->rmap->bs, bs2 = a->bs2; 106 const MatScalar *aa = a->a, *v; 107 PetscScalar s1, s2, s3, x1, x2, x3, *x, *t; 108 const PetscScalar *b; 109 110 PetscFunctionBegin; 111 PetscCall(VecGetArrayRead(bb, &b)); 112 PetscCall(VecGetArray(xx, &x)); 113 t = a->solve_work; 114 115 PetscCall(ISGetIndices(isrow, &rout)); 116 r = rout; 117 PetscCall(ISGetIndices(iscol, &cout)); 118 c = cout; 119 120 /* copy b into temp work space according to permutation */ 121 for (i = 0; i < n; i++) { 122 ii = bs * i; 123 ic = bs * c[i]; 124 t[ii] = b[ic]; 125 t[ii + 1] = b[ic + 1]; 126 t[ii + 2] = b[ic + 2]; 127 } 128 129 /* forward solve the U^T */ 130 idx = 0; 131 for (i = 0; i < n; i++) { 132 v = aa + bs2 * diag[i]; 133 /* multiply by the inverse of the block diagonal */ 134 x1 = t[idx]; 135 x2 = t[1 + idx]; 136 x3 = t[2 + idx]; 137 s1 = v[0] * x1 + v[1] * x2 + v[2] * x3; 138 s2 = v[3] * x1 + v[4] * x2 + v[5] * x3; 139 s3 = v[6] * x1 + v[7] * x2 + v[8] * x3; 140 v -= bs2; 141 142 vi = aj + diag[i] - 1; 143 nz = diag[i] - diag[i + 1] - 1; 144 for (j = 0; j > -nz; j--) { 145 oidx = bs * vi[j]; 146 t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3; 147 t[oidx + 1] -= v[3] * s1 + v[4] * s2 + v[5] * s3; 148 t[oidx + 2] -= v[6] * s1 + v[7] * s2 + v[8] * s3; 149 v -= bs2; 150 } 151 t[idx] = s1; 152 t[1 + idx] = s2; 153 t[2 + idx] = s3; 154 idx += bs; 155 } 156 /* backward solve the L^T */ 157 for (i = n - 1; i >= 0; i--) { 158 v = aa + bs2 * ai[i]; 159 vi = aj + ai[i]; 160 nz = ai[i + 1] - ai[i]; 161 idt = bs * i; 162 s1 = t[idt]; 163 s2 = t[1 + idt]; 164 s3 = t[2 + idt]; 165 for (j = 0; j < nz; j++) { 166 idx = bs * vi[j]; 167 t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3; 168 t[idx + 1] -= v[3] * s1 + v[4] * s2 + v[5] * s3; 169 t[idx + 2] -= v[6] * s1 + v[7] * s2 + v[8] * s3; 170 v += bs2; 171 } 172 } 173 174 /* copy t into x according to permutation */ 175 for (i = 0; i < n; i++) { 176 ii = bs * i; 177 ir = bs * r[i]; 178 x[ir] = t[ii]; 179 x[ir + 1] = t[ii + 1]; 180 x[ir + 2] = t[ii + 2]; 181 } 182 183 PetscCall(ISRestoreIndices(isrow, &rout)); 184 PetscCall(ISRestoreIndices(iscol, &cout)); 185 PetscCall(VecRestoreArrayRead(bb, &b)); 186 PetscCall(VecRestoreArray(xx, &x)); 187 PetscCall(PetscLogFlops(2.0 * bs2 * (a->nz) - bs * A->cmap->n)); 188 PetscFunctionReturn(0); 189 } 190