! Source: petsc/src/mat/impls/baij/seq/ftn-kernels/fsolvebaij.F90 (revision 0113e7197f2fe148981cb35224c0458dc22f510d)
!
!
!    Fortran kernels for sparse triangular solves in the BAIJ matrix format (4x4 blocks).
! These ONLY work for factorizations in the NATURAL ORDERING, i.e.
! with MatSolve_SeqBAIJ_4_NaturalOrdering().
!
7#include <petsc/finclude/petscsys.h>
8!
9
subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
  !
  ! Forward and backward triangular solve with a factored SeqBAIJ matrix
  ! whose blocks are 4x4, fully unrolled over the block size.
  !
  ! Arguments (all arrays are 0-based):
  !   n     - number of block rows; x and b must hold at least 4*n entries
  !   x     - (out) solution vector; its contents on entry are not used
  !   ai    - block-row pointers: the stored blocks of block row i occupy
  !           positions ai(i) .. ai(i+1)-1 of aj (and 16 entries each of a)
  !   aj    - block-column index of each stored block
  !   adiag - position of the diagonal block of each block row
  !   a     - block values, 16 per block, each block stored column-major
  !           (entry (r,c) of the block starting at ax is a(ax + r + 4*c))
  !   b     - right-hand side vector
  !
  ! Blocks ai(i) .. adiag(i)-1 are applied as a strictly lower factor with an
  ! implied identity diagonal, blocks adiag(i)+1 .. ai(i+1)-1 as a strictly
  ! upper factor, and the block at adiag(i) is MULTIPLIED into the result.
  ! NOTE(review): that implies the stored diagonal blocks are already
  ! inverted; confirm against the factorization that produces a.
  !
  implicit none
  PetscInt,    intent(in)  :: n
  PetscScalar, intent(out) :: x(0:*)
  PetscInt,    intent(in)  :: ai(0:*)
  PetscInt,    intent(in)  :: aj(0:*)
  PetscInt,    intent(in)  :: adiag(0:*)
  MatScalar,   intent(in)  :: a(0:*)
  PetscScalar, intent(in)  :: b(0:*)

  PetscInt    i,j,jstart,jend
  PetscInt    idx,ax,jdx
  PetscScalar s1,s2,s3,s4
  PetscScalar x1,x2,x3,x4

  ! NOTE(review): these name element 1 of 0-based arrays (the second entry,
  ! not the base address); confirm what PETSC_AssertAlignx expands to and
  ! whether element 0 was intended.
  PETSC_AssertAlignx(16,a(1))
  PETSC_AssertAlignx(16,x(1))
  PETSC_AssertAlignx(16,b(1))
  PETSC_AssertAlignx(16,ai(1))
  PETSC_AssertAlignx(16,aj(1))
  PETSC_AssertAlignx(16,adiag(1))

  ! An empty system has nothing to solve; without this guard the copy of
  ! the first block below would access x(0:3) and b(0:3) out of bounds.
  if (n .le. 0) return

  !
  ! Forward solve with the lower factor. Block row 0 is copied from b
  ! unchanged; the loop starts at block row 1.
  !
  x(0) = b(0)
  x(1) = b(1)
  x(2) = b(2)
  x(3) = b(3)
  idx  = 0
  do i=1,n-1
    jstart = ai(i)
    jend   = adiag(i) - 1
    ax     = 16*jstart
    idx    = idx + 4
    s1     = b(idx)
    s2     = b(idx+1)
    s3     = b(idx+2)
    s4     = b(idx+3)
    do j=jstart,jend
      jdx = 4*aj(j)
      x1  = x(jdx)
      x2  = x(jdx+1)
      x3  = x(jdx+2)
      x4  = x(jdx+3)
      s1  = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
      s2  = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
      s3  = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
      s4  = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
      ax  = ax + 16
    end do
    x(idx)   = s1
    x(idx+1) = s2
    x(idx+2) = s3
    x(idx+3) = s4
  end do

  !
  ! Backward solve with the upper factor. On entry idx = 4*(n-1), the
  ! offset of the last block row, and it walks back one block per row.
  !
  do i=n-1,0,-1
    jstart = adiag(i) + 1
    jend   = ai(i+1) - 1
    ax     = 16*jstart
    s1     = x(idx)
    s2     = x(idx+1)
    s3     = x(idx+2)
    s4     = x(idx+3)
    do j=jstart,jend
      jdx = 4*aj(j)
      x1  = x(jdx)
      x2  = x(jdx+1)
      x3  = x(jdx+2)
      x4  = x(jdx+3)
      s1  = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
      s2  = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
      s3  = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
      s4  = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
      ax  = ax + 16
    end do
    ! Multiply by the stored diagonal block of block row i.
    ax       = 16*adiag(i)
    x(idx)   = a(ax)*s1  +a(ax+4)*s2+a(ax+8)*s3 +a(ax+12)*s4
    x(idx+1) = a(ax+1)*s1+a(ax+5)*s2+a(ax+9)*s3 +a(ax+13)*s4
    x(idx+2) = a(ax+2)*s1+a(ax+6)*s2+a(ax+10)*s3+a(ax+14)*s4
    x(idx+3) = a(ax+3)*s1+a(ax+7)*s2+a(ax+11)*s3+a(ax+15)*s4
    idx      = idx - 4
  end do
end subroutine FortranSolveBAIJ4Unroll
98
!   Version that does not call a BLAS 2 operation for each row block.
!
subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
  !
  ! Forward and backward triangular solve with a factored SeqBAIJ matrix
  ! whose blocks are 4x4. Instead of unrolling, each block row gathers the
  ! entries of x it needs into the work array w and applies its
  ! off-diagonal blocks with a generic doubly nested loop.
  !
  ! Arguments (all arrays are 0-based):
  !   n     - number of block rows; x and b must hold at least 4*n entries
  !   x     - (out) solution vector; its contents on entry are not used
  !   ai    - block-row pointers: the stored blocks of block row i occupy
  !           positions ai(i) .. ai(i+1)-1 of aj (and 16 entries each of a)
  !   aj    - block-column index of each stored block
  !   adiag - position of the diagonal block of each block row
  !   a     - block values, 16 per block, each block stored column-major
  !           (entry (r,c) of the block starting at ax is a(ax + r + 4*c))
  !   b     - right-hand side vector
  !   w     - workspace; must hold at least 4*maxblk entries. Its contents
  !           on entry are ignored and it is overwritten.
  !
  ! Blocks ai(i) .. adiag(i)-1 are applied as a strictly lower factor with an
  ! implied identity diagonal, blocks adiag(i)+1 .. ai(i+1)-1 as a strictly
  ! upper factor, and the block at adiag(i) is MULTIPLIED into the result.
  ! NOTE(review): that implies the stored diagonal blocks are already
  ! inverted; confirm against the factorization that produces a.
  !
  ! Rows with more than maxblk off-diagonal blocks in one factor are handled
  ! in chunks of maxblk blocks, so w is never indexed past 4*maxblk-1. Each
  ! entry of the result is still accumulated in the same order as if the
  ! whole row were gathered at once.
  !
  implicit none
  PetscInt,    intent(in)    :: n
  PetscScalar, intent(out)   :: x(0:*)
  PetscInt,    intent(in)    :: ai(0:*)
  PetscInt,    intent(in)    :: aj(0:*)
  PetscInt,    intent(in)    :: adiag(0:*)
  MatScalar,   intent(in)    :: a(0:*)
  PetscScalar, intent(in)    :: b(0:*)
  PetscScalar, intent(inout) :: w(0:*)

  ! Largest number of 4x4 blocks gathered into w at one time.
  PetscInt, parameter :: maxblk = 500

  PetscInt    i,idx,ax
  PetscScalar s(0:3)

  ! NOTE(review): these name element 1 of 0-based arrays (the second entry,
  ! not the base address); confirm what PETSC_AssertAlignx expands to and
  ! whether element 0 was intended.
  PETSC_AssertAlignx(16,a(1))
  PETSC_AssertAlignx(16,w(1))
  PETSC_AssertAlignx(16,x(1))
  PETSC_AssertAlignx(16,b(1))
  PETSC_AssertAlignx(16,ai(1))
  PETSC_AssertAlignx(16,aj(1))
  PETSC_AssertAlignx(16,adiag(1))

  ! An empty system has nothing to solve; without this guard the copy of
  ! the first block below would access x(0:3) and b(0:3) out of bounds.
  if (n .le. 0) return

  !
  ! Forward solve with the lower factor. Block row 0 is copied from b
  ! unchanged; the loop starts at block row 1.
  !
  x(0) = b(0)
  x(1) = b(1)
  x(2) = b(2)
  x(3) = b(3)
  idx  = 0
  do i=1,n-1
    idx  = idx + 4
    s(0) = b(idx)
    s(1) = b(idx+1)
    s(2) = b(idx+2)
    s(3) = b(idx+3)
    call subtract_blocks(ai(i), adiag(i)-1, s)
    x(idx)   = s(0)
    x(idx+1) = s(1)
    x(idx+2) = s(2)
    x(idx+3) = s(3)
  end do

  !
  ! Backward solve with the upper factor. On entry idx = 4*(n-1), the
  ! offset of the last block row, and it walks back one block per row.
  !
  do i=n-1,0,-1
    s(0) = x(idx)
    s(1) = x(idx+1)
    s(2) = x(idx+2)
    s(3) = x(idx+3)
    call subtract_blocks(adiag(i)+1, ai(i+1)-1, s)
    ! Multiply by the stored diagonal block of block row i.
    ax       = 16*adiag(i)
    x(idx)   = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
    x(idx+1) = a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
    x(idx+2) = a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
    x(idx+3) = a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
    idx      = idx - 4
  end do

contains

  !
  ! acc = acc - sum over stored blocks j = jfrom .. jto of the 4x4 block at
  ! a(16*j) times the 4 entries of x in block column aj(j).
  ! Does nothing when jto < jfrom. Uses w (host) as gather space, at most
  ! maxblk blocks at a time.
  !
  subroutine subtract_blocks(jfrom, jto, acc)
    PetscInt,    intent(in)    :: jfrom, jto
    PetscScalar, intent(inout) :: acc(0:3)
    PetscInt chunk, last, j, jdx, kdx, nn, base, r, k

    do chunk = jfrom, jto, maxblk
      last = min(chunk + maxblk - 1, jto)
      ! Gather the x entries referenced by this chunk of blocks into w.
      kdx = 0
      do j = chunk, last
        jdx      = 4*aj(j)
        w(kdx)   = x(jdx)
        w(kdx+1) = x(jdx+1)
        w(kdx+2) = x(jdx+2)
        w(kdx+3) = x(jdx+3)
        kdx      = kdx + 4
      end do
      ! acc(r) -= sum_k a(base + 4*k + r) * w(k)   (column-major blocks)
      base = 16*chunk
      nn   = kdx - 1
      do r = 0,3
        do k = 0,nn
          acc(r) = acc(r) - a(base+4*k+r)*w(k)
        end do
      end do
    end do
  end subroutine subtract_blocks

end subroutine FortranSolveBAIJ4
210