1914b7a73SJunchao Zhang #include <../src/vec/is/sf/impls/basic/sfpack.h> 2914b7a73SJunchao Zhang 3914b7a73SJunchao Zhang #include <Kokkos_Core.hpp> 4914b7a73SJunchao Zhang 5914b7a73SJunchao Zhang using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace; 6914b7a73SJunchao Zhang using DeviceMemorySpace = typename DeviceExecutionSpace::memory_space; 7914b7a73SJunchao Zhang using HostMemorySpace = Kokkos::HostSpace; 8914b7a73SJunchao Zhang 9914b7a73SJunchao Zhang typedef Kokkos::View<char *, DeviceMemorySpace> deviceBuffer_t; 10914b7a73SJunchao Zhang typedef Kokkos::View<char *, HostMemorySpace> HostBuffer_t; 11914b7a73SJunchao Zhang 12914b7a73SJunchao Zhang typedef Kokkos::View<const char *, DeviceMemorySpace> deviceConstBuffer_t; 13914b7a73SJunchao Zhang typedef Kokkos::View<const char *, HostMemorySpace> HostConstBuffer_t; 14914b7a73SJunchao Zhang 15914b7a73SJunchao Zhang /*====================================================================================*/ 16914b7a73SJunchao Zhang /* Regular operations */ 17914b7a73SJunchao Zhang /*====================================================================================*/ 18*9371c9d4SSatish Balay template <typename Type> 19*9371c9d4SSatish Balay struct Insert { 20*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 21*9371c9d4SSatish Balay Type old = x; 22*9371c9d4SSatish Balay x = y; 23*9371c9d4SSatish Balay return old; 24*9371c9d4SSatish Balay } 25*9371c9d4SSatish Balay }; 26*9371c9d4SSatish Balay template <typename Type> 27*9371c9d4SSatish Balay struct Add { 28*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 29*9371c9d4SSatish Balay Type old = x; 30*9371c9d4SSatish Balay x += y; 31*9371c9d4SSatish Balay return old; 32*9371c9d4SSatish Balay } 33*9371c9d4SSatish Balay }; 34*9371c9d4SSatish Balay template <typename Type> 35*9371c9d4SSatish Balay struct Mult { 36*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 37*9371c9d4SSatish Balay Type old = x; 38*9371c9d4SSatish Balay x *= y; 39*9371c9d4SSatish Balay return old; 40*9371c9d4SSatish Balay } 41*9371c9d4SSatish Balay }; 42*9371c9d4SSatish Balay template <typename Type> 43*9371c9d4SSatish Balay struct Min { 44*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 45*9371c9d4SSatish Balay Type old = x; 46*9371c9d4SSatish Balay x = PetscMin(x, y); 47*9371c9d4SSatish Balay return old; 48*9371c9d4SSatish Balay } 49*9371c9d4SSatish Balay }; 50*9371c9d4SSatish Balay template <typename Type> 51*9371c9d4SSatish Balay struct Max { 52*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 53*9371c9d4SSatish Balay Type old = x; 54*9371c9d4SSatish Balay x = PetscMax(x, y); 55*9371c9d4SSatish Balay return old; 56*9371c9d4SSatish Balay } 57*9371c9d4SSatish Balay }; 58*9371c9d4SSatish Balay template <typename Type> 59*9371c9d4SSatish Balay struct LAND { 60*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 61*9371c9d4SSatish Balay Type old = x; 62*9371c9d4SSatish Balay x = x && y; 63*9371c9d4SSatish Balay return old; 64*9371c9d4SSatish Balay } 65*9371c9d4SSatish Balay }; 66*9371c9d4SSatish Balay template <typename Type> 67*9371c9d4SSatish Balay struct LOR { 68*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 69*9371c9d4SSatish Balay Type old = x; 70*9371c9d4SSatish Balay x = x || y; 71*9371c9d4SSatish Balay return old; 72*9371c9d4SSatish Balay } 73*9371c9d4SSatish Balay }; 74*9371c9d4SSatish Balay template <typename Type> 75*9371c9d4SSatish Balay struct LXOR { 76*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 77*9371c9d4SSatish Balay Type old = x; 78*9371c9d4SSatish Balay x = !x != !y; 79*9371c9d4SSatish Balay return old; 80*9371c9d4SSatish Balay } 81*9371c9d4SSatish Balay }; 82*9371c9d4SSatish Balay template <typename Type> 83*9371c9d4SSatish Balay struct BAND { 84*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 85*9371c9d4SSatish Balay Type old = x; 86*9371c9d4SSatish Balay x = x & y; 87*9371c9d4SSatish Balay return old; 88*9371c9d4SSatish Balay } 89*9371c9d4SSatish Balay }; 90*9371c9d4SSatish Balay template <typename Type> 91*9371c9d4SSatish Balay struct BOR { 92*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 93*9371c9d4SSatish Balay Type old = x; 94*9371c9d4SSatish Balay x = x | y; 95*9371c9d4SSatish Balay return old; 96*9371c9d4SSatish Balay } 97*9371c9d4SSatish Balay }; 98*9371c9d4SSatish Balay template <typename Type> 99*9371c9d4SSatish Balay struct BXOR { 100*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { 101*9371c9d4SSatish Balay Type old = x; 102*9371c9d4SSatish Balay x = x ^ y; 103*9371c9d4SSatish Balay return old; 104*9371c9d4SSatish Balay } 105*9371c9d4SSatish Balay }; 106*9371c9d4SSatish Balay template <typename PairType> 107*9371c9d4SSatish Balay struct Minloc { 108914b7a73SJunchao Zhang KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const { 109914b7a73SJunchao Zhang PairType old = x; 110914b7a73SJunchao Zhang if (y.first < x.first) x = y; 111914b7a73SJunchao Zhang else if (y.first == x.first) x.second = PetscMin(x.second, y.second); 112914b7a73SJunchao Zhang return old; 113914b7a73SJunchao Zhang } 114914b7a73SJunchao Zhang }; 115*9371c9d4SSatish Balay template <typename PairType> 116*9371c9d4SSatish Balay struct Maxloc { 117914b7a73SJunchao Zhang KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const { 118914b7a73SJunchao Zhang PairType old = x; 119914b7a73SJunchao Zhang if (y.first > x.first) x = y; 120914b7a73SJunchao Zhang else if (y.first == x.first) x.second = PetscMin(x.second, y.second); /* See MPI MAXLOC */ 121914b7a73SJunchao Zhang return old; 122914b7a73SJunchao Zhang } 123914b7a73SJunchao Zhang }; 124914b7a73SJunchao Zhang 125914b7a73SJunchao Zhang /*====================================================================================*/ 126914b7a73SJunchao Zhang /* Atomic operations */ 127914b7a73SJunchao Zhang /*====================================================================================*/ 128*9371c9d4SSatish Balay template <typename Type> 129*9371c9d4SSatish Balay struct AtomicInsert { 130*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_assign(&x, y); } 131*9371c9d4SSatish Balay }; 132*9371c9d4SSatish Balay template <typename Type> 133*9371c9d4SSatish Balay struct AtomicAdd { 134*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_add(&x, y); } 135*9371c9d4SSatish Balay }; 136*9371c9d4SSatish Balay template <typename Type> 137*9371c9d4SSatish Balay struct AtomicBAND { 138*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_and(&x, y); } 139*9371c9d4SSatish Balay }; 140*9371c9d4SSatish Balay template <typename Type> 141*9371c9d4SSatish Balay struct AtomicBOR { 142*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_or(&x, y); } 143*9371c9d4SSatish Balay }; 144*9371c9d4SSatish Balay template <typename Type> 145*9371c9d4SSatish Balay struct AtomicBXOR { 146*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_xor(&x, y); } 147*9371c9d4SSatish Balay }; 148*9371c9d4SSatish Balay template <typename Type> 149*9371c9d4SSatish Balay struct AtomicLAND { 150*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { 151*9371c9d4SSatish Balay const Type zero = 0, one = ~0; 152*9371c9d4SSatish Balay Kokkos::atomic_and(&x, y ? one : zero); 153*9371c9d4SSatish Balay } 154*9371c9d4SSatish Balay }; 155*9371c9d4SSatish Balay template <typename Type> 156*9371c9d4SSatish Balay struct AtomicLOR { 157*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { 158*9371c9d4SSatish Balay const Type zero = 0, one = 1; 159*9371c9d4SSatish Balay Kokkos::atomic_or(&x, y ? one : zero); 160*9371c9d4SSatish Balay } 161*9371c9d4SSatish Balay }; 162*9371c9d4SSatish Balay template <typename Type> 163*9371c9d4SSatish Balay struct AtomicMult { 164*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_mul(&x, y); } 165*9371c9d4SSatish Balay }; 166*9371c9d4SSatish Balay template <typename Type> 167*9371c9d4SSatish Balay struct AtomicMin { 168*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_min(&x, y); } 169*9371c9d4SSatish Balay }; 170*9371c9d4SSatish Balay template <typename Type> 171*9371c9d4SSatish Balay struct AtomicMax { 172*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_max(&x, y); } 173*9371c9d4SSatish Balay }; 174914b7a73SJunchao Zhang /* TODO: struct AtomicLXOR */ 175*9371c9d4SSatish Balay template <typename Type> 176*9371c9d4SSatish Balay struct AtomicFetchAdd { 177*9371c9d4SSatish Balay KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { return Kokkos::atomic_fetch_add(&x, y); } 178*9371c9d4SSatish Balay }; 179914b7a73SJunchao Zhang 180914b7a73SJunchao Zhang /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */ 181*9371c9d4SSatish Balay static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid) { 182914b7a73SJunchao Zhang PetscInt i, j, k, m, n, r; 183914b7a73SJunchao Zhang const PetscInt *offset, *start, *dx, *dy, *X, *Y; 184914b7a73SJunchao Zhang 185914b7a73SJunchao Zhang n = opt[0]; 186914b7a73SJunchao Zhang offset = opt + 1; 187914b7a73SJunchao Zhang start = opt + n + 2; 188914b7a73SJunchao Zhang dx = opt + 2 * n + 2; 189914b7a73SJunchao Zhang dy = opt + 3 * n + 2; 190914b7a73SJunchao Zhang X = opt + 5 * n + 2; 191914b7a73SJunchao Zhang Y = opt + 6 * n + 2; 192*9371c9d4SSatish Balay for (r = 0; r < n; r++) { 193*9371c9d4SSatish Balay if (tid < offset[r + 1]) break; 194*9371c9d4SSatish Balay } 195914b7a73SJunchao Zhang m = (tid - offset[r]); 196914b7a73SJunchao Zhang k = m / (dx[r] * dy[r]); 197914b7a73SJunchao Zhang j = (m - k * dx[r] * dy[r]) / dx[r]; 198914b7a73SJunchao Zhang i = m - k * dx[r] * dy[r] - j * dx[r]; 199914b7a73SJunchao Zhang 200914b7a73SJunchao Zhang return (start[r] + k * X[r] * Y[r] + j * X[r] + i); 201914b7a73SJunchao Zhang } 202914b7a73SJunchao Zhang 203914b7a73SJunchao Zhang /*====================================================================================*/ 204914b7a73SJunchao Zhang /* Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link' */ 205914b7a73SJunchao Zhang /*====================================================================================*/ 206914b7a73SJunchao Zhang 207914b7a73SJunchao Zhang /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then 208914b7a73SJunchao Zhang <Type> is PetscReal, which is the primitive type we operate on. 209914b7a73SJunchao Zhang <bs> is 16, which says <unit> contains 16 primitive types. 210914b7a73SJunchao Zhang <BS> is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>. 211914b7a73SJunchao Zhang <EQ> is 0, which is (bs == BS ? 1 : 0) 212914b7a73SJunchao Zhang 213914b7a73SJunchao Zhang If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant. 214914b7a73SJunchao Zhang For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled. 215914b7a73SJunchao Zhang */ 216914b7a73SJunchao Zhang template <typename Type, PetscInt BS, PetscInt EQ> 217*9371c9d4SSatish Balay static PetscErrorCode Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data_, void *buf_) { 218914b7a73SJunchao Zhang const PetscInt *iopt = opt ? opt->array : NULL; 219914b7a73SJunchao Zhang const PetscInt M = EQ ? 1 : link->bs / BS, MBS = M * BS; /* If EQ, then MBS will be a compile-time const */ 220914b7a73SJunchao Zhang const Type *data = static_cast<const Type *>(data_); 221914b7a73SJunchao Zhang Type *buf = static_cast<Type *>(buf_); 222f4af43b4SJunchao Zhang DeviceExecutionSpace exec; 223914b7a73SJunchao Zhang 224914b7a73SJunchao Zhang PetscFunctionBegin; 225*9371c9d4SSatish Balay Kokkos::parallel_for( 226*9371c9d4SSatish Balay Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) { 227914b7a73SJunchao Zhang /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous; 228914b7a73SJunchao Zhang iopt == NULL && idx == NULL ==> the indices are contiguous; 229914b7a73SJunchao Zhang */ 230914b7a73SJunchao Zhang PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS; 231914b7a73SJunchao Zhang PetscInt s = tid * MBS; 232914b7a73SJunchao Zhang for (int i = 0; i < MBS; i++) buf[s + i] = data[t + i]; 233914b7a73SJunchao Zhang }); 234914b7a73SJunchao Zhang PetscFunctionReturn(0); 235914b7a73SJunchao Zhang } 236914b7a73SJunchao Zhang 237914b7a73SJunchao Zhang template <typename Type, class Op, PetscInt BS, PetscInt EQ> 238*9371c9d4SSatish Balay static PetscErrorCode UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data_, const void *buf_) { 239914b7a73SJunchao Zhang Op op; 240914b7a73SJunchao Zhang const PetscInt *iopt = opt ? opt->array : NULL; 241914b7a73SJunchao Zhang const PetscInt M = EQ ? 1 : link->bs / BS, MBS = M * BS; 242914b7a73SJunchao Zhang Type *data = static_cast<Type *>(data_); 243914b7a73SJunchao Zhang const Type *buf = static_cast<const Type *>(buf_); 244f4af43b4SJunchao Zhang DeviceExecutionSpace exec; 245914b7a73SJunchao Zhang 246914b7a73SJunchao Zhang PetscFunctionBegin; 247*9371c9d4SSatish Balay Kokkos::parallel_for( 248*9371c9d4SSatish Balay Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) { 249914b7a73SJunchao Zhang PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS; 250914b7a73SJunchao Zhang PetscInt s = tid * MBS; 251914b7a73SJunchao Zhang for (int i = 0; i < MBS; i++) op(data[t + i], buf[s + i]); 252914b7a73SJunchao Zhang }); 253914b7a73SJunchao Zhang PetscFunctionReturn(0); 254914b7a73SJunchao Zhang } 255914b7a73SJunchao Zhang 256914b7a73SJunchao Zhang template <typename Type, class Op, PetscInt BS, PetscInt EQ> 257*9371c9d4SSatish Balay static PetscErrorCode FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf) { 258914b7a73SJunchao Zhang Op op; 259914b7a73SJunchao Zhang const PetscInt *ropt = opt ? opt->array : NULL; 260914b7a73SJunchao Zhang const PetscInt M = EQ ? 1 : link->bs / BS, MBS = M * BS; 261914b7a73SJunchao Zhang Type *rootdata = static_cast<Type *>(data), *leafbuf = static_cast<Type *>(buf); 262f4af43b4SJunchao Zhang DeviceExecutionSpace exec; 263914b7a73SJunchao Zhang 264914b7a73SJunchao Zhang PetscFunctionBegin; 265*9371c9d4SSatish Balay Kokkos::parallel_for( 266*9371c9d4SSatish Balay Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) { 267914b7a73SJunchao Zhang PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (idx ? idx[tid] : start + tid)) * MBS; 268914b7a73SJunchao Zhang PetscInt l = tid * MBS; 269914b7a73SJunchao Zhang for (int i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]); 270914b7a73SJunchao Zhang }); 271914b7a73SJunchao Zhang PetscFunctionReturn(0); 272914b7a73SJunchao Zhang } 273914b7a73SJunchao Zhang 274914b7a73SJunchao Zhang template <typename Type, class Op, PetscInt BS, PetscInt EQ> 275*9371c9d4SSatish Balay static PetscErrorCode ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_) { 276914b7a73SJunchao Zhang PetscInt srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0; 277914b7a73SJunchao Zhang const PetscInt M = (EQ) ? 1 : link->bs / BS, MBS = M * BS; 278914b7a73SJunchao Zhang const Type *src = static_cast<const Type *>(src_); 279914b7a73SJunchao Zhang Type *dst = static_cast<Type *>(dst_); 280f4af43b4SJunchao Zhang DeviceExecutionSpace exec; 281914b7a73SJunchao Zhang 282914b7a73SJunchao Zhang PetscFunctionBegin; 283914b7a73SJunchao Zhang /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */ 284*9371c9d4SSatish Balay if (srcOpt) { 285*9371c9d4SSatish Balay srcx = srcOpt->dx[0]; 286*9371c9d4SSatish Balay srcy = srcOpt->dy[0]; 287*9371c9d4SSatish Balay srcX = srcOpt->X[0]; 288*9371c9d4SSatish Balay srcY = srcOpt->Y[0]; 289*9371c9d4SSatish Balay srcStart = srcOpt->start[0]; 290*9371c9d4SSatish Balay srcIdx = NULL; 291*9371c9d4SSatish Balay } else if (!srcIdx) { 292*9371c9d4SSatish Balay srcx = srcX = count; 293*9371c9d4SSatish Balay srcy = srcY = 1; 294*9371c9d4SSatish Balay } 295914b7a73SJunchao Zhang 296*9371c9d4SSatish Balay if (dstOpt) { 297*9371c9d4SSatish Balay dstx = dstOpt->dx[0]; 298*9371c9d4SSatish Balay dsty = dstOpt->dy[0]; 299*9371c9d4SSatish Balay dstX = dstOpt->X[0]; 300*9371c9d4SSatish Balay dstY = dstOpt->Y[0]; 301*9371c9d4SSatish Balay dstStart = dstOpt->start[0]; 302*9371c9d4SSatish Balay dstIdx = NULL; 303*9371c9d4SSatish Balay } else if (!dstIdx) { 304*9371c9d4SSatish Balay dstx = dstX = count; 305*9371c9d4SSatish Balay dsty = dstY = 1; 306*9371c9d4SSatish Balay } 307914b7a73SJunchao Zhang 308*9371c9d4SSatish Balay Kokkos::parallel_for( 309*9371c9d4SSatish Balay Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) { 310914b7a73SJunchao Zhang PetscInt i, j, k, s, t; 311914b7a73SJunchao Zhang Op op; 312914b7a73SJunchao Zhang if (!srcIdx) { /* src is in 3D */ 313914b7a73SJunchao Zhang k = tid / (srcx * srcy); 314914b7a73SJunchao Zhang j = (tid - k * srcx * srcy) / srcx; 315914b7a73SJunchao Zhang i = tid - k * srcx * srcy - j * srcx; 316914b7a73SJunchao Zhang s = srcStart + k * srcX * srcY + j * srcX + i; 317914b7a73SJunchao Zhang } else { /* src is contiguous */ 318914b7a73SJunchao Zhang s = srcIdx[tid]; 319914b7a73SJunchao Zhang } 320914b7a73SJunchao Zhang 321914b7a73SJunchao Zhang if (!dstIdx) { /* 3D */ 322914b7a73SJunchao Zhang k = tid / (dstx * dsty); 323914b7a73SJunchao Zhang j = (tid - k * dstx * dsty) / dstx; 324914b7a73SJunchao Zhang i = tid - k * dstx * dsty - j * dstx; 325914b7a73SJunchao Zhang t = dstStart + k * dstX * dstY + j * dstX + i; 326914b7a73SJunchao Zhang } else { /* contiguous */ 327914b7a73SJunchao Zhang t = dstIdx[tid]; 328914b7a73SJunchao Zhang } 329914b7a73SJunchao Zhang 330914b7a73SJunchao Zhang s *= MBS; 331914b7a73SJunchao Zhang t *= MBS; 332914b7a73SJunchao Zhang for (i = 0; i < MBS; i++) op(dst[t + i], src[s + i]); 333914b7a73SJunchao Zhang }); 334914b7a73SJunchao Zhang PetscFunctionReturn(0); 335914b7a73SJunchao Zhang } 336914b7a73SJunchao Zhang 337914b7a73SJunchao Zhang /* Specialization for Insert since we may use memcpy */ 338914b7a73SJunchao Zhang template <typename Type, PetscInt BS, PetscInt EQ> 339*9371c9d4SSatish Balay static PetscErrorCode ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_) { 340914b7a73SJunchao Zhang const Type *src = static_cast<const Type *>(src_); 341914b7a73SJunchao Zhang Type *dst = static_cast<Type *>(dst_); 342f4af43b4SJunchao Zhang DeviceExecutionSpace exec; 343914b7a73SJunchao Zhang 344914b7a73SJunchao Zhang PetscFunctionBegin; 345914b7a73SJunchao Zhang if (!count) PetscFunctionReturn(0); 346914b7a73SJunchao Zhang /*src and dst are contiguous */ 347914b7a73SJunchao Zhang if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) { 348914b7a73SJunchao Zhang size_t sz = count * link->unitbytes; 349914b7a73SJunchao Zhang deviceBuffer_t dbuf(reinterpret_cast<char *>(dst + dstStart * link->bs), sz); 350914b7a73SJunchao Zhang deviceConstBuffer_t sbuf(reinterpret_cast<const char *>(src + srcStart * link->bs), sz); 351914b7a73SJunchao Zhang Kokkos::deep_copy(exec, dbuf, sbuf); 352914b7a73SJunchao Zhang } else { 3539566063dSJacob Faibussowitsch PetscCall(ScatterAndOp<Type, Insert<Type>, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst)); 354914b7a73SJunchao Zhang } 355914b7a73SJunchao Zhang PetscFunctionReturn(0); 356914b7a73SJunchao Zhang } 357914b7a73SJunchao Zhang 358914b7a73SJunchao Zhang template <typename Type, class Op, PetscInt BS, PetscInt EQ> 359*9371c9d4SSatish Balay static PetscErrorCode FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata_, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata_, void *leafupdate_) { 360914b7a73SJunchao Zhang Op op; 361914b7a73SJunchao Zhang const PetscInt M = (EQ) ? 1 : link->bs / BS, MBS = M * BS; 362914b7a73SJunchao Zhang const PetscInt *ropt = rootopt ? rootopt->array : NULL; 363914b7a73SJunchao Zhang const PetscInt *lopt = leafopt ? leafopt->array : NULL; 364914b7a73SJunchao Zhang Type *rootdata = static_cast<Type *>(rootdata_), *leafupdate = static_cast<Type *>(leafupdate_); 365914b7a73SJunchao Zhang const Type *leafdata = static_cast<const Type *>(leafdata_); 366f4af43b4SJunchao Zhang DeviceExecutionSpace exec; 367914b7a73SJunchao Zhang 368914b7a73SJunchao Zhang PetscFunctionBegin; 369*9371c9d4SSatish Balay Kokkos::parallel_for( 370*9371c9d4SSatish Balay Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) { 371914b7a73SJunchao Zhang PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS; 372914b7a73SJunchao Zhang PetscInt l = (lopt ? MapTidToIndex(lopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS; 373914b7a73SJunchao Zhang for (int i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]); 374914b7a73SJunchao Zhang }); 375914b7a73SJunchao Zhang PetscFunctionReturn(0); 376914b7a73SJunchao Zhang } 377914b7a73SJunchao Zhang 378914b7a73SJunchao Zhang /*====================================================================================*/ 379914b7a73SJunchao Zhang /* Init various types and instantiate pack/unpack function pointers */ 380914b7a73SJunchao Zhang /*====================================================================================*/ 381914b7a73SJunchao Zhang template <typename Type, PetscInt BS, PetscInt EQ> 382*9371c9d4SSatish Balay static void PackInit_RealType(PetscSFLink link) { 383914b7a73SJunchao Zhang /* Pack/unpack for remote communication */ 384914b7a73SJunchao Zhang link->d_Pack = Pack<Type, BS, EQ>; 385914b7a73SJunchao Zhang link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>; 386914b7a73SJunchao Zhang link->d_UnpackAndAdd = UnpackAndOp<Type, Add<Type>, BS, EQ>; 387914b7a73SJunchao Zhang link->d_UnpackAndMult = UnpackAndOp<Type, Mult<Type>, BS, EQ>; 388914b7a73SJunchao Zhang link->d_UnpackAndMin = UnpackAndOp<Type, Min<Type>, BS, EQ>; 389914b7a73SJunchao Zhang link->d_UnpackAndMax = UnpackAndOp<Type, Max<Type>, BS, EQ>; 390914b7a73SJunchao Zhang link->d_FetchAndAdd = FetchAndOp<Type, Add<Type>, BS, EQ>; 391914b7a73SJunchao Zhang /* Scatter for local communication */ 392914b7a73SJunchao Zhang link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; /* Has special optimizations */ 393914b7a73SJunchao Zhang link->d_ScatterAndAdd = ScatterAndOp<Type, Add<Type>, BS, EQ>; 394914b7a73SJunchao Zhang link->d_ScatterAndMult = ScatterAndOp<Type, Mult<Type>, BS, EQ>; 395914b7a73SJunchao Zhang link->d_ScatterAndMin = ScatterAndOp<Type, Min<Type>, BS, EQ>; 396914b7a73SJunchao Zhang link->d_ScatterAndMax = ScatterAndOp<Type, Max<Type>, BS, EQ>; 397914b7a73SJunchao Zhang link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>; 398914b7a73SJunchao Zhang /* Atomic versions when there are data-race possibilities */ 399914b7a73SJunchao Zhang link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>; 400914b7a73SJunchao Zhang link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>; 401914b7a73SJunchao Zhang link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>; 402914b7a73SJunchao Zhang link->da_UnpackAndMin = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>; 403914b7a73SJunchao Zhang link->da_UnpackAndMax = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>; 404914b7a73SJunchao Zhang link->da_FetchAndAdd = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>; 405914b7a73SJunchao Zhang 406914b7a73SJunchao Zhang link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>; 407914b7a73SJunchao Zhang link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>; 408914b7a73SJunchao Zhang link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>; 409914b7a73SJunchao Zhang link->da_ScatterAndMin = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>; 410914b7a73SJunchao Zhang link->da_ScatterAndMax = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>; 411914b7a73SJunchao Zhang link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>; 412914b7a73SJunchao Zhang } 413914b7a73SJunchao Zhang 414914b7a73SJunchao Zhang template <typename Type, PetscInt BS, PetscInt EQ> 415*9371c9d4SSatish Balay static void PackInit_IntegerType(PetscSFLink link) { 416914b7a73SJunchao Zhang link->d_Pack = Pack<Type, BS, EQ>; 417914b7a73SJunchao Zhang link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>; 418914b7a73SJunchao Zhang link->d_UnpackAndAdd = UnpackAndOp<Type, Add<Type>, BS, EQ>; 419914b7a73SJunchao Zhang link->d_UnpackAndMult = UnpackAndOp<Type, Mult<Type>, BS, EQ>; 420914b7a73SJunchao Zhang link->d_UnpackAndMin = UnpackAndOp<Type, Min<Type>, BS, EQ>; 421914b7a73SJunchao Zhang link->d_UnpackAndMax = UnpackAndOp<Type, Max<Type>, BS, EQ>; 422914b7a73SJunchao Zhang link->d_UnpackAndLAND = UnpackAndOp<Type, LAND<Type>, BS, EQ>; 423914b7a73SJunchao Zhang link->d_UnpackAndLOR = UnpackAndOp<Type, LOR<Type>, BS, EQ>; 424914b7a73SJunchao Zhang link->d_UnpackAndLXOR = UnpackAndOp<Type, LXOR<Type>, BS, EQ>; 425914b7a73SJunchao Zhang link->d_UnpackAndBAND = UnpackAndOp<Type, BAND<Type>, BS, EQ>; 426914b7a73SJunchao Zhang link->d_UnpackAndBOR = UnpackAndOp<Type, BOR<Type>, BS, EQ>; 427914b7a73SJunchao Zhang link->d_UnpackAndBXOR = UnpackAndOp<Type, BXOR<Type>, BS, EQ>; 428914b7a73SJunchao Zhang link->d_FetchAndAdd = FetchAndOp<Type, Add<Type>, BS, EQ>; 429914b7a73SJunchao Zhang 430914b7a73SJunchao Zhang link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; 431914b7a73SJunchao Zhang link->d_ScatterAndAdd = ScatterAndOp<Type, Add<Type>, BS, EQ>; 432914b7a73SJunchao Zhang link->d_ScatterAndMult = ScatterAndOp<Type, Mult<Type>, BS, EQ>; 433914b7a73SJunchao Zhang link->d_ScatterAndMin = ScatterAndOp<Type, Min<Type>, BS, EQ>; 434914b7a73SJunchao Zhang link->d_ScatterAndMax = ScatterAndOp<Type, Max<Type>, BS, EQ>; 435914b7a73SJunchao Zhang link->d_ScatterAndLAND = ScatterAndOp<Type, LAND<Type>, BS, EQ>; 436914b7a73SJunchao Zhang link->d_ScatterAndLOR = ScatterAndOp<Type, LOR<Type>, BS, EQ>; 437914b7a73SJunchao Zhang link->d_ScatterAndLXOR = ScatterAndOp<Type, LXOR<Type>, BS, EQ>; 438914b7a73SJunchao Zhang link->d_ScatterAndBAND = ScatterAndOp<Type, BAND<Type>, BS, EQ>; 439914b7a73SJunchao Zhang link->d_ScatterAndBOR = ScatterAndOp<Type, BOR<Type>, BS, EQ>; 440914b7a73SJunchao Zhang link->d_ScatterAndBXOR = ScatterAndOp<Type, BXOR<Type>, BS, EQ>; 441914b7a73SJunchao Zhang link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>; 442914b7a73SJunchao Zhang 443914b7a73SJunchao Zhang link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>; 444914b7a73SJunchao Zhang link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>; 445914b7a73SJunchao Zhang link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>; 446914b7a73SJunchao Zhang link->da_UnpackAndMin = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>; 447914b7a73SJunchao Zhang link->da_UnpackAndMax = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>; 448914b7a73SJunchao Zhang link->da_UnpackAndLAND = UnpackAndOp<Type, AtomicLAND<Type>, BS, EQ>; 449914b7a73SJunchao Zhang link->da_UnpackAndLOR = UnpackAndOp<Type, AtomicLOR<Type>, BS, EQ>; 450914b7a73SJunchao Zhang link->da_UnpackAndBAND = UnpackAndOp<Type, AtomicBAND<Type>, BS, EQ>; 451914b7a73SJunchao Zhang link->da_UnpackAndBOR = UnpackAndOp<Type, AtomicBOR<Type>, BS, EQ>; 452914b7a73SJunchao Zhang link->da_UnpackAndBXOR = UnpackAndOp<Type, AtomicBXOR<Type>, BS, EQ>; 453914b7a73SJunchao Zhang link->da_FetchAndAdd = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>; 454914b7a73SJunchao Zhang 455914b7a73SJunchao Zhang link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>; 456914b7a73SJunchao Zhang link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>; 457914b7a73SJunchao Zhang link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>; 458914b7a73SJunchao Zhang link->da_ScatterAndMin = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>; 459914b7a73SJunchao Zhang link->da_ScatterAndMax = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>; 460914b7a73SJunchao Zhang link->da_ScatterAndLAND = ScatterAndOp<Type, AtomicLAND<Type>, BS, EQ>; 461914b7a73SJunchao Zhang link->da_ScatterAndLOR = ScatterAndOp<Type, AtomicLOR<Type>, BS, EQ>; 462914b7a73SJunchao Zhang link->da_ScatterAndBAND = ScatterAndOp<Type, AtomicBAND<Type>, BS, EQ>; 463914b7a73SJunchao Zhang link->da_ScatterAndBOR = ScatterAndOp<Type, AtomicBOR<Type>, BS, EQ>; 464914b7a73SJunchao Zhang link->da_ScatterAndBXOR = ScatterAndOp<Type, AtomicBXOR<Type>, BS, EQ>; 465914b7a73SJunchao Zhang link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>; 466914b7a73SJunchao Zhang } 467914b7a73SJunchao Zhang 468914b7a73SJunchao Zhang #if defined(PETSC_HAVE_COMPLEX) 469914b7a73SJunchao Zhang template <typename Type, PetscInt BS, PetscInt EQ> 470*9371c9d4SSatish Balay static void PackInit_ComplexType(PetscSFLink link) { 471914b7a73SJunchao Zhang link->d_Pack = Pack<Type, BS, EQ>; 472914b7a73SJunchao Zhang link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>; 473914b7a73SJunchao Zhang link->d_UnpackAndAdd = UnpackAndOp<Type, Add<Type>, BS, EQ>; 474914b7a73SJunchao Zhang link->d_UnpackAndMult = UnpackAndOp<Type, Mult<Type>, BS, EQ>; 475914b7a73SJunchao Zhang link->d_FetchAndAdd = FetchAndOp<Type, Add<Type>, BS, EQ>; 476914b7a73SJunchao Zhang 477914b7a73SJunchao Zhang link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; 478914b7a73SJunchao Zhang link->d_ScatterAndAdd = ScatterAndOp<Type, Add<Type>, BS, EQ>; 479914b7a73SJunchao Zhang link->d_ScatterAndMult = ScatterAndOp<Type, Mult<Type>, BS, EQ>; 480914b7a73SJunchao Zhang link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>; 481914b7a73SJunchao Zhang 482914b7a73SJunchao Zhang link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>; 483914b7a73SJunchao Zhang link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>; 484914b7a73SJunchao Zhang link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>; 485914b7a73SJunchao Zhang link->da_FetchAndAdd = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>; 486914b7a73SJunchao Zhang 487914b7a73SJunchao Zhang link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>; 488914b7a73SJunchao Zhang link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>; 489914b7a73SJunchao Zhang link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>; 490914b7a73SJunchao Zhang link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>; 491914b7a73SJunchao Zhang } 492914b7a73SJunchao Zhang #endif 493914b7a73SJunchao Zhang 494914b7a73SJunchao Zhang template <typename Type> 495*9371c9d4SSatish Balay static void PackInit_PairType(PetscSFLink link) { 496914b7a73SJunchao Zhang link->d_Pack = Pack<Type, 1, 1>; 497914b7a73SJunchao Zhang link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, 1, 1>; 498914b7a73SJunchao Zhang link->d_UnpackAndMaxloc = UnpackAndOp<Type, Maxloc<Type>, 1, 1>; 499914b7a73SJunchao Zhang link->d_UnpackAndMinloc = UnpackAndOp<Type, Minloc<Type>, 1, 1>; 500914b7a73SJunchao Zhang 501914b7a73SJunchao Zhang link->d_ScatterAndInsert = ScatterAndOp<Type, Insert<Type>, 1, 1>; 502914b7a73SJunchao Zhang link->d_ScatterAndMaxloc = ScatterAndOp<Type, Maxloc<Type>, 1, 1>; 503914b7a73SJunchao Zhang link->d_ScatterAndMinloc = ScatterAndOp<Type, Minloc<Type>, 1, 1>; 504914b7a73SJunchao Zhang /* Atomics for pair types are not implemented yet */ 505914b7a73SJunchao Zhang } 506914b7a73SJunchao Zhang 507914b7a73SJunchao Zhang template <typename Type, PetscInt BS, PetscInt EQ> 508*9371c9d4SSatish Balay static void PackInit_DumbType(PetscSFLink link) { 509914b7a73SJunchao Zhang link->d_Pack = Pack<Type, BS, EQ>; 510914b7a73SJunchao Zhang link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>; 511914b7a73SJunchao Zhang link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; 512914b7a73SJunchao Zhang /* Atomics for dumb types are not implemented yet */ 513914b7a73SJunchao Zhang } 514914b7a73SJunchao Zhang 515f4af43b4SJunchao Zhang /* 516f4af43b4SJunchao Zhang Kokkos::DefaultExecutionSpace(stream) is a reference counted pointer object. It has a bug 517f4af43b4SJunchao Zhang that one is not able to repeatedly create and destroy the object. SF's original design was each 518f4af43b4SJunchao Zhang SFLink has a stream (NULL or not) and hence an execution space object. The bug prevents us from 519f4af43b4SJunchao Zhang destroying multiple SFLinks with NULL stream and the default execution space object. To avoid 520f4af43b4SJunchao Zhang memory leaks, SF_Kokkos only supports NULL stream, which is also petsc's default scheme. SF_Kokkos 521f4af43b4SJunchao Zhang does not do its own new/delete. It just uses Kokkos::DefaultExecutionSpace(), which is a singliton 522f4af43b4SJunchao Zhang object in Kokkos. 523f4af43b4SJunchao Zhang */ 524f4af43b4SJunchao Zhang /* 525914b7a73SJunchao Zhang static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link) 526914b7a73SJunchao Zhang { 527914b7a73SJunchao Zhang PetscFunctionBegin; 528914b7a73SJunchao Zhang PetscFunctionReturn(0); 529914b7a73SJunchao Zhang } 530f4af43b4SJunchao Zhang */ 531914b7a73SJunchao Zhang 532914b7a73SJunchao Zhang /* Some device-specific utilities */ 533*9371c9d4SSatish Balay static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink PETSC_UNUSED link) { 534914b7a73SJunchao Zhang PetscFunctionBegin; 535914b7a73SJunchao Zhang Kokkos::fence(); 536914b7a73SJunchao Zhang PetscFunctionReturn(0); 537914b7a73SJunchao Zhang } 538914b7a73SJunchao Zhang 539*9371c9d4SSatish Balay static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink PETSC_UNUSED link) { 540f4af43b4SJunchao Zhang DeviceExecutionSpace exec; 541914b7a73SJunchao Zhang PetscFunctionBegin; 542914b7a73SJunchao Zhang exec.fence(); 543914b7a73SJunchao Zhang PetscFunctionReturn(0); 544914b7a73SJunchao Zhang } 545914b7a73SJunchao Zhang 546*9371c9d4SSatish Balay static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink PETSC_UNUSED link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n) { 547f4af43b4SJunchao Zhang DeviceExecutionSpace exec; 548914b7a73SJunchao Zhang 549914b7a73SJunchao Zhang PetscFunctionBegin; 550914b7a73SJunchao Zhang if (!n) PetscFunctionReturn(0); 55171438e86SJunchao Zhang if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) { 5529566063dSJacob Faibussowitsch PetscCall(PetscMemcpy(dst, src, n)); 553914b7a73SJunchao Zhang } else { 55471438e86SJunchao Zhang if (PetscMemTypeDevice(dstmtype) && PetscMemTypeHost(srcmtype)) { 555914b7a73SJunchao Zhang deviceBuffer_t dbuf(static_cast<char *>(dst), n); 556914b7a73SJunchao Zhang HostConstBuffer_t sbuf(static_cast<const char *>(src), n); 557914b7a73SJunchao Zhang Kokkos::deep_copy(exec, dbuf, sbuf); 5589566063dSJacob Faibussowitsch PetscCall(PetscLogCpuToGpu(n)); 55971438e86SJunchao Zhang } else if (PetscMemTypeHost(dstmtype) && PetscMemTypeDevice(srcmtype)) { 560914b7a73SJunchao Zhang HostBuffer_t dbuf(static_cast<char *>(dst), n); 561914b7a73SJunchao Zhang deviceConstBuffer_t sbuf(static_cast<const char *>(src), n); 562914b7a73SJunchao Zhang Kokkos::deep_copy(exec, dbuf, sbuf); 5639566063dSJacob Faibussowitsch PetscCall(PetscLogGpuToCpu(n)); 56471438e86SJunchao Zhang } else if (PetscMemTypeDevice(dstmtype) && PetscMemTypeDevice(srcmtype)) { 565914b7a73SJunchao Zhang deviceBuffer_t dbuf(static_cast<char *>(dst), n); 566914b7a73SJunchao Zhang deviceConstBuffer_t sbuf(static_cast<const char *>(src), n); 567914b7a73SJunchao Zhang Kokkos::deep_copy(exec, dbuf, sbuf); 568914b7a73SJunchao Zhang } 569914b7a73SJunchao Zhang } 570914b7a73SJunchao Zhang PetscFunctionReturn(0); 571914b7a73SJunchao Zhang } 572914b7a73SJunchao Zhang 573*9371c9d4SSatish Balay PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype, size_t size, void **ptr) { 574914b7a73SJunchao Zhang PetscFunctionBegin; 5759566063dSJacob Faibussowitsch if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size, ptr)); 57671438e86SJunchao Zhang else if (PetscMemTypeDevice(mtype)) { 5779566063dSJacob Faibussowitsch if (!PetscKokkosInitialized) PetscCall(PetscKokkosInitializeCheck()); 57845639126SStefano Zampini *ptr = Kokkos::kokkos_malloc<DeviceMemorySpace>(size); 57998921bdaSJacob Faibussowitsch } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype); 580914b7a73SJunchao Zhang PetscFunctionReturn(0); 581914b7a73SJunchao Zhang } 582914b7a73SJunchao Zhang 583*9371c9d4SSatish Balay PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype, void *ptr) { 584914b7a73SJunchao Zhang PetscFunctionBegin; 5859566063dSJacob Faibussowitsch if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr)); 586*9371c9d4SSatish Balay else if (PetscMemTypeDevice(mtype)) { 587*9371c9d4SSatish Balay Kokkos::kokkos_free<DeviceMemorySpace>(ptr); 588*9371c9d4SSatish Balay } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype); 589914b7a73SJunchao Zhang PetscFunctionReturn(0); 590914b7a73SJunchao Zhang } 591914b7a73SJunchao Zhang 59271438e86SJunchao Zhang /* Destructor when the link uses MPI for communication */ 593*9371c9d4SSatish Balay static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSF sf, PetscSFLink link) { 59471438e86SJunchao Zhang PetscFunctionBegin; 59571438e86SJunchao Zhang for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) { 5969566063dSJacob Faibussowitsch PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE])); 5979566063dSJacob Faibussowitsch PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE])); 59871438e86SJunchao Zhang } 59971438e86SJunchao Zhang PetscFunctionReturn(0); 60071438e86SJunchao Zhang } 601914b7a73SJunchao Zhang 602914b7a73SJunchao Zhang /* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what needed on device */ 603*9371c9d4SSatish Balay PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF PETSC_UNUSED sf, PetscSFLink link, MPI_Datatype unit) { 604914b7a73SJunchao Zhang PetscInt nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0; 605914b7a73SJunchao Zhang PetscBool is2Int, is2PetscInt; 606914b7a73SJunchao Zhang #if defined(PETSC_HAVE_COMPLEX) 607914b7a73SJunchao Zhang PetscInt nPetscComplex = 0; 608914b7a73SJunchao Zhang #endif 609914b7a73SJunchao Zhang 610914b7a73SJunchao Zhang PetscFunctionBegin; 611914b7a73SJunchao Zhang if (link->deviceinited) PetscFunctionReturn(0); 6129566063dSJacob Faibussowitsch PetscCall(PetscKokkosInitializeCheck()); 6139566063dSJacob Faibussowitsch PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar)); 6149566063dSJacob Faibussowitsch PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar)); 615914b7a73SJunchao Zhang /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */ 6169566063dSJacob Faibussowitsch PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt)); 6179566063dSJacob Faibussowitsch PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt)); 6189566063dSJacob Faibussowitsch PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal)); 619914b7a73SJunchao Zhang #if defined(PETSC_HAVE_COMPLEX) 6209566063dSJacob Faibussowitsch PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex)); 621914b7a73SJunchao Zhang #endif 6229566063dSJacob Faibussowitsch PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int)); 6239566063dSJacob Faibussowitsch PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt)); 624914b7a73SJunchao Zhang 625914b7a73SJunchao Zhang if (is2Int) { 626914b7a73SJunchao Zhang PackInit_PairType<Kokkos::pair<int, int>>(link); 627914b7a73SJunchao Zhang } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */ 628914b7a73SJunchao Zhang PackInit_PairType<Kokkos::pair<PetscInt, PetscInt>>(link); 629914b7a73SJunchao Zhang } else if (nPetscReal) { 630d941a2f0SJunchao Zhang #if !defined(PETSC_HAVE_DEVICE) /* Skip the unimportant stuff to speed up SF device compilation time */ 631*9371c9d4SSatish Balay if (nPetscReal == 8) PackInit_RealType<PetscReal, 8, 1>(link); 632*9371c9d4SSatish Balay else if (nPetscReal % 8 == 0) PackInit_RealType<PetscReal, 8, 0>(link); 633*9371c9d4SSatish Balay else if (nPetscReal == 4) PackInit_RealType<PetscReal, 4, 1>(link); 634*9371c9d4SSatish Balay else if (nPetscReal % 4 == 0) PackInit_RealType<PetscReal, 4, 0>(link); 635*9371c9d4SSatish Balay else if (nPetscReal == 2) PackInit_RealType<PetscReal, 2, 1>(link); 636*9371c9d4SSatish Balay else if (nPetscReal % 2 == 0) PackInit_RealType<PetscReal, 2, 0>(link); 637*9371c9d4SSatish Balay else if (nPetscReal == 1) PackInit_RealType<PetscReal, 1, 1>(link); 638*9371c9d4SSatish Balay else if (nPetscReal % 1 == 0) 639eee4e20aSJunchao Zhang #endif 640d941a2f0SJunchao Zhang PackInit_RealType<PetscReal, 1, 0>(link); 641874d28e3SJunchao Zhang } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) { 642d941a2f0SJunchao Zhang #if !defined(PETSC_HAVE_DEVICE) 643*9371c9d4SSatish Balay if (nPetscInt == 8) PackInit_IntegerType<llint, 8, 1>(link); 644*9371c9d4SSatish Balay else if (nPetscInt % 8 == 0) PackInit_IntegerType<llint, 8, 0>(link); 645*9371c9d4SSatish Balay else if (nPetscInt == 4) PackInit_IntegerType<llint, 4, 1>(link); 646*9371c9d4SSatish Balay else if (nPetscInt % 4 == 0) PackInit_IntegerType<llint, 4, 0>(link); 647*9371c9d4SSatish Balay else if (nPetscInt == 2) PackInit_IntegerType<llint, 2, 1>(link); 648*9371c9d4SSatish Balay else if (nPetscInt % 2 == 0) PackInit_IntegerType<llint, 2, 0>(link); 649*9371c9d4SSatish Balay else if (nPetscInt == 1) PackInit_IntegerType<llint, 1, 1>(link); 650*9371c9d4SSatish Balay else if (nPetscInt % 1 == 0) 651eee4e20aSJunchao Zhang #endif 652d941a2f0SJunchao Zhang PackInit_IntegerType<llint, 1, 0>(link); 653914b7a73SJunchao Zhang } else if (nInt) { 654d941a2f0SJunchao Zhang #if !defined(PETSC_HAVE_DEVICE) 655*9371c9d4SSatish Balay if (nInt == 8) PackInit_IntegerType<int, 8, 1>(link); 656*9371c9d4SSatish Balay else if (nInt % 8 == 0) PackInit_IntegerType<int, 8, 0>(link); 657*9371c9d4SSatish Balay else if (nInt == 4) PackInit_IntegerType<int, 4, 1>(link); 658*9371c9d4SSatish Balay else if (nInt % 4 == 0) PackInit_IntegerType<int, 4, 0>(link); 659*9371c9d4SSatish Balay else if (nInt == 2) PackInit_IntegerType<int, 2, 1>(link); 660*9371c9d4SSatish Balay else if (nInt % 2 == 0) PackInit_IntegerType<int, 2, 0>(link); 661*9371c9d4SSatish Balay else if (nInt == 1) PackInit_IntegerType<int, 1, 1>(link); 662*9371c9d4SSatish Balay else if (nInt % 1 == 0) 663eee4e20aSJunchao Zhang #endif 664d941a2f0SJunchao Zhang PackInit_IntegerType<int, 1, 0>(link); 665914b7a73SJunchao Zhang } else if (nSignedChar) { 666d941a2f0SJunchao Zhang #if !defined(PETSC_HAVE_DEVICE) 667*9371c9d4SSatish Balay if (nSignedChar == 8) PackInit_IntegerType<char, 8, 1>(link); 668*9371c9d4SSatish Balay else if (nSignedChar % 8 == 0) PackInit_IntegerType<char, 8, 0>(link); 669*9371c9d4SSatish Balay else if (nSignedChar == 4) PackInit_IntegerType<char, 4, 1>(link); 670*9371c9d4SSatish Balay else if (nSignedChar % 4 == 0) PackInit_IntegerType<char, 4, 0>(link); 671*9371c9d4SSatish Balay else if (nSignedChar == 2) PackInit_IntegerType<char, 2, 1>(link); 672*9371c9d4SSatish Balay else if (nSignedChar % 2 == 0) PackInit_IntegerType<char, 2, 0>(link); 673*9371c9d4SSatish Balay else if (nSignedChar == 1) PackInit_IntegerType<char, 1, 1>(link); 674*9371c9d4SSatish Balay else if (nSignedChar % 1 == 0) 675eee4e20aSJunchao Zhang #endif 676d941a2f0SJunchao Zhang PackInit_IntegerType<char, 1, 0>(link); 677914b7a73SJunchao Zhang } else if (nUnsignedChar) { 678d941a2f0SJunchao Zhang #if !defined(PETSC_HAVE_DEVICE) 679*9371c9d4SSatish Balay if (nUnsignedChar == 8) PackInit_IntegerType<unsigned char, 8, 1>(link); 680*9371c9d4SSatish Balay else if (nUnsignedChar % 8 == 0) PackInit_IntegerType<unsigned char, 8, 0>(link); 681*9371c9d4SSatish Balay else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char, 4, 1>(link); 682*9371c9d4SSatish Balay else if (nUnsignedChar % 4 == 0) PackInit_IntegerType<unsigned char, 4, 0>(link); 683*9371c9d4SSatish Balay else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char, 2, 1>(link); 684*9371c9d4SSatish Balay else if (nUnsignedChar % 2 == 0) PackInit_IntegerType<unsigned char, 2, 0>(link); 685*9371c9d4SSatish Balay else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char, 1, 1>(link); 686*9371c9d4SSatish Balay else if (nUnsignedChar % 1 == 0) 687eee4e20aSJunchao Zhang #endif 688d941a2f0SJunchao Zhang PackInit_IntegerType<unsigned char, 1, 0>(link); 689914b7a73SJunchao Zhang #if defined(PETSC_HAVE_COMPLEX) 690914b7a73SJunchao Zhang } else if (nPetscComplex) { 691d941a2f0SJunchao Zhang #if !defined(PETSC_HAVE_DEVICE) 692*9371c9d4SSatish Balay if (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 1>(link); 693*9371c9d4SSatish Balay else if (nPetscComplex % 8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 0>(link); 694*9371c9d4SSatish Balay else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 1>(link); 695*9371c9d4SSatish Balay else if (nPetscComplex % 4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 0>(link); 696*9371c9d4SSatish Balay else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 1>(link); 697*9371c9d4SSatish Balay else if (nPetscComplex % 2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 0>(link); 698*9371c9d4SSatish Balay else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 1>(link); 699*9371c9d4SSatish Balay else if (nPetscComplex % 1 == 0) 700eee4e20aSJunchao Zhang #endif 701d941a2f0SJunchao Zhang PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 0>(link); 702914b7a73SJunchao Zhang #endif 703914b7a73SJunchao Zhang } else { 704914b7a73SJunchao Zhang MPI_Aint lb, nbyte; 7059566063dSJacob Faibussowitsch PetscCallMPI(MPI_Type_get_extent(unit, &lb, &nbyte)); 70608401ef6SPierre Jolivet PetscCheck(lb == 0, PETSC_COMM_SELF, PETSC_ERR_SUP, "Datatype with nonzero lower bound %ld", (long)lb); 707914b7a73SJunchao Zhang if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */ 708d941a2f0SJunchao Zhang #if !defined(PETSC_HAVE_DEVICE) 709*9371c9d4SSatish Balay if (nbyte == 4) PackInit_DumbType<char, 4, 1>(link); 710*9371c9d4SSatish Balay else if (nbyte % 4 == 0) PackInit_DumbType<char, 4, 0>(link); 711*9371c9d4SSatish Balay else if (nbyte == 2) PackInit_DumbType<char, 2, 1>(link); 712*9371c9d4SSatish Balay else if (nbyte % 2 == 0) PackInit_DumbType<char, 2, 0>(link); 713*9371c9d4SSatish Balay else if (nbyte == 1) PackInit_DumbType<char, 1, 1>(link); 714*9371c9d4SSatish Balay else if (nbyte % 1 == 0) 715eee4e20aSJunchao Zhang #endif 716d941a2f0SJunchao Zhang PackInit_DumbType<char, 1, 0>(link); 717914b7a73SJunchao Zhang } else { 718914b7a73SJunchao Zhang nInt = nbyte / sizeof(int); 719d941a2f0SJunchao Zhang #if !defined(PETSC_HAVE_DEVICE) 720*9371c9d4SSatish Balay if (nInt == 8) PackInit_DumbType<int, 8, 1>(link); 721*9371c9d4SSatish Balay else if (nInt % 8 == 0) PackInit_DumbType<int, 8, 0>(link); 722*9371c9d4SSatish Balay else if (nInt == 4) PackInit_DumbType<int, 4, 1>(link); 723*9371c9d4SSatish Balay else if (nInt % 4 == 0) PackInit_DumbType<int, 4, 0>(link); 724*9371c9d4SSatish Balay else if (nInt == 2) PackInit_DumbType<int, 2, 1>(link); 725*9371c9d4SSatish Balay else if (nInt % 2 == 0) PackInit_DumbType<int, 2, 0>(link); 726*9371c9d4SSatish Balay else if (nInt == 1) PackInit_DumbType<int, 1, 1>(link); 727*9371c9d4SSatish Balay else if (nInt % 1 == 0) 728eee4e20aSJunchao Zhang #endif 729d941a2f0SJunchao Zhang PackInit_DumbType<int, 1, 0>(link); 730914b7a73SJunchao Zhang } 731914b7a73SJunchao Zhang } 732914b7a73SJunchao Zhang 73371438e86SJunchao Zhang link->SyncDevice = PetscSFLinkSyncDevice_Kokkos; 73471438e86SJunchao Zhang link->SyncStream = PetscSFLinkSyncStream_Kokkos; 73520c24465SJunchao Zhang link->Memcpy = PetscSFLinkMemcpy_Kokkos; 73671438e86SJunchao Zhang link->Destroy = PetscSFLinkDestroy_Kokkos; 737914b7a73SJunchao Zhang link->deviceinited = PETSC_TRUE; 738914b7a73SJunchao Zhang PetscFunctionReturn(0); 739914b7a73SJunchao Zhang } 740