xref: /petsc/src/mat/impls/baij/seq/ftn-kernels/fsolvebaij.F90 (revision 0ccf82ac36648ce52b79cfc8b55f689a1594b19a)
1!
2!
3!    Fortran kernel for sparse triangular solve in the BAIJ matrix format
4! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
5! with MatSolve_SeqBAIJ_4_NaturalOrdering()
6!
7#include <petsc/finclude/petscsys.h>
8!
9
!
! Block-size-4 triangular solve x = inv(U) * inv(L) * b, fully unrolled.
!
! The factor is read as follows (all indices 0-based):
!   - L block row i : block positions ai(i) .. adiag(i)-1; the forward sweep
!                     applies no diagonal scaling, i.e. L has a unit diagonal
!   - U block row i : block positions adiag(i)+1 .. ai(i+1)-1
!   - position adiag(i) holds a 4x4 block that the backward sweep multiplies
!     by (it never solves with it), so it must already be the inverse of the
!     diagonal block of U
!   - block k is stored column-major in a(16*k : 16*k+15), i.e. entry
!     (row r, col c) is a(16*k + 4*c + r); aj(k) is its block column
! No permutation is applied anywhere, hence natural ordering only.
!
!   n     - number of block rows; x and b hold 4*n entries
!   x     - on exit the solution; also the working vector of both sweeps
!   ai    - block-row start pointers (ai(0:n) are read)
!   aj    - block-column index of each stored block
!   adiag - position of the diagonal block in each block row
!   a     - factor values
!   b     - right-hand side, not modified
!
pure subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
  implicit none (type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  PetscInt :: i,j,jstart,jend
  PetscInt :: idx,ax,jdx
  PetscScalar :: s1,s2,s3,s4
  PetscScalar :: x1,x2,x3,x4

  PETSC_AssertAlignx(16,a(1))
  PETSC_AssertAlignx(16,x(1))
  PETSC_AssertAlignx(16,b(1))
  PETSC_AssertAlignx(16,ai(1))
  PETSC_AssertAlignx(16,aj(1))
  PETSC_AssertAlignx(16,adiag(1))

  ! An empty system has nothing to solve; without this guard the seeding of
  ! x(0:3) below would read and write past the ends of b and x.
  if (n <= 0) return

  !
  ! Forward solve with the unit lower triangle
  !
  x(0:3) = b(0:3)
  idx    = 0
  do i=1,n-1
    jstart = ai(i)
    jend   = adiag(i) - 1
    ax     = 16*jstart
    idx    = idx + 4
    s1     = b(idx)
    s2     = b(idx+1)
    s3     = b(idx+2)
    s4     = b(idx+3)
    do j=jstart,jend
      ! s -= A_j * x(block aj(j)); A_j is column-major 4x4
      jdx = 4*aj(j)
      x1  = x(jdx)
      x2  = x(jdx+1)
      x3  = x(jdx+2)
      x4  = x(jdx+3)
      s1  = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
      s2  = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
      s3  = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
      s4  = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
      ax  = ax + 16
    end do
    x(idx)   = s1
    x(idx+1) = s2
    x(idx+2) = s3
    x(idx+3) = s4
  end do

  !
  ! Backward solve with the upper triangle, last block row first
  !
  idx = 4*(n-1)
  do i=n-1,0,-1
    jstart = adiag(i) + 1
    jend   = ai(i+1) - 1
    ax     = 16*jstart
    s1     = x(idx)
    s2     = x(idx+1)
    s3     = x(idx+2)
    s4     = x(idx+3)
    do j=jstart,jend
      jdx = 4*aj(j)
      x1  = x(jdx)
      x2  = x(jdx+1)
      x3  = x(jdx+2)
      x4  = x(jdx+3)
      s1  = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
      s2  = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
      s3  = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
      s4  = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
      ax  = ax + 16
    end do
    ! Multiply (not solve) by the block stored at the diagonal position.
    ax       = 16*adiag(i)
    x(idx)   = a(ax)*s1  +a(ax+4)*s2+a(ax+8)*s3 +a(ax+12)*s4
    x(idx+1) = a(ax+1)*s1+a(ax+5)*s2+a(ax+9)*s3 +a(ax+13)*s4
    x(idx+2) = a(ax+2)*s1+a(ax+6)*s2+a(ax+10)*s3+a(ax+14)*s4
    x(idx+3) = a(ax+3)*s1+a(ax+7)*s2+a(ax+11)*s3+a(ax+15)*s4
    idx      = idx - 4
  end do
end subroutine FortranSolveBAIJ4Unroll
97
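98!   Variant that makes no BLAS 2 call per block: for each block row, the needed entries of x are gathered into the work array w and applied in a single packed matrix-vector update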
99!
!
! Block-size-4 triangular solve x = inv(U) * inv(L) * b, packed variant.
!
! For every block row, the 4-vectors of x referenced by that row's
! off-diagonal blocks are gathered into the work array w. The row's update is
! then a single dense 4 x (4*nblk) matrix-vector product over the contiguous
! run of blocks in a.
!
! The factor is read as follows (all indices 0-based):
!   - L block row i : block positions ai(i) .. adiag(i)-1 (unit diagonal
!                     implied; no diagonal scaling in the forward sweep)
!   - U block row i : block positions adiag(i)+1 .. ai(i+1)-1
!   - position adiag(i) holds a 4x4 block that the backward sweep multiplies
!     by, so it must already be the inverse of the diagonal block of U
!   - block k is stored column-major in a(16*k : 16*k+15); aj(k) is its
!     block column
! No permutation is applied, hence natural ordering only.
!
!   n     - number of block rows; x and b hold 4*n entries
!   x     - on exit the solution
!   ai    - block-row start pointers (ai(0:n) are read)
!   aj    - block-column index of each stored block
!   adiag - position of the diagonal block in each block row
!   a     - factor values
!   b     - right-hand side, not modified
!   w     - scratch; must hold 4 entries per off-diagonal block of the
!           longest block row (its actual size is set by the caller)
!
subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
  use, intrinsic :: iso_fortran_env, only: output_unit
  implicit none (type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*),w(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  ! A block row with more than this many off-diagonal blocks triggers the
  ! historical "Overflowing vector" warning (test: jend-jstart >= limit).
  ! NOTE(review): the size of w is fixed by the caller and is not visible
  ! here, so this is not a verified bound -- confirm against the allocation
  ! of w. The warning does not stop the solve.
  PetscInt, parameter :: warn_nblocks = 500

  PetscInt :: i,idx,ax
  PetscScalar :: s(0:3)

  PETSC_AssertAlignx(16,a(1))
  PETSC_AssertAlignx(16,w(1))
  PETSC_AssertAlignx(16,x(1))
  PETSC_AssertAlignx(16,b(1))
  PETSC_AssertAlignx(16,ai(1))
  PETSC_AssertAlignx(16,aj(1))
  PETSC_AssertAlignx(16,adiag(1))

  ! An empty system has nothing to solve; without this guard the seeding of
  ! x(0:3) below would read and write past the ends of b and x.
  if (n <= 0) return

  !
  ! Forward solve with the unit lower triangle
  !
  x(0:3) = b(0:3)
  idx    = 0
  do i=1,n-1
    idx = idx + 4
    s   = b(idx:idx+3)
    call subtract_packed(ai(i), adiag(i)-1, s)
    x(idx:idx+3) = s
  end do

  !
  ! Backward solve with the upper triangle, last block row first
  !
  idx = 4*(n-1)
  do i=n-1,0,-1
    s = x(idx:idx+3)
    call subtract_packed(adiag(i)+1, ai(i+1)-1, s)
    ! Multiply (not solve) by the block stored at the diagonal position.
    ax       = 16*adiag(i)
    x(idx)   = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
    x(idx+1) = a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
    x(idx+2) = a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
    x(idx+3) = a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
    idx      = idx - 4
  end do

contains

  ! acc = acc - sum over blocks k = jstart..jend of A_k * x(4*aj(k):4*aj(k)+3).
  ! The referenced 4-vectors of x are gathered into w first; an empty range
  ! (jend < jstart) leaves acc unchanged.
  subroutine subtract_packed(jstart, jend, acc)
    PetscInt, intent(in) :: jstart, jend
    PetscScalar, intent(inout) :: acc(0:3)
    PetscInt :: k, col, kdx, jdx, off

    if (jend - jstart .ge. warn_nblocks) then
      write(output_unit,*) 'Overflowing vector FortranSolveBAIJ4()'
    endif

    kdx = 0
    do k=jstart,jend
      jdx          = 4*aj(k)
      w(kdx:kdx+3) = x(jdx:jdx+3)
      kdx          = kdx + 4
    end do

    ! Column col of the packed 4 x kdx matrix is a(16*jstart+4*col : +3).
    ! Walking it column by column reads a contiguously, while each acc(r)
    ! still accumulates its terms in increasing col order (same result).
    off = 16*jstart
    do col=0,kdx-1
      acc = acc - a(off:off+3)*w(col)
      off = off + 4
    end do
  end subroutine subtract_packed

end subroutine FortranSolveBAIJ4
210