xref: /petsc/src/vec/is/sf/impls/basic/kokkos/sfkok.kokkos.cxx (revision 9371c9d470a9602b6d10a8bf50c9b2280a79e45a)
1914b7a73SJunchao Zhang #include <../src/vec/is/sf/impls/basic/sfpack.h>
2914b7a73SJunchao Zhang 
3914b7a73SJunchao Zhang #include <Kokkos_Core.hpp>
4914b7a73SJunchao Zhang 
5914b7a73SJunchao Zhang using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace;
6914b7a73SJunchao Zhang using DeviceMemorySpace    = typename DeviceExecutionSpace::memory_space;
7914b7a73SJunchao Zhang using HostMemorySpace      = Kokkos::HostSpace;
8914b7a73SJunchao Zhang 
9914b7a73SJunchao Zhang typedef Kokkos::View<char *, DeviceMemorySpace> deviceBuffer_t;
10914b7a73SJunchao Zhang typedef Kokkos::View<char *, HostMemorySpace>   HostBuffer_t;
11914b7a73SJunchao Zhang 
12914b7a73SJunchao Zhang typedef Kokkos::View<const char *, DeviceMemorySpace> deviceConstBuffer_t;
13914b7a73SJunchao Zhang typedef Kokkos::View<const char *, HostMemorySpace>   HostConstBuffer_t;
14914b7a73SJunchao Zhang 
15914b7a73SJunchao Zhang /*====================================================================================*/
16914b7a73SJunchao Zhang /*                             Regular operations                           */
17914b7a73SJunchao Zhang /*====================================================================================*/
18*9371c9d4SSatish Balay template <typename Type>
19*9371c9d4SSatish Balay struct Insert {
20*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
21*9371c9d4SSatish Balay     Type old = x;
22*9371c9d4SSatish Balay     x        = y;
23*9371c9d4SSatish Balay     return old;
24*9371c9d4SSatish Balay   }
25*9371c9d4SSatish Balay };
26*9371c9d4SSatish Balay template <typename Type>
27*9371c9d4SSatish Balay struct Add {
28*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
29*9371c9d4SSatish Balay     Type old = x;
30*9371c9d4SSatish Balay     x += y;
31*9371c9d4SSatish Balay     return old;
32*9371c9d4SSatish Balay   }
33*9371c9d4SSatish Balay };
34*9371c9d4SSatish Balay template <typename Type>
35*9371c9d4SSatish Balay struct Mult {
36*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
37*9371c9d4SSatish Balay     Type old = x;
38*9371c9d4SSatish Balay     x *= y;
39*9371c9d4SSatish Balay     return old;
40*9371c9d4SSatish Balay   }
41*9371c9d4SSatish Balay };
42*9371c9d4SSatish Balay template <typename Type>
43*9371c9d4SSatish Balay struct Min {
44*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
45*9371c9d4SSatish Balay     Type old = x;
46*9371c9d4SSatish Balay     x        = PetscMin(x, y);
47*9371c9d4SSatish Balay     return old;
48*9371c9d4SSatish Balay   }
49*9371c9d4SSatish Balay };
50*9371c9d4SSatish Balay template <typename Type>
51*9371c9d4SSatish Balay struct Max {
52*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
53*9371c9d4SSatish Balay     Type old = x;
54*9371c9d4SSatish Balay     x        = PetscMax(x, y);
55*9371c9d4SSatish Balay     return old;
56*9371c9d4SSatish Balay   }
57*9371c9d4SSatish Balay };
58*9371c9d4SSatish Balay template <typename Type>
59*9371c9d4SSatish Balay struct LAND {
60*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
61*9371c9d4SSatish Balay     Type old = x;
62*9371c9d4SSatish Balay     x        = x && y;
63*9371c9d4SSatish Balay     return old;
64*9371c9d4SSatish Balay   }
65*9371c9d4SSatish Balay };
66*9371c9d4SSatish Balay template <typename Type>
67*9371c9d4SSatish Balay struct LOR {
68*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
69*9371c9d4SSatish Balay     Type old = x;
70*9371c9d4SSatish Balay     x        = x || y;
71*9371c9d4SSatish Balay     return old;
72*9371c9d4SSatish Balay   }
73*9371c9d4SSatish Balay };
74*9371c9d4SSatish Balay template <typename Type>
75*9371c9d4SSatish Balay struct LXOR {
76*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
77*9371c9d4SSatish Balay     Type old = x;
78*9371c9d4SSatish Balay     x        = !x != !y;
79*9371c9d4SSatish Balay     return old;
80*9371c9d4SSatish Balay   }
81*9371c9d4SSatish Balay };
82*9371c9d4SSatish Balay template <typename Type>
83*9371c9d4SSatish Balay struct BAND {
84*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
85*9371c9d4SSatish Balay     Type old = x;
86*9371c9d4SSatish Balay     x        = x & y;
87*9371c9d4SSatish Balay     return old;
88*9371c9d4SSatish Balay   }
89*9371c9d4SSatish Balay };
90*9371c9d4SSatish Balay template <typename Type>
91*9371c9d4SSatish Balay struct BOR {
92*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
93*9371c9d4SSatish Balay     Type old = x;
94*9371c9d4SSatish Balay     x        = x | y;
95*9371c9d4SSatish Balay     return old;
96*9371c9d4SSatish Balay   }
97*9371c9d4SSatish Balay };
98*9371c9d4SSatish Balay template <typename Type>
99*9371c9d4SSatish Balay struct BXOR {
100*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const {
101*9371c9d4SSatish Balay     Type old = x;
102*9371c9d4SSatish Balay     x        = x ^ y;
103*9371c9d4SSatish Balay     return old;
104*9371c9d4SSatish Balay   }
105*9371c9d4SSatish Balay };
106*9371c9d4SSatish Balay template <typename PairType>
107*9371c9d4SSatish Balay struct Minloc {
108914b7a73SJunchao Zhang   KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const {
109914b7a73SJunchao Zhang     PairType old = x;
110914b7a73SJunchao Zhang     if (y.first < x.first) x = y;
111914b7a73SJunchao Zhang     else if (y.first == x.first) x.second = PetscMin(x.second, y.second);
112914b7a73SJunchao Zhang     return old;
113914b7a73SJunchao Zhang   }
114914b7a73SJunchao Zhang };
115*9371c9d4SSatish Balay template <typename PairType>
116*9371c9d4SSatish Balay struct Maxloc {
117914b7a73SJunchao Zhang   KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const {
118914b7a73SJunchao Zhang     PairType old = x;
119914b7a73SJunchao Zhang     if (y.first > x.first) x = y;
120914b7a73SJunchao Zhang     else if (y.first == x.first) x.second = PetscMin(x.second, y.second); /* See MPI MAXLOC */
121914b7a73SJunchao Zhang     return old;
122914b7a73SJunchao Zhang   }
123914b7a73SJunchao Zhang };
124914b7a73SJunchao Zhang 
125914b7a73SJunchao Zhang /*====================================================================================*/
126914b7a73SJunchao Zhang /*                             Atomic operations                            */
127914b7a73SJunchao Zhang /*====================================================================================*/
128*9371c9d4SSatish Balay template <typename Type>
129*9371c9d4SSatish Balay struct AtomicInsert {
130*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_assign(&x, y); }
131*9371c9d4SSatish Balay };
132*9371c9d4SSatish Balay template <typename Type>
133*9371c9d4SSatish Balay struct AtomicAdd {
134*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_add(&x, y); }
135*9371c9d4SSatish Balay };
136*9371c9d4SSatish Balay template <typename Type>
137*9371c9d4SSatish Balay struct AtomicBAND {
138*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_and(&x, y); }
139*9371c9d4SSatish Balay };
140*9371c9d4SSatish Balay template <typename Type>
141*9371c9d4SSatish Balay struct AtomicBOR {
142*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_or(&x, y); }
143*9371c9d4SSatish Balay };
144*9371c9d4SSatish Balay template <typename Type>
145*9371c9d4SSatish Balay struct AtomicBXOR {
146*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_xor(&x, y); }
147*9371c9d4SSatish Balay };
148*9371c9d4SSatish Balay template <typename Type>
149*9371c9d4SSatish Balay struct AtomicLAND {
150*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const {
151*9371c9d4SSatish Balay     const Type zero = 0, one = ~0;
152*9371c9d4SSatish Balay     Kokkos::atomic_and(&x, y ? one : zero);
153*9371c9d4SSatish Balay   }
154*9371c9d4SSatish Balay };
155*9371c9d4SSatish Balay template <typename Type>
156*9371c9d4SSatish Balay struct AtomicLOR {
157*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const {
158*9371c9d4SSatish Balay     const Type zero = 0, one = 1;
159*9371c9d4SSatish Balay     Kokkos::atomic_or(&x, y ? one : zero);
160*9371c9d4SSatish Balay   }
161*9371c9d4SSatish Balay };
162*9371c9d4SSatish Balay template <typename Type>
163*9371c9d4SSatish Balay struct AtomicMult {
164*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_mul(&x, y); }
165*9371c9d4SSatish Balay };
166*9371c9d4SSatish Balay template <typename Type>
167*9371c9d4SSatish Balay struct AtomicMin {
168*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_min(&x, y); }
169*9371c9d4SSatish Balay };
170*9371c9d4SSatish Balay template <typename Type>
171*9371c9d4SSatish Balay struct AtomicMax {
172*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_max(&x, y); }
173*9371c9d4SSatish Balay };
174914b7a73SJunchao Zhang /* TODO: struct AtomicLXOR  */
175*9371c9d4SSatish Balay template <typename Type>
176*9371c9d4SSatish Balay struct AtomicFetchAdd {
177*9371c9d4SSatish Balay   KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { return Kokkos::atomic_fetch_add(&x, y); }
178*9371c9d4SSatish Balay };
179914b7a73SJunchao Zhang 
180914b7a73SJunchao Zhang /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */
181*9371c9d4SSatish Balay static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid) {
182914b7a73SJunchao Zhang   PetscInt        i, j, k, m, n, r;
183914b7a73SJunchao Zhang   const PetscInt *offset, *start, *dx, *dy, *X, *Y;
184914b7a73SJunchao Zhang 
185914b7a73SJunchao Zhang   n      = opt[0];
186914b7a73SJunchao Zhang   offset = opt + 1;
187914b7a73SJunchao Zhang   start  = opt + n + 2;
188914b7a73SJunchao Zhang   dx     = opt + 2 * n + 2;
189914b7a73SJunchao Zhang   dy     = opt + 3 * n + 2;
190914b7a73SJunchao Zhang   X      = opt + 5 * n + 2;
191914b7a73SJunchao Zhang   Y      = opt + 6 * n + 2;
192*9371c9d4SSatish Balay   for (r = 0; r < n; r++) {
193*9371c9d4SSatish Balay     if (tid < offset[r + 1]) break;
194*9371c9d4SSatish Balay   }
195914b7a73SJunchao Zhang   m = (tid - offset[r]);
196914b7a73SJunchao Zhang   k = m / (dx[r] * dy[r]);
197914b7a73SJunchao Zhang   j = (m - k * dx[r] * dy[r]) / dx[r];
198914b7a73SJunchao Zhang   i = m - k * dx[r] * dy[r] - j * dx[r];
199914b7a73SJunchao Zhang 
200914b7a73SJunchao Zhang   return (start[r] + k * X[r] * Y[r] + j * X[r] + i);
201914b7a73SJunchao Zhang }
202914b7a73SJunchao Zhang 
203914b7a73SJunchao Zhang /*====================================================================================*/
204914b7a73SJunchao Zhang /*  Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link'         */
205914b7a73SJunchao Zhang /*====================================================================================*/
206914b7a73SJunchao Zhang 
207914b7a73SJunchao Zhang /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then
208914b7a73SJunchao Zhang    <Type> is PetscReal, which is the primitive type we operate on.
209914b7a73SJunchao Zhang    <bs>   is 16, which says <unit> contains 16 primitive types.
210914b7a73SJunchao Zhang    <BS>   is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>.
211914b7a73SJunchao Zhang    <EQ>   is 0, which is (bs == BS ? 1 : 0)
212914b7a73SJunchao Zhang 
  If instead <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, which makes MBS below a compile-time constant.
  For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, so the inner for-loops below can be fully unrolled.
215914b7a73SJunchao Zhang */
216914b7a73SJunchao Zhang template <typename Type, PetscInt BS, PetscInt EQ>
217*9371c9d4SSatish Balay static PetscErrorCode Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data_, void *buf_) {
218914b7a73SJunchao Zhang   const PetscInt      *iopt = opt ? opt->array : NULL;
219914b7a73SJunchao Zhang   const PetscInt       M = EQ ? 1 : link->bs / BS, MBS = M * BS; /* If EQ, then MBS will be a compile-time const */
220914b7a73SJunchao Zhang   const Type          *data = static_cast<const Type *>(data_);
221914b7a73SJunchao Zhang   Type                *buf  = static_cast<Type *>(buf_);
222f4af43b4SJunchao Zhang   DeviceExecutionSpace exec;
223914b7a73SJunchao Zhang 
224914b7a73SJunchao Zhang   PetscFunctionBegin;
225*9371c9d4SSatish Balay   Kokkos::parallel_for(
226*9371c9d4SSatish Balay     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
227914b7a73SJunchao Zhang       /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous;
228914b7a73SJunchao Zhang        iopt == NULL && idx == NULL ==> the indices are contiguous;
229914b7a73SJunchao Zhang      */
230914b7a73SJunchao Zhang       PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
231914b7a73SJunchao Zhang       PetscInt s = tid * MBS;
232914b7a73SJunchao Zhang       for (int i = 0; i < MBS; i++) buf[s + i] = data[t + i];
233914b7a73SJunchao Zhang     });
234914b7a73SJunchao Zhang   PetscFunctionReturn(0);
235914b7a73SJunchao Zhang }
236914b7a73SJunchao Zhang 
237914b7a73SJunchao Zhang template <typename Type, class Op, PetscInt BS, PetscInt EQ>
238*9371c9d4SSatish Balay static PetscErrorCode UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data_, const void *buf_) {
239914b7a73SJunchao Zhang   Op                   op;
240914b7a73SJunchao Zhang   const PetscInt      *iopt = opt ? opt->array : NULL;
241914b7a73SJunchao Zhang   const PetscInt       M = EQ ? 1 : link->bs / BS, MBS = M * BS;
242914b7a73SJunchao Zhang   Type                *data = static_cast<Type *>(data_);
243914b7a73SJunchao Zhang   const Type          *buf  = static_cast<const Type *>(buf_);
244f4af43b4SJunchao Zhang   DeviceExecutionSpace exec;
245914b7a73SJunchao Zhang 
246914b7a73SJunchao Zhang   PetscFunctionBegin;
247*9371c9d4SSatish Balay   Kokkos::parallel_for(
248*9371c9d4SSatish Balay     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
249914b7a73SJunchao Zhang       PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
250914b7a73SJunchao Zhang       PetscInt s = tid * MBS;
251914b7a73SJunchao Zhang       for (int i = 0; i < MBS; i++) op(data[t + i], buf[s + i]);
252914b7a73SJunchao Zhang     });
253914b7a73SJunchao Zhang   PetscFunctionReturn(0);
254914b7a73SJunchao Zhang }
255914b7a73SJunchao Zhang 
256914b7a73SJunchao Zhang template <typename Type, class Op, PetscInt BS, PetscInt EQ>
257*9371c9d4SSatish Balay static PetscErrorCode FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf) {
258914b7a73SJunchao Zhang   Op                   op;
259914b7a73SJunchao Zhang   const PetscInt      *ropt = opt ? opt->array : NULL;
260914b7a73SJunchao Zhang   const PetscInt       M = EQ ? 1 : link->bs / BS, MBS = M * BS;
261914b7a73SJunchao Zhang   Type                *rootdata = static_cast<Type *>(data), *leafbuf = static_cast<Type *>(buf);
262f4af43b4SJunchao Zhang   DeviceExecutionSpace exec;
263914b7a73SJunchao Zhang 
264914b7a73SJunchao Zhang   PetscFunctionBegin;
265*9371c9d4SSatish Balay   Kokkos::parallel_for(
266*9371c9d4SSatish Balay     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
267914b7a73SJunchao Zhang       PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
268914b7a73SJunchao Zhang       PetscInt l = tid * MBS;
269914b7a73SJunchao Zhang       for (int i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]);
270914b7a73SJunchao Zhang     });
271914b7a73SJunchao Zhang   PetscFunctionReturn(0);
272914b7a73SJunchao Zhang }
273914b7a73SJunchao Zhang 
274914b7a73SJunchao Zhang template <typename Type, class Op, PetscInt BS, PetscInt EQ>
275*9371c9d4SSatish Balay static PetscErrorCode ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_) {
276914b7a73SJunchao Zhang   PetscInt             srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0;
277914b7a73SJunchao Zhang   const PetscInt       M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
278914b7a73SJunchao Zhang   const Type          *src = static_cast<const Type *>(src_);
279914b7a73SJunchao Zhang   Type                *dst = static_cast<Type *>(dst_);
280f4af43b4SJunchao Zhang   DeviceExecutionSpace exec;
281914b7a73SJunchao Zhang 
282914b7a73SJunchao Zhang   PetscFunctionBegin;
283914b7a73SJunchao Zhang   /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */
284*9371c9d4SSatish Balay   if (srcOpt) {
285*9371c9d4SSatish Balay     srcx     = srcOpt->dx[0];
286*9371c9d4SSatish Balay     srcy     = srcOpt->dy[0];
287*9371c9d4SSatish Balay     srcX     = srcOpt->X[0];
288*9371c9d4SSatish Balay     srcY     = srcOpt->Y[0];
289*9371c9d4SSatish Balay     srcStart = srcOpt->start[0];
290*9371c9d4SSatish Balay     srcIdx   = NULL;
291*9371c9d4SSatish Balay   } else if (!srcIdx) {
292*9371c9d4SSatish Balay     srcx = srcX = count;
293*9371c9d4SSatish Balay     srcy = srcY = 1;
294*9371c9d4SSatish Balay   }
295914b7a73SJunchao Zhang 
296*9371c9d4SSatish Balay   if (dstOpt) {
297*9371c9d4SSatish Balay     dstx     = dstOpt->dx[0];
298*9371c9d4SSatish Balay     dsty     = dstOpt->dy[0];
299*9371c9d4SSatish Balay     dstX     = dstOpt->X[0];
300*9371c9d4SSatish Balay     dstY     = dstOpt->Y[0];
301*9371c9d4SSatish Balay     dstStart = dstOpt->start[0];
302*9371c9d4SSatish Balay     dstIdx   = NULL;
303*9371c9d4SSatish Balay   } else if (!dstIdx) {
304*9371c9d4SSatish Balay     dstx = dstX = count;
305*9371c9d4SSatish Balay     dsty = dstY = 1;
306*9371c9d4SSatish Balay   }
307914b7a73SJunchao Zhang 
308*9371c9d4SSatish Balay   Kokkos::parallel_for(
309*9371c9d4SSatish Balay     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
310914b7a73SJunchao Zhang       PetscInt i, j, k, s, t;
311914b7a73SJunchao Zhang       Op       op;
312914b7a73SJunchao Zhang       if (!srcIdx) { /* src is in 3D */
313914b7a73SJunchao Zhang         k = tid / (srcx * srcy);
314914b7a73SJunchao Zhang         j = (tid - k * srcx * srcy) / srcx;
315914b7a73SJunchao Zhang         i = tid - k * srcx * srcy - j * srcx;
316914b7a73SJunchao Zhang         s = srcStart + k * srcX * srcY + j * srcX + i;
317914b7a73SJunchao Zhang       } else { /* src is contiguous */
318914b7a73SJunchao Zhang         s = srcIdx[tid];
319914b7a73SJunchao Zhang       }
320914b7a73SJunchao Zhang 
321914b7a73SJunchao Zhang       if (!dstIdx) { /* 3D */
322914b7a73SJunchao Zhang         k = tid / (dstx * dsty);
323914b7a73SJunchao Zhang         j = (tid - k * dstx * dsty) / dstx;
324914b7a73SJunchao Zhang         i = tid - k * dstx * dsty - j * dstx;
325914b7a73SJunchao Zhang         t = dstStart + k * dstX * dstY + j * dstX + i;
326914b7a73SJunchao Zhang       } else { /* contiguous */
327914b7a73SJunchao Zhang         t = dstIdx[tid];
328914b7a73SJunchao Zhang       }
329914b7a73SJunchao Zhang 
330914b7a73SJunchao Zhang       s *= MBS;
331914b7a73SJunchao Zhang       t *= MBS;
332914b7a73SJunchao Zhang       for (i = 0; i < MBS; i++) op(dst[t + i], src[s + i]);
333914b7a73SJunchao Zhang     });
334914b7a73SJunchao Zhang   PetscFunctionReturn(0);
335914b7a73SJunchao Zhang }
336914b7a73SJunchao Zhang 
337914b7a73SJunchao Zhang /* Specialization for Insert since we may use memcpy */
338914b7a73SJunchao Zhang template <typename Type, PetscInt BS, PetscInt EQ>
339*9371c9d4SSatish Balay static PetscErrorCode ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_) {
340914b7a73SJunchao Zhang   const Type          *src = static_cast<const Type *>(src_);
341914b7a73SJunchao Zhang   Type                *dst = static_cast<Type *>(dst_);
342f4af43b4SJunchao Zhang   DeviceExecutionSpace exec;
343914b7a73SJunchao Zhang 
344914b7a73SJunchao Zhang   PetscFunctionBegin;
345914b7a73SJunchao Zhang   if (!count) PetscFunctionReturn(0);
346914b7a73SJunchao Zhang   /*src and dst are contiguous */
347914b7a73SJunchao Zhang   if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
348914b7a73SJunchao Zhang     size_t              sz = count * link->unitbytes;
349914b7a73SJunchao Zhang     deviceBuffer_t      dbuf(reinterpret_cast<char *>(dst + dstStart * link->bs), sz);
350914b7a73SJunchao Zhang     deviceConstBuffer_t sbuf(reinterpret_cast<const char *>(src + srcStart * link->bs), sz);
351914b7a73SJunchao Zhang     Kokkos::deep_copy(exec, dbuf, sbuf);
352914b7a73SJunchao Zhang   } else {
3539566063dSJacob Faibussowitsch     PetscCall(ScatterAndOp<Type, Insert<Type>, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst));
354914b7a73SJunchao Zhang   }
355914b7a73SJunchao Zhang   PetscFunctionReturn(0);
356914b7a73SJunchao Zhang }
357914b7a73SJunchao Zhang 
358914b7a73SJunchao Zhang template <typename Type, class Op, PetscInt BS, PetscInt EQ>
359*9371c9d4SSatish Balay static PetscErrorCode FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata_, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata_, void *leafupdate_) {
360914b7a73SJunchao Zhang   Op                   op;
361914b7a73SJunchao Zhang   const PetscInt       M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
362914b7a73SJunchao Zhang   const PetscInt      *ropt     = rootopt ? rootopt->array : NULL;
363914b7a73SJunchao Zhang   const PetscInt      *lopt     = leafopt ? leafopt->array : NULL;
364914b7a73SJunchao Zhang   Type                *rootdata = static_cast<Type *>(rootdata_), *leafupdate = static_cast<Type *>(leafupdate_);
365914b7a73SJunchao Zhang   const Type          *leafdata = static_cast<const Type *>(leafdata_);
366f4af43b4SJunchao Zhang   DeviceExecutionSpace exec;
367914b7a73SJunchao Zhang 
368914b7a73SJunchao Zhang   PetscFunctionBegin;
369*9371c9d4SSatish Balay   Kokkos::parallel_for(
370*9371c9d4SSatish Balay     Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
371914b7a73SJunchao Zhang       PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
372914b7a73SJunchao Zhang       PetscInt l = (lopt ? MapTidToIndex(lopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS;
373914b7a73SJunchao Zhang       for (int i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]);
374914b7a73SJunchao Zhang     });
375914b7a73SJunchao Zhang   PetscFunctionReturn(0);
376914b7a73SJunchao Zhang }
377914b7a73SJunchao Zhang 
378914b7a73SJunchao Zhang /*====================================================================================*/
379914b7a73SJunchao Zhang /*  Init various types and instantiate pack/unpack function pointers                  */
380914b7a73SJunchao Zhang /*====================================================================================*/
381914b7a73SJunchao Zhang template <typename Type, PetscInt BS, PetscInt EQ>
382*9371c9d4SSatish Balay static void PackInit_RealType(PetscSFLink link) {
383914b7a73SJunchao Zhang   /* Pack/unpack for remote communication */
384914b7a73SJunchao Zhang   link->d_Pack             = Pack<Type, BS, EQ>;
385914b7a73SJunchao Zhang   link->d_UnpackAndInsert  = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
386914b7a73SJunchao Zhang   link->d_UnpackAndAdd     = UnpackAndOp<Type, Add<Type>, BS, EQ>;
387914b7a73SJunchao Zhang   link->d_UnpackAndMult    = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
388914b7a73SJunchao Zhang   link->d_UnpackAndMin     = UnpackAndOp<Type, Min<Type>, BS, EQ>;
389914b7a73SJunchao Zhang   link->d_UnpackAndMax     = UnpackAndOp<Type, Max<Type>, BS, EQ>;
390914b7a73SJunchao Zhang   link->d_FetchAndAdd      = FetchAndOp<Type, Add<Type>, BS, EQ>;
391914b7a73SJunchao Zhang   /* Scatter for local communication */
392914b7a73SJunchao Zhang   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; /* Has special optimizations */
393914b7a73SJunchao Zhang   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
394914b7a73SJunchao Zhang   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
395914b7a73SJunchao Zhang   link->d_ScatterAndMin    = ScatterAndOp<Type, Min<Type>, BS, EQ>;
396914b7a73SJunchao Zhang   link->d_ScatterAndMax    = ScatterAndOp<Type, Max<Type>, BS, EQ>;
397914b7a73SJunchao Zhang   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
398914b7a73SJunchao Zhang   /* Atomic versions when there are data-race possibilities */
399914b7a73SJunchao Zhang   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
400914b7a73SJunchao Zhang   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
401914b7a73SJunchao Zhang   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
402914b7a73SJunchao Zhang   link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
403914b7a73SJunchao Zhang   link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
404914b7a73SJunchao Zhang   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
405914b7a73SJunchao Zhang 
406914b7a73SJunchao Zhang   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
407914b7a73SJunchao Zhang   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
408914b7a73SJunchao Zhang   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
409914b7a73SJunchao Zhang   link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
410914b7a73SJunchao Zhang   link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
411914b7a73SJunchao Zhang   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
412914b7a73SJunchao Zhang }
413914b7a73SJunchao Zhang 
414914b7a73SJunchao Zhang template <typename Type, PetscInt BS, PetscInt EQ>
415*9371c9d4SSatish Balay static void PackInit_IntegerType(PetscSFLink link) {
416914b7a73SJunchao Zhang   link->d_Pack            = Pack<Type, BS, EQ>;
417914b7a73SJunchao Zhang   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
418914b7a73SJunchao Zhang   link->d_UnpackAndAdd    = UnpackAndOp<Type, Add<Type>, BS, EQ>;
419914b7a73SJunchao Zhang   link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
420914b7a73SJunchao Zhang   link->d_UnpackAndMin    = UnpackAndOp<Type, Min<Type>, BS, EQ>;
421914b7a73SJunchao Zhang   link->d_UnpackAndMax    = UnpackAndOp<Type, Max<Type>, BS, EQ>;
422914b7a73SJunchao Zhang   link->d_UnpackAndLAND   = UnpackAndOp<Type, LAND<Type>, BS, EQ>;
423914b7a73SJunchao Zhang   link->d_UnpackAndLOR    = UnpackAndOp<Type, LOR<Type>, BS, EQ>;
424914b7a73SJunchao Zhang   link->d_UnpackAndLXOR   = UnpackAndOp<Type, LXOR<Type>, BS, EQ>;
425914b7a73SJunchao Zhang   link->d_UnpackAndBAND   = UnpackAndOp<Type, BAND<Type>, BS, EQ>;
426914b7a73SJunchao Zhang   link->d_UnpackAndBOR    = UnpackAndOp<Type, BOR<Type>, BS, EQ>;
427914b7a73SJunchao Zhang   link->d_UnpackAndBXOR   = UnpackAndOp<Type, BXOR<Type>, BS, EQ>;
428914b7a73SJunchao Zhang   link->d_FetchAndAdd     = FetchAndOp<Type, Add<Type>, BS, EQ>;
429914b7a73SJunchao Zhang 
430914b7a73SJunchao Zhang   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
431914b7a73SJunchao Zhang   link->d_ScatterAndAdd    = ScatterAndOp<Type, Add<Type>, BS, EQ>;
432914b7a73SJunchao Zhang   link->d_ScatterAndMult   = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
433914b7a73SJunchao Zhang   link->d_ScatterAndMin    = ScatterAndOp<Type, Min<Type>, BS, EQ>;
434914b7a73SJunchao Zhang   link->d_ScatterAndMax    = ScatterAndOp<Type, Max<Type>, BS, EQ>;
435914b7a73SJunchao Zhang   link->d_ScatterAndLAND   = ScatterAndOp<Type, LAND<Type>, BS, EQ>;
436914b7a73SJunchao Zhang   link->d_ScatterAndLOR    = ScatterAndOp<Type, LOR<Type>, BS, EQ>;
437914b7a73SJunchao Zhang   link->d_ScatterAndLXOR   = ScatterAndOp<Type, LXOR<Type>, BS, EQ>;
438914b7a73SJunchao Zhang   link->d_ScatterAndBAND   = ScatterAndOp<Type, BAND<Type>, BS, EQ>;
439914b7a73SJunchao Zhang   link->d_ScatterAndBOR    = ScatterAndOp<Type, BOR<Type>, BS, EQ>;
440914b7a73SJunchao Zhang   link->d_ScatterAndBXOR   = ScatterAndOp<Type, BXOR<Type>, BS, EQ>;
441914b7a73SJunchao Zhang   link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
442914b7a73SJunchao Zhang 
443914b7a73SJunchao Zhang   link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
444914b7a73SJunchao Zhang   link->da_UnpackAndAdd    = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
445914b7a73SJunchao Zhang   link->da_UnpackAndMult   = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
446914b7a73SJunchao Zhang   link->da_UnpackAndMin    = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
447914b7a73SJunchao Zhang   link->da_UnpackAndMax    = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
448914b7a73SJunchao Zhang   link->da_UnpackAndLAND   = UnpackAndOp<Type, AtomicLAND<Type>, BS, EQ>;
449914b7a73SJunchao Zhang   link->da_UnpackAndLOR    = UnpackAndOp<Type, AtomicLOR<Type>, BS, EQ>;
450914b7a73SJunchao Zhang   link->da_UnpackAndBAND   = UnpackAndOp<Type, AtomicBAND<Type>, BS, EQ>;
451914b7a73SJunchao Zhang   link->da_UnpackAndBOR    = UnpackAndOp<Type, AtomicBOR<Type>, BS, EQ>;
452914b7a73SJunchao Zhang   link->da_UnpackAndBXOR   = UnpackAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
453914b7a73SJunchao Zhang   link->da_FetchAndAdd     = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
454914b7a73SJunchao Zhang 
455914b7a73SJunchao Zhang   link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
456914b7a73SJunchao Zhang   link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
457914b7a73SJunchao Zhang   link->da_ScatterAndMult   = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
458914b7a73SJunchao Zhang   link->da_ScatterAndMin    = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
459914b7a73SJunchao Zhang   link->da_ScatterAndMax    = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
460914b7a73SJunchao Zhang   link->da_ScatterAndLAND   = ScatterAndOp<Type, AtomicLAND<Type>, BS, EQ>;
461914b7a73SJunchao Zhang   link->da_ScatterAndLOR    = ScatterAndOp<Type, AtomicLOR<Type>, BS, EQ>;
462914b7a73SJunchao Zhang   link->da_ScatterAndBAND   = ScatterAndOp<Type, AtomicBAND<Type>, BS, EQ>;
463914b7a73SJunchao Zhang   link->da_ScatterAndBOR    = ScatterAndOp<Type, AtomicBOR<Type>, BS, EQ>;
464914b7a73SJunchao Zhang   link->da_ScatterAndBXOR   = ScatterAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
465914b7a73SJunchao Zhang   link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
466914b7a73SJunchao Zhang }
467914b7a73SJunchao Zhang 
#if defined(PETSC_HAVE_COMPLEX)
/*
  Install device kernels for complex scalar types.

  Only insert, add (including fetch-and-add) and multiply kernels are installed; every non-atomic
  (d_*) kernel has an atomic (da_*) counterpart built on the Atomic* operators.
*/
template <typename Type, PetscInt BS, PetscInt EQ>
static void PackInit_ComplexType(PetscSFLink link) {
  /* Packing */
  link->d_Pack = Pack<Type, BS, EQ>;

  /* INSERT: the non-atomic scatter uses the dedicated ScatterAndInsert kernel */
  link->d_UnpackAndInsert   = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
  link->d_ScatterAndInsert  = ScatterAndInsert<Type, BS, EQ>;
  link->da_UnpackAndInsert  = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
  link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;

  /* ADD, plus the fetch-and-add variants */
  link->d_UnpackAndAdd      = UnpackAndOp<Type, Add<Type>, BS, EQ>;
  link->d_ScatterAndAdd     = ScatterAndOp<Type, Add<Type>, BS, EQ>;
  link->d_FetchAndAdd       = FetchAndOp<Type, Add<Type>, BS, EQ>;
  link->d_FetchAndAddLocal  = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
  link->da_UnpackAndAdd     = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
  link->da_ScatterAndAdd    = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
  link->da_FetchAndAdd      = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
  link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;

  /* MULT */
  link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
  link->d_ScatterAndMult  = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
  link->da_UnpackAndMult  = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
  link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
}
#endif
493914b7a73SJunchao Zhang 
494914b7a73SJunchao Zhang template <typename Type>
495*9371c9d4SSatish Balay static void PackInit_PairType(PetscSFLink link) {
496914b7a73SJunchao Zhang   link->d_Pack            = Pack<Type, 1, 1>;
497914b7a73SJunchao Zhang   link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, 1, 1>;
498914b7a73SJunchao Zhang   link->d_UnpackAndMaxloc = UnpackAndOp<Type, Maxloc<Type>, 1, 1>;
499914b7a73SJunchao Zhang   link->d_UnpackAndMinloc = UnpackAndOp<Type, Minloc<Type>, 1, 1>;
500914b7a73SJunchao Zhang 
501914b7a73SJunchao Zhang   link->d_ScatterAndInsert = ScatterAndOp<Type, Insert<Type>, 1, 1>;
502914b7a73SJunchao Zhang   link->d_ScatterAndMaxloc = ScatterAndOp<Type, Maxloc<Type>, 1, 1>;
503914b7a73SJunchao Zhang   link->d_ScatterAndMinloc = ScatterAndOp<Type, Minloc<Type>, 1, 1>;
504914b7a73SJunchao Zhang   /* Atomics for pair types are not implemented yet */
505914b7a73SJunchao Zhang }
506914b7a73SJunchao Zhang 
507914b7a73SJunchao Zhang template <typename Type, PetscInt BS, PetscInt EQ>
508*9371c9d4SSatish Balay static void PackInit_DumbType(PetscSFLink link) {
509914b7a73SJunchao Zhang   link->d_Pack             = Pack<Type, BS, EQ>;
510914b7a73SJunchao Zhang   link->d_UnpackAndInsert  = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
511914b7a73SJunchao Zhang   link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
512914b7a73SJunchao Zhang   /* Atomics for dumb types are not implemented yet */
513914b7a73SJunchao Zhang }
514914b7a73SJunchao Zhang 
/*
  Kokkos::DefaultExecutionSpace(stream) is a reference-counted pointer object. It has a bug:
  the object cannot be repeatedly created and destroyed. SF's original design was that each
  SFLink had its own stream (NULL or not) and hence its own execution space object. The bug
  prevents us from destroying multiple SFLinks that use the NULL stream and the default execution
  space object. To avoid memory leaks, SF_Kokkos only supports the NULL stream, which is also
  PETSc's default scheme. SF_Kokkos does not do its own new/delete; it just uses
  Kokkos::DefaultExecutionSpace(), which is a singleton object in Kokkos.
*/
524f4af43b4SJunchao Zhang /*
525914b7a73SJunchao Zhang static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link)
526914b7a73SJunchao Zhang {
527914b7a73SJunchao Zhang   PetscFunctionBegin;
528914b7a73SJunchao Zhang   PetscFunctionReturn(0);
529914b7a73SJunchao Zhang }
530f4af43b4SJunchao Zhang */
531914b7a73SJunchao Zhang 
532914b7a73SJunchao Zhang /* Some device-specific utilities */
533*9371c9d4SSatish Balay static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink PETSC_UNUSED link) {
534914b7a73SJunchao Zhang   PetscFunctionBegin;
535914b7a73SJunchao Zhang   Kokkos::fence();
536914b7a73SJunchao Zhang   PetscFunctionReturn(0);
537914b7a73SJunchao Zhang }
538914b7a73SJunchao Zhang 
539*9371c9d4SSatish Balay static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink PETSC_UNUSED link) {
540f4af43b4SJunchao Zhang   DeviceExecutionSpace exec;
541914b7a73SJunchao Zhang   PetscFunctionBegin;
542914b7a73SJunchao Zhang   exec.fence();
543914b7a73SJunchao Zhang   PetscFunctionReturn(0);
544914b7a73SJunchao Zhang }
545914b7a73SJunchao Zhang 
546*9371c9d4SSatish Balay static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink PETSC_UNUSED link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n) {
547f4af43b4SJunchao Zhang   DeviceExecutionSpace exec;
548914b7a73SJunchao Zhang 
549914b7a73SJunchao Zhang   PetscFunctionBegin;
550914b7a73SJunchao Zhang   if (!n) PetscFunctionReturn(0);
55171438e86SJunchao Zhang   if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) {
5529566063dSJacob Faibussowitsch     PetscCall(PetscMemcpy(dst, src, n));
553914b7a73SJunchao Zhang   } else {
55471438e86SJunchao Zhang     if (PetscMemTypeDevice(dstmtype) && PetscMemTypeHost(srcmtype)) {
555914b7a73SJunchao Zhang       deviceBuffer_t    dbuf(static_cast<char *>(dst), n);
556914b7a73SJunchao Zhang       HostConstBuffer_t sbuf(static_cast<const char *>(src), n);
557914b7a73SJunchao Zhang       Kokkos::deep_copy(exec, dbuf, sbuf);
5589566063dSJacob Faibussowitsch       PetscCall(PetscLogCpuToGpu(n));
55971438e86SJunchao Zhang     } else if (PetscMemTypeHost(dstmtype) && PetscMemTypeDevice(srcmtype)) {
560914b7a73SJunchao Zhang       HostBuffer_t        dbuf(static_cast<char *>(dst), n);
561914b7a73SJunchao Zhang       deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
562914b7a73SJunchao Zhang       Kokkos::deep_copy(exec, dbuf, sbuf);
5639566063dSJacob Faibussowitsch       PetscCall(PetscLogGpuToCpu(n));
56471438e86SJunchao Zhang     } else if (PetscMemTypeDevice(dstmtype) && PetscMemTypeDevice(srcmtype)) {
565914b7a73SJunchao Zhang       deviceBuffer_t      dbuf(static_cast<char *>(dst), n);
566914b7a73SJunchao Zhang       deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
567914b7a73SJunchao Zhang       Kokkos::deep_copy(exec, dbuf, sbuf);
568914b7a73SJunchao Zhang     }
569914b7a73SJunchao Zhang   }
570914b7a73SJunchao Zhang   PetscFunctionReturn(0);
571914b7a73SJunchao Zhang }
572914b7a73SJunchao Zhang 
573*9371c9d4SSatish Balay PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype, size_t size, void **ptr) {
574914b7a73SJunchao Zhang   PetscFunctionBegin;
5759566063dSJacob Faibussowitsch   if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size, ptr));
57671438e86SJunchao Zhang   else if (PetscMemTypeDevice(mtype)) {
5779566063dSJacob Faibussowitsch     if (!PetscKokkosInitialized) PetscCall(PetscKokkosInitializeCheck());
57845639126SStefano Zampini     *ptr = Kokkos::kokkos_malloc<DeviceMemorySpace>(size);
57998921bdaSJacob Faibussowitsch   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
580914b7a73SJunchao Zhang   PetscFunctionReturn(0);
581914b7a73SJunchao Zhang }
582914b7a73SJunchao Zhang 
583*9371c9d4SSatish Balay PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype, void *ptr) {
584914b7a73SJunchao Zhang   PetscFunctionBegin;
5859566063dSJacob Faibussowitsch   if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr));
586*9371c9d4SSatish Balay   else if (PetscMemTypeDevice(mtype)) {
587*9371c9d4SSatish Balay     Kokkos::kokkos_free<DeviceMemorySpace>(ptr);
588*9371c9d4SSatish Balay   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
589914b7a73SJunchao Zhang   PetscFunctionReturn(0);
590914b7a73SJunchao Zhang }
591914b7a73SJunchao Zhang 
59271438e86SJunchao Zhang /* Destructor when the link uses MPI for communication */
593*9371c9d4SSatish Balay static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSF sf, PetscSFLink link) {
59471438e86SJunchao Zhang   PetscFunctionBegin;
59571438e86SJunchao Zhang   for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) {
5969566063dSJacob Faibussowitsch     PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
5979566063dSJacob Faibussowitsch     PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
59871438e86SJunchao Zhang   }
59971438e86SJunchao Zhang   PetscFunctionReturn(0);
60071438e86SJunchao Zhang }
601914b7a73SJunchao Zhang 
602914b7a73SJunchao Zhang /* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what needed on device */
603*9371c9d4SSatish Balay PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF PETSC_UNUSED sf, PetscSFLink link, MPI_Datatype unit) {
604914b7a73SJunchao Zhang   PetscInt  nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0;
605914b7a73SJunchao Zhang   PetscBool is2Int, is2PetscInt;
606914b7a73SJunchao Zhang #if defined(PETSC_HAVE_COMPLEX)
607914b7a73SJunchao Zhang   PetscInt nPetscComplex = 0;
608914b7a73SJunchao Zhang #endif
609914b7a73SJunchao Zhang 
610914b7a73SJunchao Zhang   PetscFunctionBegin;
611914b7a73SJunchao Zhang   if (link->deviceinited) PetscFunctionReturn(0);
6129566063dSJacob Faibussowitsch   PetscCall(PetscKokkosInitializeCheck());
6139566063dSJacob Faibussowitsch   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar));
6149566063dSJacob Faibussowitsch   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar));
615914b7a73SJunchao Zhang   /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */
6169566063dSJacob Faibussowitsch   PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt));
6179566063dSJacob Faibussowitsch   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt));
6189566063dSJacob Faibussowitsch   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal));
619914b7a73SJunchao Zhang #if defined(PETSC_HAVE_COMPLEX)
6209566063dSJacob Faibussowitsch   PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex));
621914b7a73SJunchao Zhang #endif
6229566063dSJacob Faibussowitsch   PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int));
6239566063dSJacob Faibussowitsch   PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt));
624914b7a73SJunchao Zhang 
625914b7a73SJunchao Zhang   if (is2Int) {
626914b7a73SJunchao Zhang     PackInit_PairType<Kokkos::pair<int, int>>(link);
627914b7a73SJunchao Zhang   } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
628914b7a73SJunchao Zhang     PackInit_PairType<Kokkos::pair<PetscInt, PetscInt>>(link);
629914b7a73SJunchao Zhang   } else if (nPetscReal) {
630d941a2f0SJunchao Zhang #if !defined(PETSC_HAVE_DEVICE) /* Skip the unimportant stuff to speed up SF device compilation time */
631*9371c9d4SSatish Balay     if (nPetscReal == 8) PackInit_RealType<PetscReal, 8, 1>(link);
632*9371c9d4SSatish Balay     else if (nPetscReal % 8 == 0) PackInit_RealType<PetscReal, 8, 0>(link);
633*9371c9d4SSatish Balay     else if (nPetscReal == 4) PackInit_RealType<PetscReal, 4, 1>(link);
634*9371c9d4SSatish Balay     else if (nPetscReal % 4 == 0) PackInit_RealType<PetscReal, 4, 0>(link);
635*9371c9d4SSatish Balay     else if (nPetscReal == 2) PackInit_RealType<PetscReal, 2, 1>(link);
636*9371c9d4SSatish Balay     else if (nPetscReal % 2 == 0) PackInit_RealType<PetscReal, 2, 0>(link);
637*9371c9d4SSatish Balay     else if (nPetscReal == 1) PackInit_RealType<PetscReal, 1, 1>(link);
638*9371c9d4SSatish Balay     else if (nPetscReal % 1 == 0)
639eee4e20aSJunchao Zhang #endif
640d941a2f0SJunchao Zhang       PackInit_RealType<PetscReal, 1, 0>(link);
641874d28e3SJunchao Zhang   } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) {
642d941a2f0SJunchao Zhang #if !defined(PETSC_HAVE_DEVICE)
643*9371c9d4SSatish Balay     if (nPetscInt == 8) PackInit_IntegerType<llint, 8, 1>(link);
644*9371c9d4SSatish Balay     else if (nPetscInt % 8 == 0) PackInit_IntegerType<llint, 8, 0>(link);
645*9371c9d4SSatish Balay     else if (nPetscInt == 4) PackInit_IntegerType<llint, 4, 1>(link);
646*9371c9d4SSatish Balay     else if (nPetscInt % 4 == 0) PackInit_IntegerType<llint, 4, 0>(link);
647*9371c9d4SSatish Balay     else if (nPetscInt == 2) PackInit_IntegerType<llint, 2, 1>(link);
648*9371c9d4SSatish Balay     else if (nPetscInt % 2 == 0) PackInit_IntegerType<llint, 2, 0>(link);
649*9371c9d4SSatish Balay     else if (nPetscInt == 1) PackInit_IntegerType<llint, 1, 1>(link);
650*9371c9d4SSatish Balay     else if (nPetscInt % 1 == 0)
651eee4e20aSJunchao Zhang #endif
652d941a2f0SJunchao Zhang       PackInit_IntegerType<llint, 1, 0>(link);
653914b7a73SJunchao Zhang   } else if (nInt) {
654d941a2f0SJunchao Zhang #if !defined(PETSC_HAVE_DEVICE)
655*9371c9d4SSatish Balay     if (nInt == 8) PackInit_IntegerType<int, 8, 1>(link);
656*9371c9d4SSatish Balay     else if (nInt % 8 == 0) PackInit_IntegerType<int, 8, 0>(link);
657*9371c9d4SSatish Balay     else if (nInt == 4) PackInit_IntegerType<int, 4, 1>(link);
658*9371c9d4SSatish Balay     else if (nInt % 4 == 0) PackInit_IntegerType<int, 4, 0>(link);
659*9371c9d4SSatish Balay     else if (nInt == 2) PackInit_IntegerType<int, 2, 1>(link);
660*9371c9d4SSatish Balay     else if (nInt % 2 == 0) PackInit_IntegerType<int, 2, 0>(link);
661*9371c9d4SSatish Balay     else if (nInt == 1) PackInit_IntegerType<int, 1, 1>(link);
662*9371c9d4SSatish Balay     else if (nInt % 1 == 0)
663eee4e20aSJunchao Zhang #endif
664d941a2f0SJunchao Zhang       PackInit_IntegerType<int, 1, 0>(link);
665914b7a73SJunchao Zhang   } else if (nSignedChar) {
666d941a2f0SJunchao Zhang #if !defined(PETSC_HAVE_DEVICE)
667*9371c9d4SSatish Balay     if (nSignedChar == 8) PackInit_IntegerType<char, 8, 1>(link);
668*9371c9d4SSatish Balay     else if (nSignedChar % 8 == 0) PackInit_IntegerType<char, 8, 0>(link);
669*9371c9d4SSatish Balay     else if (nSignedChar == 4) PackInit_IntegerType<char, 4, 1>(link);
670*9371c9d4SSatish Balay     else if (nSignedChar % 4 == 0) PackInit_IntegerType<char, 4, 0>(link);
671*9371c9d4SSatish Balay     else if (nSignedChar == 2) PackInit_IntegerType<char, 2, 1>(link);
672*9371c9d4SSatish Balay     else if (nSignedChar % 2 == 0) PackInit_IntegerType<char, 2, 0>(link);
673*9371c9d4SSatish Balay     else if (nSignedChar == 1) PackInit_IntegerType<char, 1, 1>(link);
674*9371c9d4SSatish Balay     else if (nSignedChar % 1 == 0)
675eee4e20aSJunchao Zhang #endif
676d941a2f0SJunchao Zhang       PackInit_IntegerType<char, 1, 0>(link);
677914b7a73SJunchao Zhang   } else if (nUnsignedChar) {
678d941a2f0SJunchao Zhang #if !defined(PETSC_HAVE_DEVICE)
679*9371c9d4SSatish Balay     if (nUnsignedChar == 8) PackInit_IntegerType<unsigned char, 8, 1>(link);
680*9371c9d4SSatish Balay     else if (nUnsignedChar % 8 == 0) PackInit_IntegerType<unsigned char, 8, 0>(link);
681*9371c9d4SSatish Balay     else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char, 4, 1>(link);
682*9371c9d4SSatish Balay     else if (nUnsignedChar % 4 == 0) PackInit_IntegerType<unsigned char, 4, 0>(link);
683*9371c9d4SSatish Balay     else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char, 2, 1>(link);
684*9371c9d4SSatish Balay     else if (nUnsignedChar % 2 == 0) PackInit_IntegerType<unsigned char, 2, 0>(link);
685*9371c9d4SSatish Balay     else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char, 1, 1>(link);
686*9371c9d4SSatish Balay     else if (nUnsignedChar % 1 == 0)
687eee4e20aSJunchao Zhang #endif
688d941a2f0SJunchao Zhang       PackInit_IntegerType<unsigned char, 1, 0>(link);
689914b7a73SJunchao Zhang #if defined(PETSC_HAVE_COMPLEX)
690914b7a73SJunchao Zhang   } else if (nPetscComplex) {
691d941a2f0SJunchao Zhang #if !defined(PETSC_HAVE_DEVICE)
692*9371c9d4SSatish Balay     if (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 1>(link);
693*9371c9d4SSatish Balay     else if (nPetscComplex % 8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 0>(link);
694*9371c9d4SSatish Balay     else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 1>(link);
695*9371c9d4SSatish Balay     else if (nPetscComplex % 4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 0>(link);
696*9371c9d4SSatish Balay     else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 1>(link);
697*9371c9d4SSatish Balay     else if (nPetscComplex % 2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 0>(link);
698*9371c9d4SSatish Balay     else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 1>(link);
699*9371c9d4SSatish Balay     else if (nPetscComplex % 1 == 0)
700eee4e20aSJunchao Zhang #endif
701d941a2f0SJunchao Zhang       PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 0>(link);
702914b7a73SJunchao Zhang #endif
703914b7a73SJunchao Zhang   } else {
704914b7a73SJunchao Zhang     MPI_Aint lb, nbyte;
7059566063dSJacob Faibussowitsch     PetscCallMPI(MPI_Type_get_extent(unit, &lb, &nbyte));
70608401ef6SPierre Jolivet     PetscCheck(lb == 0, PETSC_COMM_SELF, PETSC_ERR_SUP, "Datatype with nonzero lower bound %ld", (long)lb);
707914b7a73SJunchao Zhang     if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */
708d941a2f0SJunchao Zhang #if !defined(PETSC_HAVE_DEVICE)
709*9371c9d4SSatish Balay       if (nbyte == 4) PackInit_DumbType<char, 4, 1>(link);
710*9371c9d4SSatish Balay       else if (nbyte % 4 == 0) PackInit_DumbType<char, 4, 0>(link);
711*9371c9d4SSatish Balay       else if (nbyte == 2) PackInit_DumbType<char, 2, 1>(link);
712*9371c9d4SSatish Balay       else if (nbyte % 2 == 0) PackInit_DumbType<char, 2, 0>(link);
713*9371c9d4SSatish Balay       else if (nbyte == 1) PackInit_DumbType<char, 1, 1>(link);
714*9371c9d4SSatish Balay       else if (nbyte % 1 == 0)
715eee4e20aSJunchao Zhang #endif
716d941a2f0SJunchao Zhang         PackInit_DumbType<char, 1, 0>(link);
717914b7a73SJunchao Zhang     } else {
718914b7a73SJunchao Zhang       nInt = nbyte / sizeof(int);
719d941a2f0SJunchao Zhang #if !defined(PETSC_HAVE_DEVICE)
720*9371c9d4SSatish Balay       if (nInt == 8) PackInit_DumbType<int, 8, 1>(link);
721*9371c9d4SSatish Balay       else if (nInt % 8 == 0) PackInit_DumbType<int, 8, 0>(link);
722*9371c9d4SSatish Balay       else if (nInt == 4) PackInit_DumbType<int, 4, 1>(link);
723*9371c9d4SSatish Balay       else if (nInt % 4 == 0) PackInit_DumbType<int, 4, 0>(link);
724*9371c9d4SSatish Balay       else if (nInt == 2) PackInit_DumbType<int, 2, 1>(link);
725*9371c9d4SSatish Balay       else if (nInt % 2 == 0) PackInit_DumbType<int, 2, 0>(link);
726*9371c9d4SSatish Balay       else if (nInt == 1) PackInit_DumbType<int, 1, 1>(link);
727*9371c9d4SSatish Balay       else if (nInt % 1 == 0)
728eee4e20aSJunchao Zhang #endif
729d941a2f0SJunchao Zhang         PackInit_DumbType<int, 1, 0>(link);
730914b7a73SJunchao Zhang     }
731914b7a73SJunchao Zhang   }
732914b7a73SJunchao Zhang 
73371438e86SJunchao Zhang   link->SyncDevice   = PetscSFLinkSyncDevice_Kokkos;
73471438e86SJunchao Zhang   link->SyncStream   = PetscSFLinkSyncStream_Kokkos;
73520c24465SJunchao Zhang   link->Memcpy       = PetscSFLinkMemcpy_Kokkos;
73671438e86SJunchao Zhang   link->Destroy      = PetscSFLinkDestroy_Kokkos;
737914b7a73SJunchao Zhang   link->deviceinited = PETSC_TRUE;
738914b7a73SJunchao Zhang   PetscFunctionReturn(0);
739914b7a73SJunchao Zhang }
740