xref: /petsc/src/mat/impls/hypre/mhypre_kernels.hpp (revision 3ea99036a5fedea4d39e7e77471d0ab500c249d7)
1 #ifndef MHYPRE_KERNELS_HPP
2 #define MHYPRE_KERNELS_HPP
3 
4 #include <../src/mat/impls/hypre/mhypre.h>
5 
// Zero the n rows listed in rows[] of the hypre CSRMatrix (i, j, a) and replace the diagonal entry with diag.
//
// Parameters:
//   n    - number of entries in rows[]
//   rows - row indices to process; each value r is used to index i[r] and i[r + 1]
//   i    - CSR row offsets; the nonzeros of row r occupy positions i[r] .. i[r + 1] - 1 of j[] and a[]
//   j    - column index of each stored nonzero; an entry is treated as the diagonal when j[] equals the row index
//   a    - values of the stored nonzeros, overwritten in place
//   diag - value written to the stored diagonal entry of each listed row
//
// Only entries already present in the sparsity pattern are written: a listed row without a stored diagonal entry
// ends up entirely zero.
//
// Launch layout: the x dimension of the grid strides over rows[], the y dimension strides over the nonzeros of a
// row. Both dimensions use grid-stride loops, so any grid/block size covers all the work.
__global__ static void ZeroRows(PetscInt n, const PetscInt rows[], const HYPRE_Int i[], const HYPRE_Int j[], HYPRE_Complex a[], HYPRE_Complex diag)
{
  // Widen before multiplying so the products are not computed in 32-bit unsigned arithmetic when PetscInt is wider
  const PetscInt k0    = (PetscInt)blockDim.x * blockIdx.x + threadIdx.x; // first entry of rows[] handled by this thread
  const PetscInt c0    = (PetscInt)blockDim.y * blockIdx.y + threadIdx.y; // first nonzero (within a row) handled by this thread
  const PetscInt gridx = (PetscInt)gridDim.x * blockDim.x;                // stride over rows[]
  const PetscInt gridy = (PetscInt)gridDim.y * blockDim.y;                // stride over the nonzeros of a row

  for (PetscInt k = k0; k < n; k += gridx) {
    const PetscInt  r     = rows[k]; // r-th row of the matrix
    const HYPRE_Int start = i[r];    // position of the first nonzero of row r in j[] and a[]
    const PetscInt  nz    = i[r + 1] - start;
    // The column cursor must restart at c0 for every row: carrying it over from the previous row would leave it
    // at a value >= that row's nz and skip the leading nonzeros of this row
    for (PetscInt c = c0; c < nz; c += gridy) {
      if (r == j[start + c]) a[start + c] = diag;
      else a[start + c] = 0.0;
    }
  }
}
22 
23 #endif
24