xref: /petsc/src/vec/is/sf/impls/basic/kokkos/sfkok.kokkos.cxx (revision 914b7a734abff03dc21d49d3871e1a6fe5fa59bb)
1*914b7a73SJunchao Zhang #include <../src/vec/is/sf/impls/basic/sfpack.h>
2*914b7a73SJunchao Zhang 
3*914b7a73SJunchao Zhang #include <Kokkos_Core.hpp>
4*914b7a73SJunchao Zhang 
5*914b7a73SJunchao Zhang using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace;
6*914b7a73SJunchao Zhang using DeviceMemorySpace    = typename DeviceExecutionSpace::memory_space;
7*914b7a73SJunchao Zhang using HostMemorySpace      = Kokkos::HostSpace;
8*914b7a73SJunchao Zhang 
9*914b7a73SJunchao Zhang typedef Kokkos::View<char*,DeviceMemorySpace>       deviceBuffer_t;
10*914b7a73SJunchao Zhang typedef Kokkos::View<char*,HostMemorySpace>         HostBuffer_t;
11*914b7a73SJunchao Zhang 
12*914b7a73SJunchao Zhang typedef Kokkos::View<const char*,DeviceMemorySpace> deviceConstBuffer_t;
13*914b7a73SJunchao Zhang typedef Kokkos::View<const char*,HostMemorySpace>   HostConstBuffer_t;
14*914b7a73SJunchao Zhang 
15*914b7a73SJunchao Zhang /*====================================================================================*/
16*914b7a73SJunchao Zhang /*                             Regular operations                           */
17*914b7a73SJunchao Zhang /*====================================================================================*/
18*914b7a73SJunchao Zhang template<typename Type> struct Insert{KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = y;             return old;}};
19*914b7a73SJunchao Zhang template<typename Type> struct Add   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x += y;             return old;}};
20*914b7a73SJunchao Zhang template<typename Type> struct Mult  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x *= y;             return old;}};
21*914b7a73SJunchao Zhang template<typename Type> struct Min   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = PetscMin(x,y); return old;}};
22*914b7a73SJunchao Zhang template<typename Type> struct Max   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = PetscMax(x,y); return old;}};
23*914b7a73SJunchao Zhang template<typename Type> struct LAND  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x && y;        return old;}};
24*914b7a73SJunchao Zhang template<typename Type> struct LOR   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x || y;        return old;}};
25*914b7a73SJunchao Zhang template<typename Type> struct LXOR  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = !x != !y;      return old;}};
26*914b7a73SJunchao Zhang template<typename Type> struct BAND  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x & y;         return old;}};
27*914b7a73SJunchao Zhang template<typename Type> struct BOR   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x | y;         return old;}};
28*914b7a73SJunchao Zhang template<typename Type> struct BXOR  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x ^ y;         return old;}};
29*914b7a73SJunchao Zhang template<typename PairType> struct Minloc {
30*914b7a73SJunchao Zhang   KOKKOS_INLINE_FUNCTION PairType operator()(PairType& x,PairType y) const {
31*914b7a73SJunchao Zhang     PairType old = x;
32*914b7a73SJunchao Zhang     if (y.first < x.first) x = y;
33*914b7a73SJunchao Zhang     else if (y.first == x.first) x.second = PetscMin(x.second,y.second);
34*914b7a73SJunchao Zhang     return old;
35*914b7a73SJunchao Zhang   }
36*914b7a73SJunchao Zhang };
37*914b7a73SJunchao Zhang template<typename PairType> struct Maxloc {
38*914b7a73SJunchao Zhang   KOKKOS_INLINE_FUNCTION PairType operator()(PairType& x,PairType y) const {
39*914b7a73SJunchao Zhang     PairType old = x;
40*914b7a73SJunchao Zhang     if (y.first > x.first) x = y;
41*914b7a73SJunchao Zhang     else if (y.first == x.first) x.second = PetscMin(x.second,y.second); /* See MPI MAXLOC */
42*914b7a73SJunchao Zhang     return old;
43*914b7a73SJunchao Zhang   }
44*914b7a73SJunchao Zhang };
45*914b7a73SJunchao Zhang 
46*914b7a73SJunchao Zhang /*====================================================================================*/
47*914b7a73SJunchao Zhang /*                             Atomic operations                            */
48*914b7a73SJunchao Zhang /*====================================================================================*/
49*914b7a73SJunchao Zhang template<typename Type> struct AtomicInsert  {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_assign(&x,y);}};
50*914b7a73SJunchao Zhang template<typename Type> struct AtomicAdd     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_add(&x,y);}};
51*914b7a73SJunchao Zhang template<typename Type> struct AtomicBAND    {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_and(&x,y);}};
52*914b7a73SJunchao Zhang template<typename Type> struct AtomicBOR     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_or (&x,y);}};
53*914b7a73SJunchao Zhang template<typename Type> struct AtomicBXOR    {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_xor(&x,y);}};
54*914b7a73SJunchao Zhang template<typename Type> struct AtomicLAND    {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {const Type zero=0,one=~0; Kokkos::atomic_and(&x,y?one:zero);}};
55*914b7a73SJunchao Zhang template<typename Type> struct AtomicLOR     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {const Type zero=0,one=1;  Kokkos::atomic_or (&x,y?one:zero);}};
56*914b7a73SJunchao Zhang template<typename Type> struct AtomicMult    {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_mul(&x,y);}};
57*914b7a73SJunchao Zhang template<typename Type> struct AtomicMin     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_min(&x,y);}};
58*914b7a73SJunchao Zhang template<typename Type> struct AtomicMax     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_max(&x,y);}};
59*914b7a73SJunchao Zhang /* TODO: struct AtomicLXOR  */
60*914b7a73SJunchao Zhang template<typename Type> struct AtomicFetchAdd{KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {return Kokkos::atomic_fetch_add(&x,y);}};
61*914b7a73SJunchao Zhang 
62*914b7a73SJunchao Zhang /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */
63*914b7a73SJunchao Zhang static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt,PetscInt tid)
64*914b7a73SJunchao Zhang {
65*914b7a73SJunchao Zhang   PetscInt        i,j,k,m,n,r;
66*914b7a73SJunchao Zhang   const PetscInt  *offset,*start,*dx,*dy,*X,*Y;
67*914b7a73SJunchao Zhang 
68*914b7a73SJunchao Zhang   n      = opt[0];
69*914b7a73SJunchao Zhang   offset = opt + 1;
70*914b7a73SJunchao Zhang   start  = opt + n + 2;
71*914b7a73SJunchao Zhang   dx     = opt + 2*n + 2;
72*914b7a73SJunchao Zhang   dy     = opt + 3*n + 2;
73*914b7a73SJunchao Zhang   X      = opt + 5*n + 2;
74*914b7a73SJunchao Zhang   Y      = opt + 6*n + 2;
75*914b7a73SJunchao Zhang   for (r=0; r<n; r++) {if (tid < offset[r+1]) break;}
76*914b7a73SJunchao Zhang   m = (tid - offset[r]);
77*914b7a73SJunchao Zhang   k = m/(dx[r]*dy[r]);
78*914b7a73SJunchao Zhang   j = (m - k*dx[r]*dy[r])/dx[r];
79*914b7a73SJunchao Zhang   i = m - k*dx[r]*dy[r] - j*dx[r];
80*914b7a73SJunchao Zhang 
81*914b7a73SJunchao Zhang   return (start[r] + k*X[r]*Y[r] + j*X[r] + i);
82*914b7a73SJunchao Zhang }
83*914b7a73SJunchao Zhang 
84*914b7a73SJunchao Zhang /*====================================================================================*/
85*914b7a73SJunchao Zhang /*  Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link'         */
86*914b7a73SJunchao Zhang /*====================================================================================*/
87*914b7a73SJunchao Zhang 
/* Suppose the user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals. Then
   <Type> is PetscReal, the primitive type we operate on.
   <bs>   is 16, the number of primitive types contained in <unit>.
   <BS>   is 8, the maximal SIMD width we will try to vectorize operations on <unit> with.
   <EQ>   is 0, which is (bs == BS ? 1 : 0).

  If instead <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, which makes MBS below a compile-time constant.
  For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled.
*/
97*914b7a73SJunchao Zhang template<typename Type,PetscInt BS,PetscInt EQ>
98*914b7a73SJunchao Zhang static PetscErrorCode Pack(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,const void *data_,void *buf_)
99*914b7a73SJunchao Zhang {
100*914b7a73SJunchao Zhang   const PetscInt          *iopt = opt ? opt->array : NULL;
101*914b7a73SJunchao Zhang   const PetscInt          M = EQ ? 1 : link->bs/BS, MBS=M*BS; /* If EQ, then MBS will be a compile-time const */
102*914b7a73SJunchao Zhang   const Type              *data = static_cast<const Type*>(data_);
103*914b7a73SJunchao Zhang   Type                    *buf = static_cast<Type*>(buf_);
104*914b7a73SJunchao Zhang   DeviceExecutionSpace&   exec = *static_cast<DeviceExecutionSpace*>(link->sptr);
105*914b7a73SJunchao Zhang 
106*914b7a73SJunchao Zhang   PetscFunctionBegin;
107*914b7a73SJunchao Zhang   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
108*914b7a73SJunchao Zhang     /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous;
109*914b7a73SJunchao Zhang        iopt == NULL && idx == NULL ==> the indices are contiguous;
110*914b7a73SJunchao Zhang      */
111*914b7a73SJunchao Zhang     PetscInt t = (iopt? MapTidToIndex(iopt,tid) : (idx? idx[tid] : start+tid))*MBS;
112*914b7a73SJunchao Zhang     PetscInt s = tid*MBS;
113*914b7a73SJunchao Zhang     for (int i=0; i<MBS; i++) buf[s+i] = data[t+i];
114*914b7a73SJunchao Zhang   });
115*914b7a73SJunchao Zhang   PetscFunctionReturn(0);
116*914b7a73SJunchao Zhang }
117*914b7a73SJunchao Zhang 
118*914b7a73SJunchao Zhang template<typename Type,class Op,PetscInt BS,PetscInt EQ>
119*914b7a73SJunchao Zhang static PetscErrorCode UnpackAndOp(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,void *data_,const void *buf_)
120*914b7a73SJunchao Zhang {
121*914b7a73SJunchao Zhang   Op                      op;
122*914b7a73SJunchao Zhang   const PetscInt          *iopt = opt ? opt->array : NULL;
123*914b7a73SJunchao Zhang   const PetscInt          M = EQ ? 1 : link->bs/BS, MBS=M*BS;
124*914b7a73SJunchao Zhang   Type                    *data = static_cast<Type*>(data_);
125*914b7a73SJunchao Zhang   const Type              *buf = static_cast<const Type*>(buf_);
126*914b7a73SJunchao Zhang   DeviceExecutionSpace&   exec = *static_cast<DeviceExecutionSpace*>(link->sptr);
127*914b7a73SJunchao Zhang 
128*914b7a73SJunchao Zhang   PetscFunctionBegin;
129*914b7a73SJunchao Zhang   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
130*914b7a73SJunchao Zhang     PetscInt t = (iopt? MapTidToIndex(iopt,tid) : (idx? idx[tid] : start+tid))*MBS;
131*914b7a73SJunchao Zhang     PetscInt s = tid*MBS;
132*914b7a73SJunchao Zhang     for (int i=0; i<MBS; i++) op(data[t+i],buf[s+i]);
133*914b7a73SJunchao Zhang   });
134*914b7a73SJunchao Zhang   PetscFunctionReturn(0);
135*914b7a73SJunchao Zhang }
136*914b7a73SJunchao Zhang 
137*914b7a73SJunchao Zhang template<typename Type,class Op,PetscInt BS,PetscInt EQ>
138*914b7a73SJunchao Zhang static PetscErrorCode FetchAndOp(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,void *data,void *buf)
139*914b7a73SJunchao Zhang {
140*914b7a73SJunchao Zhang   Op                      op;
141*914b7a73SJunchao Zhang   const PetscInt          *ropt = opt ? opt->array : NULL;
142*914b7a73SJunchao Zhang   const PetscInt          M = EQ ? 1 : link->bs/BS, MBS=M*BS;
143*914b7a73SJunchao Zhang   Type                    *rootdata = static_cast<Type*>(data),*leafbuf=static_cast<Type*>(buf);
144*914b7a73SJunchao Zhang   DeviceExecutionSpace&   exec = *static_cast<DeviceExecutionSpace*>(link->sptr);
145*914b7a73SJunchao Zhang 
146*914b7a73SJunchao Zhang   PetscFunctionBegin;
147*914b7a73SJunchao Zhang   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
148*914b7a73SJunchao Zhang     PetscInt r = (ropt? MapTidToIndex(ropt,tid) : (idx? idx[tid] : start+tid))*MBS;
149*914b7a73SJunchao Zhang     PetscInt l = tid*MBS;
150*914b7a73SJunchao Zhang     for (int i=0; i<MBS; i++) leafbuf[l+i] = op(rootdata[r+i],leafbuf[l+i]);
151*914b7a73SJunchao Zhang   });
152*914b7a73SJunchao Zhang   PetscFunctionReturn(0);
153*914b7a73SJunchao Zhang }
154*914b7a73SJunchao Zhang 
155*914b7a73SJunchao Zhang template<typename Type,class Op,PetscInt BS,PetscInt EQ>
156*914b7a73SJunchao Zhang static PetscErrorCode ScatterAndOp(PetscSFLink link,PetscInt count,PetscInt srcStart,PetscSFPackOpt srcOpt,const PetscInt *srcIdx,const void *src_,PetscInt dstStart,PetscSFPackOpt dstOpt,const PetscInt *dstIdx,void *dst_)
157*914b7a73SJunchao Zhang {
158*914b7a73SJunchao Zhang   PetscInt                srcx=0,srcy=0,srcX=0,srcY=0,dstx=0,dsty=0,dstX=0,dstY=0;
159*914b7a73SJunchao Zhang   const PetscInt          M = (EQ) ? 1 : link->bs/BS, MBS=M*BS;
160*914b7a73SJunchao Zhang   const Type              *src = static_cast<const Type*>(src_);
161*914b7a73SJunchao Zhang   Type                    *dst = static_cast<Type*>(dst_);
162*914b7a73SJunchao Zhang   DeviceExecutionSpace&   exec = *static_cast<DeviceExecutionSpace*>(link->sptr);
163*914b7a73SJunchao Zhang 
164*914b7a73SJunchao Zhang   PetscFunctionBegin;
165*914b7a73SJunchao Zhang   /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */
166*914b7a73SJunchao Zhang   if (srcOpt)       {srcx = srcOpt->dx[0]; srcy = srcOpt->dy[0]; srcX = srcOpt->X[0]; srcY = srcOpt->Y[0]; srcStart = srcOpt->start[0]; srcIdx = NULL;}
167*914b7a73SJunchao Zhang   else if (!srcIdx) {srcx = srcX = count; srcy = srcY = 1;}
168*914b7a73SJunchao Zhang 
169*914b7a73SJunchao Zhang   if (dstOpt)       {dstx = dstOpt->dx[0]; dsty = dstOpt->dy[0]; dstX = dstOpt->X[0]; dstY = dstOpt->Y[0]; dstStart = dstOpt->start[0]; dstIdx = NULL;}
170*914b7a73SJunchao Zhang   else if (!dstIdx) {dstx = dstX = count; dsty = dstY = 1;}
171*914b7a73SJunchao Zhang 
172*914b7a73SJunchao Zhang   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
173*914b7a73SJunchao Zhang     PetscInt i,j,k,s,t;
174*914b7a73SJunchao Zhang     Op       op;
175*914b7a73SJunchao Zhang     if (!srcIdx) { /* src is in 3D */
176*914b7a73SJunchao Zhang       k = tid/(srcx*srcy);
177*914b7a73SJunchao Zhang       j = (tid - k*srcx*srcy)/srcx;
178*914b7a73SJunchao Zhang       i = tid - k*srcx*srcy - j*srcx;
179*914b7a73SJunchao Zhang       s = srcStart + k*srcX*srcY + j*srcX + i;
180*914b7a73SJunchao Zhang     } else { /* src is contiguous */
181*914b7a73SJunchao Zhang       s = srcIdx[tid];
182*914b7a73SJunchao Zhang     }
183*914b7a73SJunchao Zhang 
184*914b7a73SJunchao Zhang     if (!dstIdx) { /* 3D */
185*914b7a73SJunchao Zhang       k = tid/(dstx*dsty);
186*914b7a73SJunchao Zhang       j = (tid - k*dstx*dsty)/dstx;
187*914b7a73SJunchao Zhang       i = tid - k*dstx*dsty - j*dstx;
188*914b7a73SJunchao Zhang       t = dstStart + k*dstX*dstY + j*dstX + i;
189*914b7a73SJunchao Zhang     } else { /* contiguous */
190*914b7a73SJunchao Zhang       t = dstIdx[tid];
191*914b7a73SJunchao Zhang     }
192*914b7a73SJunchao Zhang 
193*914b7a73SJunchao Zhang     s *= MBS;
194*914b7a73SJunchao Zhang     t *= MBS;
195*914b7a73SJunchao Zhang     for (i=0; i<MBS; i++) op(dst[t+i],src[s+i]);
196*914b7a73SJunchao Zhang   });
197*914b7a73SJunchao Zhang   PetscFunctionReturn(0);
198*914b7a73SJunchao Zhang }
199*914b7a73SJunchao Zhang 
200*914b7a73SJunchao Zhang /* Specialization for Insert since we may use memcpy */
201*914b7a73SJunchao Zhang template<typename Type,PetscInt BS,PetscInt EQ>
202*914b7a73SJunchao Zhang static PetscErrorCode ScatterAndInsert(PetscSFLink link,PetscInt count,PetscInt srcStart,PetscSFPackOpt srcOpt,const PetscInt *srcIdx,const void *src_,PetscInt dstStart,PetscSFPackOpt dstOpt,const PetscInt *dstIdx,void *dst_)
203*914b7a73SJunchao Zhang {
204*914b7a73SJunchao Zhang   PetscErrorCode          ierr;
205*914b7a73SJunchao Zhang   const Type              *src = static_cast<const Type*>(src_);
206*914b7a73SJunchao Zhang   Type                    *dst = static_cast<Type*>(dst_);
207*914b7a73SJunchao Zhang   DeviceExecutionSpace&   exec = *static_cast<DeviceExecutionSpace*>(link->sptr);
208*914b7a73SJunchao Zhang 
209*914b7a73SJunchao Zhang   PetscFunctionBegin;
210*914b7a73SJunchao Zhang   if (!count) PetscFunctionReturn(0);
211*914b7a73SJunchao Zhang   /*src and dst are contiguous */
212*914b7a73SJunchao Zhang   if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
213*914b7a73SJunchao Zhang     size_t sz = count*link->unitbytes;
214*914b7a73SJunchao Zhang     deviceBuffer_t      dbuf(reinterpret_cast<char*>(dst+dstStart*link->bs),sz);
215*914b7a73SJunchao Zhang     deviceConstBuffer_t sbuf(reinterpret_cast<const char*>(src+srcStart*link->bs),sz);
216*914b7a73SJunchao Zhang     Kokkos::deep_copy(exec,dbuf,sbuf);
217*914b7a73SJunchao Zhang   } else {
218*914b7a73SJunchao Zhang     ierr = ScatterAndOp<Type,Insert<Type>,BS,EQ>(link,count,srcStart,srcOpt,srcIdx,src,dstStart,dstOpt,dstIdx,dst);CHKERRQ(ierr);
219*914b7a73SJunchao Zhang   }
220*914b7a73SJunchao Zhang   PetscFunctionReturn(0);
221*914b7a73SJunchao Zhang }
222*914b7a73SJunchao Zhang 
223*914b7a73SJunchao Zhang template<typename Type,class Op,PetscInt BS,PetscInt EQ>
224*914b7a73SJunchao Zhang static PetscErrorCode FetchAndOpLocal(PetscSFLink link,PetscInt count,PetscInt rootstart,PetscSFPackOpt rootopt,const PetscInt *rootidx,void *rootdata_,PetscInt leafstart,PetscSFPackOpt leafopt,const PetscInt *leafidx,const void *leafdata_,void *leafupdate_)
225*914b7a73SJunchao Zhang {
226*914b7a73SJunchao Zhang   Op                      op;
227*914b7a73SJunchao Zhang   const PetscInt          M = (EQ) ? 1 : link->bs/BS, MBS = M*BS;
228*914b7a73SJunchao Zhang   const PetscInt          *ropt = rootopt ? rootopt->array : NULL;
229*914b7a73SJunchao Zhang   const PetscInt          *lopt = leafopt ? leafopt->array : NULL;
230*914b7a73SJunchao Zhang   Type                    *rootdata = static_cast<Type*>(rootdata_),*leafupdate = static_cast<Type*>(leafupdate_);
231*914b7a73SJunchao Zhang   const Type              *leafdata = static_cast<const Type*>(leafdata_);
232*914b7a73SJunchao Zhang   DeviceExecutionSpace&   exec = *static_cast<DeviceExecutionSpace*>(link->sptr);
233*914b7a73SJunchao Zhang 
234*914b7a73SJunchao Zhang   PetscFunctionBegin;
235*914b7a73SJunchao Zhang   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
236*914b7a73SJunchao Zhang     PetscInt r = (ropt? MapTidToIndex(ropt,tid) : (rootidx? rootidx[tid] : rootstart+tid))*MBS;
237*914b7a73SJunchao Zhang     PetscInt l = (lopt? MapTidToIndex(lopt,tid) : (leafidx? leafidx[tid] : leafstart+tid))*MBS;
238*914b7a73SJunchao Zhang     for (int i=0; i<MBS; i++) leafupdate[l+i] = op(rootdata[r+i],leafdata[l+i]);
239*914b7a73SJunchao Zhang   });
240*914b7a73SJunchao Zhang   PetscFunctionReturn(0);
241*914b7a73SJunchao Zhang }
242*914b7a73SJunchao Zhang 
243*914b7a73SJunchao Zhang /*====================================================================================*/
244*914b7a73SJunchao Zhang /*  Init various types and instantiate pack/unpack function pointers                  */
245*914b7a73SJunchao Zhang /*====================================================================================*/
246*914b7a73SJunchao Zhang template<typename Type,PetscInt BS,PetscInt EQ>
247*914b7a73SJunchao Zhang static void PackInit_RealType(PetscSFLink link)
248*914b7a73SJunchao Zhang {
249*914b7a73SJunchao Zhang   /* Pack/unpack for remote communication */
250*914b7a73SJunchao Zhang   link->d_Pack              = Pack<Type,BS,EQ>;
251*914b7a73SJunchao Zhang   link->d_UnpackAndInsert   = UnpackAndOp<Type,Insert<Type>,BS,EQ>;
252*914b7a73SJunchao Zhang   link->d_UnpackAndAdd      = UnpackAndOp<Type,Add<Type>   ,BS,EQ>;
253*914b7a73SJunchao Zhang   link->d_UnpackAndMult     = UnpackAndOp<Type,Mult<Type>  ,BS,EQ>;
254*914b7a73SJunchao Zhang   link->d_UnpackAndMin      = UnpackAndOp<Type,Min<Type>   ,BS,EQ>;
255*914b7a73SJunchao Zhang   link->d_UnpackAndMax      = UnpackAndOp<Type,Max<Type>   ,BS,EQ>;
256*914b7a73SJunchao Zhang   link->d_FetchAndAdd       = FetchAndOp <Type,Add<Type>   ,BS,EQ>;
257*914b7a73SJunchao Zhang   /* Scatter for local communication */
258*914b7a73SJunchao Zhang   link->d_ScatterAndInsert  = ScatterAndInsert<Type,BS,EQ>; /* Has special optimizations */
259*914b7a73SJunchao Zhang   link->d_ScatterAndAdd     = ScatterAndOp<Type,Add<Type>    ,BS,EQ>;
260*914b7a73SJunchao Zhang   link->d_ScatterAndMult    = ScatterAndOp<Type,Mult<Type>   ,BS,EQ>;
261*914b7a73SJunchao Zhang   link->d_ScatterAndMin     = ScatterAndOp<Type,Min<Type>    ,BS,EQ>;
262*914b7a73SJunchao Zhang   link->d_ScatterAndMax     = ScatterAndOp<Type,Max<Type>    ,BS,EQ>;
263*914b7a73SJunchao Zhang   link->d_FetchAndAddLocal  = FetchAndOpLocal<Type,Add <Type>,BS,EQ>;
264*914b7a73SJunchao Zhang   /* Atomic versions when there are data-race possibilities */
265*914b7a73SJunchao Zhang   link->da_UnpackAndInsert  = UnpackAndOp<Type,AtomicInsert<Type>  ,BS,EQ>;
266*914b7a73SJunchao Zhang   link->da_UnpackAndAdd     = UnpackAndOp<Type,AtomicAdd<Type>     ,BS,EQ>;
267*914b7a73SJunchao Zhang   link->da_UnpackAndMult    = UnpackAndOp<Type,AtomicMult<Type>    ,BS,EQ>;
268*914b7a73SJunchao Zhang   link->da_UnpackAndMin     = UnpackAndOp<Type,AtomicMin<Type>     ,BS,EQ>;
269*914b7a73SJunchao Zhang   link->da_UnpackAndMax     = UnpackAndOp<Type,AtomicMax<Type>     ,BS,EQ>;
270*914b7a73SJunchao Zhang   link->da_FetchAndAdd      = FetchAndOp <Type,AtomicFetchAdd<Type>,BS,EQ>;
271*914b7a73SJunchao Zhang 
272*914b7a73SJunchao Zhang   link->da_ScatterAndInsert = ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>;
273*914b7a73SJunchao Zhang   link->da_ScatterAndAdd    = ScatterAndOp<Type,AtomicAdd<Type>   ,BS,EQ>;
274*914b7a73SJunchao Zhang   link->da_ScatterAndMult   = ScatterAndOp<Type,AtomicMult<Type>  ,BS,EQ>;
275*914b7a73SJunchao Zhang   link->da_ScatterAndMin    = ScatterAndOp<Type,AtomicMin<Type>   ,BS,EQ>;
276*914b7a73SJunchao Zhang   link->da_ScatterAndMax    = ScatterAndOp<Type,AtomicMax<Type>   ,BS,EQ>;
277*914b7a73SJunchao Zhang   link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>;
278*914b7a73SJunchao Zhang }
279*914b7a73SJunchao Zhang 
280*914b7a73SJunchao Zhang template<typename Type,PetscInt BS,PetscInt EQ>
281*914b7a73SJunchao Zhang static void PackInit_IntegerType(PetscSFLink link)
282*914b7a73SJunchao Zhang {
283*914b7a73SJunchao Zhang   link->d_Pack              = Pack<Type,BS,EQ>;
284*914b7a73SJunchao Zhang   link->d_UnpackAndInsert   = UnpackAndOp<Type,Insert<Type> ,BS,EQ>;
285*914b7a73SJunchao Zhang   link->d_UnpackAndAdd      = UnpackAndOp<Type,Add<Type>    ,BS,EQ>;
286*914b7a73SJunchao Zhang   link->d_UnpackAndMult     = UnpackAndOp<Type,Mult<Type>   ,BS,EQ>;
287*914b7a73SJunchao Zhang   link->d_UnpackAndMin      = UnpackAndOp<Type,Min<Type>    ,BS,EQ>;
288*914b7a73SJunchao Zhang   link->d_UnpackAndMax      = UnpackAndOp<Type,Max<Type>    ,BS,EQ>;
289*914b7a73SJunchao Zhang   link->d_UnpackAndLAND     = UnpackAndOp<Type,LAND<Type>   ,BS,EQ>;
290*914b7a73SJunchao Zhang   link->d_UnpackAndLOR      = UnpackAndOp<Type,LOR<Type>    ,BS,EQ>;
291*914b7a73SJunchao Zhang   link->d_UnpackAndLXOR     = UnpackAndOp<Type,LXOR<Type>   ,BS,EQ>;
292*914b7a73SJunchao Zhang   link->d_UnpackAndBAND     = UnpackAndOp<Type,BAND<Type>   ,BS,EQ>;
293*914b7a73SJunchao Zhang   link->d_UnpackAndBOR      = UnpackAndOp<Type,BOR<Type>    ,BS,EQ>;
294*914b7a73SJunchao Zhang   link->d_UnpackAndBXOR     = UnpackAndOp<Type,BXOR<Type>   ,BS,EQ>;
295*914b7a73SJunchao Zhang   link->d_FetchAndAdd       = FetchAndOp <Type,Add<Type>    ,BS,EQ>;
296*914b7a73SJunchao Zhang 
297*914b7a73SJunchao Zhang   link->d_ScatterAndInsert  = ScatterAndInsert<Type,BS,EQ>;
298*914b7a73SJunchao Zhang   link->d_ScatterAndAdd     = ScatterAndOp<Type,Add<Type>   ,BS,EQ>;
299*914b7a73SJunchao Zhang   link->d_ScatterAndMult    = ScatterAndOp<Type,Mult<Type>  ,BS,EQ>;
300*914b7a73SJunchao Zhang   link->d_ScatterAndMin     = ScatterAndOp<Type,Min<Type>   ,BS,EQ>;
301*914b7a73SJunchao Zhang   link->d_ScatterAndMax     = ScatterAndOp<Type,Max<Type>   ,BS,EQ>;
302*914b7a73SJunchao Zhang   link->d_ScatterAndLAND    = ScatterAndOp<Type,LAND<Type>  ,BS,EQ>;
303*914b7a73SJunchao Zhang   link->d_ScatterAndLOR     = ScatterAndOp<Type,LOR<Type>   ,BS,EQ>;
304*914b7a73SJunchao Zhang   link->d_ScatterAndLXOR    = ScatterAndOp<Type,LXOR<Type>  ,BS,EQ>;
305*914b7a73SJunchao Zhang   link->d_ScatterAndBAND    = ScatterAndOp<Type,BAND<Type>  ,BS,EQ>;
306*914b7a73SJunchao Zhang   link->d_ScatterAndBOR     = ScatterAndOp<Type,BOR<Type>   ,BS,EQ>;
307*914b7a73SJunchao Zhang   link->d_ScatterAndBXOR    = ScatterAndOp<Type,BXOR<Type>  ,BS,EQ>;
308*914b7a73SJunchao Zhang   link->d_FetchAndAddLocal  = FetchAndOpLocal<Type,Add<Type>,BS,EQ>;
309*914b7a73SJunchao Zhang 
310*914b7a73SJunchao Zhang   link->da_UnpackAndInsert  = UnpackAndOp<Type,AtomicInsert<Type>,BS,EQ>;
311*914b7a73SJunchao Zhang   link->da_UnpackAndAdd     = UnpackAndOp<Type,AtomicAdd<Type>   ,BS,EQ>;
312*914b7a73SJunchao Zhang   link->da_UnpackAndMult    = UnpackAndOp<Type,AtomicMult<Type>  ,BS,EQ>;
313*914b7a73SJunchao Zhang   link->da_UnpackAndMin     = UnpackAndOp<Type,AtomicMin<Type>   ,BS,EQ>;
314*914b7a73SJunchao Zhang   link->da_UnpackAndMax     = UnpackAndOp<Type,AtomicMax<Type>   ,BS,EQ>;
315*914b7a73SJunchao Zhang   link->da_UnpackAndLAND    = UnpackAndOp<Type,AtomicLAND<Type>  ,BS,EQ>;
316*914b7a73SJunchao Zhang   link->da_UnpackAndLOR     = UnpackAndOp<Type,AtomicLOR<Type>   ,BS,EQ>;
317*914b7a73SJunchao Zhang   link->da_UnpackAndBAND    = UnpackAndOp<Type,AtomicBAND<Type>  ,BS,EQ>;
318*914b7a73SJunchao Zhang   link->da_UnpackAndBOR     = UnpackAndOp<Type,AtomicBOR<Type>   ,BS,EQ>;
319*914b7a73SJunchao Zhang   link->da_UnpackAndBXOR    = UnpackAndOp<Type,AtomicBXOR<Type>  ,BS,EQ>;
320*914b7a73SJunchao Zhang   link->da_FetchAndAdd      = FetchAndOp <Type,AtomicFetchAdd<Type>,BS,EQ>;
321*914b7a73SJunchao Zhang 
322*914b7a73SJunchao Zhang   link->da_ScatterAndInsert = ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>;
323*914b7a73SJunchao Zhang   link->da_ScatterAndAdd    = ScatterAndOp<Type,AtomicAdd<Type>   ,BS,EQ>;
324*914b7a73SJunchao Zhang   link->da_ScatterAndMult   = ScatterAndOp<Type,AtomicMult<Type>  ,BS,EQ>;
325*914b7a73SJunchao Zhang   link->da_ScatterAndMin    = ScatterAndOp<Type,AtomicMin<Type>   ,BS,EQ>;
326*914b7a73SJunchao Zhang   link->da_ScatterAndMax    = ScatterAndOp<Type,AtomicMax<Type>   ,BS,EQ>;
327*914b7a73SJunchao Zhang   link->da_ScatterAndLAND   = ScatterAndOp<Type,AtomicLAND<Type>  ,BS,EQ>;
328*914b7a73SJunchao Zhang   link->da_ScatterAndLOR    = ScatterAndOp<Type,AtomicLOR<Type>   ,BS,EQ>;
329*914b7a73SJunchao Zhang   link->da_ScatterAndBAND   = ScatterAndOp<Type,AtomicBAND<Type>  ,BS,EQ>;
330*914b7a73SJunchao Zhang   link->da_ScatterAndBOR    = ScatterAndOp<Type,AtomicBOR<Type>   ,BS,EQ>;
331*914b7a73SJunchao Zhang   link->da_ScatterAndBXOR   = ScatterAndOp<Type,AtomicBXOR<Type>  ,BS,EQ>;
332*914b7a73SJunchao Zhang   link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>;
333*914b7a73SJunchao Zhang }
334*914b7a73SJunchao Zhang 
#if defined(PETSC_HAVE_COMPLEX)
/* Fill in the device kernel pointers of <link> for a complex primitive Type.
   Only Insert, Add, Mult and fetch-and-add are provided (no Min/Max or logical/bitwise operations).
   d_* entries are the plain kernels, da_* entries the atomic ones.
*/
template<typename Type,PetscInt BS,PetscInt EQ>
static void PackInit_ComplexType(PetscSFLink link)
{
  link->d_Pack              = Pack<Type,BS,EQ>;

  /* Insert */
  link->d_UnpackAndInsert   = UnpackAndOp<Type,Insert<Type>,BS,EQ>;
  link->da_UnpackAndInsert  = UnpackAndOp<Type,AtomicInsert<Type>,BS,EQ>;
  link->d_ScatterAndInsert  = ScatterAndInsert<Type,BS,EQ>;
  link->da_ScatterAndInsert = ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>;

  /* Add */
  link->d_UnpackAndAdd      = UnpackAndOp<Type,Add<Type>,BS,EQ>;
  link->da_UnpackAndAdd     = UnpackAndOp<Type,AtomicAdd<Type>,BS,EQ>;
  link->d_ScatterAndAdd     = ScatterAndOp<Type,Add<Type>,BS,EQ>;
  link->da_ScatterAndAdd    = ScatterAndOp<Type,AtomicAdd<Type>,BS,EQ>;

  /* Mult */
  link->d_UnpackAndMult     = UnpackAndOp<Type,Mult<Type>,BS,EQ>;
  link->da_UnpackAndMult    = UnpackAndOp<Type,AtomicMult<Type>,BS,EQ>;
  link->d_ScatterAndMult    = ScatterAndOp<Type,Mult<Type>,BS,EQ>;
  link->da_ScatterAndMult   = ScatterAndOp<Type,AtomicMult<Type>,BS,EQ>;

  /* Fetch-and-add */
  link->d_FetchAndAdd       = FetchAndOp<Type,Add<Type>,BS,EQ>;
  link->da_FetchAndAdd      = FetchAndOp<Type,AtomicFetchAdd<Type>,BS,EQ>;
  link->d_FetchAndAddLocal  = FetchAndOpLocal<Type,Add<Type>,BS,EQ>;
  link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>;
}
#endif
361*914b7a73SJunchao Zhang 
362*914b7a73SJunchao Zhang template<typename Type>
363*914b7a73SJunchao Zhang static void PackInit_PairType(PetscSFLink link)
364*914b7a73SJunchao Zhang {
365*914b7a73SJunchao Zhang   link->d_Pack             = Pack<Type,1,1>;
366*914b7a73SJunchao Zhang   link->d_UnpackAndInsert  = UnpackAndOp<Type,Insert<Type>,1,1>;
367*914b7a73SJunchao Zhang   link->d_UnpackAndMaxloc  = UnpackAndOp<Type,Maxloc<Type>,1,1>;
368*914b7a73SJunchao Zhang   link->d_UnpackAndMinloc  = UnpackAndOp<Type,Minloc<Type>,1,1>;
369*914b7a73SJunchao Zhang 
370*914b7a73SJunchao Zhang   link->d_ScatterAndInsert = ScatterAndOp<Type,Insert<Type>,1,1>;
371*914b7a73SJunchao Zhang   link->d_ScatterAndMaxloc = ScatterAndOp<Type,Maxloc<Type>,1,1>;
372*914b7a73SJunchao Zhang   link->d_ScatterAndMinloc = ScatterAndOp<Type,Minloc<Type>,1,1>;
373*914b7a73SJunchao Zhang   /* Atomics for pair types are not implemented yet */
374*914b7a73SJunchao Zhang }
375*914b7a73SJunchao Zhang 
376*914b7a73SJunchao Zhang template<typename Type,PetscInt BS,PetscInt EQ>
377*914b7a73SJunchao Zhang static void PackInit_DumbType(PetscSFLink link)
378*914b7a73SJunchao Zhang {
379*914b7a73SJunchao Zhang   link->d_Pack             = Pack<Type,BS,EQ>;
380*914b7a73SJunchao Zhang   link->d_UnpackAndInsert  = UnpackAndOp<Type,Insert<Type>,BS,EQ>;
381*914b7a73SJunchao Zhang   link->d_ScatterAndInsert = ScatterAndInsert<Type,BS,EQ>;
382*914b7a73SJunchao Zhang   /* Atomics for dumb types are not implemented yet */
383*914b7a73SJunchao Zhang }
384*914b7a73SJunchao Zhang 
385*914b7a73SJunchao Zhang static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link)
386*914b7a73SJunchao Zhang {
387*914b7a73SJunchao Zhang   PetscFunctionBegin;
388*914b7a73SJunchao Zhang   delete static_cast<DeviceExecutionSpace*>(link->sptr);
389*914b7a73SJunchao Zhang   PetscFunctionReturn(0);
390*914b7a73SJunchao Zhang }
391*914b7a73SJunchao Zhang 
392*914b7a73SJunchao Zhang /* Some device-specific utilities */
393*914b7a73SJunchao Zhang PetscErrorCode PetscSFLinkSyncDevice(PetscSF sf,PetscSFLink link)
394*914b7a73SJunchao Zhang {
395*914b7a73SJunchao Zhang   PetscFunctionBegin;
396*914b7a73SJunchao Zhang   Kokkos::fence();
397*914b7a73SJunchao Zhang   PetscFunctionReturn(0);
398*914b7a73SJunchao Zhang }
399*914b7a73SJunchao Zhang 
400*914b7a73SJunchao Zhang PetscErrorCode PetscSFLinkSyncStream(PetscSF sf,PetscSFLink link)
401*914b7a73SJunchao Zhang {
402*914b7a73SJunchao Zhang   DeviceExecutionSpace&  exec = *static_cast<DeviceExecutionSpace*>(link->sptr);
403*914b7a73SJunchao Zhang   PetscFunctionBegin;
404*914b7a73SJunchao Zhang   exec.fence();
405*914b7a73SJunchao Zhang   PetscFunctionReturn(0);
406*914b7a73SJunchao Zhang }
407*914b7a73SJunchao Zhang 
408*914b7a73SJunchao Zhang PetscErrorCode PetscSFLinkMemcpy(PetscSF sf,PetscSFLink link,PetscMemType dstmtype,void* dst,PetscMemType srcmtype,const void*src,size_t n)
409*914b7a73SJunchao Zhang {
410*914b7a73SJunchao Zhang   DeviceExecutionSpace&  exec = *static_cast<DeviceExecutionSpace*>(link->sptr);
411*914b7a73SJunchao Zhang 
412*914b7a73SJunchao Zhang   PetscFunctionBegin;
413*914b7a73SJunchao Zhang   if (!n) PetscFunctionReturn(0);
414*914b7a73SJunchao Zhang   if (dstmtype == PETSC_MEMTYPE_HOST && srcmtype == PETSC_MEMTYPE_HOST) {
415*914b7a73SJunchao Zhang     PetscErrorCode ierr = PetscMemcpy(dst,src,n);CHKERRQ(ierr);
416*914b7a73SJunchao Zhang   } else {
417*914b7a73SJunchao Zhang     if (dstmtype == PETSC_MEMTYPE_DEVICE && srcmtype == PETSC_MEMTYPE_HOST) {
418*914b7a73SJunchao Zhang       deviceBuffer_t       dbuf(static_cast<char*>(dst),n);
419*914b7a73SJunchao Zhang       HostConstBuffer_t    sbuf(static_cast<const char*>(src),n);
420*914b7a73SJunchao Zhang       Kokkos::deep_copy(exec,dbuf,sbuf);
421*914b7a73SJunchao Zhang       PetscErrorCode ierr = PetscLogCpuToGpu(n);CHKERRQ(ierr);
422*914b7a73SJunchao Zhang     } else if (dstmtype == PETSC_MEMTYPE_HOST && srcmtype == PETSC_MEMTYPE_DEVICE) {
423*914b7a73SJunchao Zhang       HostBuffer_t         dbuf(static_cast<char*>(dst),n);
424*914b7a73SJunchao Zhang       deviceConstBuffer_t  sbuf(static_cast<const char*>(src),n);
425*914b7a73SJunchao Zhang       Kokkos::deep_copy(exec,dbuf,sbuf);
426*914b7a73SJunchao Zhang       PetscErrorCode ierr = PetscLogGpuToCpu(n);CHKERRQ(ierr);
427*914b7a73SJunchao Zhang     } else if (dstmtype == PETSC_MEMTYPE_DEVICE && srcmtype == PETSC_MEMTYPE_DEVICE) {
428*914b7a73SJunchao Zhang       deviceBuffer_t       dbuf(static_cast<char*>(dst),n);
429*914b7a73SJunchao Zhang       deviceConstBuffer_t  sbuf(static_cast<const char*>(src),n);
430*914b7a73SJunchao Zhang       Kokkos::deep_copy(exec,dbuf,sbuf);
431*914b7a73SJunchao Zhang     }
432*914b7a73SJunchao Zhang   }
433*914b7a73SJunchao Zhang   PetscFunctionReturn(0);
434*914b7a73SJunchao Zhang }
435*914b7a73SJunchao Zhang 
436*914b7a73SJunchao Zhang PetscErrorCode PetscSFMalloc(PetscMemType mtype,size_t size,void** ptr)
437*914b7a73SJunchao Zhang {
438*914b7a73SJunchao Zhang   PetscFunctionBegin;
439*914b7a73SJunchao Zhang   if (mtype == PETSC_MEMTYPE_HOST) {PetscErrorCode ierr = PetscMalloc(size,ptr);CHKERRQ(ierr);}
440*914b7a73SJunchao Zhang   else if (mtype == PETSC_MEMTYPE_DEVICE) {*ptr = Kokkos::kokkos_malloc<DeviceMemorySpace>(size);}
441*914b7a73SJunchao Zhang   else SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONG,"Wrong PetscMemType %d", (int)mtype);
442*914b7a73SJunchao Zhang   PetscFunctionReturn(0);
443*914b7a73SJunchao Zhang }
444*914b7a73SJunchao Zhang 
445*914b7a73SJunchao Zhang PetscErrorCode PetscSFFree_Private(PetscMemType mtype,void* ptr)
446*914b7a73SJunchao Zhang {
447*914b7a73SJunchao Zhang   PetscFunctionBegin;
448*914b7a73SJunchao Zhang   if (mtype == PETSC_MEMTYPE_HOST) {PetscErrorCode ierr = PetscFree(ptr);CHKERRQ(ierr);}
449*914b7a73SJunchao Zhang   else if (mtype == PETSC_MEMTYPE_DEVICE) {Kokkos::kokkos_free<DeviceMemorySpace>(ptr);}
450*914b7a73SJunchao Zhang   else SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONG,"Wrong PetscMemType %d",(int)mtype);
451*914b7a73SJunchao Zhang   PetscFunctionReturn(0);
452*914b7a73SJunchao Zhang }
453*914b7a73SJunchao Zhang 
454*914b7a73SJunchao Zhang /*====================================================================================*/
455*914b7a73SJunchao Zhang /*                Main driver to init MPI datatype on device                          */
456*914b7a73SJunchao Zhang /*====================================================================================*/
457*914b7a73SJunchao Zhang 
458*914b7a73SJunchao Zhang /* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what needed on device */
459*914b7a73SJunchao Zhang PetscErrorCode PetscSFLinkSetUp_Device(PetscSF sf,PetscSFLink link,MPI_Datatype unit)
460*914b7a73SJunchao Zhang {
461*914b7a73SJunchao Zhang   PetscErrorCode ierr;
462*914b7a73SJunchao Zhang   PetscInt       nSignedChar=0,nUnsignedChar=0,nInt=0,nPetscInt=0,nPetscReal=0;
463*914b7a73SJunchao Zhang   PetscBool      is2Int,is2PetscInt;
464*914b7a73SJunchao Zhang #if defined(PETSC_HAVE_COMPLEX)
465*914b7a73SJunchao Zhang   PetscInt       nPetscComplex=0;
466*914b7a73SJunchao Zhang #endif
467*914b7a73SJunchao Zhang 
468*914b7a73SJunchao Zhang   PetscFunctionBegin;
469*914b7a73SJunchao Zhang   if (link->deviceinited) PetscFunctionReturn(0);
470*914b7a73SJunchao Zhang   ierr = MPIPetsc_Type_compare_contig(unit,MPI_SIGNED_CHAR,  &nSignedChar);CHKERRQ(ierr);
471*914b7a73SJunchao Zhang   ierr = MPIPetsc_Type_compare_contig(unit,MPI_UNSIGNED_CHAR,&nUnsignedChar);CHKERRQ(ierr);
472*914b7a73SJunchao Zhang   /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */
473*914b7a73SJunchao Zhang   ierr = MPIPetsc_Type_compare_contig(unit,MPI_INT,  &nInt);CHKERRQ(ierr);
474*914b7a73SJunchao Zhang   ierr = MPIPetsc_Type_compare_contig(unit,MPIU_INT, &nPetscInt);CHKERRQ(ierr);
475*914b7a73SJunchao Zhang   ierr = MPIPetsc_Type_compare_contig(unit,MPIU_REAL,&nPetscReal);CHKERRQ(ierr);
476*914b7a73SJunchao Zhang #if defined(PETSC_HAVE_COMPLEX)
477*914b7a73SJunchao Zhang   ierr = MPIPetsc_Type_compare_contig(unit,MPIU_COMPLEX,&nPetscComplex);CHKERRQ(ierr);
478*914b7a73SJunchao Zhang #endif
479*914b7a73SJunchao Zhang   ierr = MPIPetsc_Type_compare(unit,MPI_2INT,&is2Int);CHKERRQ(ierr);
480*914b7a73SJunchao Zhang   ierr = MPIPetsc_Type_compare(unit,MPIU_2INT,&is2PetscInt);CHKERRQ(ierr);
481*914b7a73SJunchao Zhang 
482*914b7a73SJunchao Zhang   if (is2Int) {
483*914b7a73SJunchao Zhang     PackInit_PairType<Kokkos::pair<int,int>>(link);
484*914b7a73SJunchao Zhang   } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
485*914b7a73SJunchao Zhang     PackInit_PairType<Kokkos::pair<PetscInt,PetscInt>>(link);
486*914b7a73SJunchao Zhang   } else if (nPetscReal) {
487*914b7a73SJunchao Zhang     if      (nPetscReal == 8) PackInit_RealType<PetscReal,8,1>(link); else if (nPetscReal%8 == 0) PackInit_RealType<PetscReal,8,0>(link);
488*914b7a73SJunchao Zhang     else if (nPetscReal == 4) PackInit_RealType<PetscReal,4,1>(link); else if (nPetscReal%4 == 0) PackInit_RealType<PetscReal,4,0>(link);
489*914b7a73SJunchao Zhang     else if (nPetscReal == 2) PackInit_RealType<PetscReal,2,1>(link); else if (nPetscReal%2 == 0) PackInit_RealType<PetscReal,2,0>(link);
490*914b7a73SJunchao Zhang     else if (nPetscReal == 1) PackInit_RealType<PetscReal,1,1>(link); else if (nPetscReal%1 == 0) PackInit_RealType<PetscReal,1,0>(link);
491*914b7a73SJunchao Zhang   } else if (nPetscInt) {
492*914b7a73SJunchao Zhang     if      (nPetscInt == 8) PackInit_IntegerType<PetscInt,8,1>(link); else if (nPetscInt%8 == 0) PackInit_IntegerType<PetscInt,8,0>(link);
493*914b7a73SJunchao Zhang     else if (nPetscInt == 4) PackInit_IntegerType<PetscInt,4,1>(link); else if (nPetscInt%4 == 0) PackInit_IntegerType<PetscInt,4,0>(link);
494*914b7a73SJunchao Zhang     else if (nPetscInt == 2) PackInit_IntegerType<PetscInt,2,1>(link); else if (nPetscInt%2 == 0) PackInit_IntegerType<PetscInt,2,0>(link);
495*914b7a73SJunchao Zhang     else if (nPetscInt == 1) PackInit_IntegerType<PetscInt,1,1>(link); else if (nPetscInt%1 == 0) PackInit_IntegerType<PetscInt,1,0>(link);
496*914b7a73SJunchao Zhang #if defined(PETSC_USE_64BIT_INDICES)
497*914b7a73SJunchao Zhang   } else if (nInt) {
498*914b7a73SJunchao Zhang     if      (nInt == 8) PackInit_IntegerType<int,8,1>(link); else if (nInt%8 == 0) PackInit_IntegerType<int,8,0>(link);
499*914b7a73SJunchao Zhang     else if (nInt == 4) PackInit_IntegerType<int,4,1>(link); else if (nInt%4 == 0) PackInit_IntegerType<int,4,0>(link);
500*914b7a73SJunchao Zhang     else if (nInt == 2) PackInit_IntegerType<int,2,1>(link); else if (nInt%2 == 0) PackInit_IntegerType<int,2,0>(link);
501*914b7a73SJunchao Zhang     else if (nInt == 1) PackInit_IntegerType<int,1,1>(link); else if (nInt%1 == 0) PackInit_IntegerType<int,1,0>(link);
502*914b7a73SJunchao Zhang #endif
503*914b7a73SJunchao Zhang   } else if (nSignedChar) {
504*914b7a73SJunchao Zhang     if      (nSignedChar == 8) PackInit_IntegerType<char,8,1>(link); else if (nSignedChar%8 == 0) PackInit_IntegerType<char,8,0>(link);
505*914b7a73SJunchao Zhang     else if (nSignedChar == 4) PackInit_IntegerType<char,4,1>(link); else if (nSignedChar%4 == 0) PackInit_IntegerType<char,4,0>(link);
506*914b7a73SJunchao Zhang     else if (nSignedChar == 2) PackInit_IntegerType<char,2,1>(link); else if (nSignedChar%2 == 0) PackInit_IntegerType<char,2,0>(link);
507*914b7a73SJunchao Zhang     else if (nSignedChar == 1) PackInit_IntegerType<char,1,1>(link); else if (nSignedChar%1 == 0) PackInit_IntegerType<char,1,0>(link);
508*914b7a73SJunchao Zhang   }  else if (nUnsignedChar) {
509*914b7a73SJunchao Zhang     if      (nUnsignedChar == 8) PackInit_IntegerType<unsigned char,8,1>(link); else if (nUnsignedChar%8 == 0) PackInit_IntegerType<unsigned char,8,0>(link);
510*914b7a73SJunchao Zhang     else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char,4,1>(link); else if (nUnsignedChar%4 == 0) PackInit_IntegerType<unsigned char,4,0>(link);
511*914b7a73SJunchao Zhang     else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char,2,1>(link); else if (nUnsignedChar%2 == 0) PackInit_IntegerType<unsigned char,2,0>(link);
512*914b7a73SJunchao Zhang     else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char,1,1>(link); else if (nUnsignedChar%1 == 0) PackInit_IntegerType<unsigned char,1,0>(link);
513*914b7a73SJunchao Zhang #if defined(PETSC_HAVE_COMPLEX)
514*914b7a73SJunchao Zhang   } else if (nPetscComplex) {
515*914b7a73SJunchao Zhang     if      (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>,8,1>(link); else if (nPetscComplex%8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,8,0>(link);
516*914b7a73SJunchao Zhang     else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>,4,1>(link); else if (nPetscComplex%4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,4,0>(link);
517*914b7a73SJunchao Zhang     else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>,2,1>(link); else if (nPetscComplex%2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,2,0>(link);
518*914b7a73SJunchao Zhang     else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>,1,1>(link); else if (nPetscComplex%1 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,1,0>(link);
519*914b7a73SJunchao Zhang #endif
520*914b7a73SJunchao Zhang   } else {
521*914b7a73SJunchao Zhang     MPI_Aint lb,nbyte;
522*914b7a73SJunchao Zhang     ierr = MPI_Type_get_extent(unit,&lb,&nbyte);CHKERRQ(ierr);
523*914b7a73SJunchao Zhang     if (lb != 0) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_SUP,"Datatype with nonzero lower bound %ld\n",(long)lb);
524*914b7a73SJunchao Zhang     if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */
525*914b7a73SJunchao Zhang       if      (nbyte == 4) PackInit_DumbType<char,4,1>(link); else if (nbyte%4 == 0) PackInit_DumbType<char,4,0>(link);
526*914b7a73SJunchao Zhang       else if (nbyte == 2) PackInit_DumbType<char,2,1>(link); else if (nbyte%2 == 0) PackInit_DumbType<char,2,0>(link);
527*914b7a73SJunchao Zhang       else if (nbyte == 1) PackInit_DumbType<char,1,1>(link); else if (nbyte%1 == 0) PackInit_DumbType<char,1,0>(link);
528*914b7a73SJunchao Zhang     } else {
529*914b7a73SJunchao Zhang       nInt = nbyte / sizeof(int);
530*914b7a73SJunchao Zhang       if      (nInt == 8) PackInit_DumbType<int,8,1>(link); else if (nInt%8 == 0) PackInit_DumbType<int,8,0>(link);
531*914b7a73SJunchao Zhang       else if (nInt == 4) PackInit_DumbType<int,4,1>(link); else if (nInt%4 == 0) PackInit_DumbType<int,4,0>(link);
532*914b7a73SJunchao Zhang       else if (nInt == 2) PackInit_DumbType<int,2,1>(link); else if (nInt%2 == 0) PackInit_DumbType<int,2,0>(link);
533*914b7a73SJunchao Zhang       else if (nInt == 1) PackInit_DumbType<int,1,1>(link); else if (nInt%1 == 0) PackInit_DumbType<int,1,0>(link);
534*914b7a73SJunchao Zhang     }
535*914b7a73SJunchao Zhang   }
536*914b7a73SJunchao Zhang 
537*914b7a73SJunchao Zhang #if defined(KOKKOS_ENABLE_CUDA)
538*914b7a73SJunchao Zhang   if (!sf->use_default_stream) {cudaError_t cerr = cudaStreamCreate(&link->stream);CHKERRCUDA(cerr);}
539*914b7a73SJunchao Zhang   link->sptr         = new DeviceExecutionSpace(link->stream);
540*914b7a73SJunchao Zhang #elif defined(KOKKOS_ENABLE_HIP)
541*914b7a73SJunchao Zhang   if (!sf->use_default_stream) {hipError_t cerr = hipStreamCreate(&link->stream);CHKERRQ(cerr);}
542*914b7a73SJunchao Zhang   link->sptr         = new DeviceExecutionSpace(link->stream);
543*914b7a73SJunchao Zhang #endif
544*914b7a73SJunchao Zhang   link->Destroy      = PetscSFLinkDestroy_Kokkos;
545*914b7a73SJunchao Zhang   link->deviceinited = PETSC_TRUE;
546*914b7a73SJunchao Zhang   PetscFunctionReturn(0);
547*914b7a73SJunchao Zhang }
548