xref: /petsc/src/vec/is/sf/impls/basic/kokkos/sfkok.kokkos.cxx (revision 503c0ea9b45bcfbcebbb1ea5341243bbc69f0bea)
1 #include <../src/vec/is/sf/impls/basic/sfpack.h>
2 
3 #include <Kokkos_Core.hpp>
4 
5 using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace;
6 using DeviceMemorySpace    = typename DeviceExecutionSpace::memory_space;
7 using HostMemorySpace      = Kokkos::HostSpace;
8 
9 typedef Kokkos::View<char*,DeviceMemorySpace>       deviceBuffer_t;
10 typedef Kokkos::View<char*,HostMemorySpace>         HostBuffer_t;
11 
12 typedef Kokkos::View<const char*,DeviceMemorySpace> deviceConstBuffer_t;
13 typedef Kokkos::View<const char*,HostMemorySpace>   HostConstBuffer_t;
14 
15 /*====================================================================================*/
16 /*                             Regular operations                           */
17 /*====================================================================================*/
18 template<typename Type> struct Insert{KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = y;             return old;}};
19 template<typename Type> struct Add   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x += y;             return old;}};
20 template<typename Type> struct Mult  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x *= y;             return old;}};
21 template<typename Type> struct Min   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = PetscMin(x,y); return old;}};
22 template<typename Type> struct Max   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = PetscMax(x,y); return old;}};
23 template<typename Type> struct LAND  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x && y;        return old;}};
24 template<typename Type> struct LOR   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x || y;        return old;}};
25 template<typename Type> struct LXOR  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = !x != !y;      return old;}};
26 template<typename Type> struct BAND  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x & y;         return old;}};
27 template<typename Type> struct BOR   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x | y;         return old;}};
28 template<typename Type> struct BXOR  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x ^ y;         return old;}};
29 template<typename PairType> struct Minloc {
30   KOKKOS_INLINE_FUNCTION PairType operator()(PairType& x,PairType y) const {
31     PairType old = x;
32     if (y.first < x.first) x = y;
33     else if (y.first == x.first) x.second = PetscMin(x.second,y.second);
34     return old;
35   }
36 };
37 template<typename PairType> struct Maxloc {
38   KOKKOS_INLINE_FUNCTION PairType operator()(PairType& x,PairType y) const {
39     PairType old = x;
40     if (y.first > x.first) x = y;
41     else if (y.first == x.first) x.second = PetscMin(x.second,y.second); /* See MPI MAXLOC */
42     return old;
43   }
44 };
45 
46 /*====================================================================================*/
47 /*                             Atomic operations                            */
48 /*====================================================================================*/
49 template<typename Type> struct AtomicInsert  {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_assign(&x,y);}};
50 template<typename Type> struct AtomicAdd     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_add(&x,y);}};
51 template<typename Type> struct AtomicBAND    {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_and(&x,y);}};
52 template<typename Type> struct AtomicBOR     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_or (&x,y);}};
53 template<typename Type> struct AtomicBXOR    {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_xor(&x,y);}};
54 template<typename Type> struct AtomicLAND    {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {const Type zero=0,one=~0; Kokkos::atomic_and(&x,y?one:zero);}};
55 template<typename Type> struct AtomicLOR     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {const Type zero=0,one=1;  Kokkos::atomic_or (&x,y?one:zero);}};
56 template<typename Type> struct AtomicMult    {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_mul(&x,y);}};
57 template<typename Type> struct AtomicMin     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_min(&x,y);}};
58 template<typename Type> struct AtomicMax     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_max(&x,y);}};
59 /* TODO: struct AtomicLXOR  */
60 template<typename Type> struct AtomicFetchAdd{KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {return Kokkos::atomic_fetch_add(&x,y);}};
61 
62 /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */
63 static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt,PetscInt tid)
64 {
65   PetscInt        i,j,k,m,n,r;
66   const PetscInt  *offset,*start,*dx,*dy,*X,*Y;
67 
68   n      = opt[0];
69   offset = opt + 1;
70   start  = opt + n + 2;
71   dx     = opt + 2*n + 2;
72   dy     = opt + 3*n + 2;
73   X      = opt + 5*n + 2;
74   Y      = opt + 6*n + 2;
75   for (r=0; r<n; r++) {if (tid < offset[r+1]) break;}
76   m = (tid - offset[r]);
77   k = m/(dx[r]*dy[r]);
78   j = (m - k*dx[r]*dy[r])/dx[r];
79   i = m - k*dx[r]*dy[r] - j*dx[r];
80 
81   return (start[r] + k*X[r]*Y[r] + j*X[r] + i);
82 }
83 
84 /*====================================================================================*/
85 /*  Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link'         */
86 /*====================================================================================*/
87 
88 /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then
89    <Type> is PetscReal, which is the primitive type we operate on.
90    <bs>   is 16, which says <unit> contains 16 primitive types.
91    <BS>   is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>.
92    <EQ>   is 0, which is (bs == BS ? 1 : 0)
93 
94   If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant.
95   For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled.
96 */
97 template<typename Type,PetscInt BS,PetscInt EQ>
98 static PetscErrorCode Pack(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,const void *data_,void *buf_)
99 {
100   const PetscInt          *iopt = opt ? opt->array : NULL;
101   const PetscInt          M = EQ ? 1 : link->bs/BS, MBS=M*BS; /* If EQ, then MBS will be a compile-time const */
102   const Type              *data = static_cast<const Type*>(data_);
103   Type                    *buf = static_cast<Type*>(buf_);
104   DeviceExecutionSpace    exec;
105 
106   PetscFunctionBegin;
107   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
108     /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous;
109        iopt == NULL && idx == NULL ==> the indices are contiguous;
110      */
111     PetscInt t = (iopt? MapTidToIndex(iopt,tid) : (idx? idx[tid] : start+tid))*MBS;
112     PetscInt s = tid*MBS;
113     for (int i=0; i<MBS; i++) buf[s+i] = data[t+i];
114   });
115   PetscFunctionReturn(0);
116 }
117 
118 template<typename Type,class Op,PetscInt BS,PetscInt EQ>
119 static PetscErrorCode UnpackAndOp(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,void *data_,const void *buf_)
120 {
121   Op                      op;
122   const PetscInt          *iopt = opt ? opt->array : NULL;
123   const PetscInt          M = EQ ? 1 : link->bs/BS, MBS=M*BS;
124   Type                    *data = static_cast<Type*>(data_);
125   const Type              *buf = static_cast<const Type*>(buf_);
126   DeviceExecutionSpace    exec;
127 
128   PetscFunctionBegin;
129   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
130     PetscInt t = (iopt? MapTidToIndex(iopt,tid) : (idx? idx[tid] : start+tid))*MBS;
131     PetscInt s = tid*MBS;
132     for (int i=0; i<MBS; i++) op(data[t+i],buf[s+i]);
133   });
134   PetscFunctionReturn(0);
135 }
136 
137 template<typename Type,class Op,PetscInt BS,PetscInt EQ>
138 static PetscErrorCode FetchAndOp(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,void *data,void *buf)
139 {
140   Op                      op;
141   const PetscInt          *ropt = opt ? opt->array : NULL;
142   const PetscInt          M = EQ ? 1 : link->bs/BS, MBS=M*BS;
143   Type                    *rootdata = static_cast<Type*>(data),*leafbuf=static_cast<Type*>(buf);
144   DeviceExecutionSpace    exec;
145 
146   PetscFunctionBegin;
147   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
148     PetscInt r = (ropt? MapTidToIndex(ropt,tid) : (idx? idx[tid] : start+tid))*MBS;
149     PetscInt l = tid*MBS;
150     for (int i=0; i<MBS; i++) leafbuf[l+i] = op(rootdata[r+i],leafbuf[l+i]);
151   });
152   PetscFunctionReturn(0);
153 }
154 
155 template<typename Type,class Op,PetscInt BS,PetscInt EQ>
156 static PetscErrorCode ScatterAndOp(PetscSFLink link,PetscInt count,PetscInt srcStart,PetscSFPackOpt srcOpt,const PetscInt *srcIdx,const void *src_,PetscInt dstStart,PetscSFPackOpt dstOpt,const PetscInt *dstIdx,void *dst_)
157 {
158   PetscInt                srcx=0,srcy=0,srcX=0,srcY=0,dstx=0,dsty=0,dstX=0,dstY=0;
159   const PetscInt          M = (EQ) ? 1 : link->bs/BS, MBS=M*BS;
160   const Type              *src = static_cast<const Type*>(src_);
161   Type                    *dst = static_cast<Type*>(dst_);
162   DeviceExecutionSpace    exec;
163 
164   PetscFunctionBegin;
165   /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */
166   if (srcOpt)       {srcx = srcOpt->dx[0]; srcy = srcOpt->dy[0]; srcX = srcOpt->X[0]; srcY = srcOpt->Y[0]; srcStart = srcOpt->start[0]; srcIdx = NULL;}
167   else if (!srcIdx) {srcx = srcX = count; srcy = srcY = 1;}
168 
169   if (dstOpt)       {dstx = dstOpt->dx[0]; dsty = dstOpt->dy[0]; dstX = dstOpt->X[0]; dstY = dstOpt->Y[0]; dstStart = dstOpt->start[0]; dstIdx = NULL;}
170   else if (!dstIdx) {dstx = dstX = count; dsty = dstY = 1;}
171 
172   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
173     PetscInt i,j,k,s,t;
174     Op       op;
175     if (!srcIdx) { /* src is in 3D */
176       k = tid/(srcx*srcy);
177       j = (tid - k*srcx*srcy)/srcx;
178       i = tid - k*srcx*srcy - j*srcx;
179       s = srcStart + k*srcX*srcY + j*srcX + i;
180     } else { /* src is contiguous */
181       s = srcIdx[tid];
182     }
183 
184     if (!dstIdx) { /* 3D */
185       k = tid/(dstx*dsty);
186       j = (tid - k*dstx*dsty)/dstx;
187       i = tid - k*dstx*dsty - j*dstx;
188       t = dstStart + k*dstX*dstY + j*dstX + i;
189     } else { /* contiguous */
190       t = dstIdx[tid];
191     }
192 
193     s *= MBS;
194     t *= MBS;
195     for (i=0; i<MBS; i++) op(dst[t+i],src[s+i]);
196   });
197   PetscFunctionReturn(0);
198 }
199 
200 /* Specialization for Insert since we may use memcpy */
201 template<typename Type,PetscInt BS,PetscInt EQ>
202 static PetscErrorCode ScatterAndInsert(PetscSFLink link,PetscInt count,PetscInt srcStart,PetscSFPackOpt srcOpt,const PetscInt *srcIdx,const void *src_,PetscInt dstStart,PetscSFPackOpt dstOpt,const PetscInt *dstIdx,void *dst_)
203 {
204   const Type              *src = static_cast<const Type*>(src_);
205   Type                    *dst = static_cast<Type*>(dst_);
206   DeviceExecutionSpace    exec;
207 
208   PetscFunctionBegin;
209   if (!count) PetscFunctionReturn(0);
210   /*src and dst are contiguous */
211   if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
212     size_t sz = count*link->unitbytes;
213     deviceBuffer_t      dbuf(reinterpret_cast<char*>(dst+dstStart*link->bs),sz);
214     deviceConstBuffer_t sbuf(reinterpret_cast<const char*>(src+srcStart*link->bs),sz);
215     Kokkos::deep_copy(exec,dbuf,sbuf);
216   } else {
217     PetscCall(ScatterAndOp<Type,Insert<Type>,BS,EQ>(link,count,srcStart,srcOpt,srcIdx,src,dstStart,dstOpt,dstIdx,dst));
218   }
219   PetscFunctionReturn(0);
220 }
221 
222 template<typename Type,class Op,PetscInt BS,PetscInt EQ>
223 static PetscErrorCode FetchAndOpLocal(PetscSFLink link,PetscInt count,PetscInt rootstart,PetscSFPackOpt rootopt,const PetscInt *rootidx,void *rootdata_,PetscInt leafstart,PetscSFPackOpt leafopt,const PetscInt *leafidx,const void *leafdata_,void *leafupdate_)
224 {
225   Op                      op;
226   const PetscInt          M = (EQ) ? 1 : link->bs/BS, MBS = M*BS;
227   const PetscInt          *ropt = rootopt ? rootopt->array : NULL;
228   const PetscInt          *lopt = leafopt ? leafopt->array : NULL;
229   Type                    *rootdata = static_cast<Type*>(rootdata_),*leafupdate = static_cast<Type*>(leafupdate_);
230   const Type              *leafdata = static_cast<const Type*>(leafdata_);
231   DeviceExecutionSpace    exec;
232 
233   PetscFunctionBegin;
234   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
235     PetscInt r = (ropt? MapTidToIndex(ropt,tid) : (rootidx? rootidx[tid] : rootstart+tid))*MBS;
236     PetscInt l = (lopt? MapTidToIndex(lopt,tid) : (leafidx? leafidx[tid] : leafstart+tid))*MBS;
237     for (int i=0; i<MBS; i++) leafupdate[l+i] = op(rootdata[r+i],leafdata[l+i]);
238   });
239   PetscFunctionReturn(0);
240 }
241 
242 /*====================================================================================*/
243 /*  Init various types and instantiate pack/unpack function pointers                  */
244 /*====================================================================================*/
245 template<typename Type,PetscInt BS,PetscInt EQ>
246 static void PackInit_RealType(PetscSFLink link)
247 {
248   /* Pack/unpack for remote communication */
249   link->d_Pack              = Pack<Type,BS,EQ>;
250   link->d_UnpackAndInsert   = UnpackAndOp<Type,Insert<Type>,BS,EQ>;
251   link->d_UnpackAndAdd      = UnpackAndOp<Type,Add<Type>   ,BS,EQ>;
252   link->d_UnpackAndMult     = UnpackAndOp<Type,Mult<Type>  ,BS,EQ>;
253   link->d_UnpackAndMin      = UnpackAndOp<Type,Min<Type>   ,BS,EQ>;
254   link->d_UnpackAndMax      = UnpackAndOp<Type,Max<Type>   ,BS,EQ>;
255   link->d_FetchAndAdd       = FetchAndOp <Type,Add<Type>   ,BS,EQ>;
256   /* Scatter for local communication */
257   link->d_ScatterAndInsert  = ScatterAndInsert<Type,BS,EQ>; /* Has special optimizations */
258   link->d_ScatterAndAdd     = ScatterAndOp<Type,Add<Type>    ,BS,EQ>;
259   link->d_ScatterAndMult    = ScatterAndOp<Type,Mult<Type>   ,BS,EQ>;
260   link->d_ScatterAndMin     = ScatterAndOp<Type,Min<Type>    ,BS,EQ>;
261   link->d_ScatterAndMax     = ScatterAndOp<Type,Max<Type>    ,BS,EQ>;
262   link->d_FetchAndAddLocal  = FetchAndOpLocal<Type,Add <Type>,BS,EQ>;
263   /* Atomic versions when there are data-race possibilities */
264   link->da_UnpackAndInsert  = UnpackAndOp<Type,AtomicInsert<Type>  ,BS,EQ>;
265   link->da_UnpackAndAdd     = UnpackAndOp<Type,AtomicAdd<Type>     ,BS,EQ>;
266   link->da_UnpackAndMult    = UnpackAndOp<Type,AtomicMult<Type>    ,BS,EQ>;
267   link->da_UnpackAndMin     = UnpackAndOp<Type,AtomicMin<Type>     ,BS,EQ>;
268   link->da_UnpackAndMax     = UnpackAndOp<Type,AtomicMax<Type>     ,BS,EQ>;
269   link->da_FetchAndAdd      = FetchAndOp <Type,AtomicFetchAdd<Type>,BS,EQ>;
270 
271   link->da_ScatterAndInsert = ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>;
272   link->da_ScatterAndAdd    = ScatterAndOp<Type,AtomicAdd<Type>   ,BS,EQ>;
273   link->da_ScatterAndMult   = ScatterAndOp<Type,AtomicMult<Type>  ,BS,EQ>;
274   link->da_ScatterAndMin    = ScatterAndOp<Type,AtomicMin<Type>   ,BS,EQ>;
275   link->da_ScatterAndMax    = ScatterAndOp<Type,AtomicMax<Type>   ,BS,EQ>;
276   link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>;
277 }
278 
279 template<typename Type,PetscInt BS,PetscInt EQ>
280 static void PackInit_IntegerType(PetscSFLink link)
281 {
282   link->d_Pack              = Pack<Type,BS,EQ>;
283   link->d_UnpackAndInsert   = UnpackAndOp<Type,Insert<Type> ,BS,EQ>;
284   link->d_UnpackAndAdd      = UnpackAndOp<Type,Add<Type>    ,BS,EQ>;
285   link->d_UnpackAndMult     = UnpackAndOp<Type,Mult<Type>   ,BS,EQ>;
286   link->d_UnpackAndMin      = UnpackAndOp<Type,Min<Type>    ,BS,EQ>;
287   link->d_UnpackAndMax      = UnpackAndOp<Type,Max<Type>    ,BS,EQ>;
288   link->d_UnpackAndLAND     = UnpackAndOp<Type,LAND<Type>   ,BS,EQ>;
289   link->d_UnpackAndLOR      = UnpackAndOp<Type,LOR<Type>    ,BS,EQ>;
290   link->d_UnpackAndLXOR     = UnpackAndOp<Type,LXOR<Type>   ,BS,EQ>;
291   link->d_UnpackAndBAND     = UnpackAndOp<Type,BAND<Type>   ,BS,EQ>;
292   link->d_UnpackAndBOR      = UnpackAndOp<Type,BOR<Type>    ,BS,EQ>;
293   link->d_UnpackAndBXOR     = UnpackAndOp<Type,BXOR<Type>   ,BS,EQ>;
294   link->d_FetchAndAdd       = FetchAndOp <Type,Add<Type>    ,BS,EQ>;
295 
296   link->d_ScatterAndInsert  = ScatterAndInsert<Type,BS,EQ>;
297   link->d_ScatterAndAdd     = ScatterAndOp<Type,Add<Type>   ,BS,EQ>;
298   link->d_ScatterAndMult    = ScatterAndOp<Type,Mult<Type>  ,BS,EQ>;
299   link->d_ScatterAndMin     = ScatterAndOp<Type,Min<Type>   ,BS,EQ>;
300   link->d_ScatterAndMax     = ScatterAndOp<Type,Max<Type>   ,BS,EQ>;
301   link->d_ScatterAndLAND    = ScatterAndOp<Type,LAND<Type>  ,BS,EQ>;
302   link->d_ScatterAndLOR     = ScatterAndOp<Type,LOR<Type>   ,BS,EQ>;
303   link->d_ScatterAndLXOR    = ScatterAndOp<Type,LXOR<Type>  ,BS,EQ>;
304   link->d_ScatterAndBAND    = ScatterAndOp<Type,BAND<Type>  ,BS,EQ>;
305   link->d_ScatterAndBOR     = ScatterAndOp<Type,BOR<Type>   ,BS,EQ>;
306   link->d_ScatterAndBXOR    = ScatterAndOp<Type,BXOR<Type>  ,BS,EQ>;
307   link->d_FetchAndAddLocal  = FetchAndOpLocal<Type,Add<Type>,BS,EQ>;
308 
309   link->da_UnpackAndInsert  = UnpackAndOp<Type,AtomicInsert<Type>,BS,EQ>;
310   link->da_UnpackAndAdd     = UnpackAndOp<Type,AtomicAdd<Type>   ,BS,EQ>;
311   link->da_UnpackAndMult    = UnpackAndOp<Type,AtomicMult<Type>  ,BS,EQ>;
312   link->da_UnpackAndMin     = UnpackAndOp<Type,AtomicMin<Type>   ,BS,EQ>;
313   link->da_UnpackAndMax     = UnpackAndOp<Type,AtomicMax<Type>   ,BS,EQ>;
314   link->da_UnpackAndLAND    = UnpackAndOp<Type,AtomicLAND<Type>  ,BS,EQ>;
315   link->da_UnpackAndLOR     = UnpackAndOp<Type,AtomicLOR<Type>   ,BS,EQ>;
316   link->da_UnpackAndBAND    = UnpackAndOp<Type,AtomicBAND<Type>  ,BS,EQ>;
317   link->da_UnpackAndBOR     = UnpackAndOp<Type,AtomicBOR<Type>   ,BS,EQ>;
318   link->da_UnpackAndBXOR    = UnpackAndOp<Type,AtomicBXOR<Type>  ,BS,EQ>;
319   link->da_FetchAndAdd      = FetchAndOp <Type,AtomicFetchAdd<Type>,BS,EQ>;
320 
321   link->da_ScatterAndInsert = ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>;
322   link->da_ScatterAndAdd    = ScatterAndOp<Type,AtomicAdd<Type>   ,BS,EQ>;
323   link->da_ScatterAndMult   = ScatterAndOp<Type,AtomicMult<Type>  ,BS,EQ>;
324   link->da_ScatterAndMin    = ScatterAndOp<Type,AtomicMin<Type>   ,BS,EQ>;
325   link->da_ScatterAndMax    = ScatterAndOp<Type,AtomicMax<Type>   ,BS,EQ>;
326   link->da_ScatterAndLAND   = ScatterAndOp<Type,AtomicLAND<Type>  ,BS,EQ>;
327   link->da_ScatterAndLOR    = ScatterAndOp<Type,AtomicLOR<Type>   ,BS,EQ>;
328   link->da_ScatterAndBAND   = ScatterAndOp<Type,AtomicBAND<Type>  ,BS,EQ>;
329   link->da_ScatterAndBOR    = ScatterAndOp<Type,AtomicBOR<Type>   ,BS,EQ>;
330   link->da_ScatterAndBXOR   = ScatterAndOp<Type,AtomicBXOR<Type>  ,BS,EQ>;
331   link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>;
332 }
333 
#if defined(PETSC_HAVE_COMPLEX)
/* Fill the device (d_) and device-atomic (da_) function pointers of link for a complex primitive Type:
   pack, insert, add, mult and fetch-and-add only. */
template<typename Type,PetscInt BS,PetscInt EQ>
static void PackInit_ComplexType(PetscSFLink link)
{
  /* Pack and fetch-and-add (regular and atomic) */
  link->d_Pack              = Pack<Type,BS,EQ>;
  link->d_FetchAndAdd       = FetchAndOp     <Type,Add<Type>           ,BS,EQ>;
  link->d_FetchAndAddLocal  = FetchAndOpLocal<Type,Add<Type>           ,BS,EQ>;
  link->da_FetchAndAdd      = FetchAndOp     <Type,AtomicFetchAdd<Type>,BS,EQ>;
  link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>;

  /* Unpack (regular, then atomic) */
  link->d_UnpackAndInsert   = UnpackAndOp<Type,Insert<Type>,BS,EQ>;
  link->d_UnpackAndAdd      = UnpackAndOp<Type,Add<Type>   ,BS,EQ>;
  link->d_UnpackAndMult     = UnpackAndOp<Type,Mult<Type>  ,BS,EQ>;
  link->da_UnpackAndInsert  = UnpackAndOp<Type,AtomicInsert<Type>,BS,EQ>;
  link->da_UnpackAndAdd     = UnpackAndOp<Type,AtomicAdd<Type>   ,BS,EQ>;
  link->da_UnpackAndMult    = UnpackAndOp<Type,AtomicMult<Type>  ,BS,EQ>;

  /* Scatter (regular, then atomic) */
  link->d_ScatterAndInsert  = ScatterAndInsert<Type,BS,EQ>;
  link->d_ScatterAndAdd     = ScatterAndOp<Type,Add<Type> ,BS,EQ>;
  link->d_ScatterAndMult    = ScatterAndOp<Type,Mult<Type>,BS,EQ>;
  link->da_ScatterAndInsert = ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>;
  link->da_ScatterAndAdd    = ScatterAndOp<Type,AtomicAdd<Type>   ,BS,EQ>;
  link->da_ScatterAndMult   = ScatterAndOp<Type,AtomicMult<Type>  ,BS,EQ>;
}
#endif
360 
361 template<typename Type>
362 static void PackInit_PairType(PetscSFLink link)
363 {
364   link->d_Pack             = Pack<Type,1,1>;
365   link->d_UnpackAndInsert  = UnpackAndOp<Type,Insert<Type>,1,1>;
366   link->d_UnpackAndMaxloc  = UnpackAndOp<Type,Maxloc<Type>,1,1>;
367   link->d_UnpackAndMinloc  = UnpackAndOp<Type,Minloc<Type>,1,1>;
368 
369   link->d_ScatterAndInsert = ScatterAndOp<Type,Insert<Type>,1,1>;
370   link->d_ScatterAndMaxloc = ScatterAndOp<Type,Maxloc<Type>,1,1>;
371   link->d_ScatterAndMinloc = ScatterAndOp<Type,Minloc<Type>,1,1>;
372   /* Atomics for pair types are not implemented yet */
373 }
374 
375 template<typename Type,PetscInt BS,PetscInt EQ>
376 static void PackInit_DumbType(PetscSFLink link)
377 {
378   link->d_Pack             = Pack<Type,BS,EQ>;
379   link->d_UnpackAndInsert  = UnpackAndOp<Type,Insert<Type>,BS,EQ>;
380   link->d_ScatterAndInsert = ScatterAndInsert<Type,BS,EQ>;
381   /* Atomics for dumb types are not implemented yet */
382 }
383 
/*
  Kokkos::DefaultExecutionSpace(stream) is a reference-counted pointer object. It has a bug:
  one is not able to repeatedly create and destroy the object. SF's original design was that each
  SFLink has a stream (NULL or not) and hence an execution space object. The bug prevents us from
  destroying multiple SFLinks with a NULL stream and the default execution space object. To avoid
  memory leaks, SF_Kokkos only supports the NULL stream, which is also PETSc's default scheme. SF_Kokkos
  does not do its own new/delete. It just uses Kokkos::DefaultExecutionSpace(), which is a singleton
  object in Kokkos.
*/
393 /*
394 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link)
395 {
396   PetscFunctionBegin;
397   PetscFunctionReturn(0);
398 }
399 */
400 
401 /* Some device-specific utilities */
402 static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink PETSC_UNUSED link)
403 {
404   PetscFunctionBegin;
405   Kokkos::fence();
406   PetscFunctionReturn(0);
407 }
408 
409 static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink PETSC_UNUSED link)
410 {
411   DeviceExecutionSpace    exec;
412   PetscFunctionBegin;
413   exec.fence();
414   PetscFunctionReturn(0);
415 }
416 
417 static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink PETSC_UNUSED link,PetscMemType dstmtype,void* dst,PetscMemType srcmtype,const void*src,size_t n)
418 {
419   DeviceExecutionSpace    exec;
420 
421   PetscFunctionBegin;
422   if (!n) PetscFunctionReturn(0);
423   if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) {
424     PetscCall(PetscMemcpy(dst,src,n));
425   } else {
426     if (PetscMemTypeDevice(dstmtype) && PetscMemTypeHost(srcmtype)) {
427       deviceBuffer_t       dbuf(static_cast<char*>(dst),n);
428       HostConstBuffer_t    sbuf(static_cast<const char*>(src),n);
429       Kokkos::deep_copy(exec,dbuf,sbuf);
430       PetscCall(PetscLogCpuToGpu(n));
431     } else if (PetscMemTypeHost(dstmtype) && PetscMemTypeDevice(srcmtype)) {
432       HostBuffer_t         dbuf(static_cast<char*>(dst),n);
433       deviceConstBuffer_t  sbuf(static_cast<const char*>(src),n);
434       Kokkos::deep_copy(exec,dbuf,sbuf);
435       PetscCall(PetscLogGpuToCpu(n));
436     } else if (PetscMemTypeDevice(dstmtype) && PetscMemTypeDevice(srcmtype)) {
437       deviceBuffer_t       dbuf(static_cast<char*>(dst),n);
438       deviceConstBuffer_t  sbuf(static_cast<const char*>(src),n);
439       Kokkos::deep_copy(exec,dbuf,sbuf);
440     }
441   }
442   PetscFunctionReturn(0);
443 }
444 
445 PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype,size_t size,void** ptr)
446 {
447   PetscFunctionBegin;
448   if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size,ptr));
449   else if (PetscMemTypeDevice(mtype)) {
450     if (!PetscKokkosInitialized) PetscCall(PetscKokkosInitializeCheck());
451     *ptr = Kokkos::kokkos_malloc<DeviceMemorySpace>(size);
452   } else SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONG,"Wrong PetscMemType %d", (int)mtype);
453   PetscFunctionReturn(0);
454 }
455 
456 PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype,void* ptr)
457 {
458   PetscFunctionBegin;
459   if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr));
460   else if (PetscMemTypeDevice(mtype)) {Kokkos::kokkos_free<DeviceMemorySpace>(ptr);}
461   else SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONG,"Wrong PetscMemType %d",(int)mtype);
462   PetscFunctionReturn(0);
463 }
464 
465 /* Destructor when the link uses MPI for communication */
466 static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSF sf,PetscSFLink link)
467 {
468   PetscFunctionBegin;
469   for (int i=PETSCSF_LOCAL; i<=PETSCSF_REMOTE; i++) {
470     PetscCall(PetscSFFree(sf,PETSC_MEMTYPE_DEVICE,link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
471     PetscCall(PetscSFFree(sf,PETSC_MEMTYPE_DEVICE,link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
472   }
473   PetscFunctionReturn(0);
474 }
475 
/* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what is needed on the device */
477 PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF PETSC_UNUSED sf,PetscSFLink link,MPI_Datatype unit)
478 {
479   PetscInt           nSignedChar=0,nUnsignedChar=0,nInt=0,nPetscInt=0,nPetscReal=0;
480   PetscBool          is2Int,is2PetscInt;
481 #if defined(PETSC_HAVE_COMPLEX)
482   PetscInt           nPetscComplex=0;
483 #endif
484 
485   PetscFunctionBegin;
486   if (link->deviceinited) PetscFunctionReturn(0);
487   PetscCall(PetscKokkosInitializeCheck());
488   PetscCall(MPIPetsc_Type_compare_contig(unit,MPI_SIGNED_CHAR,  &nSignedChar));
489   PetscCall(MPIPetsc_Type_compare_contig(unit,MPI_UNSIGNED_CHAR,&nUnsignedChar));
490   /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */
491   PetscCall(MPIPetsc_Type_compare_contig(unit,MPI_INT,  &nInt));
492   PetscCall(MPIPetsc_Type_compare_contig(unit,MPIU_INT, &nPetscInt));
493   PetscCall(MPIPetsc_Type_compare_contig(unit,MPIU_REAL,&nPetscReal));
494 #if defined(PETSC_HAVE_COMPLEX)
495   PetscCall(MPIPetsc_Type_compare_contig(unit,MPIU_COMPLEX,&nPetscComplex));
496 #endif
497   PetscCall(MPIPetsc_Type_compare(unit,MPI_2INT,&is2Int));
498   PetscCall(MPIPetsc_Type_compare(unit,MPIU_2INT,&is2PetscInt));
499 
500   if (is2Int) {
501     PackInit_PairType<Kokkos::pair<int,int>>(link);
502   } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
503     PackInit_PairType<Kokkos::pair<PetscInt,PetscInt>>(link);
504   } else if (nPetscReal) {
505    #if !defined(PETSC_HAVE_DEVICE)  /* Skip the unimportant stuff to speed up SF device compilation time */
506     if      (nPetscReal == 8) PackInit_RealType<PetscReal,8,1>(link); else if (nPetscReal%8 == 0) PackInit_RealType<PetscReal,8,0>(link);
507     else if (nPetscReal == 4) PackInit_RealType<PetscReal,4,1>(link); else if (nPetscReal%4 == 0) PackInit_RealType<PetscReal,4,0>(link);
508     else if (nPetscReal == 2) PackInit_RealType<PetscReal,2,1>(link); else if (nPetscReal%2 == 0) PackInit_RealType<PetscReal,2,0>(link);
509     else if (nPetscReal == 1) PackInit_RealType<PetscReal,1,1>(link); else if (nPetscReal%1 == 0)
510    #endif
511     PackInit_RealType<PetscReal,1,0>(link);
512   } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) {
513    #if !defined(PETSC_HAVE_DEVICE)
514     if      (nPetscInt == 8) PackInit_IntegerType<llint,8,1>(link); else if (nPetscInt%8 == 0) PackInit_IntegerType<llint,8,0>(link);
515     else if (nPetscInt == 4) PackInit_IntegerType<llint,4,1>(link); else if (nPetscInt%4 == 0) PackInit_IntegerType<llint,4,0>(link);
516     else if (nPetscInt == 2) PackInit_IntegerType<llint,2,1>(link); else if (nPetscInt%2 == 0) PackInit_IntegerType<llint,2,0>(link);
517     else if (nPetscInt == 1) PackInit_IntegerType<llint,1,1>(link); else if (nPetscInt%1 == 0)
518    #endif
519     PackInit_IntegerType<llint,1,0>(link);
520   } else if (nInt) {
521    #if !defined(PETSC_HAVE_DEVICE)
522     if      (nInt == 8) PackInit_IntegerType<int,8,1>(link); else if (nInt%8 == 0) PackInit_IntegerType<int,8,0>(link);
523     else if (nInt == 4) PackInit_IntegerType<int,4,1>(link); else if (nInt%4 == 0) PackInit_IntegerType<int,4,0>(link);
524     else if (nInt == 2) PackInit_IntegerType<int,2,1>(link); else if (nInt%2 == 0) PackInit_IntegerType<int,2,0>(link);
525     else if (nInt == 1) PackInit_IntegerType<int,1,1>(link); else if (nInt%1 == 0)
526    #endif
527     PackInit_IntegerType<int,1,0>(link);
528   } else if (nSignedChar) {
529    #if !defined(PETSC_HAVE_DEVICE)
530     if      (nSignedChar == 8) PackInit_IntegerType<char,8,1>(link); else if (nSignedChar%8 == 0) PackInit_IntegerType<char,8,0>(link);
531     else if (nSignedChar == 4) PackInit_IntegerType<char,4,1>(link); else if (nSignedChar%4 == 0) PackInit_IntegerType<char,4,0>(link);
532     else if (nSignedChar == 2) PackInit_IntegerType<char,2,1>(link); else if (nSignedChar%2 == 0) PackInit_IntegerType<char,2,0>(link);
533     else if (nSignedChar == 1) PackInit_IntegerType<char,1,1>(link); else if (nSignedChar%1 == 0)
534    #endif
535     PackInit_IntegerType<char,1,0>(link);
536   }  else if (nUnsignedChar) {
537    #if !defined(PETSC_HAVE_DEVICE)
538     if      (nUnsignedChar == 8) PackInit_IntegerType<unsigned char,8,1>(link); else if (nUnsignedChar%8 == 0) PackInit_IntegerType<unsigned char,8,0>(link);
539     else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char,4,1>(link); else if (nUnsignedChar%4 == 0) PackInit_IntegerType<unsigned char,4,0>(link);
540     else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char,2,1>(link); else if (nUnsignedChar%2 == 0) PackInit_IntegerType<unsigned char,2,0>(link);
541     else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char,1,1>(link); else if (nUnsignedChar%1 == 0)
542    #endif
543     PackInit_IntegerType<unsigned char,1,0>(link);
544 #if defined(PETSC_HAVE_COMPLEX)
545   } else if (nPetscComplex) {
546    #if !defined(PETSC_HAVE_DEVICE)
547     if      (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>,8,1>(link); else if (nPetscComplex%8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,8,0>(link);
548     else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>,4,1>(link); else if (nPetscComplex%4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,4,0>(link);
549     else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>,2,1>(link); else if (nPetscComplex%2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,2,0>(link);
550     else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>,1,1>(link); else if (nPetscComplex%1 == 0)
551    #endif
552     PackInit_ComplexType<Kokkos::complex<PetscReal>,1,0>(link);
553 #endif
554   } else {
555     MPI_Aint lb,nbyte;
556     PetscCallMPI(MPI_Type_get_extent(unit,&lb,&nbyte));
557     PetscCheckFalse(lb != 0,PETSC_COMM_SELF,PETSC_ERR_SUP,"Datatype with nonzero lower bound %ld",(long)lb);
558     if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */
559      #if !defined(PETSC_HAVE_DEVICE)
560       if      (nbyte == 4) PackInit_DumbType<char,4,1>(link); else if (nbyte%4 == 0) PackInit_DumbType<char,4,0>(link);
561       else if (nbyte == 2) PackInit_DumbType<char,2,1>(link); else if (nbyte%2 == 0) PackInit_DumbType<char,2,0>(link);
562       else if (nbyte == 1) PackInit_DumbType<char,1,1>(link); else if (nbyte%1 == 0)
563      #endif
564       PackInit_DumbType<char,1,0>(link);
565     } else {
566       nInt = nbyte / sizeof(int);
567      #if !defined(PETSC_HAVE_DEVICE)
568       if      (nInt == 8) PackInit_DumbType<int,8,1>(link); else if (nInt%8 == 0) PackInit_DumbType<int,8,0>(link);
569       else if (nInt == 4) PackInit_DumbType<int,4,1>(link); else if (nInt%4 == 0) PackInit_DumbType<int,4,0>(link);
570       else if (nInt == 2) PackInit_DumbType<int,2,1>(link); else if (nInt%2 == 0) PackInit_DumbType<int,2,0>(link);
571       else if (nInt == 1) PackInit_DumbType<int,1,1>(link); else if (nInt%1 == 0)
572      #endif
573       PackInit_DumbType<int,1,0>(link);
574     }
575   }
576 
577   link->SyncDevice   = PetscSFLinkSyncDevice_Kokkos;
578   link->SyncStream   = PetscSFLinkSyncStream_Kokkos;
579   link->Memcpy       = PetscSFLinkMemcpy_Kokkos;
580   link->Destroy      = PetscSFLinkDestroy_Kokkos;
581   link->deviceinited = PETSC_TRUE;
582   PetscFunctionReturn(0);
583 }
584