xref: /petsc/src/mat/impls/baij/seq/ftn-kernels/fsolvebaij.F90 (revision d66e387e1b17bf1eafe50b0bb7df00ccc9053b5a)
!
!
!    Fortran kernels for the sparse triangular solves in the BAIJ matrix format
!    (4x4 blocks). These ONLY work for factorizations in the NATURAL ORDERING,
!    i.e. with MatSolve_SeqBAIJ_4_NaturalOrdering()
!
7#include <petsc/finclude/petscsys.h>
8!
9
!
!   FortranSolveBAIJ4Unroll - forward sweep followed by backward sweep of a
!   block triangular solve with 4x4 blocks; every 4x4 block times 4-vector
!   product is written out by hand.
!
!   Arguments (all arrays are indexed from 0):
!     n     - number of block rows; nothing is touched when n <= 0
!     x     - solution, 4 entries per block row; holds the forward-sweep
!             result between the two sweeps
!     ai    - ai(i) is the index of the first stored block of block row i
!     aj    - block column index of each stored block
!     adiag - adiag(i) is the index of the block of row i that separates the
!             blocks used by the forward sweep from those used by the backward
!             sweep; it is applied by multiplication, so it is presumably the
!             already inverted diagonal block (confirm against the caller)
!     a     - block values, 16 per block; entry (r,c) of a block sits at
!             offset r + 4*c from the start of the block
!     b     - right-hand side, 4 entries per block row
!
pure subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
  implicit none (type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  PetscInt :: i,j,jstart,jend
  PetscInt :: idx,ax,jdx
  PetscScalar :: s1,s2,s3,s4
  PetscScalar :: x1,x2,x3,x4

  PETSC_AssertAlignx(16,a(1))
  PETSC_AssertAlignx(16,x(1))
  PETSC_AssertAlignx(16,b(1))
  PETSC_AssertAlignx(16,ai(1))
  PETSC_AssertAlignx(16,aj(1))
  PETSC_AssertAlignx(16,adiag(1))

  ! With no block rows the vectors are empty: the copy of block row 0 below
  ! would read and write four entries that do not exist.
  if (n <= 0) return

  !
  ! Forward Solve
  !
  ! Block row 0 uses no earlier entries of x, so it is a plain copy.
  x(0:3) = b(0:3)
  do i=1,n-1
    jstart = ai(i)
    jend   = adiag(i) - 1
    ax     = 16*jstart
    idx    = 4*i
    s1     = b(idx+0)
    s2     = b(idx+1)
    s3     = b(idx+2)
    s4     = b(idx+3)
    do j=jstart,jend
      jdx = 4*aj(j)
      x1  = x(jdx+0)
      x2  = x(jdx+1)
      x3  = x(jdx+2)
      x4  = x(jdx+3)
      s1  = s1-(a(ax+0)*x1+a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
      s2  = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
      s3  = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
      s4  = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
      ax  = ax + 16
    end do
    x(idx+0) = s1
    x(idx+1) = s2
    x(idx+2) = s3
    x(idx+3) = s4
  end do

  !
  ! Backward solve the upper triangular
  !
  do i=n-1,0,-1
    jstart = adiag(i) + 1
    jend   = ai(i+1) - 1
    ax     = 16*jstart
    ! Computed from i so this sweep does not depend on the value of idx left
    ! behind by the forward loop.
    idx    = 4*i
    s1     = x(idx+0)
    s2     = x(idx+1)
    s3     = x(idx+2)
    s4     = x(idx+3)
    do j=jstart,jend
      jdx = 4*aj(j)
      x1  = x(jdx+0)
      x2  = x(jdx+1)
      x3  = x(jdx+2)
      x4  = x(jdx+3)
      s1  = s1-(a(ax+0)*x1+a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
      s2  = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
      s3  = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
      s4  = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
      ax  = ax + 16
    end do
    ! Apply the block stored at adiag(i) to the accumulated 4-vector.
    ax       = 16*adiag(i)
    x(idx+0) = a(ax+0)*s1+a(ax+4)*s2+a(ax+8)*s3 +a(ax+12)*s4
    x(idx+1) = a(ax+1)*s1+a(ax+5)*s2+a(ax+9)*s3 +a(ax+13)*s4
    x(idx+2) = a(ax+2)*s1+a(ax+6)*s2+a(ax+10)*s3+a(ax+14)*s4
    x(idx+3) = a(ax+3)*s1+a(ax+7)*s2+a(ax+11)*s3+a(ax+15)*s4
  end do
end subroutine FortranSolveBAIJ4Unroll
94
!   Version that does not call a BLAS 2 operation for each row block
!
!
!   FortranSolveBAIJ4 - forward sweep followed by backward sweep of a block
!   triangular solve with 4x4 blocks. For each block row the needed entries of
!   x are first packed into the work vector w, then the packed row of blocks
!   is multiplied against w in one loop.
!
!   Arguments (all arrays are indexed from 0):
!     n     - number of block rows; nothing is touched when n <= 0
!     x     - solution, 4 entries per block row; holds the forward-sweep
!             result between the two sweeps
!     ai    - ai(i) is the index of the first stored block of block row i
!     aj    - block column index of each stored block
!     adiag - adiag(i) is the index of the block of row i that separates the
!             blocks used by the forward sweep from those used by the backward
!             sweep; it is applied by multiplication, so it is presumably the
!             already inverted diagonal block (confirm against the caller)
!     a     - block values, 16 per block; entry (r,c) of a block sits at
!             offset r + 4*c from the start of the block
!     b     - right-hand side, 4 entries per block row
!     w     - work vector; 4 entries are written for every block of the
!             longest row segment that gets packed
!
subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
  implicit none (type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*),w(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  ! A warning is printed when (last block - first block) of a packed segment
  ! reaches this value. NOTE(review): the real capacity of w is decided by the
  ! caller and is not visible here; the check only warns and the solve goes on.
  integer, parameter :: max_block_span = 500

  PetscInt :: i,idx,ax,jstart,jend
  PetscScalar :: s(0:3)

  PETSC_AssertAlignx(16,a(1))
  PETSC_AssertAlignx(16,w(1))
  PETSC_AssertAlignx(16,x(1))
  PETSC_AssertAlignx(16,b(1))
  PETSC_AssertAlignx(16,ai(1))
  PETSC_AssertAlignx(16,aj(1))
  PETSC_AssertAlignx(16,adiag(1))

  ! With no block rows the vectors are empty: the copy of block row 0 below
  ! would read and write four entries that do not exist.
  if (n <= 0) return

  !
  !     Forward Solve
  !
  ! Block row 0 uses no earlier entries of x, so it is a plain copy.
  x(0:3) = b(0:3)
  do i=1,n-1
    jstart = ai(i)
    jend   = adiag(i) - 1
    idx    = 4*i
    s(0:3) = b(idx:idx+3)
    call subtract_row_segment(jstart, jend)
    x(idx:idx+3) = s(0:3)
  end do

  !
  ! Backward solve the upper triangular
  !
  do i=n-1,0,-1
    jstart = adiag(i) + 1
    jend   = ai(i+1) - 1
    ! Computed from i so this sweep does not depend on the value of idx left
    ! behind by the forward loop.
    idx    = 4*i
    s(0:3) = x(idx:idx+3)
    call subtract_row_segment(jstart, jend)

    ! Apply the block stored at adiag(i) to the accumulated 4-vector.
    ax       = 16*adiag(i)
    x(idx)   = a(ax+0)*s(0)+a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
    x(idx+1) = a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
    x(idx+2) = a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
    x(idx+3) = a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
  end do

contains

  ! s = s - (blocks jfirst..jlast of a) * (matching 4-entry pieces of x).
  ! The pieces of x are gathered into w first; an empty range leaves s alone.
  subroutine subtract_row_segment(jfirst, jlast)
    PetscInt, intent(in) :: jfirst, jlast

    PetscInt :: j,jj,jdx,kdx,acol,nn

    if (jlast - jfirst >= max_block_span) write(6,*) 'Overflowing vector FortranSolveBAIJ4()'

    !
    ! Pack required part of vector into work array
    !
    kdx = 0
    do j=jfirst,jlast
      jdx          = 4*aj(j)
      w(kdx:kdx+3) = x(jdx:jdx+3)
      kdx          = kdx + 4
    end do

    !
    !    s = s - a(16*jfirst:)*w
    !
    ! Packed entry jj of w pairs with the 4 consecutive values of a that start
    ! at 16*jfirst + 4*jj. Walking a in storage order keeps the accesses
    ! contiguous, and each component of s still accumulates its terms in
    ! increasing jj order.
    acol = 16*jfirst
    nn   = 4*(jlast - jfirst + 1) - 1
    do jj=0,nn
      s(0:3) = s(0:3) - a(acol:acol+3)*w(jj)
      acol   = acol + 4
    end do
  end subroutine subtract_row_segment
end subroutine FortranSolveBAIJ4
187