#pragma once

#include <../src/mat/impls/hypre/mhypre.h>

/*
  ZeroRows - Zero the n rows listed in rows[] of the hypre CSR matrix (i, j, a) and replace the
  diagonal entry of each of those rows (if it is stored) with diag.

  Launch layout: intended for a 2D grid. The x dimension strides over the entries of rows[]; the y dimension
  strides over the nonzeros within a single row. Both loops advance by the full grid extent in their dimension,
  so any grid/block shape covers all the work (a 1x1 launch is valid, just slow). No shared memory is used.

  Input Parameters:
+ n    - number of entries in rows[]
. rows - row indices (into the CSR matrix) of the rows to zero
. i    - CSR row pointer array; the nonzeros of row r are at offsets [i[r], i[r+1])
. j    - CSR column index array
. a    - CSR value array, modified in place
- diag - value stored into the diagonal entry of every zeroed row

  Notes:
  The diagonal entry is recognized by j[...] == r, so rows[] and j[] are assumed to use the same index
  space (e.g. the local diagonal block) -- NOTE(review): confirm at the call site.
  If rows[] contains duplicates, several threads may write the same slot, but they always write the same value.
*/
__global__ static void ZeroRows(PetscInt n, const PetscInt rows[], const HYPRE_Int i[], const HYPRE_Int j[], HYPRE_Complex a[], HYPRE_Complex diag)
{
  // Cast before multiplying so that, with 64-bit PetscInt, the products are not formed in 32-bit unsigned arithmetic
  const PetscInt k0    = (PetscInt)blockDim.x * blockIdx.x + threadIdx.x; // first entry of rows[] handled by this thread
  const PetscInt c0    = (PetscInt)blockDim.y * blockIdx.y + threadIdx.y; // first nonzero (within a row) handled by this thread
  const PetscInt gridx = (PetscInt)gridDim.x * blockDim.x;
  const PetscInt gridy = (PetscInt)gridDim.y * blockDim.y;

  for (PetscInt k = k0; k < n; k += gridx) {
    const PetscInt r     = rows[k];          // r-th row of the matrix
    const PetscInt start = i[r];             // offset of row r's first nonzero in j[] and a[]
    const PetscInt nz    = i[r + 1] - start; // number of stored nonzeros in row r
    // The nonzero counter must restart at c0 for every row; carrying it over from the previous row skips entries
    for (PetscInt c = c0; c < nz; c += gridy) {
      if (r == j[start + c]) a[start + c] = diag;
      else a[start + c] = 0.0;
    }
  }
}