1 #include <../src/mat/impls/baij/seq/baij.h> 2 #include <petsc/private/kernels/blockinvert.h> 3 4 PetscErrorCode MatSolveTranspose_SeqBAIJ_6_inplace(Mat A, Vec bb, Vec xx) 5 { 6 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data; 7 IS iscol = a->col, isrow = a->row; 8 const PetscInt *r, *c, *rout, *cout; 9 const PetscInt *diag = a->diag, n = a->mbs, *vi, *ai = a->i, *aj = a->j; 10 PetscInt i, nz, idx, idt, ii, ic, ir, oidx; 11 const MatScalar *aa = a->a, *v; 12 PetscScalar s1, s2, s3, s4, s5, s6, x1, x2, x3, x4, x5, x6, *x, *t; 13 const PetscScalar *b; 14 15 PetscFunctionBegin; 16 PetscCall(VecGetArrayRead(bb, &b)); 17 PetscCall(VecGetArray(xx, &x)); 18 t = a->solve_work; 19 20 PetscCall(ISGetIndices(isrow, &rout)); 21 r = rout; 22 PetscCall(ISGetIndices(iscol, &cout)); 23 c = cout; 24 25 /* copy the b into temp work space according to permutation */ 26 ii = 0; 27 for (i = 0; i < n; i++) { 28 ic = 6 * c[i]; 29 t[ii] = b[ic]; 30 t[ii + 1] = b[ic + 1]; 31 t[ii + 2] = b[ic + 2]; 32 t[ii + 3] = b[ic + 3]; 33 t[ii + 4] = b[ic + 4]; 34 t[ii + 5] = b[ic + 5]; 35 ii += 6; 36 } 37 38 /* forward solve the U^T */ 39 idx = 0; 40 for (i = 0; i < n; i++) { 41 v = aa + 36 * diag[i]; 42 /* multiply by the inverse of the block diagonal */ 43 x1 = t[idx]; 44 x2 = t[1 + idx]; 45 x3 = t[2 + idx]; 46 x4 = t[3 + idx]; 47 x5 = t[4 + idx]; 48 x6 = t[5 + idx]; 49 s1 = v[0] * x1 + v[1] * x2 + v[2] * x3 + v[3] * x4 + v[4] * x5 + v[5] * x6; 50 s2 = v[6] * x1 + v[7] * x2 + v[8] * x3 + v[9] * x4 + v[10] * x5 + v[11] * x6; 51 s3 = v[12] * x1 + v[13] * x2 + v[14] * x3 + v[15] * x4 + v[16] * x5 + v[17] * x6; 52 s4 = v[18] * x1 + v[19] * x2 + v[20] * x3 + v[21] * x4 + v[22] * x5 + v[23] * x6; 53 s5 = v[24] * x1 + v[25] * x2 + v[26] * x3 + v[27] * x4 + v[28] * x5 + v[29] * x6; 54 s6 = v[30] * x1 + v[31] * x2 + v[32] * x3 + v[33] * x4 + v[34] * x5 + v[35] * x6; 55 v += 36; 56 57 vi = aj + diag[i] + 1; 58 nz = ai[i + 1] - diag[i] - 1; 59 while (nz--) { 60 oidx = 6 * (*vi++); 61 t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5 + v[5] * s6; 62 t[oidx + 1] -= v[6] * s1 + v[7] * s2 + v[8] * s3 + v[9] * s4 + v[10] * s5 + v[11] * s6; 63 t[oidx + 2] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4 + v[16] * s5 + v[17] * s6; 64 t[oidx + 3] -= v[18] * s1 + v[19] * s2 + v[20] * s3 + v[21] * s4 + v[22] * s5 + v[23] * s6; 65 t[oidx + 4] -= v[24] * s1 + v[25] * s2 + v[26] * s3 + v[27] * s4 + v[28] * s5 + v[29] * s6; 66 t[oidx + 5] -= v[30] * s1 + v[31] * s2 + v[32] * s3 + v[33] * s4 + v[34] * s5 + v[35] * s6; 67 v += 36; 68 } 69 t[idx] = s1; 70 t[1 + idx] = s2; 71 t[2 + idx] = s3; 72 t[3 + idx] = s4; 73 t[4 + idx] = s5; 74 t[5 + idx] = s6; 75 idx += 6; 76 } 77 /* backward solve the L^T */ 78 for (i = n - 1; i >= 0; i--) { 79 v = aa + 36 * diag[i] - 36; 80 vi = aj + diag[i] - 1; 81 nz = diag[i] - ai[i]; 82 idt = 6 * i; 83 s1 = t[idt]; 84 s2 = t[1 + idt]; 85 s3 = t[2 + idt]; 86 s4 = t[3 + idt]; 87 s5 = t[4 + idt]; 88 s6 = t[5 + idt]; 89 while (nz--) { 90 idx = 6 * (*vi--); 91 t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5 + v[5] * s6; 92 t[idx + 1] -= v[6] * s1 + v[7] * s2 + v[8] * s3 + v[9] * s4 + v[10] * s5 + v[11] * s6; 93 t[idx + 2] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4 + v[16] * s5 + v[17] * s6; 94 t[idx + 3] -= v[18] * s1 + v[19] * s2 + v[20] * s3 + v[21] * s4 + v[22] * s5 + v[23] * s6; 95 t[idx + 4] -= v[24] * s1 + v[25] * s2 + v[26] * s3 + v[27] * s4 + v[28] * s5 + v[29] * s6; 96 t[idx + 5] -= v[30] * s1 + v[31] * s2 + v[32] * s3 + v[33] * s4 + v[34] * s5 + v[35] * s6; 97 v -= 36; 98 } 99 } 100 101 /* copy t into x according to permutation */ 102 ii = 0; 103 for (i = 0; i < n; i++) { 104 ir = 6 * r[i]; 105 x[ir] = t[ii]; 106 x[ir + 1] = t[ii + 1]; 107 x[ir + 2] = t[ii + 2]; 108 x[ir + 3] = t[ii + 3]; 109 x[ir + 4] = t[ii + 4]; 110 x[ir + 5] = t[ii + 5]; 111 ii += 6; 112 } 113 114 PetscCall(ISRestoreIndices(isrow, &rout)); 115 PetscCall(ISRestoreIndices(iscol, &cout)); 116 PetscCall(VecRestoreArrayRead(bb, &b)); 117 PetscCall(VecRestoreArray(xx, &x)); 118 PetscCall(PetscLogFlops(2.0 * 36 * (a->nz) - 6.0 * A->cmap->n)); 119 PetscFunctionReturn(PETSC_SUCCESS); 120 } 121 122 PetscErrorCode MatSolveTranspose_SeqBAIJ_6(Mat A, Vec bb, Vec xx) 123 { 124 Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data; 125 IS iscol = a->col, isrow = a->row; 126 const PetscInt n = a->mbs, *vi, *ai = a->i, *aj = a->j, *diag = a->diag; 127 const PetscInt *r, *c, *rout, *cout; 128 PetscInt nz, idx, idt, j, i, oidx, ii, ic, ir; 129 const PetscInt bs = A->rmap->bs, bs2 = a->bs2; 130 const MatScalar *aa = a->a, *v; 131 PetscScalar s1, s2, s3, s4, s5, s6, x1, x2, x3, x4, x5, x6, *x, *t; 132 const PetscScalar *b; 133 134 PetscFunctionBegin; 135 PetscCall(VecGetArrayRead(bb, &b)); 136 PetscCall(VecGetArray(xx, &x)); 137 t = a->solve_work; 138 139 PetscCall(ISGetIndices(isrow, &rout)); 140 r = rout; 141 PetscCall(ISGetIndices(iscol, &cout)); 142 c = cout; 143 144 /* copy b into temp work space according to permutation */ 145 for (i = 0; i < n; i++) { 146 ii = bs * i; 147 ic = bs * c[i]; 148 t[ii] = b[ic]; 149 t[ii + 1] = b[ic + 1]; 150 t[ii + 2] = b[ic + 2]; 151 t[ii + 3] = b[ic + 3]; 152 t[ii + 4] = b[ic + 4]; 153 t[ii + 5] = b[ic + 5]; 154 } 155 156 /* forward solve the U^T */ 157 idx = 0; 158 for (i = 0; i < n; i++) { 159 v = aa + bs2 * diag[i]; 160 /* multiply by the inverse of the block diagonal */ 161 x1 = t[idx]; 162 x2 = t[1 + idx]; 163 x3 = t[2 + idx]; 164 x4 = t[3 + idx]; 165 x5 = t[4 + idx]; 166 x6 = t[5 + idx]; 167 s1 = v[0] * x1 + v[1] * x2 + v[2] * x3 + v[3] * x4 + v[4] * x5 + v[5] * x6; 168 s2 = v[6] * x1 + v[7] * x2 + v[8] * x3 + v[9] * x4 + v[10] * x5 + v[11] * x6; 169 s3 = v[12] * x1 + v[13] * x2 + v[14] * x3 + v[15] * x4 + v[16] * x5 + v[17] * x6; 170 s4 = v[18] * x1 + v[19] * x2 + v[20] * x3 + v[21] * x4 + v[22] * x5 + v[23] * x6; 171 s5 = v[24] * x1 + v[25] * x2 + v[26] * x3 + v[27] * x4 + v[28] * x5 + v[29] * x6; 172 s6 = v[30] * x1 + v[31] * x2 + v[32] * x3 + v[33] * x4 + v[34] * x5 + v[35] * x6; 173 v -= bs2; 174 175 vi = aj + diag[i] - 1; 176 nz = diag[i] - diag[i + 1] - 1; 177 for (j = 0; j > -nz; j--) { 178 oidx = bs * vi[j]; 179 t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5 + v[5] * s6; 180 t[oidx + 1] -= v[6] * s1 + v[7] * s2 + v[8] * s3 + v[9] * s4 + v[10] * s5 + v[11] * s6; 181 t[oidx + 2] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4 + v[16] * s5 + v[17] * s6; 182 t[oidx + 3] -= v[18] * s1 + v[19] * s2 + v[20] * s3 + v[21] * s4 + v[22] * s5 + v[23] * s6; 183 t[oidx + 4] -= v[24] * s1 + v[25] * s2 + v[26] * s3 + v[27] * s4 + v[28] * s5 + v[29] * s6; 184 t[oidx + 5] -= v[30] * s1 + v[31] * s2 + v[32] * s3 + v[33] * s4 + v[34] * s5 + v[35] * s6; 185 v -= bs2; 186 } 187 t[idx] = s1; 188 t[1 + idx] = s2; 189 t[2 + idx] = s3; 190 t[3 + idx] = s4; 191 t[4 + idx] = s5; 192 t[5 + idx] = s6; 193 idx += bs; 194 } 195 /* backward solve the L^T */ 196 for (i = n - 1; i >= 0; i--) { 197 v = aa + bs2 * ai[i]; 198 vi = aj + ai[i]; 199 nz = ai[i + 1] - ai[i]; 200 idt = bs * i; 201 s1 = t[idt]; 202 s2 = t[1 + idt]; 203 s3 = t[2 + idt]; 204 s4 = t[3 + idt]; 205 s5 = t[4 + idt]; 206 s6 = t[5 + idt]; 207 for (j = 0; j < nz; j++) { 208 idx = bs * vi[j]; 209 t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3 + v[3] * s4 + v[4] * s5 + v[5] * s6; 210 t[idx + 1] -= v[6] * s1 + v[7] * s2 + v[8] * s3 + v[9] * s4 + v[10] * s5 + v[11] * s6; 211 t[idx + 2] -= v[12] * s1 + v[13] * s2 + v[14] * s3 + v[15] * s4 + v[16] * s5 + v[17] * s6; 212 t[idx + 3] -= v[18] * s1 + v[19] * s2 + v[20] * s3 + v[21] * s4 + v[22] * s5 + v[23] * s6; 213 t[idx + 4] -= v[24] * s1 + v[25] * s2 + v[26] * s3 + v[27] * s4 + v[28] * s5 + v[29] * s6; 214 t[idx + 5] -= v[30] * s1 + v[31] * s2 + v[32] * s3 + v[33] * s4 + v[34] * s5 + v[35] * s6; 215 v += bs2; 216 } 217 } 218 219 /* copy t into x according to permutation */ 220 for (i = 0; i < n; i++) { 221 ii = bs * i; 222 ir = bs * r[i]; 223 x[ir] = t[ii]; 224 x[ir + 1] = t[ii + 1]; 225 x[ir + 2] = t[ii + 2]; 226 x[ir + 3] = t[ii + 3]; 227 x[ir + 4] = t[ii + 4]; 228 x[ir + 5] = t[ii + 5]; 229 } 230 231 PetscCall(ISRestoreIndices(isrow, &rout)); 232 PetscCall(ISRestoreIndices(iscol, &cout)); 233 PetscCall(VecRestoreArrayRead(bb, &b)); 234 PetscCall(VecRestoreArray(xx, &x)); 235 PetscCall(PetscLogFlops(2.0 * bs2 * (a->nz) - bs * A->cmap->n)); 236 PetscFunctionReturn(PETSC_SUCCESS); 237 } 238