1914b7a73SJunchao Zhang #include <../src/vec/is/sf/impls/basic/sfpack.h> 2914b7a73SJunchao Zhang 3914b7a73SJunchao Zhang #include <Kokkos_Core.hpp> 4914b7a73SJunchao Zhang 5914b7a73SJunchao Zhang using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace; 6914b7a73SJunchao Zhang using DeviceMemorySpace = typename DeviceExecutionSpace::memory_space; 7914b7a73SJunchao Zhang using HostMemorySpace = Kokkos::HostSpace; 8914b7a73SJunchao Zhang 9914b7a73SJunchao Zhang typedef Kokkos::View<char*,DeviceMemorySpace> deviceBuffer_t; 10914b7a73SJunchao Zhang typedef Kokkos::View<char*,HostMemorySpace> HostBuffer_t; 11914b7a73SJunchao Zhang 12914b7a73SJunchao Zhang typedef Kokkos::View<const char*,DeviceMemorySpace> deviceConstBuffer_t; 13914b7a73SJunchao Zhang typedef Kokkos::View<const char*,HostMemorySpace> HostConstBuffer_t; 14914b7a73SJunchao Zhang 15914b7a73SJunchao Zhang /*====================================================================================*/ 16914b7a73SJunchao Zhang /* Regular operations */ 17914b7a73SJunchao Zhang /*====================================================================================*/ 18914b7a73SJunchao Zhang template<typename Type> struct Insert{KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x = y; return old;}}; 19914b7a73SJunchao Zhang template<typename Type> struct Add {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x += y; return old;}}; 20914b7a73SJunchao Zhang template<typename Type> struct Mult {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x *= y; return old;}}; 21914b7a73SJunchao Zhang template<typename Type> struct Min {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x = PetscMin(x,y); return old;}}; 22914b7a73SJunchao Zhang template<typename Type> struct Max {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x = PetscMax(x,y); return old;}}; 23914b7a73SJunchao Zhang template<typename Type> struct LAND {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x = x && y; return old;}}; 24914b7a73SJunchao Zhang template<typename Type> struct LOR {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x = x || y; return old;}}; 25914b7a73SJunchao Zhang template<typename Type> struct LXOR {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x = !x != !y; return old;}}; 26914b7a73SJunchao Zhang template<typename Type> struct BAND {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x = x & y; return old;}}; 27914b7a73SJunchao Zhang template<typename Type> struct BOR {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x = x | y; return old;}}; 28914b7a73SJunchao Zhang template<typename Type> struct BXOR {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x = x ^ y; return old;}}; 29914b7a73SJunchao Zhang template<typename PairType> struct Minloc { 30914b7a73SJunchao Zhang KOKKOS_INLINE_FUNCTION PairType operator()(PairType& x,PairType y) const { 31914b7a73SJunchao Zhang PairType old = x; 32914b7a73SJunchao Zhang if (y.first < x.first) x = y; 33914b7a73SJunchao Zhang else if (y.first == x.first) x.second = PetscMin(x.second,y.second); 34914b7a73SJunchao Zhang return old; 35914b7a73SJunchao Zhang } 36914b7a73SJunchao Zhang }; 37914b7a73SJunchao Zhang template<typename PairType> struct Maxloc { 38914b7a73SJunchao Zhang KOKKOS_INLINE_FUNCTION PairType operator()(PairType& x,PairType y) const { 39914b7a73SJunchao Zhang PairType old = x; 40914b7a73SJunchao Zhang if (y.first > x.first) x = y; 41914b7a73SJunchao Zhang else if (y.first == x.first) x.second = PetscMin(x.second,y.second); /* See MPI MAXLOC */ 42914b7a73SJunchao Zhang return old; 43914b7a73SJunchao Zhang } 44914b7a73SJunchao Zhang }; 45914b7a73SJunchao Zhang 46914b7a73SJunchao Zhang /*====================================================================================*/ 47914b7a73SJunchao Zhang /* Atomic operations */ 48914b7a73SJunchao Zhang /*====================================================================================*/ 49914b7a73SJunchao Zhang template<typename Type> struct AtomicInsert {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_assign(&x,y);}}; 50914b7a73SJunchao Zhang template<typename Type> struct AtomicAdd {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_add(&x,y);}}; 51914b7a73SJunchao Zhang template<typename Type> struct AtomicBAND {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_and(&x,y);}}; 52914b7a73SJunchao Zhang template<typename Type> struct AtomicBOR {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_or (&x,y);}}; 53914b7a73SJunchao Zhang template<typename Type> struct AtomicBXOR {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_xor(&x,y);}}; 54914b7a73SJunchao Zhang template<typename Type> struct AtomicLAND {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {const Type zero=0,one=~0; Kokkos::atomic_and(&x,y?one:zero);}}; 55914b7a73SJunchao Zhang template<typename Type> struct AtomicLOR {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {const Type zero=0,one=1; Kokkos::atomic_or (&x,y?one:zero);}}; 56914b7a73SJunchao Zhang template<typename Type> struct AtomicMult {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_mul(&x,y);}}; 57914b7a73SJunchao Zhang template<typename Type> struct AtomicMin {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_min(&x,y);}}; 58914b7a73SJunchao Zhang template<typename Type> struct AtomicMax {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_max(&x,y);}}; 59914b7a73SJunchao Zhang /* TODO: struct AtomicLXOR */ 60914b7a73SJunchao Zhang template<typename Type> struct AtomicFetchAdd{KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {return Kokkos::atomic_fetch_add(&x,y);}}; 61914b7a73SJunchao Zhang 62914b7a73SJunchao Zhang /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */ 63914b7a73SJunchao Zhang static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt,PetscInt tid) 64914b7a73SJunchao Zhang { 65914b7a73SJunchao Zhang PetscInt i,j,k,m,n,r; 66914b7a73SJunchao Zhang const PetscInt *offset,*start,*dx,*dy,*X,*Y; 67914b7a73SJunchao Zhang 68914b7a73SJunchao Zhang n = opt[0]; 69914b7a73SJunchao Zhang offset = opt + 1; 70914b7a73SJunchao Zhang start = opt + n + 2; 71914b7a73SJunchao Zhang dx = opt + 2*n + 2; 72914b7a73SJunchao Zhang dy = opt + 3*n + 2; 73914b7a73SJunchao Zhang X = opt + 5*n + 2; 74914b7a73SJunchao Zhang Y = opt + 6*n + 2; 75914b7a73SJunchao Zhang for (r=0; r<n; r++) {if (tid < offset[r+1]) break;} 76914b7a73SJunchao Zhang m = (tid - offset[r]); 77914b7a73SJunchao Zhang k = m/(dx[r]*dy[r]); 78914b7a73SJunchao Zhang j = (m - k*dx[r]*dy[r])/dx[r]; 79914b7a73SJunchao Zhang i = m - k*dx[r]*dy[r] - j*dx[r]; 80914b7a73SJunchao Zhang 81914b7a73SJunchao Zhang return (start[r] + k*X[r]*Y[r] + j*X[r] + i); 82914b7a73SJunchao Zhang } 83914b7a73SJunchao Zhang 84914b7a73SJunchao Zhang /*====================================================================================*/ 85914b7a73SJunchao Zhang /* Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link' */ 86914b7a73SJunchao Zhang /*====================================================================================*/ 87914b7a73SJunchao Zhang 88914b7a73SJunchao Zhang /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then 89914b7a73SJunchao Zhang <Type> is PetscReal, which is the primitive type we operate on. 90914b7a73SJunchao Zhang <bs> is 16, which says <unit> contains 16 primitive types. 91914b7a73SJunchao Zhang <BS> is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>. 92914b7a73SJunchao Zhang <EQ> is 0, which is (bs == BS ? 1 : 0) 93914b7a73SJunchao Zhang 94914b7a73SJunchao Zhang If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant. 95914b7a73SJunchao Zhang For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled. 96914b7a73SJunchao Zhang */ 97914b7a73SJunchao Zhang template<typename Type,PetscInt BS,PetscInt EQ> 98914b7a73SJunchao Zhang static PetscErrorCode Pack(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,const void *data_,void *buf_) 99914b7a73SJunchao Zhang { 100914b7a73SJunchao Zhang const PetscInt *iopt = opt ? opt->array : NULL; 101914b7a73SJunchao Zhang const PetscInt M = EQ ? 1 : link->bs/BS, MBS=M*BS; /* If EQ, then MBS will be a compile-time const */ 102914b7a73SJunchao Zhang const Type *data = static_cast<const Type*>(data_); 103914b7a73SJunchao Zhang Type *buf = static_cast<Type*>(buf_); 104*f4af43b4SJunchao Zhang DeviceExecutionSpace exec; 105914b7a73SJunchao Zhang 106914b7a73SJunchao Zhang PetscFunctionBegin; 107914b7a73SJunchao Zhang Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) { 108914b7a73SJunchao Zhang /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous; 109914b7a73SJunchao Zhang iopt == NULL && idx == NULL ==> the indices are contiguous; 110914b7a73SJunchao Zhang */ 111914b7a73SJunchao Zhang PetscInt t = (iopt? MapTidToIndex(iopt,tid) : (idx? idx[tid] : start+tid))*MBS; 112914b7a73SJunchao Zhang PetscInt s = tid*MBS; 113914b7a73SJunchao Zhang for (int i=0; i<MBS; i++) buf[s+i] = data[t+i]; 114914b7a73SJunchao Zhang }); 115914b7a73SJunchao Zhang PetscFunctionReturn(0); 116914b7a73SJunchao Zhang } 117914b7a73SJunchao Zhang 118914b7a73SJunchao Zhang template<typename Type,class Op,PetscInt BS,PetscInt EQ> 119914b7a73SJunchao Zhang static PetscErrorCode UnpackAndOp(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,void *data_,const void *buf_) 120914b7a73SJunchao Zhang { 121914b7a73SJunchao Zhang Op op; 122914b7a73SJunchao Zhang const PetscInt *iopt = opt ? opt->array : NULL; 123914b7a73SJunchao Zhang const PetscInt M = EQ ? 1 : link->bs/BS, MBS=M*BS; 124914b7a73SJunchao Zhang Type *data = static_cast<Type*>(data_); 125914b7a73SJunchao Zhang const Type *buf = static_cast<const Type*>(buf_); 126*f4af43b4SJunchao Zhang DeviceExecutionSpace exec; 127914b7a73SJunchao Zhang 128914b7a73SJunchao Zhang PetscFunctionBegin; 129914b7a73SJunchao Zhang Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) { 130914b7a73SJunchao Zhang PetscInt t = (iopt? MapTidToIndex(iopt,tid) : (idx? idx[tid] : start+tid))*MBS; 131914b7a73SJunchao Zhang PetscInt s = tid*MBS; 132914b7a73SJunchao Zhang for (int i=0; i<MBS; i++) op(data[t+i],buf[s+i]); 133914b7a73SJunchao Zhang }); 134914b7a73SJunchao Zhang PetscFunctionReturn(0); 135914b7a73SJunchao Zhang } 136914b7a73SJunchao Zhang 137914b7a73SJunchao Zhang template<typename Type,class Op,PetscInt BS,PetscInt EQ> 138914b7a73SJunchao Zhang static PetscErrorCode FetchAndOp(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,void *data,void *buf) 139914b7a73SJunchao Zhang { 140914b7a73SJunchao Zhang Op op; 141914b7a73SJunchao Zhang const PetscInt *ropt = opt ? opt->array : NULL; 142914b7a73SJunchao Zhang const PetscInt M = EQ ? 1 : link->bs/BS, MBS=M*BS; 143914b7a73SJunchao Zhang Type *rootdata = static_cast<Type*>(data),*leafbuf=static_cast<Type*>(buf); 144*f4af43b4SJunchao Zhang DeviceExecutionSpace exec; 145914b7a73SJunchao Zhang 146914b7a73SJunchao Zhang PetscFunctionBegin; 147914b7a73SJunchao Zhang Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) { 148914b7a73SJunchao Zhang PetscInt r = (ropt? MapTidToIndex(ropt,tid) : (idx? idx[tid] : start+tid))*MBS; 149914b7a73SJunchao Zhang PetscInt l = tid*MBS; 150914b7a73SJunchao Zhang for (int i=0; i<MBS; i++) leafbuf[l+i] = op(rootdata[r+i],leafbuf[l+i]); 151914b7a73SJunchao Zhang }); 152914b7a73SJunchao Zhang PetscFunctionReturn(0); 153914b7a73SJunchao Zhang } 154914b7a73SJunchao Zhang 155914b7a73SJunchao Zhang template<typename Type,class Op,PetscInt BS,PetscInt EQ> 156914b7a73SJunchao Zhang static PetscErrorCode ScatterAndOp(PetscSFLink link,PetscInt count,PetscInt srcStart,PetscSFPackOpt srcOpt,const PetscInt *srcIdx,const void *src_,PetscInt dstStart,PetscSFPackOpt dstOpt,const PetscInt *dstIdx,void *dst_) 157914b7a73SJunchao Zhang { 158914b7a73SJunchao Zhang PetscInt srcx=0,srcy=0,srcX=0,srcY=0,dstx=0,dsty=0,dstX=0,dstY=0; 159914b7a73SJunchao Zhang const PetscInt M = (EQ) ? 1 : link->bs/BS, MBS=M*BS; 160914b7a73SJunchao Zhang const Type *src = static_cast<const Type*>(src_); 161914b7a73SJunchao Zhang Type *dst = static_cast<Type*>(dst_); 162*f4af43b4SJunchao Zhang DeviceExecutionSpace exec; 163914b7a73SJunchao Zhang 164914b7a73SJunchao Zhang PetscFunctionBegin; 165914b7a73SJunchao Zhang /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */ 166914b7a73SJunchao Zhang if (srcOpt) {srcx = srcOpt->dx[0]; srcy = srcOpt->dy[0]; srcX = srcOpt->X[0]; srcY = srcOpt->Y[0]; srcStart = srcOpt->start[0]; srcIdx = NULL;} 167914b7a73SJunchao Zhang else if (!srcIdx) {srcx = srcX = count; srcy = srcY = 1;} 168914b7a73SJunchao Zhang 169914b7a73SJunchao Zhang if (dstOpt) {dstx = dstOpt->dx[0]; dsty = dstOpt->dy[0]; dstX = dstOpt->X[0]; dstY = dstOpt->Y[0]; dstStart = dstOpt->start[0]; dstIdx = NULL;} 170914b7a73SJunchao Zhang else if (!dstIdx) {dstx = dstX = count; dsty = dstY = 1;} 171914b7a73SJunchao Zhang 172914b7a73SJunchao Zhang Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) { 173914b7a73SJunchao Zhang PetscInt i,j,k,s,t; 174914b7a73SJunchao Zhang Op op; 175914b7a73SJunchao Zhang if (!srcIdx) { /* src is in 3D */ 176914b7a73SJunchao Zhang k = tid/(srcx*srcy); 177914b7a73SJunchao Zhang j = (tid - k*srcx*srcy)/srcx; 178914b7a73SJunchao Zhang i = tid - k*srcx*srcy - j*srcx; 179914b7a73SJunchao Zhang s = srcStart + k*srcX*srcY + j*srcX + i; 180914b7a73SJunchao Zhang } else { /* src is contiguous */ 181914b7a73SJunchao Zhang s = srcIdx[tid]; 182914b7a73SJunchao Zhang } 183914b7a73SJunchao Zhang 184914b7a73SJunchao Zhang if (!dstIdx) { /* 3D */ 185914b7a73SJunchao Zhang k = tid/(dstx*dsty); 186914b7a73SJunchao Zhang j = (tid - k*dstx*dsty)/dstx; 187914b7a73SJunchao Zhang i = tid - k*dstx*dsty - j*dstx; 188914b7a73SJunchao Zhang t = dstStart + k*dstX*dstY + j*dstX + i; 189914b7a73SJunchao Zhang } else { /* contiguous */ 190914b7a73SJunchao Zhang t = dstIdx[tid]; 191914b7a73SJunchao Zhang } 192914b7a73SJunchao Zhang 193914b7a73SJunchao Zhang s *= MBS; 194914b7a73SJunchao Zhang t *= MBS; 195914b7a73SJunchao Zhang for (i=0; i<MBS; i++) op(dst[t+i],src[s+i]); 196914b7a73SJunchao Zhang }); 197914b7a73SJunchao Zhang PetscFunctionReturn(0); 198914b7a73SJunchao Zhang } 199914b7a73SJunchao Zhang 200914b7a73SJunchao Zhang /* Specialization for Insert since we may use memcpy */ 201914b7a73SJunchao Zhang template<typename Type,PetscInt BS,PetscInt EQ> 202914b7a73SJunchao Zhang static PetscErrorCode ScatterAndInsert(PetscSFLink link,PetscInt count,PetscInt srcStart,PetscSFPackOpt srcOpt,const PetscInt *srcIdx,const void *src_,PetscInt dstStart,PetscSFPackOpt dstOpt,const PetscInt *dstIdx,void *dst_) 203914b7a73SJunchao Zhang { 204914b7a73SJunchao Zhang PetscErrorCode ierr; 205914b7a73SJunchao Zhang const Type *src = static_cast<const Type*>(src_); 206914b7a73SJunchao Zhang Type *dst = static_cast<Type*>(dst_); 207*f4af43b4SJunchao Zhang DeviceExecutionSpace exec; 208914b7a73SJunchao Zhang 209914b7a73SJunchao Zhang PetscFunctionBegin; 210914b7a73SJunchao Zhang if (!count) PetscFunctionReturn(0); 211914b7a73SJunchao Zhang /*src and dst are contiguous */ 212914b7a73SJunchao Zhang if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) { 213914b7a73SJunchao Zhang size_t sz = count*link->unitbytes; 214914b7a73SJunchao Zhang deviceBuffer_t dbuf(reinterpret_cast<char*>(dst+dstStart*link->bs),sz); 215914b7a73SJunchao Zhang deviceConstBuffer_t sbuf(reinterpret_cast<const char*>(src+srcStart*link->bs),sz); 216914b7a73SJunchao Zhang Kokkos::deep_copy(exec,dbuf,sbuf); 217914b7a73SJunchao Zhang } else { 218914b7a73SJunchao Zhang ierr = ScatterAndOp<Type,Insert<Type>,BS,EQ>(link,count,srcStart,srcOpt,srcIdx,src,dstStart,dstOpt,dstIdx,dst);CHKERRQ(ierr); 219914b7a73SJunchao Zhang } 220914b7a73SJunchao Zhang PetscFunctionReturn(0); 221914b7a73SJunchao Zhang } 222914b7a73SJunchao Zhang 223914b7a73SJunchao Zhang template<typename Type,class Op,PetscInt BS,PetscInt EQ> 224914b7a73SJunchao Zhang static PetscErrorCode FetchAndOpLocal(PetscSFLink link,PetscInt count,PetscInt rootstart,PetscSFPackOpt rootopt,const PetscInt *rootidx,void *rootdata_,PetscInt leafstart,PetscSFPackOpt leafopt,const PetscInt *leafidx,const void *leafdata_,void *leafupdate_) 225914b7a73SJunchao Zhang { 226914b7a73SJunchao Zhang Op op; 227914b7a73SJunchao Zhang const PetscInt M = (EQ) ? 1 : link->bs/BS, MBS = M*BS; 228914b7a73SJunchao Zhang const PetscInt *ropt = rootopt ? rootopt->array : NULL; 229914b7a73SJunchao Zhang const PetscInt *lopt = leafopt ? leafopt->array : NULL; 230914b7a73SJunchao Zhang Type *rootdata = static_cast<Type*>(rootdata_),*leafupdate = static_cast<Type*>(leafupdate_); 231914b7a73SJunchao Zhang const Type *leafdata = static_cast<const Type*>(leafdata_); 232*f4af43b4SJunchao Zhang DeviceExecutionSpace exec; 233914b7a73SJunchao Zhang 234914b7a73SJunchao Zhang PetscFunctionBegin; 235914b7a73SJunchao Zhang Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) { 236914b7a73SJunchao Zhang PetscInt r = (ropt? MapTidToIndex(ropt,tid) : (rootidx? rootidx[tid] : rootstart+tid))*MBS; 237914b7a73SJunchao Zhang PetscInt l = (lopt? MapTidToIndex(lopt,tid) : (leafidx? leafidx[tid] : leafstart+tid))*MBS; 238914b7a73SJunchao Zhang for (int i=0; i<MBS; i++) leafupdate[l+i] = op(rootdata[r+i],leafdata[l+i]); 239914b7a73SJunchao Zhang }); 240914b7a73SJunchao Zhang PetscFunctionReturn(0); 241914b7a73SJunchao Zhang } 242914b7a73SJunchao Zhang 243914b7a73SJunchao Zhang /*====================================================================================*/ 244914b7a73SJunchao Zhang /* Init various types and instantiate pack/unpack function pointers */ 245914b7a73SJunchao Zhang /*====================================================================================*/ 246914b7a73SJunchao Zhang template<typename Type,PetscInt BS,PetscInt EQ> 247914b7a73SJunchao Zhang static void PackInit_RealType(PetscSFLink link) 248914b7a73SJunchao Zhang { 249914b7a73SJunchao Zhang /* Pack/unpack for remote communication */ 250914b7a73SJunchao Zhang link->d_Pack = Pack<Type,BS,EQ>; 251914b7a73SJunchao Zhang link->d_UnpackAndInsert = UnpackAndOp<Type,Insert<Type>,BS,EQ>; 252914b7a73SJunchao Zhang link->d_UnpackAndAdd = UnpackAndOp<Type,Add<Type> ,BS,EQ>; 253914b7a73SJunchao Zhang link->d_UnpackAndMult = UnpackAndOp<Type,Mult<Type> ,BS,EQ>; 254914b7a73SJunchao Zhang link->d_UnpackAndMin = UnpackAndOp<Type,Min<Type> ,BS,EQ>; 255914b7a73SJunchao Zhang link->d_UnpackAndMax = UnpackAndOp<Type,Max<Type> ,BS,EQ>; 256914b7a73SJunchao Zhang link->d_FetchAndAdd = FetchAndOp <Type,Add<Type> ,BS,EQ>; 257914b7a73SJunchao Zhang /* Scatter for local communication */ 258914b7a73SJunchao Zhang link->d_ScatterAndInsert = ScatterAndInsert<Type,BS,EQ>; /* Has special optimizations */ 259914b7a73SJunchao Zhang link->d_ScatterAndAdd = ScatterAndOp<Type,Add<Type> ,BS,EQ>; 260914b7a73SJunchao Zhang link->d_ScatterAndMult = ScatterAndOp<Type,Mult<Type> ,BS,EQ>; 261914b7a73SJunchao Zhang link->d_ScatterAndMin = ScatterAndOp<Type,Min<Type> ,BS,EQ>; 262914b7a73SJunchao Zhang link->d_ScatterAndMax = ScatterAndOp<Type,Max<Type> ,BS,EQ>; 263914b7a73SJunchao Zhang link->d_FetchAndAddLocal = FetchAndOpLocal<Type,Add <Type>,BS,EQ>; 264914b7a73SJunchao Zhang /* Atomic versions when there are data-race possibilities */ 265914b7a73SJunchao Zhang link->da_UnpackAndInsert = UnpackAndOp<Type,AtomicInsert<Type> ,BS,EQ>; 266914b7a73SJunchao Zhang link->da_UnpackAndAdd = UnpackAndOp<Type,AtomicAdd<Type> ,BS,EQ>; 267914b7a73SJunchao Zhang link->da_UnpackAndMult = UnpackAndOp<Type,AtomicMult<Type> ,BS,EQ>; 268914b7a73SJunchao Zhang link->da_UnpackAndMin = UnpackAndOp<Type,AtomicMin<Type> ,BS,EQ>; 269914b7a73SJunchao Zhang link->da_UnpackAndMax = UnpackAndOp<Type,AtomicMax<Type> ,BS,EQ>; 270914b7a73SJunchao Zhang link->da_FetchAndAdd = FetchAndOp <Type,AtomicFetchAdd<Type>,BS,EQ>; 271914b7a73SJunchao Zhang 272914b7a73SJunchao Zhang link->da_ScatterAndInsert = ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>; 273914b7a73SJunchao Zhang link->da_ScatterAndAdd = ScatterAndOp<Type,AtomicAdd<Type> ,BS,EQ>; 274914b7a73SJunchao Zhang link->da_ScatterAndMult = ScatterAndOp<Type,AtomicMult<Type> ,BS,EQ>; 275914b7a73SJunchao Zhang link->da_ScatterAndMin = ScatterAndOp<Type,AtomicMin<Type> ,BS,EQ>; 276914b7a73SJunchao Zhang link->da_ScatterAndMax = ScatterAndOp<Type,AtomicMax<Type> ,BS,EQ>; 277914b7a73SJunchao Zhang link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>; 278914b7a73SJunchao Zhang } 279914b7a73SJunchao Zhang 280914b7a73SJunchao Zhang template<typename Type,PetscInt BS,PetscInt EQ> 281914b7a73SJunchao Zhang static void PackInit_IntegerType(PetscSFLink link) 282914b7a73SJunchao Zhang { 283914b7a73SJunchao Zhang link->d_Pack = Pack<Type,BS,EQ>; 284914b7a73SJunchao Zhang link->d_UnpackAndInsert = UnpackAndOp<Type,Insert<Type> ,BS,EQ>; 285914b7a73SJunchao Zhang link->d_UnpackAndAdd = UnpackAndOp<Type,Add<Type> ,BS,EQ>; 286914b7a73SJunchao Zhang link->d_UnpackAndMult = UnpackAndOp<Type,Mult<Type> ,BS,EQ>; 287914b7a73SJunchao Zhang link->d_UnpackAndMin = UnpackAndOp<Type,Min<Type> ,BS,EQ>; 288914b7a73SJunchao Zhang link->d_UnpackAndMax = UnpackAndOp<Type,Max<Type> ,BS,EQ>; 289914b7a73SJunchao Zhang link->d_UnpackAndLAND = UnpackAndOp<Type,LAND<Type> ,BS,EQ>; 290914b7a73SJunchao Zhang link->d_UnpackAndLOR = UnpackAndOp<Type,LOR<Type> ,BS,EQ>; 291914b7a73SJunchao Zhang link->d_UnpackAndLXOR = UnpackAndOp<Type,LXOR<Type> ,BS,EQ>; 292914b7a73SJunchao Zhang link->d_UnpackAndBAND = UnpackAndOp<Type,BAND<Type> ,BS,EQ>; 293914b7a73SJunchao Zhang link->d_UnpackAndBOR = UnpackAndOp<Type,BOR<Type> ,BS,EQ>; 294914b7a73SJunchao Zhang link->d_UnpackAndBXOR = UnpackAndOp<Type,BXOR<Type> ,BS,EQ>; 295914b7a73SJunchao Zhang link->d_FetchAndAdd = FetchAndOp <Type,Add<Type> ,BS,EQ>; 296914b7a73SJunchao Zhang 297914b7a73SJunchao Zhang link->d_ScatterAndInsert = ScatterAndInsert<Type,BS,EQ>; 298914b7a73SJunchao Zhang link->d_ScatterAndAdd = ScatterAndOp<Type,Add<Type> ,BS,EQ>; 299914b7a73SJunchao Zhang link->d_ScatterAndMult = ScatterAndOp<Type,Mult<Type> ,BS,EQ>; 300914b7a73SJunchao Zhang link->d_ScatterAndMin = ScatterAndOp<Type,Min<Type> ,BS,EQ>; 301914b7a73SJunchao Zhang link->d_ScatterAndMax = ScatterAndOp<Type,Max<Type> ,BS,EQ>; 302914b7a73SJunchao Zhang link->d_ScatterAndLAND = ScatterAndOp<Type,LAND<Type> ,BS,EQ>; 303914b7a73SJunchao Zhang link->d_ScatterAndLOR = ScatterAndOp<Type,LOR<Type> ,BS,EQ>; 304914b7a73SJunchao Zhang link->d_ScatterAndLXOR = ScatterAndOp<Type,LXOR<Type> ,BS,EQ>; 305914b7a73SJunchao Zhang link->d_ScatterAndBAND = ScatterAndOp<Type,BAND<Type> ,BS,EQ>; 306914b7a73SJunchao Zhang link->d_ScatterAndBOR = ScatterAndOp<Type,BOR<Type> ,BS,EQ>; 307914b7a73SJunchao Zhang link->d_ScatterAndBXOR = ScatterAndOp<Type,BXOR<Type> ,BS,EQ>; 308914b7a73SJunchao Zhang link->d_FetchAndAddLocal = FetchAndOpLocal<Type,Add<Type>,BS,EQ>; 309914b7a73SJunchao Zhang 310914b7a73SJunchao Zhang link->da_UnpackAndInsert = UnpackAndOp<Type,AtomicInsert<Type>,BS,EQ>; 311914b7a73SJunchao Zhang link->da_UnpackAndAdd = UnpackAndOp<Type,AtomicAdd<Type> ,BS,EQ>; 312914b7a73SJunchao Zhang link->da_UnpackAndMult = UnpackAndOp<Type,AtomicMult<Type> ,BS,EQ>; 313914b7a73SJunchao Zhang link->da_UnpackAndMin = UnpackAndOp<Type,AtomicMin<Type> ,BS,EQ>; 314914b7a73SJunchao Zhang link->da_UnpackAndMax = UnpackAndOp<Type,AtomicMax<Type> ,BS,EQ>; 315914b7a73SJunchao Zhang link->da_UnpackAndLAND = UnpackAndOp<Type,AtomicLAND<Type> ,BS,EQ>; 316914b7a73SJunchao Zhang link->da_UnpackAndLOR = UnpackAndOp<Type,AtomicLOR<Type> ,BS,EQ>; 317914b7a73SJunchao Zhang link->da_UnpackAndBAND = UnpackAndOp<Type,AtomicBAND<Type> ,BS,EQ>; 318914b7a73SJunchao Zhang link->da_UnpackAndBOR = UnpackAndOp<Type,AtomicBOR<Type> ,BS,EQ>; 319914b7a73SJunchao Zhang link->da_UnpackAndBXOR = UnpackAndOp<Type,AtomicBXOR<Type> ,BS,EQ>; 320914b7a73SJunchao Zhang link->da_FetchAndAdd = FetchAndOp <Type,AtomicFetchAdd<Type>,BS,EQ>; 321914b7a73SJunchao Zhang 322914b7a73SJunchao Zhang link->da_ScatterAndInsert = ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>; 323914b7a73SJunchao Zhang link->da_ScatterAndAdd = ScatterAndOp<Type,AtomicAdd<Type> ,BS,EQ>; 324914b7a73SJunchao Zhang link->da_ScatterAndMult = ScatterAndOp<Type,AtomicMult<Type> ,BS,EQ>; 325914b7a73SJunchao Zhang link->da_ScatterAndMin = ScatterAndOp<Type,AtomicMin<Type> ,BS,EQ>; 326914b7a73SJunchao Zhang link->da_ScatterAndMax = ScatterAndOp<Type,AtomicMax<Type> ,BS,EQ>; 327914b7a73SJunchao Zhang link->da_ScatterAndLAND = ScatterAndOp<Type,AtomicLAND<Type> ,BS,EQ>; 328914b7a73SJunchao Zhang link->da_ScatterAndLOR = ScatterAndOp<Type,AtomicLOR<Type> ,BS,EQ>; 329914b7a73SJunchao Zhang link->da_ScatterAndBAND = ScatterAndOp<Type,AtomicBAND<Type> ,BS,EQ>; 330914b7a73SJunchao Zhang link->da_ScatterAndBOR = ScatterAndOp<Type,AtomicBOR<Type> ,BS,EQ>; 331914b7a73SJunchao Zhang link->da_ScatterAndBXOR = ScatterAndOp<Type,AtomicBXOR<Type> ,BS,EQ>; 332914b7a73SJunchao Zhang link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>; 333914b7a73SJunchao Zhang } 334914b7a73SJunchao Zhang 335914b7a73SJunchao Zhang #if defined(PETSC_HAVE_COMPLEX) 336914b7a73SJunchao Zhang template<typename Type,PetscInt BS,PetscInt EQ> 337914b7a73SJunchao Zhang static void PackInit_ComplexType(PetscSFLink link) 338914b7a73SJunchao Zhang { 339914b7a73SJunchao Zhang link->d_Pack = Pack<Type,BS,EQ>; 340914b7a73SJunchao Zhang link->d_UnpackAndInsert = UnpackAndOp<Type,Insert<Type>,BS,EQ>; 341914b7a73SJunchao Zhang link->d_UnpackAndAdd = UnpackAndOp<Type,Add<Type> ,BS,EQ>; 342914b7a73SJunchao Zhang link->d_UnpackAndMult = UnpackAndOp<Type,Mult<Type> ,BS,EQ>; 343914b7a73SJunchao Zhang link->d_FetchAndAdd = FetchAndOp <Type,Add<Type> ,BS,EQ>; 344914b7a73SJunchao Zhang 345914b7a73SJunchao Zhang link->d_ScatterAndInsert = ScatterAndInsert<Type,BS,EQ>; 346914b7a73SJunchao Zhang link->d_ScatterAndAdd = ScatterAndOp<Type,Add<Type> ,BS,EQ>; 347914b7a73SJunchao Zhang link->d_ScatterAndMult = ScatterAndOp<Type,Mult<Type> ,BS,EQ>; 348914b7a73SJunchao Zhang link->d_FetchAndAddLocal = FetchAndOpLocal<Type,Add<Type>,BS,EQ>; 349914b7a73SJunchao Zhang 350914b7a73SJunchao Zhang link->da_UnpackAndInsert = UnpackAndOp<Type,AtomicInsert<Type> ,BS,EQ>; 351914b7a73SJunchao Zhang link->da_UnpackAndAdd = UnpackAndOp<Type,AtomicAdd<Type> ,BS,EQ>; 352914b7a73SJunchao Zhang link->da_UnpackAndMult = UnpackAndOp<Type,AtomicMult<Type> ,BS,EQ>; 353914b7a73SJunchao Zhang link->da_FetchAndAdd = FetchAndOp<Type,AtomicFetchAdd<Type>,BS,EQ>; 354914b7a73SJunchao Zhang 355914b7a73SJunchao Zhang link->da_ScatterAndInsert = ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>; 356914b7a73SJunchao Zhang link->da_ScatterAndAdd = ScatterAndOp<Type,AtomicAdd<Type> ,BS,EQ>; 357914b7a73SJunchao Zhang link->da_ScatterAndMult = ScatterAndOp<Type,AtomicMult<Type> ,BS,EQ>; 358914b7a73SJunchao Zhang link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>; 359914b7a73SJunchao Zhang } 360914b7a73SJunchao Zhang #endif 361914b7a73SJunchao Zhang 362914b7a73SJunchao Zhang template<typename Type> 363914b7a73SJunchao Zhang static void PackInit_PairType(PetscSFLink link) 364914b7a73SJunchao Zhang { 365914b7a73SJunchao Zhang link->d_Pack = Pack<Type,1,1>; 366914b7a73SJunchao Zhang link->d_UnpackAndInsert = UnpackAndOp<Type,Insert<Type>,1,1>; 367914b7a73SJunchao Zhang link->d_UnpackAndMaxloc = UnpackAndOp<Type,Maxloc<Type>,1,1>; 368914b7a73SJunchao Zhang link->d_UnpackAndMinloc = UnpackAndOp<Type,Minloc<Type>,1,1>; 369914b7a73SJunchao Zhang 370914b7a73SJunchao Zhang link->d_ScatterAndInsert = ScatterAndOp<Type,Insert<Type>,1,1>; 371914b7a73SJunchao Zhang link->d_ScatterAndMaxloc = ScatterAndOp<Type,Maxloc<Type>,1,1>; 372914b7a73SJunchao Zhang link->d_ScatterAndMinloc = ScatterAndOp<Type,Minloc<Type>,1,1>; 373914b7a73SJunchao Zhang /* Atomics for pair types are not implemented yet */ 374914b7a73SJunchao Zhang } 375914b7a73SJunchao Zhang 376914b7a73SJunchao Zhang template<typename Type,PetscInt BS,PetscInt EQ> 377914b7a73SJunchao Zhang static void PackInit_DumbType(PetscSFLink link) 378914b7a73SJunchao Zhang { 379914b7a73SJunchao Zhang link->d_Pack = Pack<Type,BS,EQ>; 380914b7a73SJunchao Zhang link->d_UnpackAndInsert = UnpackAndOp<Type,Insert<Type>,BS,EQ>; 381914b7a73SJunchao Zhang link->d_ScatterAndInsert = ScatterAndInsert<Type,BS,EQ>; 382914b7a73SJunchao Zhang /* Atomics for dumb types are not implemented yet */ 383914b7a73SJunchao Zhang } 384914b7a73SJunchao Zhang 385*f4af43b4SJunchao Zhang /* 386*f4af43b4SJunchao Zhang Kokkos::DefaultExecutionSpace(stream) is a reference counted pointer object. It has a bug 387*f4af43b4SJunchao Zhang that one is not able to repeatedly create and destroy the object. SF's original design was each 388*f4af43b4SJunchao Zhang SFLink has a stream (NULL or not) and hence an execution space object. The bug prevents us from 389*f4af43b4SJunchao Zhang destroying multiple SFLinks with NULL stream and the default execution space object. To avoid 390*f4af43b4SJunchao Zhang memory leaks, SF_Kokkos only supports NULL stream, which is also petsc's default scheme. SF_Kokkos 391*f4af43b4SJunchao Zhang does not do its own new/delete. It just uses Kokkos::DefaultExecutionSpace(), which is a singliton 392*f4af43b4SJunchao Zhang object in Kokkos. 393*f4af43b4SJunchao Zhang */ 394*f4af43b4SJunchao Zhang /* 395914b7a73SJunchao Zhang static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link) 396914b7a73SJunchao Zhang { 397914b7a73SJunchao Zhang PetscFunctionBegin; 398914b7a73SJunchao Zhang PetscFunctionReturn(0); 399914b7a73SJunchao Zhang } 400*f4af43b4SJunchao Zhang */ 401914b7a73SJunchao Zhang 402914b7a73SJunchao Zhang /* Some device-specific utilities */ 40320c24465SJunchao Zhang static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink link) 404914b7a73SJunchao Zhang { 405914b7a73SJunchao Zhang PetscFunctionBegin; 406914b7a73SJunchao Zhang Kokkos::fence(); 407914b7a73SJunchao Zhang PetscFunctionReturn(0); 408914b7a73SJunchao Zhang } 409914b7a73SJunchao Zhang 41020c24465SJunchao Zhang static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink link) 411914b7a73SJunchao Zhang { 412*f4af43b4SJunchao Zhang DeviceExecutionSpace exec; 413914b7a73SJunchao Zhang PetscFunctionBegin; 414914b7a73SJunchao Zhang exec.fence(); 415914b7a73SJunchao Zhang PetscFunctionReturn(0); 416914b7a73SJunchao Zhang } 417914b7a73SJunchao Zhang 41820c24465SJunchao Zhang static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink link,PetscMemType dstmtype,void* dst,PetscMemType srcmtype,const void*src,size_t n) 419914b7a73SJunchao Zhang { 420*f4af43b4SJunchao Zhang DeviceExecutionSpace exec; 421914b7a73SJunchao Zhang 422914b7a73SJunchao Zhang PetscFunctionBegin; 423914b7a73SJunchao Zhang if (!n) PetscFunctionReturn(0); 424914b7a73SJunchao Zhang if (dstmtype == PETSC_MEMTYPE_HOST && srcmtype == PETSC_MEMTYPE_HOST) { 425914b7a73SJunchao Zhang PetscErrorCode ierr = PetscMemcpy(dst,src,n);CHKERRQ(ierr); 426914b7a73SJunchao Zhang } else { 427914b7a73SJunchao Zhang if (dstmtype == PETSC_MEMTYPE_DEVICE && srcmtype == PETSC_MEMTYPE_HOST) { 428914b7a73SJunchao Zhang deviceBuffer_t dbuf(static_cast<char*>(dst),n); 429914b7a73SJunchao Zhang HostConstBuffer_t sbuf(static_cast<const char*>(src),n); 430914b7a73SJunchao Zhang Kokkos::deep_copy(exec,dbuf,sbuf); 431914b7a73SJunchao Zhang PetscErrorCode ierr = PetscLogCpuToGpu(n);CHKERRQ(ierr); 432914b7a73SJunchao Zhang } else if (dstmtype == PETSC_MEMTYPE_HOST && srcmtype == PETSC_MEMTYPE_DEVICE) { 433914b7a73SJunchao Zhang HostBuffer_t dbuf(static_cast<char*>(dst),n); 434914b7a73SJunchao Zhang deviceConstBuffer_t sbuf(static_cast<const char*>(src),n); 435914b7a73SJunchao Zhang Kokkos::deep_copy(exec,dbuf,sbuf); 436914b7a73SJunchao Zhang PetscErrorCode ierr = PetscLogGpuToCpu(n);CHKERRQ(ierr); 437914b7a73SJunchao Zhang } else if (dstmtype == PETSC_MEMTYPE_DEVICE && srcmtype == PETSC_MEMTYPE_DEVICE) { 438914b7a73SJunchao Zhang deviceBuffer_t dbuf(static_cast<char*>(dst),n); 439914b7a73SJunchao Zhang deviceConstBuffer_t sbuf(static_cast<const char*>(src),n); 440914b7a73SJunchao Zhang Kokkos::deep_copy(exec,dbuf,sbuf); 441914b7a73SJunchao Zhang } 442914b7a73SJunchao Zhang } 443914b7a73SJunchao Zhang PetscFunctionReturn(0); 444914b7a73SJunchao Zhang } 445914b7a73SJunchao Zhang 44620c24465SJunchao Zhang PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype,size_t size,void** ptr) 447914b7a73SJunchao Zhang { 448914b7a73SJunchao Zhang PetscFunctionBegin; 449914b7a73SJunchao Zhang if (mtype == PETSC_MEMTYPE_HOST) {PetscErrorCode ierr = PetscMalloc(size,ptr);CHKERRQ(ierr);} 450914b7a73SJunchao Zhang else if (mtype == PETSC_MEMTYPE_DEVICE) {*ptr = Kokkos::kokkos_malloc<DeviceMemorySpace>(size);} 451914b7a73SJunchao Zhang else SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONG,"Wrong PetscMemType %d", (int)mtype); 452914b7a73SJunchao Zhang PetscFunctionReturn(0); 453914b7a73SJunchao Zhang } 454914b7a73SJunchao Zhang 45520c24465SJunchao Zhang PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype,void* ptr) 456914b7a73SJunchao Zhang { 457914b7a73SJunchao Zhang PetscFunctionBegin; 458914b7a73SJunchao Zhang if (mtype == PETSC_MEMTYPE_HOST) {PetscErrorCode ierr = PetscFree(ptr);CHKERRQ(ierr);} 459914b7a73SJunchao Zhang else if (mtype == PETSC_MEMTYPE_DEVICE) {Kokkos::kokkos_free<DeviceMemorySpace>(ptr);} 460914b7a73SJunchao Zhang else SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONG,"Wrong PetscMemType %d",(int)mtype); 461914b7a73SJunchao Zhang PetscFunctionReturn(0); 462914b7a73SJunchao Zhang } 463914b7a73SJunchao Zhang 464914b7a73SJunchao Zhang /*====================================================================================*/ 465914b7a73SJunchao Zhang /* Main driver to init MPI datatype on device */ 466914b7a73SJunchao Zhang /*====================================================================================*/ 467914b7a73SJunchao Zhang 468914b7a73SJunchao Zhang /* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what needed on device */ 46920c24465SJunchao Zhang PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF sf,PetscSFLink link,MPI_Datatype unit) 470914b7a73SJunchao Zhang { 471914b7a73SJunchao Zhang PetscErrorCode ierr; 472914b7a73SJunchao Zhang PetscInt nSignedChar=0,nUnsignedChar=0,nInt=0,nPetscInt=0,nPetscReal=0; 473914b7a73SJunchao Zhang PetscBool is2Int,is2PetscInt; 474914b7a73SJunchao Zhang #if defined(PETSC_HAVE_COMPLEX) 475914b7a73SJunchao Zhang PetscInt nPetscComplex=0; 476914b7a73SJunchao Zhang #endif 477914b7a73SJunchao Zhang 478914b7a73SJunchao Zhang PetscFunctionBegin; 479914b7a73SJunchao Zhang if (link->deviceinited) PetscFunctionReturn(0); 480914b7a73SJunchao Zhang ierr = MPIPetsc_Type_compare_contig(unit,MPI_SIGNED_CHAR, &nSignedChar);CHKERRQ(ierr); 481914b7a73SJunchao Zhang ierr = MPIPetsc_Type_compare_contig(unit,MPI_UNSIGNED_CHAR,&nUnsignedChar);CHKERRQ(ierr); 482914b7a73SJunchao Zhang /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */ 483914b7a73SJunchao Zhang ierr = MPIPetsc_Type_compare_contig(unit,MPI_INT, &nInt);CHKERRQ(ierr); 484914b7a73SJunchao Zhang ierr = MPIPetsc_Type_compare_contig(unit,MPIU_INT, &nPetscInt);CHKERRQ(ierr); 485914b7a73SJunchao Zhang ierr = MPIPetsc_Type_compare_contig(unit,MPIU_REAL,&nPetscReal);CHKERRQ(ierr); 486914b7a73SJunchao Zhang #if defined(PETSC_HAVE_COMPLEX) 487914b7a73SJunchao Zhang ierr = MPIPetsc_Type_compare_contig(unit,MPIU_COMPLEX,&nPetscComplex);CHKERRQ(ierr); 488914b7a73SJunchao Zhang #endif 489914b7a73SJunchao Zhang ierr = MPIPetsc_Type_compare(unit,MPI_2INT,&is2Int);CHKERRQ(ierr); 490914b7a73SJunchao Zhang ierr = MPIPetsc_Type_compare(unit,MPIU_2INT,&is2PetscInt);CHKERRQ(ierr); 491914b7a73SJunchao Zhang 492914b7a73SJunchao Zhang if (is2Int) { 493914b7a73SJunchao Zhang PackInit_PairType<Kokkos::pair<int,int>>(link); 494914b7a73SJunchao Zhang } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */ 495914b7a73SJunchao Zhang PackInit_PairType<Kokkos::pair<PetscInt,PetscInt>>(link); 496914b7a73SJunchao Zhang } else if (nPetscReal) { 497914b7a73SJunchao Zhang if (nPetscReal == 8) PackInit_RealType<PetscReal,8,1>(link); else if (nPetscReal%8 == 0) PackInit_RealType<PetscReal,8,0>(link); 498914b7a73SJunchao Zhang else if (nPetscReal == 4) PackInit_RealType<PetscReal,4,1>(link); else if (nPetscReal%4 == 0) PackInit_RealType<PetscReal,4,0>(link); 499914b7a73SJunchao Zhang else if (nPetscReal == 2) PackInit_RealType<PetscReal,2,1>(link); else if (nPetscReal%2 == 0) PackInit_RealType<PetscReal,2,0>(link); 500914b7a73SJunchao Zhang else if (nPetscReal == 1) PackInit_RealType<PetscReal,1,1>(link); else if (nPetscReal%1 == 0) PackInit_RealType<PetscReal,1,0>(link); 501914b7a73SJunchao Zhang } else if (nPetscInt) { 502914b7a73SJunchao Zhang if (nPetscInt == 8) PackInit_IntegerType<PetscInt,8,1>(link); else if (nPetscInt%8 == 0) PackInit_IntegerType<PetscInt,8,0>(link); 503914b7a73SJunchao Zhang else if (nPetscInt == 4) PackInit_IntegerType<PetscInt,4,1>(link); else if (nPetscInt%4 == 0) PackInit_IntegerType<PetscInt,4,0>(link); 504914b7a73SJunchao Zhang else if (nPetscInt == 2) PackInit_IntegerType<PetscInt,2,1>(link); else if (nPetscInt%2 == 0) PackInit_IntegerType<PetscInt,2,0>(link); 505914b7a73SJunchao Zhang else if (nPetscInt == 1) PackInit_IntegerType<PetscInt,1,1>(link); else if (nPetscInt%1 == 0) PackInit_IntegerType<PetscInt,1,0>(link); 506914b7a73SJunchao Zhang #if defined(PETSC_USE_64BIT_INDICES) 507914b7a73SJunchao Zhang } else if (nInt) { 508914b7a73SJunchao Zhang if (nInt == 8) PackInit_IntegerType<int,8,1>(link); else if (nInt%8 == 0) PackInit_IntegerType<int,8,0>(link); 509914b7a73SJunchao Zhang else if (nInt == 4) PackInit_IntegerType<int,4,1>(link); else if (nInt%4 == 0) PackInit_IntegerType<int,4,0>(link); 510914b7a73SJunchao Zhang else if (nInt == 2) PackInit_IntegerType<int,2,1>(link); else if (nInt%2 == 0) PackInit_IntegerType<int,2,0>(link); 511914b7a73SJunchao Zhang else if (nInt == 1) PackInit_IntegerType<int,1,1>(link); else if (nInt%1 == 0) PackInit_IntegerType<int,1,0>(link); 512914b7a73SJunchao Zhang #endif 513914b7a73SJunchao Zhang } else if (nSignedChar) { 514914b7a73SJunchao Zhang if (nSignedChar == 8) PackInit_IntegerType<char,8,1>(link); else if (nSignedChar%8 == 0) PackInit_IntegerType<char,8,0>(link); 515914b7a73SJunchao Zhang else if (nSignedChar == 4) PackInit_IntegerType<char,4,1>(link); else if (nSignedChar%4 == 0) PackInit_IntegerType<char,4,0>(link); 516914b7a73SJunchao Zhang else if (nSignedChar == 2) PackInit_IntegerType<char,2,1>(link); else if (nSignedChar%2 == 0) PackInit_IntegerType<char,2,0>(link); 517914b7a73SJunchao Zhang else if (nSignedChar == 1) PackInit_IntegerType<char,1,1>(link); else if (nSignedChar%1 == 0) PackInit_IntegerType<char,1,0>(link); 518914b7a73SJunchao Zhang } else if (nUnsignedChar) { 519914b7a73SJunchao Zhang if (nUnsignedChar == 8) PackInit_IntegerType<unsigned char,8,1>(link); else if (nUnsignedChar%8 == 0) PackInit_IntegerType<unsigned char,8,0>(link); 520914b7a73SJunchao Zhang else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char,4,1>(link); else if (nUnsignedChar%4 == 0) PackInit_IntegerType<unsigned char,4,0>(link); 521914b7a73SJunchao Zhang else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char,2,1>(link); else if (nUnsignedChar%2 == 0) PackInit_IntegerType<unsigned char,2,0>(link); 522914b7a73SJunchao Zhang else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char,1,1>(link); else if (nUnsignedChar%1 == 0) PackInit_IntegerType<unsigned char,1,0>(link); 523914b7a73SJunchao Zhang #if defined(PETSC_HAVE_COMPLEX) 524914b7a73SJunchao Zhang } else if (nPetscComplex) { 525914b7a73SJunchao Zhang if (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>,8,1>(link); else if (nPetscComplex%8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,8,0>(link); 526914b7a73SJunchao Zhang else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>,4,1>(link); else if (nPetscComplex%4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,4,0>(link); 527914b7a73SJunchao Zhang else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>,2,1>(link); else if (nPetscComplex%2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,2,0>(link); 528914b7a73SJunchao Zhang else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>,1,1>(link); else if (nPetscComplex%1 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,1,0>(link); 529914b7a73SJunchao Zhang #endif 530914b7a73SJunchao Zhang } else { 531914b7a73SJunchao Zhang MPI_Aint lb,nbyte; 532914b7a73SJunchao Zhang ierr = MPI_Type_get_extent(unit,&lb,&nbyte);CHKERRQ(ierr); 533914b7a73SJunchao Zhang if (lb != 0) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_SUP,"Datatype with nonzero lower bound %ld\n",(long)lb); 534914b7a73SJunchao Zhang if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */ 535914b7a73SJunchao Zhang if (nbyte == 4) PackInit_DumbType<char,4,1>(link); else if (nbyte%4 == 0) PackInit_DumbType<char,4,0>(link); 536914b7a73SJunchao Zhang else if (nbyte == 2) PackInit_DumbType<char,2,1>(link); else if (nbyte%2 == 0) PackInit_DumbType<char,2,0>(link); 537914b7a73SJunchao Zhang else if (nbyte == 1) PackInit_DumbType<char,1,1>(link); else if (nbyte%1 == 0) PackInit_DumbType<char,1,0>(link); 538914b7a73SJunchao Zhang } else { 539914b7a73SJunchao Zhang nInt = nbyte / sizeof(int); 540914b7a73SJunchao Zhang if (nInt == 8) PackInit_DumbType<int,8,1>(link); else if (nInt%8 == 0) PackInit_DumbType<int,8,0>(link); 541914b7a73SJunchao Zhang else if (nInt == 4) PackInit_DumbType<int,4,1>(link); else if (nInt%4 == 0) PackInit_DumbType<int,4,0>(link); 542914b7a73SJunchao Zhang else if (nInt == 2) PackInit_DumbType<int,2,1>(link); else if (nInt%2 == 0) PackInit_DumbType<int,2,0>(link); 543914b7a73SJunchao Zhang else if (nInt == 1) PackInit_DumbType<int,1,1>(link); else if (nInt%1 == 0) PackInit_DumbType<int,1,0>(link); 544914b7a73SJunchao Zhang } 545914b7a73SJunchao Zhang } 546914b7a73SJunchao Zhang 547*f4af43b4SJunchao Zhang if (!sf->use_default_stream) { 548*f4af43b4SJunchao Zhang #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) 549*f4af43b4SJunchao Zhang SETERRQ(PETSC_COMM_SELF,PETSC_ERR_SUP,"Non-default cuda/hip streams are not supported by the SF Kokkos backend. If it is cuda, use -sf_backend cuda instead"); 550914b7a73SJunchao Zhang #endif 551*f4af43b4SJunchao Zhang } 55220c24465SJunchao Zhang 55320c24465SJunchao Zhang link->d_SyncDevice = PetscSFLinkSyncDevice_Kokkos; 55420c24465SJunchao Zhang link->d_SyncStream = PetscSFLinkSyncStream_Kokkos; 55520c24465SJunchao Zhang link->Memcpy = PetscSFLinkMemcpy_Kokkos; 556*f4af43b4SJunchao Zhang link->spptr = NULL; /* Unused now */ 557*f4af43b4SJunchao Zhang link->Destroy = NULL; /* PetscSFLinkDestroy_Kokkos; */ 558914b7a73SJunchao Zhang link->deviceinited = PETSC_TRUE; 559914b7a73SJunchao Zhang PetscFunctionReturn(0); 560914b7a73SJunchao Zhang } 561