1 #include <../src/mat/impls/baij/seq/baij.h> 2 #include <petsc/private/kernels/blockinvert.h> 3 4 PetscErrorCode MatSolveTranspose_SeqBAIJ_4_inplace(Mat A, Vec bb, Vec xx) 5 { 6 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data; 7 IS iscol = a->col, isrow = a->row; 8 const PetscInt *r, *c, *rout, *cout; 9 const PetscInt *diag = a->diag, n = a->mbs, *vi, *ai = a->i, *aj = a->j; 10 PetscInt i, nz, idx, idt, ii, ic, ir, oidx; 11 const MatScalar *aa = a->a, *v; 12 PetscScalar s1, s2, s3, s4, x1, x2, x3, x4, *x, *t; 13 const PetscScalar *b; 14 15 PetscFunctionBegin; 16 PetscCall(VecGetArrayRead(bb, &b)); 17 PetscCall(VecGetArray(xx, &x)); 18 t = a->solve_work; 19 20 PetscCall(ISGetIndices(isrow, &rout)); 21 r = rout; 22 PetscCall(ISGetIndices(iscol, &cout)); 23 c = cout; 24 25 /* copy the b into temp work space according to permutation */ 26 ii = 0; 27 for (i = 0; i < n; i++) { 28 ic = 4 * c[i]; 29 t[ii] = b[ic]; 30 t[ii + 1] = b[ic + 1]; 31 t[ii + 2] = b[ic + 2]; 32 t[ii + 3] = b[ic + 3]; 33 ii += 4; 34 } 35 36 /* forward solve the U^T */ 37 idx = 0; 38 for (i = 0; i < n; i++) { 39 v = aa + 16 * diag[i]; 40 /* multiply by the inverse of the block diagonal */ 41 x1 = t[idx]; 42 x2 = t[1 + idx]; 43 x3 = t[2 + idx]; 44 x4 = t[3 + idx]; 45 s1 = v[0] * x1 + v[1] * x2 + v[2] * x3 + v[3] * x4; 46 s2 = v[4] * x1 + v[5] * x2 + v[6] * x3 + v[7] * x4; 47 s3 = v[8] * x1 + v[9] * x2 + v[10] * x3 + v[11] * x4; 48 s4 = v[12] * x1 + v[13] * x2 + v[14] * x3 + v[15] * x4; 49 v += 16; 50 51 vi = aj + diag[i] + 1; 52 nz = ai[i + 1] - diag[i] - 1; 53 while (nz--) { 54 oidx = 4 * (*vi++); 55 t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4; 56 t[oidx + 1] -= v[4] * s1 + v[5] * s2 + v[6] * s3 + v[7] * s4; 57 t[oidx + 2] -= v[8] * s1 + v[9] * s2 + v[10] * s3 + v[11] * s4; 58 t[oidx + 3] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4; 59 v += 16; 60 } 61 t[idx] = s1; 62 t[1 + idx] = s2; 63 t[2 + idx] = s3; 64 t[3 + idx] = s4; 65 idx += 4; 66 } 67 /* backward solve the L^T */ 68 for (i = n - 1; i >= 0; i--) { 69 v = aa + 16 * diag[i] - 16; 70 vi = aj + diag[i] - 1; 71 nz = diag[i] - ai[i]; 72 idt = 4 * i; 73 s1 = t[idt]; 74 s2 = t[1 + idt]; 75 s3 = t[2 + idt]; 76 s4 = t[3 + idt]; 77 while (nz--) { 78 idx = 4 * (*vi--); 79 t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4; 80 t[idx + 1] -= v[4] * s1 + v[5] * s2 + v[6] * s3 + v[7] * s4; 81 t[idx + 2] -= v[8] * s1 + v[9] * s2 + v[10] * s3 + v[11] * s4; 82 t[idx + 3] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4; 83 v -= 16; 84 } 85 } 86 87 /* copy t into x according to permutation */ 88 ii = 0; 89 for (i = 0; i < n; i++) { 90 ir = 4 * r[i]; 91 x[ir] = t[ii]; 92 x[ir + 1] = t[ii + 1]; 93 x[ir + 2] = t[ii + 2]; 94 x[ir + 3] = t[ii + 3]; 95 ii += 4; 96 } 97 98 PetscCall(ISRestoreIndices(isrow, &rout)); 99 PetscCall(ISRestoreIndices(iscol, &cout)); 100 PetscCall(VecRestoreArrayRead(bb, &b)); 101 PetscCall(VecRestoreArray(xx, &x)); 102 PetscCall(PetscLogFlops(2.0 * 16 * (a->nz) - 4.0 * A->cmap->n)); 103 PetscFunctionReturn(PETSC_SUCCESS); 104 } 105 106 PetscErrorCode MatSolveTranspose_SeqBAIJ_4(Mat A, Vec bb, Vec xx) 107 { 108 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data; 109 IS iscol = a->col, isrow = a->row; 110 const PetscInt n = a->mbs, *vi, *ai = a->i, *aj = a->j, *diag = a->diag; 111 const PetscInt *r, *c, *rout, *cout; 112 PetscInt nz, idx, idt, j, i, oidx, ii, ic, ir; 113 const PetscInt bs = A->rmap->bs, bs2 = a->bs2; 114 const MatScalar *aa = a->a, *v; 115 PetscScalar s1, s2, s3, s4, x1, x2, x3, x4, *x, *t; 116 const PetscScalar *b; 117 118 PetscFunctionBegin; 119 PetscCall(VecGetArrayRead(bb, &b)); 120 PetscCall(VecGetArray(xx, &x)); 121 t = a->solve_work; 122 123 PetscCall(ISGetIndices(isrow, &rout)); 124 r = rout; 125 PetscCall(ISGetIndices(iscol, &cout)); 126 c = cout; 127 128 /* copy b into temp work space according to permutation */ 129 for (i = 0; i < n; i++) { 130 ii = bs * i; 131 ic = bs * c[i]; 132 t[ii] = b[ic]; 133 t[ii + 1] = b[ic + 1]; 134 t[ii + 2] = b[ic + 2]; 135 t[ii + 3] = b[ic + 3]; 136 } 137 138 /* forward solve the U^T */ 139 idx = 0; 140 for (i = 0; i < n; i++) { 141 v = aa + bs2 * diag[i]; 142 /* multiply by the inverse of the block diagonal */ 143 x1 = t[idx]; 144 x2 = t[1 + idx]; 145 x3 = t[2 + idx]; 146 x4 = t[3 + idx]; 147 s1 = v[0] * x1 + v[1] * x2 + v[2] * x3 + v[3] * x4; 148 s2 = v[4] * x1 + v[5] * x2 + v[6] * x3 + v[7] * x4; 149 s3 = v[8] * x1 + v[9] * x2 + v[10] * x3 + v[11] * x4; 150 s4 = v[12] * x1 + v[13] * x2 + v[14] * x3 + v[15] * x4; 151 v -= bs2; 152 153 vi = aj + diag[i] - 1; 154 nz = diag[i] - diag[i + 1] - 1; 155 for (j = 0; j > -nz; j--) { 156 oidx = bs * vi[j]; 157 t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4; 158 t[oidx + 1] -= v[4] * s1 + v[5] * s2 + v[6] * s3 + v[7] * s4; 159 t[oidx + 2] -= v[8] * s1 + v[9] * s2 + v[10] * s3 + v[11] * s4; 160 t[oidx + 3] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4; 161 v -= bs2; 162 } 163 t[idx] = s1; 164 t[1 + idx] = s2; 165 t[2 + idx] = s3; 166 t[3 + idx] = s4; 167 idx += bs; 168 } 169 /* backward solve the L^T */ 170 for (i = n - 1; i >= 0; i--) { 171 v = aa + bs2 * ai[i]; 172 vi = aj + ai[i]; 173 nz = ai[i + 1] - ai[i]; 174 idt = bs * i; 175 s1 = t[idt]; 176 s2 = t[1 + idt]; 177 s3 = t[2 + idt]; 178 s4 = t[3 + idt]; 179 for (j = 0; j < nz; j++) { 180 idx = bs * vi[j]; 181 t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4; 182 t[idx + 1] -= v[4] * s1 + v[5] * s2 + v[6] * s3 + v[7] * s4; 183 t[idx + 2] -= v[8] * s1 + v[9] * s2 + v[10] * s3 + v[11] * s4; 184 t[idx + 3] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4; 185 v += bs2; 186 } 187 } 188 189 /* copy t into x according to permutation */ 190 for (i = 0; i < n; i++) { 191 ii = bs * i; 192 ir = bs * r[i]; 193 x[ir] = t[ii]; 194 x[ir + 1] = t[ii + 1]; 195 x[ir + 2] = t[ii + 2]; 196 x[ir + 3] = t[ii + 3]; 197 } 198 199 PetscCall(ISRestoreIndices(isrow, &rout)); 200 PetscCall(ISRestoreIndices(iscol, &cout)); 201 PetscCall(VecRestoreArrayRead(bb, &b)); 202 PetscCall(VecRestoreArray(xx, &x)); 203 PetscCall(PetscLogFlops(2.0 * bs2 * (a->nz) - bs * A->cmap->n)); 204 PetscFunctionReturn(PETSC_SUCCESS); 205 } 206