xref: /petsc/src/mat/impls/baij/seq/ftn-kernels/fsolvebaij.F90 (revision f7243fda4f70acf8416ce2a83a22bcbfa57fb3c8)
1!
2!
!    Fortran kernels for sparse triangular solves with 4x4-block BAIJ factors.
! These ONLY work for factorizations in the NATURAL ORDERING, i.e. for use
! with MatSolve_SeqBAIJ_4_NaturalOrdering().
6!
7#include <petsc/finclude/petscsys.h>
8!
9
!
!   Solve L U x = b for a factored 4x4-block BAIJ matrix in natural ordering,
!   with every 4x4 block-times-vector product written out (unrolled) by hand.
!
!   Arguments (all indices are 0-based):
!     n      number of block rows (the vectors have 4*n entries)
!     x      on return, the solution
!     ai     block-row pointers, ai(0:n)
!     aj     block-column index of each stored block
!     adiag  adiag(i) is the position of the diagonal block of block row i;
!            blocks ai(i)..adiag(i)-1 belong to L, adiag(i)+1..ai(i+1)-1 to U
!     a      the 4x4 blocks; block k occupies a(16*k:16*k+15), stored column
!            by column, i.e. entry (r,c) of block k is a(16*k + r + 4*c)
!     b      right-hand side (not modified)
!
!   L has an implicit unit diagonal (the forward sweep applies no diagonal).
!   The stored diagonal block is multiplied into the right-hand side in the
!   backward sweep, so it presumably holds the inverted diagonal block --
!   confirm against the factorization routine.
!
pure subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
  implicit none (type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  PetscInt :: i,j
  PetscInt :: idx,ax,jdx
  PetscScalar :: s(0:3)

  ! Alignment hints, asserted on element 0, which is the first element of
  ! each 0-based dummy. NOTE(review): assumes callers pass 16-byte-aligned
  ! arrays -- confirm.
  PETSC_AssertAlignx(16,a(0))
  PETSC_AssertAlignx(16,x(0))
  PETSC_AssertAlignx(16,b(0))
  PETSC_AssertAlignx(16,ai(0))
  PETSC_AssertAlignx(16,aj(0))
  PETSC_AssertAlignx(16,adiag(0))

  ! No block rows means there is nothing to solve. The guard also keeps
  ! x(0:3) = b(0:3) below from touching entries that do not exist.
  if (n <= 0) return

  !
  ! Forward solve with the unit lower triangular factor L
  !
  x(0:3) = b(0:3)
  do i=1,n-1
    idx    = 4*i
    ax     = 16*ai(i)
    s(0:3) = b(idx:idx+3)
    do j=ai(i),adiag(i)-1
      jdx  = 4*aj(j)
      s(0) = s(0)-(a(ax+0)*x(jdx+0)+a(ax+4)*x(jdx+1)+a(ax+ 8)*x(jdx+2)+a(ax+12)*x(jdx+3))
      s(1) = s(1)-(a(ax+1)*x(jdx+0)+a(ax+5)*x(jdx+1)+a(ax+ 9)*x(jdx+2)+a(ax+13)*x(jdx+3))
      s(2) = s(2)-(a(ax+2)*x(jdx+0)+a(ax+6)*x(jdx+1)+a(ax+10)*x(jdx+2)+a(ax+14)*x(jdx+3))
      s(3) = s(3)-(a(ax+3)*x(jdx+0)+a(ax+7)*x(jdx+1)+a(ax+11)*x(jdx+2)+a(ax+15)*x(jdx+3))
      ax   = ax + 16
    end do
    x(idx:idx+3) = s(0:3)
  end do

  !
  ! Backward solve with the upper triangular factor U. idx is recomputed
  ! from i, so this sweep does not depend on where the forward sweep left it.
  !
  do i=n-1,0,-1
    idx    = 4*i
    ax     = 16*(adiag(i)+1)
    s(0:3) = x(idx:idx+3)
    do j=adiag(i)+1,ai(i+1)-1
      jdx  = 4*aj(j)
      s(0) = s(0)-(a(ax+0)*x(jdx+0)+a(ax+4)*x(jdx+1)+a(ax+ 8)*x(jdx+2)+a(ax+12)*x(jdx+3))
      s(1) = s(1)-(a(ax+1)*x(jdx+0)+a(ax+5)*x(jdx+1)+a(ax+ 9)*x(jdx+2)+a(ax+13)*x(jdx+3))
      s(2) = s(2)-(a(ax+2)*x(jdx+0)+a(ax+6)*x(jdx+1)+a(ax+10)*x(jdx+2)+a(ax+14)*x(jdx+3))
      s(3) = s(3)-(a(ax+3)*x(jdx+0)+a(ax+7)*x(jdx+1)+a(ax+11)*x(jdx+2)+a(ax+15)*x(jdx+3))
      ax   = ax + 16
    end do
    ! Multiply s by the stored diagonal block of this block row
    ax       = 16*adiag(i)
    x(idx+0) = a(ax+0)*s(0)+a(ax+4)*s(1)+a(ax+ 8)*s(2)+a(ax+12)*s(3)
    x(idx+1) = a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+ 9)*s(2)+a(ax+13)*s(3)
    x(idx+2) = a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
    x(idx+3) = a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
  end do
end subroutine FortranSolveBAIJ4Unroll
76
!
!   Version that does not call a BLAS 2 operation for each 4x4 block. For
!   each block row, it first gathers the pieces of x referenced by that row's
!   off-diagonal blocks into the work array w. It then subtracts the whole
!   4 x (4*nblocks) panel times w in a single loop nest.
!
!   Arguments (all indices are 0-based):
!     n      number of block rows (the vectors have 4*n entries)
!     x      on return, the solution
!     ai     block-row pointers, ai(0:n)
!     aj     block-column index of each stored block
!     adiag  adiag(i) is the position of the diagonal block of block row i;
!            blocks ai(i)..adiag(i)-1 belong to L, adiag(i)+1..ai(i+1)-1 to U
!     a      the 4x4 blocks; block k occupies a(16*k:16*k+15), stored column
!            by column, i.e. entry (r,c) of block k is a(16*k + r + 4*c)
!     b      right-hand side (not modified)
!     w      gather buffer; at most 4*max_blocks entries are written
!
!   L has an implicit unit diagonal (the forward sweep applies no diagonal).
!   The stored diagonal block is multiplied into the right-hand side in the
!   backward sweep, so it presumably holds the inverted diagonal block --
!   confirm against the factorization routine.
!   NOTE(review): presumably the caller's w holds at least 4*max_blocks
!   scalars -- confirm at the call site.
!
pure subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
  implicit none (type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*),w(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  ! Largest number of blocks on one side of the diagonal in a single block
  ! row that the gather into w may handle (4 scalars per block).
  PetscInt, parameter :: max_blocks = 500

  PetscInt :: i,idx,ax
  PetscScalar :: s(0:3)

  ! Alignment hints, asserted on element 0, which is the first element of
  ! each 0-based dummy. NOTE(review): assumes callers pass 16-byte-aligned
  ! arrays -- confirm.
  PETSC_AssertAlignx(16,a(0))
  PETSC_AssertAlignx(16,w(0))
  PETSC_AssertAlignx(16,x(0))
  PETSC_AssertAlignx(16,b(0))
  PETSC_AssertAlignx(16,ai(0))
  PETSC_AssertAlignx(16,aj(0))
  PETSC_AssertAlignx(16,adiag(0))

  ! No block rows means there is nothing to solve. The guard also keeps
  ! x(0:3) = b(0:3) below from touching entries that do not exist.
  if (n <= 0) return

  !
  ! Forward solve with the unit lower triangular factor L
  !
  x(0:3) = b(0:3)
  do i=1,n-1
    idx    = 4*i
    s(0:3) = b(idx:idx+3)
    call subtract_panel(ai(i), adiag(i)-1, w, s)
    x(idx:idx+3) = s(0:3)
  end do

  !
  ! Backward solve with the upper triangular factor U. idx is recomputed
  ! from i, so this sweep does not depend on where the forward sweep left it.
  !
  do i=n-1,0,-1
    idx    = 4*i
    s(0:3) = x(idx:idx+3)
    call subtract_panel(adiag(i)+1, ai(i+1)-1, w, s)

    ! Multiply s by the stored diagonal block of this block row
    ax       = 16*adiag(i)
    x(idx+0) = a(ax+0)*s(0)+a(ax+4)*s(1)+a(ax+ 8)*s(2)+a(ax+12)*s(3)
    x(idx+1) = a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+ 9)*s(2)+a(ax+13)*s(3)
    x(idx+2) = a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
    x(idx+3) = a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
  end do

contains

  ! acc = acc - (blocks jfirst..jlast of a) * (their 4-entry pieces of x).
  ! The pieces of x are first gathered contiguously into wrk.
  pure subroutine subtract_panel(jfirst, jlast, wrk, acc)
    PetscInt, intent(in) :: jfirst, jlast
    PetscScalar, intent(inout) :: wrk(0:*)
    PetscScalar, intent(inout) :: acc(0:3)

    PetscInt :: j, k, kdx, jdx, base

    if (jlast - jfirst >= max_blocks) error stop 'Overflowing vector FortranSolveBAIJ4()'

    ! Gather: wrk(4*m:4*m+3) = x piece for block jfirst+m
    kdx = 0
    do j=jfirst,jlast
      jdx            = 4*aj(j)
      wrk(kdx:kdx+3) = x(jdx:jdx+3)
      kdx            = kdx + 4
    end do

    ! wrk(k) multiplies column k of the 4 x kdx panel that starts at
    ! a(16*jfirst). The inner section walks that column contiguously, and
    ! each acc(r) still accumulates in increasing k order.
    base = 16*jfirst
    do k=0,kdx-1
      acc(0:3) = acc(0:3) - a(base+4*k:base+4*k+3)*wrk(k)
    end do
  end subroutine subtract_panel

end subroutine FortranSolveBAIJ4
168