1 #include <../src/vec/is/sf/impls/basic/sfpack.h> 2 3 #include <Kokkos_Core.hpp> 4 5 using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace; 6 using DeviceMemorySpace = typename DeviceExecutionSpace::memory_space; 7 using HostMemorySpace = Kokkos::HostSpace; 8 9 typedef Kokkos::View<char*,DeviceMemorySpace> deviceBuffer_t; 10 typedef Kokkos::View<char*,HostMemorySpace> HostBuffer_t; 11 12 typedef Kokkos::View<const char*,DeviceMemorySpace> deviceConstBuffer_t; 13 typedef Kokkos::View<const char*,HostMemorySpace> HostConstBuffer_t; 14 15 /*====================================================================================*/ 16 /* Regular operations */ 17 /*====================================================================================*/ 18 template<typename Type> struct Insert{KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x = y; return old;}}; 19 template<typename Type> struct Add {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x += y; return old;}}; 20 template<typename Type> struct Mult {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x *= y; return old;}}; 21 template<typename Type> struct Min {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x = PetscMin(x,y); return old;}}; 22 template<typename Type> struct Max {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x = PetscMax(x,y); return old;}}; 23 template<typename Type> struct LAND {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x = x && y; return old;}}; 24 template<typename Type> struct LOR {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x = x || y; return old;}}; 25 template<typename Type> struct LXOR {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x = !x != !y; return old;}}; 26 template<typename Type> struct BAND {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) 
const {Type old = x; x = x & y; return old;}}; 27 template<typename Type> struct BOR {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x = x | y; return old;}}; 28 template<typename Type> struct BXOR {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x = x ^ y; return old;}}; 29 template<typename PairType> struct Minloc { 30 KOKKOS_INLINE_FUNCTION PairType operator()(PairType& x,PairType y) const { 31 PairType old = x; 32 if (y.first < x.first) x = y; 33 else if (y.first == x.first) x.second = PetscMin(x.second,y.second); 34 return old; 35 } 36 }; 37 template<typename PairType> struct Maxloc { 38 KOKKOS_INLINE_FUNCTION PairType operator()(PairType& x,PairType y) const { 39 PairType old = x; 40 if (y.first > x.first) x = y; 41 else if (y.first == x.first) x.second = PetscMin(x.second,y.second); /* See MPI MAXLOC */ 42 return old; 43 } 44 }; 45 46 /*====================================================================================*/ 47 /* Atomic operations */ 48 /*====================================================================================*/ 49 template<typename Type> struct AtomicInsert {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_assign(&x,y);}}; 50 template<typename Type> struct AtomicAdd {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_add(&x,y);}}; 51 template<typename Type> struct AtomicBAND {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_and(&x,y);}}; 52 template<typename Type> struct AtomicBOR {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_or (&x,y);}}; 53 template<typename Type> struct AtomicBXOR {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_xor(&x,y);}}; 54 template<typename Type> struct AtomicLAND {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {const Type zero=0,one=~0; Kokkos::atomic_and(&x,y?one:zero);}}; 55 
/* Logical OR expressed as a bitwise OR with 1 (y true) or 0 (y false).
   Unlike LOR, a true result may keep other bits of x set instead of storing exactly 1. */
template<typename Type> struct AtomicLOR {
  KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const
  {
    const Type bit = y ? static_cast<Type>(1) : static_cast<Type>(0);
    Kokkos::atomic_or(&x,bit);
  }
};
template<typename Type> struct AtomicMult {
  KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const { Kokkos::atomic_fetch_mul(&x,y); }
};
template<typename Type> struct AtomicMin {
  KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const { Kokkos::atomic_fetch_min(&x,y); }
};
template<typename Type> struct AtomicMax {
  KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const { Kokkos::atomic_fetch_max(&x,y); }
};
/* TODO: struct AtomicLXOR */
/* Unlike the other atomic functors, this one returns the value x held before the addition */
template<typename Type> struct AtomicFetchAdd {
  KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const { return Kokkos::atomic_fetch_add(&x,y); }
};

/* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt.
   Layout of opt as read here (nsub = number of subdomains):
     opt[0] = nsub, then offset[nsub+1], start[nsub], dx[nsub], dy[nsub], nsub entries not read here, X[nsub], Y[nsub].
   tid falls in subdomain r with offset[r] <= tid < offset[r+1]; its position inside that dx-by-dy-by-... box
   is split into (i,j,k) and mapped through the X and X*Y strides of the full array. */
static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt,PetscInt tid)
{
  const PetscInt  nsub   = opt[0];
  const PetscInt *offset = opt + 1;
  const PetscInt *start  = opt + nsub + 2;
  const PetscInt *dx     = opt + 2*nsub + 2;
  const PetscInt *dy     = opt + 3*nsub + 2;
  const PetscInt *X      = opt + 5*nsub + 2;
  const PetscInt *Y      = opt + 6*nsub + 2;
  PetscInt        r      = 0;

  /* Linear search for the subdomain owning tid */
  while (r < nsub && tid >= offset[r+1]) r++;

  const PetscInt local = tid - offset[r]; /* position of tid inside subdomain r */
  const PetscInt plane = dx[r]*dy[r];     /* entries in one k-slice of the subdomain */
  const PetscInt k     = local/plane;
  const PetscInt rem   = local - k*plane;
  const PetscInt j     = rem/dx[r];
  const PetscInt i     = rem - j*dx[r];

  return start[r] + k*X[r]*Y[r] + j*X[r] + i;
}

/*====================================================================================*/
/*  Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link'  */
/*====================================================================================*/

/* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then
     <Type> is PetscReal, which is the primitive type we operate on.
     <bs>   is 16, which says <unit> contains 16 primitive types.
91 <BS> is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>. 92 <EQ> is 0, which is (bs == BS ? 1 : 0) 93 94 If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant. 95 For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled. 96 */ 97 template<typename Type,PetscInt BS,PetscInt EQ> 98 static PetscErrorCode Pack(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,const void *data_,void *buf_) 99 { 100 const PetscInt *iopt = opt ? opt->array : NULL; 101 const PetscInt M = EQ ? 1 : link->bs/BS, MBS=M*BS; /* If EQ, then MBS will be a compile-time const */ 102 const Type *data = static_cast<const Type*>(data_); 103 Type *buf = static_cast<Type*>(buf_); 104 DeviceExecutionSpace& exec = *static_cast<DeviceExecutionSpace*>(link->sptr); 105 106 PetscFunctionBegin; 107 Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) { 108 /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous; 109 iopt == NULL && idx == NULL ==> the indices are contiguous; 110 */ 111 PetscInt t = (iopt? MapTidToIndex(iopt,tid) : (idx? idx[tid] : start+tid))*MBS; 112 PetscInt s = tid*MBS; 113 for (int i=0; i<MBS; i++) buf[s+i] = data[t+i]; 114 }); 115 PetscFunctionReturn(0); 116 } 117 118 template<typename Type,class Op,PetscInt BS,PetscInt EQ> 119 static PetscErrorCode UnpackAndOp(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,void *data_,const void *buf_) 120 { 121 Op op; 122 const PetscInt *iopt = opt ? opt->array : NULL; 123 const PetscInt M = EQ ? 
1 : link->bs/BS, MBS=M*BS; 124 Type *data = static_cast<Type*>(data_); 125 const Type *buf = static_cast<const Type*>(buf_); 126 DeviceExecutionSpace& exec = *static_cast<DeviceExecutionSpace*>(link->sptr); 127 128 PetscFunctionBegin; 129 Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) { 130 PetscInt t = (iopt? MapTidToIndex(iopt,tid) : (idx? idx[tid] : start+tid))*MBS; 131 PetscInt s = tid*MBS; 132 for (int i=0; i<MBS; i++) op(data[t+i],buf[s+i]); 133 }); 134 PetscFunctionReturn(0); 135 } 136 137 template<typename Type,class Op,PetscInt BS,PetscInt EQ> 138 static PetscErrorCode FetchAndOp(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,void *data,void *buf) 139 { 140 Op op; 141 const PetscInt *ropt = opt ? opt->array : NULL; 142 const PetscInt M = EQ ? 1 : link->bs/BS, MBS=M*BS; 143 Type *rootdata = static_cast<Type*>(data),*leafbuf=static_cast<Type*>(buf); 144 DeviceExecutionSpace& exec = *static_cast<DeviceExecutionSpace*>(link->sptr); 145 146 PetscFunctionBegin; 147 Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) { 148 PetscInt r = (ropt? MapTidToIndex(ropt,tid) : (idx? idx[tid] : start+tid))*MBS; 149 PetscInt l = tid*MBS; 150 for (int i=0; i<MBS; i++) leafbuf[l+i] = op(rootdata[r+i],leafbuf[l+i]); 151 }); 152 PetscFunctionReturn(0); 153 } 154 155 template<typename Type,class Op,PetscInt BS,PetscInt EQ> 156 static PetscErrorCode ScatterAndOp(PetscSFLink link,PetscInt count,PetscInt srcStart,PetscSFPackOpt srcOpt,const PetscInt *srcIdx,const void *src_,PetscInt dstStart,PetscSFPackOpt dstOpt,const PetscInt *dstIdx,void *dst_) 157 { 158 PetscInt srcx=0,srcy=0,srcX=0,srcY=0,dstx=0,dsty=0,dstX=0,dstY=0; 159 const PetscInt M = (EQ) ? 
1 : link->bs/BS, MBS=M*BS; 160 const Type *src = static_cast<const Type*>(src_); 161 Type *dst = static_cast<Type*>(dst_); 162 DeviceExecutionSpace& exec = *static_cast<DeviceExecutionSpace*>(link->sptr); 163 164 PetscFunctionBegin; 165 /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */ 166 if (srcOpt) {srcx = srcOpt->dx[0]; srcy = srcOpt->dy[0]; srcX = srcOpt->X[0]; srcY = srcOpt->Y[0]; srcStart = srcOpt->start[0]; srcIdx = NULL;} 167 else if (!srcIdx) {srcx = srcX = count; srcy = srcY = 1;} 168 169 if (dstOpt) {dstx = dstOpt->dx[0]; dsty = dstOpt->dy[0]; dstX = dstOpt->X[0]; dstY = dstOpt->Y[0]; dstStart = dstOpt->start[0]; dstIdx = NULL;} 170 else if (!dstIdx) {dstx = dstX = count; dsty = dstY = 1;} 171 172 Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) { 173 PetscInt i,j,k,s,t; 174 Op op; 175 if (!srcIdx) { /* src is in 3D */ 176 k = tid/(srcx*srcy); 177 j = (tid - k*srcx*srcy)/srcx; 178 i = tid - k*srcx*srcy - j*srcx; 179 s = srcStart + k*srcX*srcY + j*srcX + i; 180 } else { /* src is contiguous */ 181 s = srcIdx[tid]; 182 } 183 184 if (!dstIdx) { /* 3D */ 185 k = tid/(dstx*dsty); 186 j = (tid - k*dstx*dsty)/dstx; 187 i = tid - k*dstx*dsty - j*dstx; 188 t = dstStart + k*dstX*dstY + j*dstX + i; 189 } else { /* contiguous */ 190 t = dstIdx[tid]; 191 } 192 193 s *= MBS; 194 t *= MBS; 195 for (i=0; i<MBS; i++) op(dst[t+i],src[s+i]); 196 }); 197 PetscFunctionReturn(0); 198 } 199 200 /* Specialization for Insert since we may use memcpy */ 201 template<typename Type,PetscInt BS,PetscInt EQ> 202 static PetscErrorCode ScatterAndInsert(PetscSFLink link,PetscInt count,PetscInt srcStart,PetscSFPackOpt srcOpt,const PetscInt *srcIdx,const void *src_,PetscInt dstStart,PetscSFPackOpt dstOpt,const PetscInt *dstIdx,void *dst_) 203 { 204 PetscErrorCode ierr; 205 const Type *src = static_cast<const Type*>(src_); 206 Type *dst = 
static_cast<Type*>(dst_); 207 DeviceExecutionSpace& exec = *static_cast<DeviceExecutionSpace*>(link->sptr); 208 209 PetscFunctionBegin; 210 if (!count) PetscFunctionReturn(0); 211 /*src and dst are contiguous */ 212 if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) { 213 size_t sz = count*link->unitbytes; 214 deviceBuffer_t dbuf(reinterpret_cast<char*>(dst+dstStart*link->bs),sz); 215 deviceConstBuffer_t sbuf(reinterpret_cast<const char*>(src+srcStart*link->bs),sz); 216 Kokkos::deep_copy(exec,dbuf,sbuf); 217 } else { 218 ierr = ScatterAndOp<Type,Insert<Type>,BS,EQ>(link,count,srcStart,srcOpt,srcIdx,src,dstStart,dstOpt,dstIdx,dst);CHKERRQ(ierr); 219 } 220 PetscFunctionReturn(0); 221 } 222 223 template<typename Type,class Op,PetscInt BS,PetscInt EQ> 224 static PetscErrorCode FetchAndOpLocal(PetscSFLink link,PetscInt count,PetscInt rootstart,PetscSFPackOpt rootopt,const PetscInt *rootidx,void *rootdata_,PetscInt leafstart,PetscSFPackOpt leafopt,const PetscInt *leafidx,const void *leafdata_,void *leafupdate_) 225 { 226 Op op; 227 const PetscInt M = (EQ) ? 1 : link->bs/BS, MBS = M*BS; 228 const PetscInt *ropt = rootopt ? rootopt->array : NULL; 229 const PetscInt *lopt = leafopt ? leafopt->array : NULL; 230 Type *rootdata = static_cast<Type*>(rootdata_),*leafupdate = static_cast<Type*>(leafupdate_); 231 const Type *leafdata = static_cast<const Type*>(leafdata_); 232 DeviceExecutionSpace& exec = *static_cast<DeviceExecutionSpace*>(link->sptr); 233 234 PetscFunctionBegin; 235 Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) { 236 PetscInt r = (ropt? MapTidToIndex(ropt,tid) : (rootidx? rootidx[tid] : rootstart+tid))*MBS; 237 PetscInt l = (lopt? MapTidToIndex(lopt,tid) : (leafidx? 
leafidx[tid] : leafstart+tid))*MBS; 238 for (int i=0; i<MBS; i++) leafupdate[l+i] = op(rootdata[r+i],leafdata[l+i]); 239 }); 240 PetscFunctionReturn(0); 241 } 242 243 /*====================================================================================*/ 244 /* Init various types and instantiate pack/unpack function pointers */ 245 /*====================================================================================*/ 246 template<typename Type,PetscInt BS,PetscInt EQ> 247 static void PackInit_RealType(PetscSFLink link) 248 { 249 /* Pack/unpack for remote communication */ 250 link->d_Pack = Pack<Type,BS,EQ>; 251 link->d_UnpackAndInsert = UnpackAndOp<Type,Insert<Type>,BS,EQ>; 252 link->d_UnpackAndAdd = UnpackAndOp<Type,Add<Type> ,BS,EQ>; 253 link->d_UnpackAndMult = UnpackAndOp<Type,Mult<Type> ,BS,EQ>; 254 link->d_UnpackAndMin = UnpackAndOp<Type,Min<Type> ,BS,EQ>; 255 link->d_UnpackAndMax = UnpackAndOp<Type,Max<Type> ,BS,EQ>; 256 link->d_FetchAndAdd = FetchAndOp <Type,Add<Type> ,BS,EQ>; 257 /* Scatter for local communication */ 258 link->d_ScatterAndInsert = ScatterAndInsert<Type,BS,EQ>; /* Has special optimizations */ 259 link->d_ScatterAndAdd = ScatterAndOp<Type,Add<Type> ,BS,EQ>; 260 link->d_ScatterAndMult = ScatterAndOp<Type,Mult<Type> ,BS,EQ>; 261 link->d_ScatterAndMin = ScatterAndOp<Type,Min<Type> ,BS,EQ>; 262 link->d_ScatterAndMax = ScatterAndOp<Type,Max<Type> ,BS,EQ>; 263 link->d_FetchAndAddLocal = FetchAndOpLocal<Type,Add <Type>,BS,EQ>; 264 /* Atomic versions when there are data-race possibilities */ 265 link->da_UnpackAndInsert = UnpackAndOp<Type,AtomicInsert<Type> ,BS,EQ>; 266 link->da_UnpackAndAdd = UnpackAndOp<Type,AtomicAdd<Type> ,BS,EQ>; 267 link->da_UnpackAndMult = UnpackAndOp<Type,AtomicMult<Type> ,BS,EQ>; 268 link->da_UnpackAndMin = UnpackAndOp<Type,AtomicMin<Type> ,BS,EQ>; 269 link->da_UnpackAndMax = UnpackAndOp<Type,AtomicMax<Type> ,BS,EQ>; 270 link->da_FetchAndAdd = FetchAndOp <Type,AtomicFetchAdd<Type>,BS,EQ>; 271 272 link->da_ScatterAndInsert 
= ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>; 273 link->da_ScatterAndAdd = ScatterAndOp<Type,AtomicAdd<Type> ,BS,EQ>; 274 link->da_ScatterAndMult = ScatterAndOp<Type,AtomicMult<Type> ,BS,EQ>; 275 link->da_ScatterAndMin = ScatterAndOp<Type,AtomicMin<Type> ,BS,EQ>; 276 link->da_ScatterAndMax = ScatterAndOp<Type,AtomicMax<Type> ,BS,EQ>; 277 link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>; 278 } 279 280 template<typename Type,PetscInt BS,PetscInt EQ> 281 static void PackInit_IntegerType(PetscSFLink link) 282 { 283 link->d_Pack = Pack<Type,BS,EQ>; 284 link->d_UnpackAndInsert = UnpackAndOp<Type,Insert<Type> ,BS,EQ>; 285 link->d_UnpackAndAdd = UnpackAndOp<Type,Add<Type> ,BS,EQ>; 286 link->d_UnpackAndMult = UnpackAndOp<Type,Mult<Type> ,BS,EQ>; 287 link->d_UnpackAndMin = UnpackAndOp<Type,Min<Type> ,BS,EQ>; 288 link->d_UnpackAndMax = UnpackAndOp<Type,Max<Type> ,BS,EQ>; 289 link->d_UnpackAndLAND = UnpackAndOp<Type,LAND<Type> ,BS,EQ>; 290 link->d_UnpackAndLOR = UnpackAndOp<Type,LOR<Type> ,BS,EQ>; 291 link->d_UnpackAndLXOR = UnpackAndOp<Type,LXOR<Type> ,BS,EQ>; 292 link->d_UnpackAndBAND = UnpackAndOp<Type,BAND<Type> ,BS,EQ>; 293 link->d_UnpackAndBOR = UnpackAndOp<Type,BOR<Type> ,BS,EQ>; 294 link->d_UnpackAndBXOR = UnpackAndOp<Type,BXOR<Type> ,BS,EQ>; 295 link->d_FetchAndAdd = FetchAndOp <Type,Add<Type> ,BS,EQ>; 296 297 link->d_ScatterAndInsert = ScatterAndInsert<Type,BS,EQ>; 298 link->d_ScatterAndAdd = ScatterAndOp<Type,Add<Type> ,BS,EQ>; 299 link->d_ScatterAndMult = ScatterAndOp<Type,Mult<Type> ,BS,EQ>; 300 link->d_ScatterAndMin = ScatterAndOp<Type,Min<Type> ,BS,EQ>; 301 link->d_ScatterAndMax = ScatterAndOp<Type,Max<Type> ,BS,EQ>; 302 link->d_ScatterAndLAND = ScatterAndOp<Type,LAND<Type> ,BS,EQ>; 303 link->d_ScatterAndLOR = ScatterAndOp<Type,LOR<Type> ,BS,EQ>; 304 link->d_ScatterAndLXOR = ScatterAndOp<Type,LXOR<Type> ,BS,EQ>; 305 link->d_ScatterAndBAND = ScatterAndOp<Type,BAND<Type> ,BS,EQ>; 306 link->d_ScatterAndBOR = ScatterAndOp<Type,BOR<Type> 
,BS,EQ>; 307 link->d_ScatterAndBXOR = ScatterAndOp<Type,BXOR<Type> ,BS,EQ>; 308 link->d_FetchAndAddLocal = FetchAndOpLocal<Type,Add<Type>,BS,EQ>; 309 310 link->da_UnpackAndInsert = UnpackAndOp<Type,AtomicInsert<Type>,BS,EQ>; 311 link->da_UnpackAndAdd = UnpackAndOp<Type,AtomicAdd<Type> ,BS,EQ>; 312 link->da_UnpackAndMult = UnpackAndOp<Type,AtomicMult<Type> ,BS,EQ>; 313 link->da_UnpackAndMin = UnpackAndOp<Type,AtomicMin<Type> ,BS,EQ>; 314 link->da_UnpackAndMax = UnpackAndOp<Type,AtomicMax<Type> ,BS,EQ>; 315 link->da_UnpackAndLAND = UnpackAndOp<Type,AtomicLAND<Type> ,BS,EQ>; 316 link->da_UnpackAndLOR = UnpackAndOp<Type,AtomicLOR<Type> ,BS,EQ>; 317 link->da_UnpackAndBAND = UnpackAndOp<Type,AtomicBAND<Type> ,BS,EQ>; 318 link->da_UnpackAndBOR = UnpackAndOp<Type,AtomicBOR<Type> ,BS,EQ>; 319 link->da_UnpackAndBXOR = UnpackAndOp<Type,AtomicBXOR<Type> ,BS,EQ>; 320 link->da_FetchAndAdd = FetchAndOp <Type,AtomicFetchAdd<Type>,BS,EQ>; 321 322 link->da_ScatterAndInsert = ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>; 323 link->da_ScatterAndAdd = ScatterAndOp<Type,AtomicAdd<Type> ,BS,EQ>; 324 link->da_ScatterAndMult = ScatterAndOp<Type,AtomicMult<Type> ,BS,EQ>; 325 link->da_ScatterAndMin = ScatterAndOp<Type,AtomicMin<Type> ,BS,EQ>; 326 link->da_ScatterAndMax = ScatterAndOp<Type,AtomicMax<Type> ,BS,EQ>; 327 link->da_ScatterAndLAND = ScatterAndOp<Type,AtomicLAND<Type> ,BS,EQ>; 328 link->da_ScatterAndLOR = ScatterAndOp<Type,AtomicLOR<Type> ,BS,EQ>; 329 link->da_ScatterAndBAND = ScatterAndOp<Type,AtomicBAND<Type> ,BS,EQ>; 330 link->da_ScatterAndBOR = ScatterAndOp<Type,AtomicBOR<Type> ,BS,EQ>; 331 link->da_ScatterAndBXOR = ScatterAndOp<Type,AtomicBXOR<Type> ,BS,EQ>; 332 link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>; 333 } 334 335 #if defined(PETSC_HAVE_COMPLEX) 336 template<typename Type,PetscInt BS,PetscInt EQ> 337 static void PackInit_ComplexType(PetscSFLink link) 338 { 339 link->d_Pack = Pack<Type,BS,EQ>; 340 link->d_UnpackAndInsert = 
UnpackAndOp<Type,Insert<Type>,BS,EQ>; 341 link->d_UnpackAndAdd = UnpackAndOp<Type,Add<Type> ,BS,EQ>; 342 link->d_UnpackAndMult = UnpackAndOp<Type,Mult<Type> ,BS,EQ>; 343 link->d_FetchAndAdd = FetchAndOp <Type,Add<Type> ,BS,EQ>; 344 345 link->d_ScatterAndInsert = ScatterAndInsert<Type,BS,EQ>; 346 link->d_ScatterAndAdd = ScatterAndOp<Type,Add<Type> ,BS,EQ>; 347 link->d_ScatterAndMult = ScatterAndOp<Type,Mult<Type> ,BS,EQ>; 348 link->d_FetchAndAddLocal = FetchAndOpLocal<Type,Add<Type>,BS,EQ>; 349 350 link->da_UnpackAndInsert = UnpackAndOp<Type,AtomicInsert<Type> ,BS,EQ>; 351 link->da_UnpackAndAdd = UnpackAndOp<Type,AtomicAdd<Type> ,BS,EQ>; 352 link->da_UnpackAndMult = UnpackAndOp<Type,AtomicMult<Type> ,BS,EQ>; 353 link->da_FetchAndAdd = FetchAndOp<Type,AtomicFetchAdd<Type>,BS,EQ>; 354 355 link->da_ScatterAndInsert = ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>; 356 link->da_ScatterAndAdd = ScatterAndOp<Type,AtomicAdd<Type> ,BS,EQ>; 357 link->da_ScatterAndMult = ScatterAndOp<Type,AtomicMult<Type> ,BS,EQ>; 358 link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>; 359 } 360 #endif 361 362 template<typename Type> 363 static void PackInit_PairType(PetscSFLink link) 364 { 365 link->d_Pack = Pack<Type,1,1>; 366 link->d_UnpackAndInsert = UnpackAndOp<Type,Insert<Type>,1,1>; 367 link->d_UnpackAndMaxloc = UnpackAndOp<Type,Maxloc<Type>,1,1>; 368 link->d_UnpackAndMinloc = UnpackAndOp<Type,Minloc<Type>,1,1>; 369 370 link->d_ScatterAndInsert = ScatterAndOp<Type,Insert<Type>,1,1>; 371 link->d_ScatterAndMaxloc = ScatterAndOp<Type,Maxloc<Type>,1,1>; 372 link->d_ScatterAndMinloc = ScatterAndOp<Type,Minloc<Type>,1,1>; 373 /* Atomics for pair types are not implemented yet */ 374 } 375 376 template<typename Type,PetscInt BS,PetscInt EQ> 377 static void PackInit_DumbType(PetscSFLink link) 378 { 379 link->d_Pack = Pack<Type,BS,EQ>; 380 link->d_UnpackAndInsert = UnpackAndOp<Type,Insert<Type>,BS,EQ>; 381 link->d_ScatterAndInsert = ScatterAndInsert<Type,BS,EQ>; 382 /* 
Atomics for dumb types are not implemented yet */ 383 } 384 385 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link) 386 { 387 PetscFunctionBegin; 388 delete static_cast<DeviceExecutionSpace*>(link->sptr); 389 PetscFunctionReturn(0); 390 } 391 392 /* Some device-specific utilities */ 393 static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink link) 394 { 395 PetscFunctionBegin; 396 Kokkos::fence(); 397 PetscFunctionReturn(0); 398 } 399 400 static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink link) 401 { 402 DeviceExecutionSpace& exec = *static_cast<DeviceExecutionSpace*>(link->sptr); 403 PetscFunctionBegin; 404 exec.fence(); 405 PetscFunctionReturn(0); 406 } 407 408 static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink link,PetscMemType dstmtype,void* dst,PetscMemType srcmtype,const void*src,size_t n) 409 { 410 DeviceExecutionSpace& exec = *static_cast<DeviceExecutionSpace*>(link->sptr); 411 412 PetscFunctionBegin; 413 if (!n) PetscFunctionReturn(0); 414 if (dstmtype == PETSC_MEMTYPE_HOST && srcmtype == PETSC_MEMTYPE_HOST) { 415 PetscErrorCode ierr = PetscMemcpy(dst,src,n);CHKERRQ(ierr); 416 } else { 417 if (dstmtype == PETSC_MEMTYPE_DEVICE && srcmtype == PETSC_MEMTYPE_HOST) { 418 deviceBuffer_t dbuf(static_cast<char*>(dst),n); 419 HostConstBuffer_t sbuf(static_cast<const char*>(src),n); 420 Kokkos::deep_copy(exec,dbuf,sbuf); 421 PetscErrorCode ierr = PetscLogCpuToGpu(n);CHKERRQ(ierr); 422 } else if (dstmtype == PETSC_MEMTYPE_HOST && srcmtype == PETSC_MEMTYPE_DEVICE) { 423 HostBuffer_t dbuf(static_cast<char*>(dst),n); 424 deviceConstBuffer_t sbuf(static_cast<const char*>(src),n); 425 Kokkos::deep_copy(exec,dbuf,sbuf); 426 PetscErrorCode ierr = PetscLogGpuToCpu(n);CHKERRQ(ierr); 427 } else if (dstmtype == PETSC_MEMTYPE_DEVICE && srcmtype == PETSC_MEMTYPE_DEVICE) { 428 deviceBuffer_t dbuf(static_cast<char*>(dst),n); 429 deviceConstBuffer_t sbuf(static_cast<const char*>(src),n); 430 Kokkos::deep_copy(exec,dbuf,sbuf); 431 } 432 } 433 
PetscFunctionReturn(0); 434 } 435 436 PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype,size_t size,void** ptr) 437 { 438 PetscFunctionBegin; 439 if (mtype == PETSC_MEMTYPE_HOST) {PetscErrorCode ierr = PetscMalloc(size,ptr);CHKERRQ(ierr);} 440 else if (mtype == PETSC_MEMTYPE_DEVICE) {*ptr = Kokkos::kokkos_malloc<DeviceMemorySpace>(size);} 441 else SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONG,"Wrong PetscMemType %d", (int)mtype); 442 PetscFunctionReturn(0); 443 } 444 445 PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype,void* ptr) 446 { 447 PetscFunctionBegin; 448 if (mtype == PETSC_MEMTYPE_HOST) {PetscErrorCode ierr = PetscFree(ptr);CHKERRQ(ierr);} 449 else if (mtype == PETSC_MEMTYPE_DEVICE) {Kokkos::kokkos_free<DeviceMemorySpace>(ptr);} 450 else SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONG,"Wrong PetscMemType %d",(int)mtype); 451 PetscFunctionReturn(0); 452 } 453 454 /*====================================================================================*/ 455 /* Main driver to init MPI datatype on device */ 456 /*====================================================================================*/ 457 458 /* Some fields of link are initialized by PetscSFPackSetUp_Host. 
This routine only does what needed on device */ 459 PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF sf,PetscSFLink link,MPI_Datatype unit) 460 { 461 PetscErrorCode ierr; 462 PetscInt nSignedChar=0,nUnsignedChar=0,nInt=0,nPetscInt=0,nPetscReal=0; 463 PetscBool is2Int,is2PetscInt; 464 #if defined(PETSC_HAVE_COMPLEX) 465 PetscInt nPetscComplex=0; 466 #endif 467 468 PetscFunctionBegin; 469 if (link->deviceinited) PetscFunctionReturn(0); 470 ierr = MPIPetsc_Type_compare_contig(unit,MPI_SIGNED_CHAR, &nSignedChar);CHKERRQ(ierr); 471 ierr = MPIPetsc_Type_compare_contig(unit,MPI_UNSIGNED_CHAR,&nUnsignedChar);CHKERRQ(ierr); 472 /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */ 473 ierr = MPIPetsc_Type_compare_contig(unit,MPI_INT, &nInt);CHKERRQ(ierr); 474 ierr = MPIPetsc_Type_compare_contig(unit,MPIU_INT, &nPetscInt);CHKERRQ(ierr); 475 ierr = MPIPetsc_Type_compare_contig(unit,MPIU_REAL,&nPetscReal);CHKERRQ(ierr); 476 #if defined(PETSC_HAVE_COMPLEX) 477 ierr = MPIPetsc_Type_compare_contig(unit,MPIU_COMPLEX,&nPetscComplex);CHKERRQ(ierr); 478 #endif 479 ierr = MPIPetsc_Type_compare(unit,MPI_2INT,&is2Int);CHKERRQ(ierr); 480 ierr = MPIPetsc_Type_compare(unit,MPIU_2INT,&is2PetscInt);CHKERRQ(ierr); 481 482 if (is2Int) { 483 PackInit_PairType<Kokkos::pair<int,int>>(link); 484 } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. 
*/ 485 PackInit_PairType<Kokkos::pair<PetscInt,PetscInt>>(link); 486 } else if (nPetscReal) { 487 if (nPetscReal == 8) PackInit_RealType<PetscReal,8,1>(link); else if (nPetscReal%8 == 0) PackInit_RealType<PetscReal,8,0>(link); 488 else if (nPetscReal == 4) PackInit_RealType<PetscReal,4,1>(link); else if (nPetscReal%4 == 0) PackInit_RealType<PetscReal,4,0>(link); 489 else if (nPetscReal == 2) PackInit_RealType<PetscReal,2,1>(link); else if (nPetscReal%2 == 0) PackInit_RealType<PetscReal,2,0>(link); 490 else if (nPetscReal == 1) PackInit_RealType<PetscReal,1,1>(link); else if (nPetscReal%1 == 0) PackInit_RealType<PetscReal,1,0>(link); 491 } else if (nPetscInt) { 492 if (nPetscInt == 8) PackInit_IntegerType<PetscInt,8,1>(link); else if (nPetscInt%8 == 0) PackInit_IntegerType<PetscInt,8,0>(link); 493 else if (nPetscInt == 4) PackInit_IntegerType<PetscInt,4,1>(link); else if (nPetscInt%4 == 0) PackInit_IntegerType<PetscInt,4,0>(link); 494 else if (nPetscInt == 2) PackInit_IntegerType<PetscInt,2,1>(link); else if (nPetscInt%2 == 0) PackInit_IntegerType<PetscInt,2,0>(link); 495 else if (nPetscInt == 1) PackInit_IntegerType<PetscInt,1,1>(link); else if (nPetscInt%1 == 0) PackInit_IntegerType<PetscInt,1,0>(link); 496 #if defined(PETSC_USE_64BIT_INDICES) 497 } else if (nInt) { 498 if (nInt == 8) PackInit_IntegerType<int,8,1>(link); else if (nInt%8 == 0) PackInit_IntegerType<int,8,0>(link); 499 else if (nInt == 4) PackInit_IntegerType<int,4,1>(link); else if (nInt%4 == 0) PackInit_IntegerType<int,4,0>(link); 500 else if (nInt == 2) PackInit_IntegerType<int,2,1>(link); else if (nInt%2 == 0) PackInit_IntegerType<int,2,0>(link); 501 else if (nInt == 1) PackInit_IntegerType<int,1,1>(link); else if (nInt%1 == 0) PackInit_IntegerType<int,1,0>(link); 502 #endif 503 } else if (nSignedChar) { 504 if (nSignedChar == 8) PackInit_IntegerType<char,8,1>(link); else if (nSignedChar%8 == 0) PackInit_IntegerType<char,8,0>(link); 505 else if (nSignedChar == 4) 
PackInit_IntegerType<char,4,1>(link); else if (nSignedChar%4 == 0) PackInit_IntegerType<char,4,0>(link); 506 else if (nSignedChar == 2) PackInit_IntegerType<char,2,1>(link); else if (nSignedChar%2 == 0) PackInit_IntegerType<char,2,0>(link); 507 else if (nSignedChar == 1) PackInit_IntegerType<char,1,1>(link); else if (nSignedChar%1 == 0) PackInit_IntegerType<char,1,0>(link); 508 } else if (nUnsignedChar) { 509 if (nUnsignedChar == 8) PackInit_IntegerType<unsigned char,8,1>(link); else if (nUnsignedChar%8 == 0) PackInit_IntegerType<unsigned char,8,0>(link); 510 else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char,4,1>(link); else if (nUnsignedChar%4 == 0) PackInit_IntegerType<unsigned char,4,0>(link); 511 else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char,2,1>(link); else if (nUnsignedChar%2 == 0) PackInit_IntegerType<unsigned char,2,0>(link); 512 else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char,1,1>(link); else if (nUnsignedChar%1 == 0) PackInit_IntegerType<unsigned char,1,0>(link); 513 #if defined(PETSC_HAVE_COMPLEX) 514 } else if (nPetscComplex) { 515 if (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>,8,1>(link); else if (nPetscComplex%8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,8,0>(link); 516 else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>,4,1>(link); else if (nPetscComplex%4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,4,0>(link); 517 else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>,2,1>(link); else if (nPetscComplex%2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,2,0>(link); 518 else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>,1,1>(link); else if (nPetscComplex%1 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,1,0>(link); 519 #endif 520 } else { 521 MPI_Aint lb,nbyte; 522 ierr = MPI_Type_get_extent(unit,&lb,&nbyte);CHKERRQ(ierr); 523 if (lb != 0) 
SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_SUP,"Datatype with nonzero lower bound %ld\n",(long)lb); 524 if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */ 525 if (nbyte == 4) PackInit_DumbType<char,4,1>(link); else if (nbyte%4 == 0) PackInit_DumbType<char,4,0>(link); 526 else if (nbyte == 2) PackInit_DumbType<char,2,1>(link); else if (nbyte%2 == 0) PackInit_DumbType<char,2,0>(link); 527 else if (nbyte == 1) PackInit_DumbType<char,1,1>(link); else if (nbyte%1 == 0) PackInit_DumbType<char,1,0>(link); 528 } else { 529 nInt = nbyte / sizeof(int); 530 if (nInt == 8) PackInit_DumbType<int,8,1>(link); else if (nInt%8 == 0) PackInit_DumbType<int,8,0>(link); 531 else if (nInt == 4) PackInit_DumbType<int,4,1>(link); else if (nInt%4 == 0) PackInit_DumbType<int,4,0>(link); 532 else if (nInt == 2) PackInit_DumbType<int,2,1>(link); else if (nInt%2 == 0) PackInit_DumbType<int,2,0>(link); 533 else if (nInt == 1) PackInit_DumbType<int,1,1>(link); else if (nInt%1 == 0) PackInit_DumbType<int,1,0>(link); 534 } 535 } 536 537 #if defined(KOKKOS_ENABLE_CUDA) 538 if (!sf->use_default_stream) {cudaError_t cerr = cudaStreamCreate(&link->stream);CHKERRCUDA(cerr);} 539 link->sptr = new DeviceExecutionSpace(link->stream); 540 #elif defined(KOKKOS_ENABLE_HIP) 541 if (!sf->use_default_stream) {hipError_t cerr = hipStreamCreate(&link->stream);CHKERRQ(cerr);} 542 link->sptr = new DeviceExecutionSpace(link->stream); 543 #endif 544 545 link->d_SyncDevice = PetscSFLinkSyncDevice_Kokkos; 546 link->d_SyncStream = PetscSFLinkSyncStream_Kokkos; 547 link->Memcpy = PetscSFLinkMemcpy_Kokkos; 548 link->Destroy = PetscSFLinkDestroy_Kokkos; 549 link->deviceinited = PETSC_TRUE; 550 PetscFunctionReturn(0); 551 } 552