!
!
!    Fortran kernels for the sparse triangular solve in the BAIJ matrix format
!    with 4x4 blocks.
!    These ONLY work for factorizations in the NATURAL ORDERING, i.e.
!    with MatSolve_SeqBAIJ_4_NaturalOrdering()
!
!    Indexing shared by both kernels (every array is zero based):
!      - block j of the matrix occupies a(16*j) .. a(16*j+15); the entry
!        a(16*j + r + 4*c) multiplies component c of a vector block and
!        contributes to component r of the result
!      - aj(j) is the block column of block j, so the matching vector block
!        starts at index 4*aj(j)
!      - block row i uses blocks ai(i) .. adiag(i)-1 in the forward sweep,
!        block adiag(i) as its diagonal block, and blocks
!        adiag(i)+1 .. ai(i+1)-1 in the backward sweep
!      - the diagonal block is applied by multiplication, so it is presumably
!        stored already inverted - TODO confirm against the factorization
!
#include <petsc/finclude/petscsys.h>
!

!
!   FortranSolveBAIJ4Unroll - forward and backward block sweeps with every
!   4x4 block product written out by hand.
!
!   Arguments:
!     n              - number of block rows; nothing is done when n <= 0
!     x              - receives the result, 4 entries per block row
!     ai, aj, adiag  - block row starts, block columns, diagonal block positions
!     a              - matrix blocks, 16 entries per block
!     b              - right hand side, 4 entries per block row
!
pure subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
  implicit none (type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  PetscInt :: i,j,jstart,jend
  PetscInt :: idx,ax,jdx
  PetscScalar :: s1,s2,s3,s4
  PetscScalar :: x1,x2,x3,x4

  ! Alignment assertion macros supplied by the include file above
  PETSC_AssertAlignx(16,a(1))
  PETSC_AssertAlignx(16,x(1))
  PETSC_AssertAlignx(16,b(1))
  PETSC_AssertAlignx(16,ai(1))
  PETSC_AssertAlignx(16,aj(1))
  PETSC_AssertAlignx(16,adiag(1))

  ! An empty system has no first block to copy, so leave x untouched
  if (n <= 0) return

  !
  ! Forward Solve
  !
  ! Block row 0 has no blocks to subtract: copy the right hand side
  x(0:3) = b(0:3)
  idx = 0
  do i=1,n-1
    jstart = ai(i)
    jend = adiag(i) - 1
    ax = 16*jstart
    idx = idx + 4
    s1 = b(idx+0)
    s2 = b(idx+1)
    s3 = b(idx+2)
    s4 = b(idx+3)
    do j=jstart,jend
      jdx = 4*aj(j)

      x1 = x(jdx+0)
      x2 = x(jdx+1)
      x3 = x(jdx+2)
      x4 = x(jdx+3)
      s1 = s1-(a(ax+0)*x1+a(ax+4)*x2+a(ax+ 8)*x3+a(ax+12)*x4)
      s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
      s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
      s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
      ax = ax + 16
    end do
    x(idx+0) = s1
    x(idx+1) = s2
    x(idx+2) = s3
    x(idx+3) = s4
  end do

  !
  ! Backward solve the upper triangular
  !
  ! idx is carried over from the forward sweep and already addresses the
  ! last block row, 4*(n-1)
  do i=n-1,0,-1
    jstart = adiag(i) + 1
    jend = ai(i+1) - 1
    ax = 16*jstart
    s1 = x(idx+0)
    s2 = x(idx+1)
    s3 = x(idx+2)
    s4 = x(idx+3)
    do j=jstart,jend
      jdx = 4*aj(j)
      x1 = x(jdx+0)
      x2 = x(jdx+1)
      x3 = x(jdx+2)
      x4 = x(jdx+3)
      s1 = s1-(a(ax+0)*x1+a(ax+4)*x2+a(ax+ 8)*x3+a(ax+12)*x4)
      s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
      s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
      s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
      ax = ax + 16
    end do
    ! Multiply the accumulated block by the diagonal block of row i
    ax = 16*adiag(i)
    x(idx+0) = a(ax+0)*s1+a(ax+4)*s2+a(ax+ 8)*s3+a(ax+12)*s4
    x(idx+1) = a(ax+1)*s1+a(ax+5)*s2+a(ax+9)*s3 +a(ax+13)*s4
    x(idx+2) = a(ax+2)*s1+a(ax+6)*s2+a(ax+10)*s3+a(ax+14)*s4
    x(idx+3) = a(ax+3)*s1+a(ax+7)*s2+a(ax+11)*s3+a(ax+15)*s4
    idx = idx - 4
  end do
end subroutine FortranSolveBAIJ4Unroll

!   version that does not call BLAS 2 operation for each row block
!
!   FortranSolveBAIJ4 - same two sweeps as the unrolled kernel, but for each
!   block row the needed vector blocks are first gathered into the work
!   array w and then subtracted with one dense loop nest.
!
!   Arguments:
!     n              - number of block rows; nothing is done when n <= 0
!     x              - receives the result, 4 entries per block row
!     ai, aj, adiag  - block row starts, block columns, diagonal block positions
!     a              - matrix blocks, 16 entries per block
!     b              - right hand side, 4 entries per block row
!     w              - work array, overwritten; 4 entries are used for each
!                      block gathered in a row
!
subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
  use, intrinsic :: iso_fortran_env, only: output_unit
  implicit none (type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*),w(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  ! Number of blocks in one row beyond which the overflow warning is printed.
  ! NOTE(review): presumably the capacity of w in blocks - confirm with the caller
  PetscInt, parameter :: max_row_blocks = 500

  PetscInt :: ii,jj,i,j
  PetscInt :: jstart,jend,idx,ax,jdx,kdx,nn
  PetscScalar :: s(0:3)

  ! Alignment assertion macros supplied by the include file above
  PETSC_AssertAlignx(16,a(1))
  PETSC_AssertAlignx(16,w(1))
  PETSC_AssertAlignx(16,x(1))
  PETSC_AssertAlignx(16,b(1))
  PETSC_AssertAlignx(16,ai(1))
  PETSC_AssertAlignx(16,aj(1))
  PETSC_AssertAlignx(16,adiag(1))

  ! An empty system has no first block to copy, so leave x untouched
  if (n <= 0) return

  !
  ! Forward Solve
  !
  ! Block row 0 has no blocks to subtract: copy the right hand side
  x(0:3) = b(0:3)
  idx = 0
  do i=1,n-1
    !
    ! Pack required part of vector into work array
    !
    kdx = 0
    jstart = ai(i)
    jend = adiag(i) - 1

    ! Only a warning is printed; the gather below still runs
    if (jend - jstart >= max_row_blocks) write(output_unit,*) 'Overflowing vector FortranSolveBAIJ4()'

    do j=jstart,jend
      jdx = 4*aj(j)
      w(kdx:kdx+3) = x(jdx:jdx+3)
      kdx = kdx + 4
    end do

    ax = 16*jstart
    idx = idx + 4
    s(0:3) = b(idx:idx+3)
    !
    ! s = s - a(ax:)*w
    !
    ! jj is the outer loop so that a is read contiguously; every s(ii) still
    ! accumulates its terms in increasing jj order
    nn = 4*(jend - jstart + 1) - 1
    do jj=0,nn
      do ii=0,3
        s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
      end do
    end do

    x(idx:idx+3) = s(0:3)
  end do
  !
  ! Backward solve the upper triangular
  !
  ! idx is carried over from the forward sweep and already addresses the
  ! last block row, 4*(n-1)
  do i=n-1,0,-1
    jstart = adiag(i) + 1
    jend = ai(i+1) - 1
    ax = 16*jstart
    s(0:3) = x(idx:idx+3)
    !
    ! Pack each chunk of vector needed
    !
    kdx = 0
    ! Only a warning is printed; the gather below still runs
    if (jend - jstart >= max_row_blocks) write(output_unit,*) 'Overflowing vector FortranSolveBAIJ4()'

    do j=jstart,jend
      jdx = 4*aj(j)
      w(kdx:kdx+3) = x(jdx:jdx+3)
      kdx = kdx + 4
    end do
    ! s = s - a(ax:)*w, same loop nest as in the forward sweep
    nn = 4*(jend - jstart + 1) - 1
    do jj=0,nn
      do ii=0,3
        s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
      end do
    end do

    ! Multiply the accumulated block by the diagonal block of row i
    ax = 16*adiag(i)
    x(idx)  = a(ax+0)*s(0)+a(ax+4)*s(1)+a(ax+ 8)*s(2)+a(ax+12)*s(3)
    x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
    x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
    x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
    idx = idx - 4
  end do
end subroutine FortranSolveBAIJ4