xref: /petsc/src/mat/impls/hypre/kokkos/hypre3.kokkos.cxx (revision a32e9c995d3c9cc14233efbb30d372fdb63ce962)
1*a32e9c99SJunchao Zhang #include <HYPRE_utilities.h>
2*a32e9c99SJunchao Zhang #include <Kokkos_Core.hpp>
3*a32e9c99SJunchao Zhang #include <../src/mat/impls/hypre/mhypre.h>
4*a32e9c99SJunchao Zhang 
5*a32e9c99SJunchao Zhang PetscErrorCode MatZeroRows_Kokkos(PetscInt n, const PetscInt rows[], const HYPRE_Int i[], const HYPRE_Int j[], HYPRE_Complex a[], HYPRE_Complex diag)
6*a32e9c99SJunchao Zhang {
7*a32e9c99SJunchao Zhang   PetscFunctionBegin;
8*a32e9c99SJunchao Zhang   if (!n) PetscFunctionReturn(PETSC_SUCCESS);
9*a32e9c99SJunchao Zhang   PetscCall(PetscKokkosInitializeCheck()); // As we might have not created any petsc/kokkos object yet
10*a32e9c99SJunchao Zhang   Kokkos::parallel_for(
11*a32e9c99SJunchao Zhang     Kokkos::TeamPolicy<>(n, Kokkos::AUTO()), KOKKOS_LAMBDA(const Kokkos::TeamPolicy<>::member_type &t) {
12*a32e9c99SJunchao Zhang       PetscInt r = rows[t.league_rank()]; // row r
13*a32e9c99SJunchao Zhang       Kokkos::parallel_for(Kokkos::TeamThreadRange(t, i[r + 1] - i[r]), [&](PetscInt c) {
14*a32e9c99SJunchao Zhang         if (r == j[i[r] + c]) a[i[r] + c] = diag;
15*a32e9c99SJunchao Zhang         else a[i[r] + c] = 0.0;
16*a32e9c99SJunchao Zhang       });
17*a32e9c99SJunchao Zhang     });
18*a32e9c99SJunchao Zhang   PetscFunctionReturn(PETSC_SUCCESS);
19*a32e9c99SJunchao Zhang }
20