!
!
!    Fortran kernel for sparse triangular solve in the BAIJ matrix format
! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
! with MatSolve_SeqBAIJ_4_NaturalOrdering()
!
#include <petsc/finclude/petscsys.h>
!

!> Forward/backward block-triangular solve with 4x4 blocks, with the block
!> matrix-vector products written out explicitly.
!>
!> Arguments (all arrays are indexed from 0):
!>   n      number of block rows; nothing is done when n <= 0
!>   x      on exit the solution, 4 scalars per block row; it is also used to
!>          hold the intermediate result of the forward sweep
!>   ai     ai(i) is the index of the first block of block row i
!>   aj     aj(j) is the block-column index of block j
!>   adiag  adiag(i) is the index of the diagonal block of block row i
!>   a      block values, 16 scalars per block; entry (r,c) of block j is
!>          a(16*j + r + 4*c)
!>   b      right-hand side, 4 scalars per block row
!>
!> The diagonal block is applied by multiplication in the backward sweep, so it
!> presumably already holds the inverted diagonal block - confirm with caller.
pure subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
  implicit none (type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  PetscInt :: i,j,jstart,jend
  PetscInt :: idx,ax,jdx
  PetscScalar :: s1,s2,s3,s4
  PetscScalar :: x1,x2,x3,x4

  ! Alignment assertions; the macro comes from the PETSc include above.
  PETSC_AssertAlignx(16,a(1))
  PETSC_AssertAlignx(16,x(1))
  PETSC_AssertAlignx(16,b(1))
  PETSC_AssertAlignx(16,ai(1))
  PETSC_AssertAlignx(16,aj(1))
  PETSC_AssertAlignx(16,adiag(1))

  ! An empty system has no x(0:3) / b(0:3) to touch.
  if (n <= 0) return

  !
  !     Forward Solve
  !
  ! Block row 0 has no blocks left of the diagonal: copy b straight through.
  x(0) = b(0)
  x(1) = b(1)
  x(2) = b(2)
  x(3) = b(3)
  idx = 0
  do i=1,n-1
    ! Blocks strictly left of the diagonal of block row i.
    jstart = ai(i)
    jend = adiag(i) - 1
    ax = 16*jstart
    idx = idx + 4
    s1 = b(idx)
    s2 = b(idx+1)
    s3 = b(idx+2)
    s4 = b(idx+3)
    do j=jstart,jend
      jdx = 4*aj(j)

      x1 = x(jdx)
      x2 = x(jdx+1)
      x3 = x(jdx+2)
      x4 = x(jdx+3)
      ! s = s - A_j * x_j for one 4x4 block
      s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
      s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
      s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
      s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
      ax = ax + 16
    end do
    x(idx) = s1
    x(idx+1) = s2
    x(idx+2) = s3
    x(idx+3) = s4
  end do

  !
  !     Backward solve the upper triangular
  !
  ! idx equals 4*(n-1) here, i.e. it addresses the last block row.
  do i=n-1,0,-1
    ! Blocks strictly right of the diagonal of block row i.
    jstart = adiag(i) + 1
    jend = ai(i+1) - 1
    ax = 16*jstart
    s1 = x(idx)
    s2 = x(idx+1)
    s3 = x(idx+2)
    s4 = x(idx+3)
    do j=jstart,jend
      jdx = 4*aj(j)
      x1 = x(jdx)
      x2 = x(jdx+1)
      x3 = x(jdx+2)
      x4 = x(jdx+3)
      s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
      s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
      s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
      s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
      ax = ax + 16
    end do
    ! Apply the diagonal block: x_i = D_i * s
    ax = 16*adiag(i)
    x(idx)   = a(ax)*s1  +a(ax+4)*s2+a(ax+8)*s3 +a(ax+12)*s4
    x(idx+1) = a(ax+1)*s1+a(ax+5)*s2+a(ax+9)*s3 +a(ax+13)*s4
    x(idx+2) = a(ax+2)*s1+a(ax+6)*s2+a(ax+10)*s3+a(ax+14)*s4
    x(idx+3) = a(ax+3)*s1+a(ax+7)*s2+a(ax+11)*s3+a(ax+15)*s4
    idx = idx - 4
  end do
end subroutine FortranSolveBAIJ4Unroll

!   version that does not call BLAS 2 operation for each row block
!
!> Same solve as FortranSolveBAIJ4Unroll, but the entries of x needed by a block
!> row are first packed contiguously into the work array w, and the update is
!> then done as one 4 x (4*nblocks) matrix-vector product.
!>
!> Arguments (all arrays are indexed from 0):
!>   n, x, ai, aj, adiag, a, b   as in FortranSolveBAIJ4Unroll
!>   w      work array, 4 scalars per packed block. The limit of 500 blocks per
!>          row is taken from the existing overflow check; the actual size of w
!>          is chosen by the caller and is not visible here.
!>
!> Stops with an error if a block row has 500 or more off-diagonal blocks on one
!> side of the diagonal, instead of writing past the capacity of w.
subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
  implicit none (type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*),w(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  ! Largest number of blocks that may be packed into w for one row.
  integer, parameter :: max_row_blocks = 500

  PetscInt :: ii,jj,i,j
  PetscInt :: jstart,jend,idx,ax,jdx,kdx,nn
  PetscScalar :: s(0:3)

  ! Alignment assertions; the macro comes from the PETSc include above.
  PETSC_AssertAlignx(16,a(1))
  PETSC_AssertAlignx(16,w(1))
  PETSC_AssertAlignx(16,x(1))
  PETSC_AssertAlignx(16,b(1))
  PETSC_AssertAlignx(16,ai(1))
  PETSC_AssertAlignx(16,aj(1))
  PETSC_AssertAlignx(16,adiag(1))

  ! An empty system has no x(0:3) / b(0:3) to touch.
  if (n <= 0) return

  !
  !     Forward Solve
  !
  ! Block row 0 has no blocks left of the diagonal: copy b straight through.
  x(0) = b(0)
  x(1) = b(1)
  x(2) = b(2)
  x(3) = b(3)
  idx = 0
  do i=1,n-1
    !
    !        Pack required part of vector into work array
    !
    kdx = 0
    jstart = ai(i)
    jend = adiag(i) - 1
    ! Abort before packing: continuing would overrun the work array.
    if (jend - jstart >= max_row_blocks) then
      error stop 'Overflowing vector FortranSolveBAIJ4()'
    end if
    do j=jstart,jend

      jdx = 4*aj(j)

      w(kdx) = x(jdx)
      w(kdx+1) = x(jdx+1)
      w(kdx+2) = x(jdx+2)
      w(kdx+3) = x(jdx+3)
      kdx = kdx + 4
    end do

    ax = 16*jstart
    idx = idx + 4
    s(0) = b(idx)
    s(1) = b(idx+1)
    s(2) = b(idx+2)
    s(3) = b(idx+3)
    !
    !    s = s - a(ax:)*w
    !
    ! nn is the last valid index into w (4 scalars per packed block).
    nn = 4*(jend - jstart + 1) - 1
    do ii=0,3
      do jj=0,nn
        s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
      end do
    end do

    x(idx) = s(0)
    x(idx+1) = s(1)
    x(idx+2) = s(2)
    x(idx+3) = s(3)
  end do
  !
  !     Backward solve the upper triangular
  !
  ! idx equals 4*(n-1) here, i.e. it addresses the last block row.
  do i=n-1,0,-1
    ! Blocks strictly right of the diagonal of block row i.
    jstart = adiag(i) + 1
    jend = ai(i+1) - 1
    ax = 16*jstart
    s(0) = x(idx)
    s(1) = x(idx+1)
    s(2) = x(idx+2)
    s(3) = x(idx+3)
    !
    !   Pack each chunk of vector needed
    !
    kdx = 0
    ! Abort before packing: continuing would overrun the work array.
    if (jend - jstart >= max_row_blocks) then
      error stop 'Overflowing vector FortranSolveBAIJ4()'
    end if
    do j=jstart,jend
      jdx = 4*aj(j)
      w(kdx) = x(jdx)
      w(kdx+1) = x(jdx+1)
      w(kdx+2) = x(jdx+2)
      w(kdx+3) = x(jdx+3)
      kdx = kdx + 4
    end do
    nn = 4*(jend - jstart + 1) - 1
    do ii=0,3
      do jj=0,nn
        s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
      end do
    end do

    ! Apply the diagonal block: x_i = D_i * s
    ax = 16*adiag(i)
    x(idx)  = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
    x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
    x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
    x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
    idx = idx - 4
  end do
end subroutine FortranSolveBAIJ4