!
!
!    Fortran kernel for sparse triangular solve in the BAIJ matrix format
! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
! with MatSolve_SeqBAIJ_4_NaturalOrdering()
!
#include <petsc/finclude/petscsys.h>
!

!
!   FortranSolveBAIJ4Unroll - forward sweep followed by backward sweep over
!   4x4 blocks, with every block times vector product written out by hand.
!
!   All arrays are indexed from 0.
!
!   Input:
!     n     - number of block rows; x and b are accessed at 0 .. 4*n-1
!     ai    - ai(i) is the index of the first stored block of row i and
!             ai(i+1)-1 the index of the last one
!     aj    - block column index of each stored block
!     adiag - adiag(i) is the index of the diagonal block of row i
!     a     - block values, 16 per block; entry (r,c) of block k is read
!             from a(16*k + 4*c + r)
!     b     - right-hand side
!
!   In/out:
!     x     - overwritten with the result of the two sweeps
!
!   When n <= 0 the routine returns immediately and x is left untouched.
!
pure subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
  implicit none (type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  PetscInt :: i,j,jstart,jend
  PetscInt :: idx,ax,jdx
  PetscScalar :: s(0:3)

  ! NOTE(review): the arrays are declared with lower bound 0, so element 1 is
  ! the second entry rather than the first; confirm that this is the intended
  ! target of the alignment assertions.
  PETSC_AssertAlignx(16,a(1))
  PETSC_AssertAlignx(16,x(1))
  PETSC_AssertAlignx(16,b(1))
  PETSC_AssertAlignx(16,ai(1))
  PETSC_AssertAlignx(16,aj(1))
  PETSC_AssertAlignx(16,adiag(1))

  ! Nothing to solve. Without this guard the copy below would still read
  ! b(0:3) and write x(0:3) although both loops are empty.
  if (n <= 0) return

  !
  ! Forward Solve
  !
  ! Block row 0 is copied straight from b; the loop handles rows 1 .. n-1.
  x(0:3) = b(0:3)
  idx = 0
  do i=1,n-1
    ! Blocks ai(i) .. adiag(i)-1 are the ones stored before the diagonal block
    jstart = ai(i)
    jend = adiag(i) - 1
    ax = 16*jstart
    idx = idx + 4
    s(0:3) = b(idx+0:idx+3)
    do j=jstart,jend
      jdx = 4*aj(j)

      ! s = s - (4x4 block at ax) * x(jdx:jdx+3)
      s(0) = s(0)-(a(ax+0)*x(jdx+0)+a(ax+4)*x(jdx+1)+a(ax+ 8)*x(jdx+2)+a(ax+12)*x(jdx+3))
      s(1) = s(1)-(a(ax+1)*x(jdx+0)+a(ax+5)*x(jdx+1)+a(ax+ 9)*x(jdx+2)+a(ax+13)*x(jdx+3))
      s(2) = s(2)-(a(ax+2)*x(jdx+0)+a(ax+6)*x(jdx+1)+a(ax+10)*x(jdx+2)+a(ax+14)*x(jdx+3))
      s(3) = s(3)-(a(ax+3)*x(jdx+0)+a(ax+7)*x(jdx+1)+a(ax+11)*x(jdx+2)+a(ax+15)*x(jdx+3))
      ax = ax + 16
    end do
    x(idx+0:idx+3) = s(0:3)
  end do

  !
  ! Backward solve the upper triangular
  !
  ! idx is not reset here: it leaves the forward sweep equal to 4*(n-1),
  ! which is the offset of the last block row, and is stepped down by 4.
  do i=n-1,0,-1
    ! Blocks adiag(i)+1 .. ai(i+1)-1 are the ones stored after the diagonal block
    jstart = adiag(i) + 1
    jend = ai(i+1) - 1
    ax = 16*jstart
    s(0:3) = x(idx+0:idx+3)
    do j=jstart,jend
      jdx = 4*aj(j)
      s(0) = s(0)-(a(ax+0)*x(jdx+0)+a(ax+4)*x(jdx+1)+a(ax+ 8)*x(jdx+2)+a(ax+12)*x(jdx+3))
      s(1) = s(1)-(a(ax+1)*x(jdx+0)+a(ax+5)*x(jdx+1)+a(ax+ 9)*x(jdx+2)+a(ax+13)*x(jdx+3))
      s(2) = s(2)-(a(ax+2)*x(jdx+0)+a(ax+6)*x(jdx+1)+a(ax+10)*x(jdx+2)+a(ax+14)*x(jdx+3))
      s(3) = s(3)-(a(ax+3)*x(jdx+0)+a(ax+7)*x(jdx+1)+a(ax+11)*x(jdx+2)+a(ax+15)*x(jdx+3))
      ax = ax + 16
    end do
    ! x(idx:idx+3) = (diagonal block) * s. The block is multiplied, not
    ! solved with; presumably it is stored already inverted by the
    ! factorization - TODO confirm against the caller.
    ax = 16*adiag(i)
    x(idx+0) = a(ax+0)*s(0)+a(ax+4)*s(1)+a(ax+ 8)*s(2)+a(ax+12)*s(3)
    x(idx+1) = a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+ 9)*s(2)+a(ax+13)*s(3)
    x(idx+2) = a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
    x(idx+3) = a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
    idx = idx - 4
  end do
end subroutine FortranSolveBAIJ4Unroll

!   version that does not call BLAS 2 operation for each row block
!
!   FortranSolveBAIJ4 - same two sweeps as FortranSolveBAIJ4Unroll, but the
!   entries of x needed by a block row are first gathered into the work array
!   w and then multiplied by the row of blocks in one pair of loops.
!
!   All arrays are indexed from 0.
!
!   Input:
!     n     - number of block rows; x and b are accessed at 0 .. 4*n-1
!     ai    - ai(i) is the index of the first stored block of row i and
!             ai(i+1)-1 the index of the last one
!     aj    - block column index of each stored block
!     adiag - adiag(i) is the index of the diagonal block of row i
!     a     - block values, 16 per block; entry (r,c) of block k is read
!             from a(16*k + 4*c + r)
!     b     - right-hand side
!
!   In/out:
!     x     - overwritten with the result of the two sweeps
!     w     - scratch space; 4 entries are written per gathered block
!
!   Execution is aborted with error stop when a block row has more than 500
!   blocks on one side of the diagonal.
!   When n <= 0 the routine returns immediately and x and w are left untouched.
!
pure subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
  implicit none (type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*),w(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  PetscInt :: ii,jj,i,j
  PetscInt :: jstart,jend,idx,ax,jdx,kdx,nn
  PetscScalar :: s(0:3)

  ! NOTE(review): the arrays are declared with lower bound 0, so element 1 is
  ! the second entry rather than the first; confirm that this is the intended
  ! target of the alignment assertions.
  PETSC_AssertAlignx(16,a(1))
  PETSC_AssertAlignx(16,w(1))
  PETSC_AssertAlignx(16,x(1))
  PETSC_AssertAlignx(16,b(1))
  PETSC_AssertAlignx(16,ai(1))
  PETSC_AssertAlignx(16,aj(1))
  PETSC_AssertAlignx(16,adiag(1))

  ! Nothing to solve. Without this guard the copy below would still read
  ! b(0:3) and write x(0:3) although both loops are empty.
  if (n <= 0) return

  !
  ! Forward Solve
  !
  ! Block row 0 is copied straight from b; the loop handles rows 1 .. n-1.
  x(0:3) = b(0:3)
  idx = 0
  do i=1,n-1
    !
    ! Pack required part of vector into work array
    !
    kdx = 0
    jstart = ai(i)
    jend = adiag(i) - 1

    ! Refuse rows with more than 500 blocks before the diagonal.
    ! NOTE(review): presumably 500 blocks (2000 entries) is the capacity
    ! the caller provides for w - confirm.
    if (jend - jstart >= 500) error stop 'Overflowing vector FortranSolveBAIJ4()'

    do j=jstart,jend
      jdx = 4*aj(j)
      w(kdx:kdx+3) = x(jdx:jdx+3)
      kdx = kdx + 4
    end do

    ax = 16*jstart
    idx = idx + 4
    s(0:3) = b(idx:idx+3)
    !
    ! s = s - a(ax:)*w
    !
    ! The gathered row is treated as a 4 x (nn+1) matrix stored column by
    ! column starting at a(ax); nn is the last index used in w.
    nn = 4*(jend - jstart + 1) - 1
    do ii=0,3
      do jj=0,nn
        s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
      end do
    end do

    x(idx:idx+3) = s(0:3)
  end do
  !
  ! Backward solve the upper triangular
  !
  ! idx is not reset here: it leaves the forward sweep equal to 4*(n-1),
  ! which is the offset of the last block row, and is stepped down by 4.
  do i=n-1,0,-1
    jstart = adiag(i) + 1
    jend = ai(i+1) - 1
    ax = 16*jstart
    s(0:3) = x(idx:idx+3)
    !
    ! Pack each chunk of vector needed
    !
    kdx = 0
    ! Same limit as in the forward sweep, applied to the blocks after the diagonal.
    if (jend - jstart >= 500) error stop 'Overflowing vector FortranSolveBAIJ4()'

    do j=jstart,jend
      jdx = 4*aj(j)
      w(kdx:kdx+3) = x(jdx:jdx+3)
      kdx = kdx + 4
    end do
    nn = 4*(jend - jstart + 1) - 1
    do ii=0,3
      do jj=0,nn
        s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
      end do
    end do

    ! x(idx:idx+3) = (diagonal block) * s. The block is multiplied, not
    ! solved with; presumably it is stored already inverted by the
    ! factorization - TODO confirm against the caller.
    ax = 16*adiag(i)
    x(idx)  = a(ax+0)*s(0)+a(ax+4)*s(1)+a(ax+ 8)*s(2)+a(ax+12)*s(3)
    x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+ 9)*s(2)+a(ax+13)*s(3)
    x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
    x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
    idx = idx - 4
  end do
end subroutine FortranSolveBAIJ4