xref: /petsc/src/vec/is/sf/impls/basic/nvshmem/sfnvshmem.cu (revision 71438e86cf8dea9f708c08733b37fef8eb68dc06)
1*71438e86SJunchao Zhang #include <petsc/private/cudavecimpl.h>
2*71438e86SJunchao Zhang #include <petsccublas.h>
3*71438e86SJunchao Zhang #include <../src/vec/is/sf/impls/basic/sfpack.h>
4*71438e86SJunchao Zhang #include <mpi.h>
5*71438e86SJunchao Zhang #include <nvshmem.h>
6*71438e86SJunchao Zhang #include <nvshmemx.h>
7*71438e86SJunchao Zhang 
/*
  PetscNvshmemInitializeCheck - Initialize NVSHMEM on first use; a no-op on later calls.

  NVSHMEM does not provide a routine to check whether it is initialized, so the PETSc-side
  flag PetscNvshmemInitialized is what is tested. CUDA is initialized first, then NVSHMEM is
  bootstrapped from PETSC_COMM_WORLD. On success both PetscNvshmemInitialized and
  PetscBeganNvshmem are set to PETSC_TRUE.
*/
PetscErrorCode PetscNvshmemInitializeCheck(void)
{
  PetscErrorCode       ierr;
  nvshmemx_init_attr_t attr;

  PetscFunctionBegin;
  if (PetscNvshmemInitialized) PetscFunctionReturn(0); /* nothing left to do */

  ierr = PetscCUDAInitializeCheck();CHKERRQ(ierr);
  attr.mpi_comm = &PETSC_COMM_WORLD; /* NVSHMEM takes a pointer to the MPI communicator */
  ierr = nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM,&attr);CHKERRQ(ierr);

  PetscBeganNvshmem       = PETSC_TRUE;
  PetscNvshmemInitialized = PETSC_TRUE;
  PetscFunctionReturn(0);
}
23*71438e86SJunchao Zhang 
/*
  PetscNvshmemMalloc - Allocate memory with nvshmem_malloc(), initializing NVSHMEM first if needed.

  Input Parameter:
. size - number of bytes to allocate

  Output Parameter:
. ptr  - the address returned by nvshmem_malloc()

  Notes:
  Any NULL return from nvshmem_malloc() is treated as a failure and reported with PETSC_ERR_MEM
  (an out-of-memory condition), not as a bad argument.
*/
PetscErrorCode PetscNvshmemMalloc(size_t size, void** ptr)
{
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = PetscNvshmemInitializeCheck();CHKERRQ(ierr);
  *ptr = nvshmem_malloc(size);
  if (!*ptr) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_MEM,"nvshmem_malloc() failed to allocate %zu bytes",size);
  PetscFunctionReturn(0);
}
34*71438e86SJunchao Zhang 
/*
  PetscNvshmemCalloc - Allocate memory with nvshmem_calloc(size,1), initializing NVSHMEM first if needed.

  Input Parameter:
. size - number of bytes to allocate

  Output Parameter:
. ptr  - the address returned by nvshmem_calloc()

  Notes:
  Any NULL return from nvshmem_calloc() is treated as a failure and reported with PETSC_ERR_MEM
  (an out-of-memory condition), not as a bad argument.
*/
PetscErrorCode PetscNvshmemCalloc(size_t size, void**ptr)
{
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = PetscNvshmemInitializeCheck();CHKERRQ(ierr);
  *ptr = nvshmem_calloc(size,1);
  if (!*ptr) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_MEM,"nvshmem_calloc() failed to allocate %zu bytes",size);
  PetscFunctionReturn(0);
}
45*71438e86SJunchao Zhang 
/*
  PetscNvshmemFree_Private - Hand a pointer back to NVSHMEM via nvshmem_free().

  Input Parameter:
. ptr - the address to free; it is passed to nvshmem_free() unchanged

  Always returns 0 since nvshmem_free() reports no status.
*/
PetscErrorCode PetscNvshmemFree_Private(void *ptr)
{
  PetscFunctionBegin;
  nvshmem_free(ptr);
  PetscFunctionReturn(0);
}
52*71438e86SJunchao Zhang 
/*
  PetscNvshmemFinalize - Shut down NVSHMEM by calling nvshmem_finalize().

  Always returns 0 since nvshmem_finalize() reports no status. The PETSc-side
  initialization flags are not touched here.
*/
PetscErrorCode PetscNvshmemFinalize(void)
{
  PetscFunctionBegin;
  nvshmem_finalize();
  PetscFunctionReturn(0);
}
59*71438e86SJunchao Zhang 
/* Release the NVSHMEM related arrays kept in the SF: the host displacement arrays
   (allocated pairwise, hence PetscFree2) and their device copies (freed as PETSC_MEMTYPE_CUDA).
*/
PetscErrorCode PetscSFReset_Basic_NVSHMEM(PetscSF sf)
{
  PetscSF_Basic     *basic = (PetscSF_Basic*)sf->data;
  PetscErrorCode    ierr;

  PetscFunctionBegin;
  /* Host side displacement arrays */
  ierr = PetscFree2(basic->leafsigdisp,basic->leafbufdisp);CHKERRQ(ierr);
  ierr = PetscFree2(sf->rootsigdisp,sf->rootbufdisp);CHKERRQ(ierr);

  /* Device arrays stored in the SFBASIC data */
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,basic->leafbufdisp_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,basic->leafsigdisp_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,basic->iranks_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,basic->ioffset_d);CHKERRQ(ierr);

  /* Device arrays stored in the SF itself */
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->rootbufdisp_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->rootsigdisp_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->ranks_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->roffset_d);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}
80*71438e86SJunchao Zhang 
/* Set up the NVSHMEM related fields of an SF of type SFBASIC. Call only after PetscSFSetup_Basic()
   has set up the fields this routine depends on (ranks, offsets and remote buffer lengths).

   The routine
     - records the number of remote root/leaf ranks and their maxima over the communicator,
     - exchanges signal and buffer displacements between root ranks and leaf ranks with four
       rounds of point-to-point MPI messages, all using the same tag,
     - mirrors the displacement, rank and offset arrays of the remote ranks on the device.
*/
static PetscErrorCode PetscSFSetUp_Basic_NVSHMEM(PetscSF sf)
{
  PetscErrorCode    ierr;
  cudaError_t       cerr;
  MPI_Comm          comm;
  PetscMPIInt       tag;
  PetscSF_Basic     *bas        = (PetscSF_Basic*)sf->data;
  const PetscInt    nroots      = sf->nranks-sf->ndranks;      /* number of remote root ranks */
  const PetscInt    nleaves     = bas->niranks-bas->ndiranks;  /* number of remote leaf ranks */
  const PetscMPIInt *rootranks  = sf->ranks+sf->ndranks;       /* remote root ranks follow the first ndranks entries */
  const PetscMPIInt *leafranks  = bas->iranks+bas->ndiranks;   /* remote leaf ranks follow the first ndiranks entries */
  const PetscInt    *rootoffset = sf->roffset+sf->ndranks;
  const PetscInt    *leafoffset = bas->ioffset+bas->ndiranks;
  MPI_Request       *rootreqs,*leafreqs;
  PetscInt          k,disp;
  PetscInt          localmax[4],globalmax[4];

  PetscFunctionBegin;
  ierr = PetscObjectGetComm((PetscObject)sf,&comm);CHKERRQ(ierr);
  ierr = PetscObjectGetNewTag((PetscObject)sf,&tag);CHKERRQ(ierr);

  sf->nRemoteRootRanks  = nroots;
  bas->nRemoteLeafRanks = nleaves;

  /* rootreqs[] has one request per remote leaf rank, leafreqs[] one per remote root rank */
  ierr = PetscMalloc2(nleaves,&rootreqs,nroots,&leafreqs);CHKERRQ(ierr);

  /* Maxima over comm of the rank counts and of the remote buffer lengths */
  localmax[0] = nroots;
  localmax[1] = sf->leafbuflen[PETSCSF_REMOTE];
  localmax[2] = nleaves;
  localmax[3] = bas->rootbuflen[PETSCSF_REMOTE];
  ierr = MPIU_Allreduce(localmax,globalmax,4,MPIU_INT,MPI_MAX,comm);CHKERRMPI(ierr);
  sf->nRemoteRootRanksMax  = globalmax[0];
  sf->leafbuflen_rmax      = globalmax[1];
  bas->nRemoteLeafRanksMax = globalmax[2];
  bas->rootbuflen_rmax     = globalmax[3];

  /* Four rounds of MPI communication follow. The send value changes in every iteration, so blocking MPI_Send is used. */

  /* Round 1, root ranks -> leaf ranks: rootsigdisp[]. A root sends to its k-th remote leaf rank the index k. */
  ierr = PetscMalloc2(nroots,&sf->rootsigdisp,nroots,&sf->rootbufdisp);CHKERRQ(ierr);
  for (k=0; k<nroots; k++) {
    ierr = MPI_Irecv(&sf->rootsigdisp[k],1,MPIU_INT,rootranks[k],tag,comm,&leafreqs[k]);CHKERRMPI(ierr);
  }
  for (k=0; k<nleaves; k++) {
    ierr = MPI_Send(&k,1,MPIU_INT,leafranks[k],tag,comm);CHKERRMPI(ierr);
  }
  ierr = MPI_Waitall(nroots,leafreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);

  /* Round 2, root ranks -> leaf ranks: rootbufdisp[]. A root sends the offset of the k-th remote leaf rank relative to the first remote one. */
  for (k=0; k<nroots; k++) {
    ierr = MPI_Irecv(&sf->rootbufdisp[k],1,MPIU_INT,rootranks[k],tag,comm,&leafreqs[k]);CHKERRMPI(ierr);
  }
  for (k=0; k<nleaves; k++) {
    disp = leafoffset[k] - leafoffset[0];
    ierr = MPI_Send(&disp,1,MPIU_INT,leafranks[k],tag,comm);CHKERRMPI(ierr);
  }
  ierr = MPI_Waitall(nroots,leafreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);

  /* Device mirrors of the arrays indexed by remote root rank */
  cerr = cudaMalloc((void**)&sf->rootbufdisp_d,nroots*sizeof(PetscInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&sf->rootsigdisp_d,nroots*sizeof(PetscInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&sf->ranks_d,nroots*sizeof(PetscMPIInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&sf->roffset_d,(nroots+1)*sizeof(PetscInt));CHKERRCUDA(cerr);

  cerr = cudaMemcpyAsync(sf->rootbufdisp_d,sf->rootbufdisp,nroots*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(sf->rootsigdisp_d,sf->rootsigdisp,nroots*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(sf->ranks_d,rootranks,nroots*sizeof(PetscMPIInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(sf->roffset_d,rootoffset,(nroots+1)*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);

  /* Round 3, leaf ranks -> root ranks: leafsigdisp[]. A leaf sends to its k-th remote root rank the index k. */
  ierr = PetscMalloc2(nleaves,&bas->leafsigdisp,nleaves,&bas->leafbufdisp);CHKERRQ(ierr);
  for (k=0; k<nleaves; k++) {
    ierr = MPI_Irecv(&bas->leafsigdisp[k],1,MPIU_INT,leafranks[k],tag,comm,&rootreqs[k]);CHKERRMPI(ierr);
  }
  for (k=0; k<nroots; k++) {
    ierr = MPI_Send(&k,1,MPIU_INT,rootranks[k],tag,comm);CHKERRMPI(ierr);
  }
  ierr = MPI_Waitall(nleaves,rootreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);

  /* Round 4, leaf ranks -> root ranks: leafbufdisp[]. A leaf sends the offset of the k-th remote root rank relative to the first remote one. */
  for (k=0; k<nleaves; k++) {
    ierr = MPI_Irecv(&bas->leafbufdisp[k],1,MPIU_INT,leafranks[k],tag,comm,&rootreqs[k]);CHKERRMPI(ierr);
  }
  for (k=0; k<nroots; k++) {
    disp = rootoffset[k] - rootoffset[0];
    ierr = MPI_Send(&disp,1,MPIU_INT,rootranks[k],tag,comm);CHKERRMPI(ierr);
  }
  ierr = MPI_Waitall(nleaves,rootreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);

  /* Device mirrors of the arrays indexed by remote leaf rank */
  cerr = cudaMalloc((void**)&bas->leafbufdisp_d,nleaves*sizeof(PetscInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&bas->leafsigdisp_d,nleaves*sizeof(PetscInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&bas->iranks_d,nleaves*sizeof(PetscMPIInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&bas->ioffset_d,(nleaves+1)*sizeof(PetscInt));CHKERRCUDA(cerr);

  cerr = cudaMemcpyAsync(bas->leafbufdisp_d,bas->leafbufdisp,nleaves*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(bas->leafsigdisp_d,bas->leafsigdisp,nleaves*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(bas->iranks_d,leafranks,nleaves*sizeof(PetscMPIInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(bas->ioffset_d,leafoffset,(nleaves+1)*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);

  ierr = PetscFree2(rootreqs,leafreqs);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}
167*71438e86SJunchao Zhang 
/*
  PetscSFLinkNvshmemCheck - Decide whether a communication on this SF with the given root/leaf data should use NVSHMEM.

  Input Parameters:
+ sf        - the star forest
. rootmtype - memory type of rootdata
. rootdata  - root data, may be NULL
. leafmtype - memory type of leafdata
- leafdata  - leaf data, may be NULL

  Output Parameter:
. use_nvshmem - PETSC_TRUE if NVSHMEM is to be used, otherwise PETSC_FALSE

  Notes:
  The first call on an SF that has sf->use_nvshmem set performs a one-time eligibility check and records
  that it was done in sf->checked_nvshmem_eligibility. When the SF is not eligible, sf->use_nvshmem is
  cleared so later calls return PETSC_FALSE immediately. The eligibility check contains an MPI_Allreduce on
  the SF's communicator. NVSHMEM itself and the NVSHMEM fields of the SF are set up on demand the first
  time the answer is PETSC_TRUE.
*/
PetscErrorCode PetscSFLinkNvshmemCheck(PetscSF sf,PetscMemType rootmtype,const void *rootdata,PetscMemType leafmtype,const void *leafdata,PetscBool *use_nvshmem)
{
  PetscErrorCode   ierr;
  MPI_Comm         comm;
  PetscBool        isBasic;
  PetscMPIInt      result = MPI_UNEQUAL;

  PetscFunctionBegin;
  ierr = PetscObjectGetComm((PetscObject)sf,&comm);CHKERRQ(ierr);
  /* Check if the sf is eligible for NVSHMEM, if we have not checked yet.
     Note the check result <use_nvshmem> must be the same over comm, since an SFLink must be collectively either NVSHMEM or MPI.
     The checked flag must only be set after the check has run; setting it beforehand would make this branch unreachable.
  */
  if (sf->use_nvshmem && !sf->checked_nvshmem_eligibility) {
    /* Only use NVSHMEM for SFBASIC on PETSC_COMM_WORLD  */
    ierr = PetscObjectTypeCompare((PetscObject)sf,PETSCSFBASIC,&isBasic);CHKERRQ(ierr);
    if (isBasic) {ierr = MPI_Comm_compare(PETSC_COMM_WORLD,comm,&result);CHKERRMPI(ierr);}
    if (!isBasic || (result != MPI_IDENT && result != MPI_CONGRUENT)) sf->use_nvshmem = PETSC_FALSE; /* If not eligible, clear the flag so that we don't try again */

    /* Do further check: If on a rank, both rootdata and leafdata are NULL, we might think they are PETSC_MEMTYPE_CUDA (or HOST)
       and then use NVSHMEM. But if root/leafmtypes on other ranks are PETSC_MEMTYPE_HOST (or DEVICE), this would lead to
       inconsistency on the return value <use_nvshmem>. To be safe, we simply disable nvshmem on these rare SFs.
    */
    if (sf->use_nvshmem) {
      PetscInt hasNullRank = (!rootdata && !leafdata) ? 1 : 0;
      ierr = MPI_Allreduce(MPI_IN_PLACE,&hasNullRank,1,MPIU_INT,MPI_LOR,comm);CHKERRMPI(ierr);
      if (hasNullRank) sf->use_nvshmem = PETSC_FALSE;
    }
    sf->checked_nvshmem_eligibility = PETSC_TRUE; /* If eligible, don't do above check again */
  }

  /* Check if rootmtype and leafmtype collectively are PETSC_MEMTYPE_CUDA */
  if (sf->use_nvshmem) {
    PetscInt oneCuda = (!rootdata || PetscMemTypeCUDA(rootmtype)) && (!leafdata || PetscMemTypeCUDA(leafmtype)) ? 1 : 0; /* Do I use cuda for both root&leafmtype? */
    PetscInt allCuda = oneCuda; /* Assume the same for all ranks. But if not, in opt mode, return value <use_nvshmem> won't be collective! */
   #if defined(PETSC_USE_DEBUG)  /* Check in debug mode. Note MPI_Allreduce is expensive, so only in debug mode */
    ierr = MPI_Allreduce(&oneCuda,&allCuda,1,MPIU_INT,MPI_LAND,comm);CHKERRMPI(ierr);
    if (allCuda != oneCuda) SETERRQ(comm,PETSC_ERR_SUP,"root/leaf mtypes are inconsistent among ranks, which may lead to SF nvshmem failure in opt mode. Add -use_nvshmem 0 to disable it.");
   #endif
    if (allCuda) {
      ierr = PetscNvshmemInitializeCheck();CHKERRQ(ierr);
      if (!sf->setup_nvshmem) { /* Set up nvshmem related fields on this SF on-demand */
        ierr = PetscSFSetUp_Basic_NVSHMEM(sf);CHKERRQ(ierr);
        sf->setup_nvshmem = PETSC_TRUE;
      }
      *use_nvshmem = PETSC_TRUE;
    } else {
      *use_nvshmem = PETSC_FALSE;
    }
  } else {
    *use_nvshmem = PETSC_FALSE;
  }
  PetscFunctionReturn(0);
}
222*71438e86SJunchao Zhang 
/* At the entry of an NVSHMEM communication, make <remoteCommStream> wait for the work queued so far on <stream>.
   Nothing is done when the remote buffer on the sending side of the given direction is empty.
*/
static PetscErrorCode PetscSFLinkBuildDependenceBegin(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  cudaError_t    cerr;
  PetscSF_Basic  *bas = (PetscSF_Basic *)sf->data;
  PetscInt       sendlen;

  PetscFunctionBegin;
  /* Root buffer is the send buffer in ROOT2LEAF, leaf buffer otherwise */
  if (direction == PETSCSF_ROOT2LEAF) sendlen = bas->rootbuflen[PETSCSF_REMOTE];
  else                                sendlen = sf->leafbuflen[PETSCSF_REMOTE];
  if (!sendlen) PetscFunctionReturn(0);

  cerr = cudaEventRecord(link->dataReady,link->stream);CHKERRCUDA(cerr);
  cerr = cudaStreamWaitEvent(link->remoteCommStream,link->dataReady,0);CHKERRCUDA(cerr);
  PetscFunctionReturn(0);
}
237*71438e86SJunchao Zhang 
/* At the exit of an NVSHMEM communication, make <stream> wait for the work queued so far on <remoteCommStream>.
   Nothing is done when the remote buffer on the receiving side of the given direction is empty.
*/
static PetscErrorCode PetscSFLinkBuildDependenceEnd(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  cudaError_t    cerr;
  PetscSF_Basic  *bas = (PetscSF_Basic *)sf->data;
  PetscInt       recvlen;

  PetscFunctionBegin;
  /* Leaf buffer is the receive buffer in ROOT2LEAF, root buffer otherwise */
  if (direction == PETSCSF_ROOT2LEAF) recvlen = sf->leafbuflen[PETSCSF_REMOTE];
  else                                recvlen = bas->rootbuflen[PETSCSF_REMOTE];
  if (!recvlen) PetscFunctionReturn(0);

  /* There is a non-empty buffer to unpack, so later work on <stream> must see the end of the remote communication */
  cerr = cudaEventRecord(link->endRemoteComm,link->remoteCommStream);CHKERRCUDA(cerr);
  cerr = cudaStreamWaitEvent(link->stream,link->endRemoteComm,0);CHKERRCUDA(cerr);
  PetscFunctionReturn(0);
}
253*71438e86SJunchao Zhang 
/* Send/Put signals to remote ranks; thread i of the launch handles the i-th remote rank.

 Input parameters:
  + n        - Number of remote ranks
  . sig      - Signal address in symmetric heap
  . sigdisp  - To i-th rank, use its signal at offset sigdisp[i]
  . ranks    - remote ranks
  - newval   - Set signals to this value

 The launch must provide at least n threads in total; surplus threads do nothing.
*/
__global__ static void NvshmemSendSignals(PetscInt n,uint64_t *sig,PetscInt *sigdisp,PetscMPIInt *ranks,uint64_t newval)
{
  const int tid = blockIdx.x*blockDim.x + threadIdx.x;

  if (tid >= n) return; /* thread beyond the last remote rank */
  nvshmemx_uint64_signal(sig+sigdisp[tid],newval,ranks[tid]);
}
270*71438e86SJunchao Zhang 
/* Wait until local signals equal to the expected value and then set them to a new value

 Input parameters:
  + n        - Number of signals
  . sig      - Local signal address
  . expval   - expected value
  - newval   - Set signals to this new value

 Every thread of the launch waits on, and then rewrites, all n signals; callers in this file launch it with
 a single thread. A one-thread-per-signal variant based on nvshmem_signal_wait_until() was considered, but
 Akhil Langer@NVIDIA said using 1 thread and nvshmem_uint64_wait_until_all is better.
*/
__global__ static void NvshmemWaitSignals(PetscInt n,uint64_t *sig,uint64_t expval,uint64_t newval)
{
  int k = 0;

  nvshmem_uint64_wait_until_all(sig,n,NULL/*no mask*/,NVSHMEM_CMP_EQ,expval);
  while (k < n) {
    sig[k] = newval;
    k++;
  }
}
293*71438e86SJunchao Zhang 
294*71438e86SJunchao Zhang /* ===========================================================================================================
295*71438e86SJunchao Zhang 
296*71438e86SJunchao Zhang    A set of routines to support receiver initiated communication using the get method
297*71438e86SJunchao Zhang 
298*71438e86SJunchao Zhang     The getting protocol is:
299*71438e86SJunchao Zhang 
300*71438e86SJunchao Zhang     Sender has a send buf (sbuf) and a signal variable (ssig);  Receiver has a recv buf (rbuf) and a signal variable (rsig);
301*71438e86SJunchao Zhang     All signal variables have an initial value 0.
302*71438e86SJunchao Zhang 
303*71438e86SJunchao Zhang     Sender:                                 |  Receiver:
  1.  Wait ssig to be 0, then set it to 1   |
305*71438e86SJunchao Zhang   2.  Pack data into stand alone sbuf       |
306*71438e86SJunchao Zhang   3.  Put 1 to receiver's rsig              |   1. Wait rsig to be 1, then set it 0
307*71438e86SJunchao Zhang                                             |   2. Get data from remote sbuf to local rbuf
308*71438e86SJunchao Zhang                                             |   3. Put 1 to sender's ssig
309*71438e86SJunchao Zhang                                             |   4. Unpack data from local rbuf
310*71438e86SJunchao Zhang    ===========================================================================================================*/
/* PrePack operation -- the sender is about to overwrite its send buffer, which receivers might still be getting data from.
   So the sender first waits for the signals (set by receivers) indicating that receivers have finished getting data.
   The wait is queued on link->remoteCommStream; this routine does not block the host.
*/
PetscErrorCode PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
  /* ROOT2LEAF: leaf ranks are getting data and they set my rootSendSig; LEAF2ROOT: the other way round */
  uint64_t          *sendsig = (direction == PETSCSF_ROOT2LEAF) ? link->rootSendSig : link->leafSendSig;
  PetscInt          nsig     = (direction == PETSCSF_ROOT2LEAF) ? bas->nRemoteLeafRanks : sf->nRemoteRootRanks;

  PetscFunctionBegin;
  if (!nsig) PetscFunctionReturn(0); /* no remote receivers, nothing to wait for */

  NvshmemWaitSignals<<<1,1,0,link->remoteCommStream>>>(nsig,sendsig,0,1); /* wait the signals to be 0, then set them to 1 */
  cudaError_t cerr = cudaGetLastError();CHKERRCUDA(cerr);
  PetscFunctionReturn(0);
}
335*71438e86SJunchao Zhang 
/* Get data from the source ranks that are not directly (load/store) accessible, i.e., for which nvshmem_ptr() returns NULL.

   Launch layout: nsrcranks thread blocks; block i takes charge of the i-th source rank. Every thread of a
   block issues the same get, so callers launch one thread per block.

 Input parameters:
  + nsrcranks - Number of source ranks
  . srcranks  - The source ranks (PEs)
  . src       - Address of the source buffer, as passed to nvshmem_ptr()/nvshmem_getmem_nbi() together with the PE
  . srcdisp   - For the i-th source rank, read its buffer at offset srcdisp[i] (in units)
  . dstdisp   - Offsets (in units) of the local destination buffer; dstdisp[0] is not necessarily 0, so offsets are taken relative to it
  - unitbytes - Number of bytes per unit

 Output parameter:
  . dst       - Local destination buffer

 The gets are non-blocking (nbi); completing them is left to the caller.
*/
__global__ static void GetDataFromRemotelyAccessible(PetscInt nsrcranks,PetscMPIInt *srcranks,const char *src,PetscInt *srcdisp,char *dst,PetscInt *dstdisp,PetscInt unitbytes)
{
  const int bid = blockIdx.x;

  if (bid >= nsrcranks) return; /* guard against a grid larger than the number of source ranks */

  const PetscMPIInt pe = srcranks[bid];

  if (!nvshmem_ptr(src,pe)) {
    /* Do the byte arithmetic in size_t so that it cannot overflow a 32-bit PetscInt */
    const size_t ub     = (size_t)unitbytes;
    const size_t nbytes = (size_t)(dstdisp[bid+1]-dstdisp[bid])*ub;
    const size_t dstoff = (size_t)(dstdisp[bid]-dstdisp[0])*ub;
    const size_t srcoff = (size_t)srcdisp[bid]*ub;

    nvshmem_getmem_nbi(dst+dstoff,src+srcoff,nbytes,pe);
  }
}
347*71438e86SJunchao Zhang 
/* Start communication -- Get data in the given direction.

   All device work is queued on link->remoteCommStream, in this order:
     1. as a sender, signal the ranks that will get data from me that my send buffer is ready;
     2. as a receiver, wait for the same kind of signal from the ranks I get data from;
     3. issue non-blocking gets, which are finished in PetscSFLinkGetDataEnd_NVSHMEM().
*/
PetscErrorCode PetscSFLinkGetDataBegin_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscErrorCode    ierr;
  cudaError_t       cerr;
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
  const PetscBool   root2leaf = (direction == PETSCSF_ROOT2LEAF) ? PETSC_TRUE : PETSC_FALSE;
  PetscInt          nsrcranks,ndstranks,nLocallyAccessible = 0;
  char              *src,*dst;
  PetscInt          *srcdisp_h,*srcdisp_d;
  PetscInt          *dstdisp_h,*dstdisp_d;
  PetscMPIInt       *srcranks_h,*srcranks_d;
  PetscMPIInt       *dstranks_d;
  uint64_t          *dstsig;
  PetscInt          *dstsigdisp_d;

  PetscFunctionBegin;
  ierr = PetscSFLinkBuildDependenceBegin(sf,link,direction);CHKERRQ(ierr);

  /* Source side. ROOT2LEAF: src is root, and the root buf is the send buf. LEAF2ROOT: src is leaf, and the leaf buf is the send buf.
     Both buffers are in symmetric heap. For my i-th remote src rank, I will access its buf at offset srcdisp[i].
  */
  nsrcranks    = root2leaf ? sf->nRemoteRootRanks : bas->nRemoteLeafRanks;
  src          = root2leaf ? link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] : link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
  srcdisp_h    = root2leaf ? sf->rootbufdisp : bas->leafbufdisp;
  srcdisp_d    = root2leaf ? sf->rootbufdisp_d : bas->leafbufdisp_d;
  srcranks_h   = root2leaf ? sf->ranks+sf->ndranks : bas->iranks+bas->ndiranks;
  srcranks_d   = root2leaf ? sf->ranks_d : bas->iranks_d;

  /* Destination side. The local leaf (ROOT2LEAF) or root (LEAF2ROOT) buf is the recv buf. Note dstdisp[0] is not necessarily 0 */
  ndstranks    = root2leaf ? bas->nRemoteLeafRanks : sf->nRemoteRootRanks;
  dst          = root2leaf ? link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] : link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
  dstdisp_h    = root2leaf ? sf->roffset+sf->ndranks : bas->ioffset+bas->ndiranks;
  dstdisp_d    = root2leaf ? sf->roffset_d : bas->ioffset_d;
  dstranks_d   = root2leaf ? bas->iranks_d : sf->ranks_d;

  /* Receive signals of the destination side and their displacements on the remote ranks */
  dstsig       = root2leaf ? link->leafRecvSig : link->rootRecvSig;
  dstsigdisp_d = root2leaf ? bas->leafsigdisp_d : sf->rootsigdisp_d;

  /* After Pack operation -- src tells dst ranks that they are allowed to get data */
  if (ndstranks) {
    NvshmemSendSignals<<<(ndstranks+255)/256,256,0,link->remoteCommStream>>>(ndstranks,dstsig,dstsigdisp_d,dstranks_d,1); /* set signals to 1 */
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }

  /* dst waits for signals (permissions) from src ranks to start getting data */
  if (nsrcranks) {
    NvshmemWaitSignals<<<1,1,0,link->remoteCommStream>>>(nsrcranks,dstsig,1,0); /* wait the signals to be 1, then set them to 0 */
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }

  /* dst gets data from src ranks using non-blocking nvshmem_gets, which are finished in PetscSFLinkGetDataEnd_NVSHMEM() */

  /* Count the src ranks whose buffer nvshmem_ptr() can address directly; expected to be a small number */
  for (int r=0; r<nsrcranks; r++) nLocallyAccessible += nvshmem_ptr(src,srcranks_h[r]) ? 1 : 0;

  /* Remotely accessible PEs are served by one kernel, which skips the locally accessible ones itself */
  if (nLocallyAccessible < nsrcranks) {
    GetDataFromRemotelyAccessible<<<nsrcranks,1,0,link->remoteCommStream>>>(nsrcranks,srcranks_d,src,srcdisp_d,dst,dstdisp_d,link->unitbytes);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }

  /* Locally accessible PEs are served from the host with stream-ordered gets */
  if (nLocallyAccessible) {
    for (int r=0; r<nsrcranks; r++) {
      int pe = srcranks_h[r];

      if (!nvshmem_ptr(src,pe)) continue; /* already handled by the kernel above */
      size_t nelems = (dstdisp_h[r+1]-dstdisp_h[r])*link->unitbytes;
      nvshmemx_getmem_nbi_on_stream(dst+(dstdisp_h[r]-dstdisp_h[0])*link->unitbytes,src+srcdisp_h[r]*link->unitbytes,nelems,pe,link->remoteCommStream);
    }
  }
  PetscFunctionReturn(0);
}
440*71438e86SJunchao Zhang 
/* Complete a get-based transfer (this may run before Unpack).

   The receiver first drains the non-blocking gets queued on link->remoteCommStream, then writes 0 into the send
   signal of every sender it pulled data from. That tells the senders their send buffers may be reused.

   Input Parameters:
+  sf        - the star forest
.  link      - the NVSHMEM link carrying the signals and the communication stream
-  direction - PETSCSF_ROOT2LEAF (leaf ranks were getting data) or LEAF2ROOT (root ranks were getting data)
*/
PetscErrorCode PetscSFLinkGetDataEnd_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscErrorCode    ierr;
  cudaError_t       cerr;
  PetscSF_Basic     *bas      = (PetscSF_Basic*)sf->data;
  const bool        toLeaves  = (direction == PETSCSF_ROOT2LEAF);
  /* Senders are the remote root ranks when leaves get data, and the remote leaf ranks otherwise */
  PetscInt          nsenders    = toLeaves ? sf->nRemoteRootRanks : bas->nRemoteLeafRanks;
  uint64_t          *sendersig  = toLeaves ? link->rootSendSig    : link->leafSendSig;   /* signal array to be set on the senders */
  PetscInt          *sigdisp_d  = toLeaves ? sf->rootsigdisp_d    : bas->leafsigdisp_d;  /* offset of my signal on each sender */
  PetscMPIInt       *senders_d  = toLeaves ? sf->ranks_d          : bas->iranks_d;       /* ranks of the senders */

  PetscFunctionBegin;
  if (nsenders) {
    /* The gets must be complete before Unpack may read the receive buffer */
    nvshmemx_quiet_on_stream(link->remoteCommStream);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);

    /* Reset the senders' signals to 0 */
    NvshmemSendSignals<<<(nsenders+511)/512,512,0,link->remoteCommStream>>>(nsenders,sendersig,sigdisp_d,senders_d,0);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }
  ierr = PetscSFLinkBuildDependenceEnd(sf,link,direction);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}
475*71438e86SJunchao Zhang 
/* ===========================================================================================================

   A set of routines to support sender-initiated communication using the put-based method (the default)

    The putting protocol is:

    The sender has a send buf (sbuf) and a send signal var (ssig); the receiver has a stand-alone recv buf (rbuf)
    and a recv signal var (rsig). All signal variables have an initial value of 0. rbuf is allocated by SF and
    is in nvshmem space.

    Sender:                                       |  Receiver:
                                                  |
  1.  Pack data into sbuf                         |
  2.  Wait for ssig to be 0, then set it to 1     |
  3.  Put data to remote stand-alone rbuf         |
  4.  Fence // make sure 5 happens after 3        |
  5.  Put 1 to receiver's rsig                    |   1. Wait for rsig to be 1, then set it to 0
                                                  |   2. Unpack data from local rbuf
                                                  |   3. Put 0 to sender's ssig
   ===========================================================================================================*/
496*71438e86SJunchao Zhang 
/* Launched with one thread block per remote rank; block b handles dstranks[b].

   Ranks for which nvshmem_ptr(dst,pe) returns a non-NULL pointer are left untouched by this kernel. For every
   other rank the block waits until its send signal is 0, raises the signal to 1, and then starts a non-blocking
   put of that rank's slice of src into dst on the remote PE. Displacements are in units of unitbytes, and the
   src offset is taken relative to srcdisp[0].
*/
__global__ static void WaitAndPutDataToRemotelyAccessible(PetscInt ndstranks,PetscMPIInt *dstranks,char *dst,PetscInt *dstdisp,const char *src,PetscInt *srcdisp,uint64_t *srcsig,PetscInt unitbytes)
{
  const int         which = blockIdx.x;
  const PetscMPIInt pe    = dstranks[which];
  PetscInt          nbytes;

  if (nvshmem_ptr(dst,pe)) return; /* directly addressable PE: nothing to do in this kernel */

  nbytes = (srcdisp[which+1]-srcdisp[which])*unitbytes;
  nvshmem_uint64_wait_until(&srcsig[which],NVSHMEM_CMP_EQ,0); /* block until the signal reads 0 */
  srcsig[which] = 1;
  nvshmem_putmem_nbi(dst+dstdisp[which]*unitbytes,src+(srcdisp[which]-srcdisp[0])*unitbytes,nbytes,pe);
}
510*71438e86SJunchao Zhang 
/* One-thread kernel that serves all locally accessible destination ranks.
   For each rank whose dst is directly addressable (nvshmem_ptr() is non-NULL), wait until the matching
   send signal is 0 and then set it to 1. Other ranks are skipped.
*/
__global__ static void WaitSignalsFromLocallyAccessible(PetscInt ndstranks,PetscMPIInt *dstranks,uint64_t *srcsig,const char *dst)
{
  for (int k=0; k<ndstranks; k++) {
    if (!nvshmem_ptr(dst,dstranks[k])) continue; /* not locally accessible */
    nvshmem_uint64_wait_until(&srcsig[k],NVSHMEM_CMP_EQ,0); /* block until the signal reads 0 */
    srcsig[k] = 1;
  }
}
522*71438e86SJunchao Zhang 
/* Start a put-based transfer in the given direction.

   For each remote destination rank the sender waits until that rank's send signal is 0, sets it to 1 and starts
   a non-blocking put of the packed data into the rank's receive buffer. Ranks that are not directly addressable
   are handled by one device kernel; directly addressable ranks (nvshmem_ptr() != NULL) are handled with the
   host-side on-stream put. All work is queued on link->remoteCommStream. The puts are finished in
   PetscSFLinkPutDataEnd_NVSHMEM().

   Input Parameters:
+  sf        - the star forest
.  link      - the NVSHMEM link holding buffers, signals and the communication stream
-  direction - PETSCSF_ROOT2LEAF (rootbuf -> leafbuf) or LEAF2ROOT (leafbuf -> rootbuf)
*/
PetscErrorCode PetscSFLinkPutDataBegin_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscErrorCode    ierr;
  cudaError_t       cerr;
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
  PetscInt          ndstranks,nLocallyAccessible = 0;
  char              *src,*dst;
  PetscInt          *srcdisp_h,*dstdisp_h;
  PetscInt          *srcdisp_d,*dstdisp_d;
  PetscMPIInt       *dstranks_h;
  PetscMPIInt       *dstranks_d;
  uint64_t          *srcsig;

  PetscFunctionBegin;
  ierr = PetscSFLinkBuildDependenceBegin(sf,link,direction);CHKERRQ(ierr);
  if (direction == PETSCSF_ROOT2LEAF) { /* put data in rootbuf to leafbuf  */
    ndstranks    = bas->nRemoteLeafRanks; /* number of (remote) leaf ranks */
    src          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* Both src & dst must be symmetric */
    dst          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

    srcdisp_h    = bas->ioffset+bas->ndiranks;  /* offsets of rootbuf. srcdisp[0] is not necessarily zero */
    srcdisp_d    = bas->ioffset_d;
    srcsig       = link->rootSendSig;

    dstdisp_h    = bas->leafbufdisp;            /* for my i-th remote leaf rank, I will access its leaf buf at offset leafbufdisp[i] */
    dstdisp_d    = bas->leafbufdisp_d;
    dstranks_h   = bas->iranks+bas->ndiranks;   /* remote leaf ranks */
    dstranks_d   = bas->iranks_d;
  } else { /* put data in leafbuf to rootbuf */
    ndstranks    = sf->nRemoteRootRanks;
    src          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
    dst          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

    srcdisp_h    = sf->roffset+sf->ndranks; /* offsets of leafbuf */
    srcdisp_d    = sf->roffset_d;
    srcsig       = link->leafSendSig;

    dstdisp_h    = sf->rootbufdisp;         /* for my i-th remote root rank, I will access its root buf at offset rootbufdisp[i] */
    dstdisp_d    = sf->rootbufdisp_d;
    dstranks_h   = sf->ranks+sf->ndranks;   /* remote root ranks */
    dstranks_d   = sf->ranks_d;
  }

  /* Wait for signals and then put data to dst ranks using non-blocking nvshmem_put, which are finished in PetscSFLinkPutDataEnd_NVSHMEM */

  /* Count number of locally accessible neighbors, which should be a small number */
  for (int i=0; i<ndstranks; i++) {if (nvshmem_ptr(dst,dstranks_h[i])) nLocallyAccessible++;}

  /* For remotely accessible PEs, send data to them in one kernel call (one thread block per dst rank) */
  if (nLocallyAccessible < ndstranks) {
    WaitAndPutDataToRemotelyAccessible<<<ndstranks,1,0,link->remoteCommStream>>>(ndstranks,dstranks_d,dst,dstdisp_d,src,srcdisp_d,srcsig,link->unitbytes);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }

  /* For locally accessible PEs, use host API, which uses CUDA copy-engines and is much faster than device API */
  if (nLocallyAccessible) {
    WaitSignalsFromLocallyAccessible<<<1,1,0,link->remoteCommStream>>>(ndstranks,dstranks_d,srcsig,dst);
    cerr = cudaGetLastError();CHKERRCUDA(cerr); /* catch launch failures, as done after every other kernel launch in this file */
    for (int i=0; i<ndstranks; i++) {
      int pe = dstranks_h[i];
      if (nvshmem_ptr(dst,pe)) { /* If return a non-null pointer, then <pe> is locally accessible */
        size_t nelems = (srcdisp_h[i+1]-srcdisp_h[i])*link->unitbytes; /* size of the message in bytes */
        /* Initiate the nonblocking communication */
        nvshmemx_putmem_nbi_on_stream(dst+dstdisp_h[i]*link->unitbytes,src+(srcdisp_h[i]-srcdisp_h[0])*link->unitbytes,nelems,pe,link->remoteCommStream);
      }
    }
    /* Calling nvshmem_fence/quiet() does not fence the above nvshmemx_putmem_nbi_on_stream! So use the on-stream quiet */
    nvshmemx_quiet_on_stream(link->remoteCommStream);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }
  PetscFunctionReturn(0);
}
596*71438e86SJunchao Zhang 
/* One-thread kernel; the single thread serves every remote PE.

   Step 1 writes 1 into this rank's slot of the receive signal on every destination rank.
   Step 2 (only when this rank also receives) waits until the first nsrcranks local signals are all 1, then
   clears them to 0.

   TODO (kept from the original author): should the non-blocking remote puts be completed here?
*/
__global__ static void PutDataEnd(PetscInt nsrcranks,PetscInt ndstranks,PetscMPIInt *dstranks,uint64_t *dstsig,PetscInt *dstsigdisp)
{
  int k;

  /* Step 1: notify every destination rank.
     No fence is issued first: per Akhil@NVIDIA, IB is ordered, so remote PEs need none, and for local PEs
     nvshmemx_quiet_on_stream() has already been called.
  */
  for (k=0; k<ndstranks; k++) {
    nvshmemx_uint64_signal(dstsig+dstsigdisp[k],1,dstranks[k]); /* set remote sig to 1 */
  }

  /* Step 2: wait for the signals from my source ranks, if there are any */
  if (!nsrcranks) return;
  nvshmem_uint64_wait_until_all(dstsig,nsrcranks,NULL/*no mask*/,NVSHMEM_CMP_EQ,1); /* all sigs must read 1 */
  for (k=0; k<nsrcranks; k++) dstsig[k] = 0; /* re-arm the signals */
}
615*71438e86SJunchao Zhang 
/* Complete a put-based transfer: signal the ranks this rank put data to, and wait until this rank's own
   receive buffer has been signalled as filled. Both steps are done by the one-thread PutDataEnd kernel
   queued on link->remoteCommStream.

   Input Parameters:
+  sf        - the star forest
.  link      - the NVSHMEM link
-  direction - PETSCSF_ROOT2LEAF (root data was put to leaves) or LEAF2ROOT
*/
PetscErrorCode PetscSFLinkPutDataEnd_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscErrorCode    ierr;
  cudaError_t       cerr;
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
  PetscInt          nsenders,nreceivers;
  PetscMPIInt       *receivers_d;   /* ranks this rank has put data to */
  uint64_t          *recvsig;       /* receive-signal array */
  PetscInt          *recvsigdisp_d; /* offset of my slot in each receiver's signal array */

  PetscFunctionBegin;
  if (direction != PETSCSF_ROOT2LEAF) { /* LEAF2ROOT */
    nsenders      = bas->nRemoteLeafRanks;
    nreceivers    = sf->nRemoteRootRanks;
    receivers_d   = sf->ranks_d;
    recvsig       = link->rootRecvSig;
    recvsigdisp_d = sf->rootsigdisp_d;
  } else { /* root data was put to leaves */
    nsenders      = sf->nRemoteRootRanks;
    nreceivers    = bas->nRemoteLeafRanks;
    receivers_d   = bas->iranks_d;       /* leaf ranks */
    recvsig       = link->leafRecvSig;   /* the RecvSig of my leaf ranks is what I set */
    recvsigdisp_d = bas->leafsigdisp_d;  /* my i-th remote leaf rank's signal is at offset leafsigdisp[i] */
  }

  /* Skip the launch when this rank neither sends nor receives */
  if (nsenders || nreceivers) {
    PutDataEnd<<<1,1,0,link->remoteCommStream>>>(nsenders,nreceivers,receivers_d,recvsig,recvsigdisp_d);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }
  ierr = PetscSFLinkBuildDependenceEnd(sf,link,direction);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}
650*71438e86SJunchao Zhang 
/* PostUnpack step of the put protocol: the receiver writes 0 into the send signal of each of its senders,
   which allows them to put data here again (i.e. the receive buffer is free to take new data).

   Input Parameters:
+  sf        - the star forest
.  link      - the NVSHMEM link
-  direction - PETSCSF_ROOT2LEAF (my root ranks are the senders) or LEAF2ROOT (my leaf ranks are the senders)
*/
PetscErrorCode PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscSF_Basic     *bas        = (PetscSF_Basic*)sf->data;
  const bool        fromRoots   = (direction == PETSCSF_ROOT2LEAF);
  PetscInt          nsenders    = fromRoots ? sf->nRemoteRootRanks : bas->nRemoteLeafRanks;
  uint64_t          *sendsig    = fromRoots ? link->rootSendSig    : link->leafSendSig;   /* send signals to be set on the senders */
  PetscInt          *sigdisp_d  = fromRoots ? sf->rootsigdisp_d    : bas->leafsigdisp_d;  /* offset of my slot on each sender */
  PetscMPIInt       *senders_d  = fromRoots ? sf->ranks_d          : bas->iranks_d;       /* ranks of the senders */
  cudaError_t       cerr;

  PetscFunctionBegin;
  if (nsenders) {
    /* Write 0 to the remote send signals */
    NvshmemSendSignals<<<(nsenders+255)/256,256,0,link->remoteCommStream>>>(nsenders,sendsig,sigdisp_d,senders_d,0);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }
  PetscFunctionReturn(0);
}
678*71438e86SJunchao Zhang 
/* Link destructor used when the link communicates through NVSHMEM.
   Releases the two CUDA events, the communication stream, and the NVSHMEM-allocated device buffers and signals.
*/
static PetscErrorCode PetscSFLinkDestroy_NVSHMEM(PetscSF sf,PetscSFLink link)
{
  PetscErrorCode    ierr;
  cudaError_t       cerr;

  PetscFunctionBegin;
  /* CUDA objects created for this link */
  cerr = cudaEventDestroy(link->dataReady);
  CHKERRCUDA(cerr);
  cerr = cudaEventDestroy(link->endRemoteComm);
  CHKERRCUDA(cerr);
  cerr = cudaStreamDestroy(link->remoteCommStream);
  CHKERRCUDA(cerr);

  /* Only device-side allocations are released; nvshmem needs no host buffers, so those should be NULL */

  /* leaf side: buffer, then signals */
  ierr = PetscNvshmemFree(link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);
  CHKERRQ(ierr);
  ierr = PetscNvshmemFree(link->leafSendSig);
  CHKERRQ(ierr);
  ierr = PetscNvshmemFree(link->leafRecvSig);
  CHKERRQ(ierr);

  /* root side: buffer, then signals */
  ierr = PetscNvshmemFree(link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);
  CHKERRQ(ierr);
  ierr = PetscNvshmemFree(link->rootSendSig);
  CHKERRQ(ierr);
  ierr = PetscNvshmemFree(link->rootRecvSig);
  CHKERRQ(ierr);
  PetscFunctionReturn(0);
}
699*71438e86SJunchao Zhang 
/* Obtain an NVSHMEM link for the given communication, reusing a cached one when possible.

   A free link from bas->avail is reused when it is an NVSHMEM link with a matching unit type; otherwise a new
   link is created with signals, a high-priority communication stream and two events. In both cases the remote
   root/leaf buffers are then pointed either directly at the user data or at NVSHMEM-allocated buffers, and the
   link is pushed onto bas->inuse.

   Input Parameters:
+  sf        - the star forest
.  unit      - MPI datatype of one unit
.  rootmtype - memory type of rootdata
.  rootdata  - root data; also a key to look up the link later
.  leafmtype - memory type of leafdata
.  leafdata  - leaf data; also a key to look up the link later
.  op        - reduction operation
-  sfop      - PETSCSF_BCAST, PETSCSF_REDUCE, or anything else (treated as PETSCSF_FETCH)

   Output Parameter:
.  mylink    - the link
*/
PetscErrorCode PetscSFLinkCreate_NVSHMEM(PetscSF sf,MPI_Datatype unit,PetscMemType rootmtype,const void *rootdata,PetscMemType leafmtype,const void *leafdata,MPI_Op op,PetscSFOperation sfop,PetscSFLink *mylink)
{
  PetscErrorCode    ierr;
  cudaError_t       cerr;
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
  PetscSFLink       *slot,link = NULL;
  PetscBool         sameUnit;
  PetscBool         rootDirect = PETSC_FALSE,leafDirect = PETSC_FALSE; /* for the PETSCSF_REMOTE part only */
  int               topPriority;

  PetscFunctionBegin;
  /* Decide whether root/leafdata can be sent/received directly with the given sf, sfop and op.
     Only the remote part matters, since local communication with NVSHMEM never needs intermediate buffers.
     Both flags default to PETSC_FALSE; that is also the final answer for PETSCSF_FETCH, which always needs a
     separate rootbuf and is also given a separate leafbuf so that leafdata and leafupdate can share mpi requests.
  */
  if (sfop == PETSCSF_BCAST) { /* data moves from rootbuf to leafbuf */
    if (sf->use_nvshmem_get) {
      /* the send buffer has to be stand-alone (can't be rootdata), so only the leaf side may be direct */
      if (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) leafDirect = PETSC_TRUE;
    } else {
      /* the put-protocol always needs a nvshmem alloc'ed recv buffer, so only the root side may be direct */
      if (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE]) rootDirect = PETSC_TRUE;
    }
  } else if (sfop == PETSCSF_REDUCE) { /* data moves from leafbuf to rootbuf */
    if (sf->use_nvshmem_get) {
      if (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) rootDirect = PETSC_TRUE;
    } else {
      if (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE]) leafDirect = PETSC_TRUE;
    }
  }

  /* Search the cache of free links for an nvshmem link with the same unit */
  for (slot=&bas->avail; *slot; slot=&(*slot)->next) {
    PetscSFLink cached = *slot;

    if (!cached->use_nvshmem) continue;
    ierr = MPIPetsc_Type_compare(unit,cached->unit,&sameUnit);CHKERRQ(ierr);
    if (sameUnit) {
      *slot = cached->next; /* unlink it from the available list */
      link  = cached;
      break;
    }
  }

  if (!link) { /* nothing reusable: build a new link */
    ierr = PetscNew(&link);CHKERRQ(ierr);
    ierr = PetscSFLinkSetUp_Host(sf,link,unit);CHKERRQ(ierr); /* Compute link->unitbytes, dup link->unit etc. */
    if (sf->backend == PETSCSF_BACKEND_CUDA) {ierr = PetscSFLinkSetUp_CUDA(sf,link,unit);CHKERRQ(ierr);} /* Setup pack routines, streams etc */
#if defined(PETSC_HAVE_KOKKOS)
    else if (sf->backend == PETSCSF_BACKEND_KOKKOS) {ierr = PetscSFLinkSetUp_Kokkos(sf,link,unit);CHKERRQ(ierr);}
#endif

    /* The local part uses root/leafdata directly */
    link->rootdirect[PETSCSF_LOCAL] = PETSC_TRUE;
    link->leafdirect[PETSCSF_LOCAL] = PETSC_TRUE;

    /* Signals start out as zero */
    if (!link->rootSendSig) {
      ierr = PetscNvshmemCalloc(bas->nRemoteLeafRanksMax*sizeof(uint64_t),(void**)&link->rootSendSig);CHKERRQ(ierr);
    }
    if (!link->rootRecvSig) {
      ierr = PetscNvshmemCalloc(bas->nRemoteLeafRanksMax*sizeof(uint64_t),(void**)&link->rootRecvSig);CHKERRQ(ierr);
    }
    if (!link->leafSendSig) {
      ierr = PetscNvshmemCalloc(sf->nRemoteRootRanksMax*sizeof(uint64_t),(void**)&link->leafSendSig);CHKERRQ(ierr);
    }
    if (!link->leafRecvSig) {
      ierr = PetscNvshmemCalloc(sf->nRemoteRootRanksMax*sizeof(uint64_t),(void**)&link->leafRecvSig);CHKERRQ(ierr);
    }

    link->use_nvshmem = PETSC_TRUE;
    link->rootmtype   = PETSC_MEMTYPE_DEVICE; /* Only need 0/1-based mtype from now on */
    link->leafmtype   = PETSC_MEMTYPE_DEVICE;

    /* Replace some of the function pointers installed by PetscSFLinkSetUp_CUDA */
    link->Destroy = PetscSFLinkDestroy_NVSHMEM;
    if (!sf->use_nvshmem_get) { /* put-based protocol */
      link->StartCommunication  = PetscSFLinkPutDataBegin_NVSHMEM;
      link->FinishCommunication = PetscSFLinkPutDataEnd_NVSHMEM;
      link->PostUnpack          = PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM;
    } else { /* get-based protocol */
      link->PrePack             = PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM;
      link->StartCommunication  = PetscSFLinkGetDataBegin_NVSHMEM;
      link->FinishCommunication = PetscSFLinkGetDataEnd_NVSHMEM;
    }

    /* The communication stream is created with the greatest priority the device reports */
    cerr = cudaDeviceGetStreamPriorityRange(NULL,&topPriority);CHKERRCUDA(cerr);
    cerr = cudaStreamCreateWithPriority(&link->remoteCommStream,cudaStreamNonBlocking,topPriority);CHKERRCUDA(cerr);

    cerr = cudaEventCreateWithFlags(&link->dataReady,cudaEventDisableTiming);CHKERRCUDA(cerr);
    cerr = cudaEventCreateWithFlags(&link->endRemoteComm,cudaEventDisableTiming);CHKERRCUDA(cerr);
  }

  /* Remote root buffer: an nvshmem buffer allocated on first need, or rootdata itself when direct */
  if (!rootDirect) {
    if (!link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) {
      ierr = PetscNvshmemMalloc(bas->rootbuflen_rmax*link->unitbytes,(void**)&link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
    }
    link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
  } else {
    link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char*)rootdata + bas->rootstart[PETSCSF_REMOTE]*link->unitbytes;
  }

  /* Remote leaf buffer: same scheme */
  if (!leafDirect) {
    if (!link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) {
      ierr = PetscNvshmemMalloc(sf->leafbuflen_rmax*link->unitbytes,(void**)&link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
    }
    link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
  } else {
    link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char*)leafdata + sf->leafstart[PETSCSF_REMOTE]*link->unitbytes;
  }

  link->rootdirect[PETSCSF_REMOTE] = rootDirect;
  link->leafdirect[PETSCSF_REMOTE] = leafDirect;

  /* root/leafdata are the keys used to look up links in PetscSFXxxEnd */
  link->rootdata = rootdata;
  link->leafdata = leafdata;

  /* Put the link at the head of the in-use list and hand it back */
  link->next = bas->inuse;
  bas->inuse = link;
  *mylink    = link;
  PetscFunctionReturn(0);
}
809*71438e86SJunchao Zhang 
810*71438e86SJunchao Zhang #if defined(PETSC_USE_REAL_SINGLE)
/* Sum-reduce count floats from src into dst over NVSHMEM_TEAM_WORLD (single-precision build).
   The reduction is queued on PetscDefaultCudaStream; this routine does not synchronize the stream.
   NOTE(review): dst and src are presumably required to be in nvshmem (symmetric) memory -- confirm with callers.

   Input Parameters:
+  count - number of entries to reduce; must fit in a PetscMPIInt
-  src   - source array

   Output Parameter:
.  dst   - destination array

   Returns 0 on success, or a PETSc error if count overflows PetscMPIInt or NVSHMEM reports a failure.
*/
PetscErrorCode PetscNvshmemSum(PetscInt count,float *dst,const float *src)
{
  PetscErrorCode    ierr;
  PetscMPIInt       num; /* Assume nvshmem's int is MPI's int */
  int               err;

  PetscFunctionBegin;
  ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
  /* The reduce call returns zero on success; do not silently drop a failure */
  err  = nvshmemx_float_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
  if (err) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_LIB,"nvshmemx_float_sum_reduce_on_stream() failed with error code %d",err);
  PetscFunctionReturn(0);
}
821*71438e86SJunchao Zhang 
/* Max-reduce count floats from src into dst over NVSHMEM_TEAM_WORLD (single-precision build).
   The reduction is queued on PetscDefaultCudaStream; this routine does not synchronize the stream.
   NOTE(review): dst and src are presumably required to be in nvshmem (symmetric) memory -- confirm with callers.

   Input Parameters:
+  count - number of entries to reduce; must fit in a PetscMPIInt
-  src   - source array

   Output Parameter:
.  dst   - destination array

   Returns 0 on success, or a PETSc error if count overflows PetscMPIInt or NVSHMEM reports a failure.
*/
PetscErrorCode PetscNvshmemMax(PetscInt count,float *dst,const float *src)
{
  PetscErrorCode    ierr;
  PetscMPIInt       num;
  int               err;

  PetscFunctionBegin;
  ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
  /* The reduce call returns zero on success; do not silently drop a failure */
  err  = nvshmemx_float_max_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
  if (err) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_LIB,"nvshmemx_float_max_reduce_on_stream() failed with error code %d",err);
  PetscFunctionReturn(0);
}
832*71438e86SJunchao Zhang #elif defined(PETSC_USE_REAL_DOUBLE)
/* Sum-reduce count doubles from src into dst over NVSHMEM_TEAM_WORLD (double-precision build).
   The reduction is queued on PetscDefaultCudaStream; this routine does not synchronize the stream.
   NOTE(review): dst and src are presumably required to be in nvshmem (symmetric) memory -- confirm with callers.

   Input Parameters:
+  count - number of entries to reduce; must fit in a PetscMPIInt
-  src   - source array

   Output Parameter:
.  dst   - destination array

   Returns 0 on success, or a PETSc error if count overflows PetscMPIInt or NVSHMEM reports a failure.
*/
PetscErrorCode PetscNvshmemSum(PetscInt count,double *dst,const double *src)
{
  PetscErrorCode    ierr;
  PetscMPIInt       num;
  int               err;

  PetscFunctionBegin;
  ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
  /* The reduce call returns zero on success; do not silently drop a failure */
  err  = nvshmemx_double_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
  if (err) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_LIB,"nvshmemx_double_sum_reduce_on_stream() failed with error code %d",err);
  PetscFunctionReturn(0);
}
843*71438e86SJunchao Zhang 
/* Max-reduce count doubles from src into dst over NVSHMEM_TEAM_WORLD (double-precision build).
   The reduction is queued on PetscDefaultCudaStream; this routine does not synchronize the stream.
   NOTE(review): dst and src are presumably required to be in nvshmem (symmetric) memory -- confirm with callers.

   Input Parameters:
+  count - number of entries to reduce; must fit in a PetscMPIInt
-  src   - source array

   Output Parameter:
.  dst   - destination array

   Returns 0 on success, or a PETSc error if count overflows PetscMPIInt or NVSHMEM reports a failure.
*/
PetscErrorCode PetscNvshmemMax(PetscInt count,double *dst,const double *src)
{
  PetscErrorCode    ierr;
  PetscMPIInt       num;
  int               err;

  PetscFunctionBegin;
  ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
  /* The reduce call returns zero on success; do not silently drop a failure */
  err  = nvshmemx_double_max_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
  if (err) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_LIB,"nvshmemx_double_max_reduce_on_stream() failed with error code %d",err);
  PetscFunctionReturn(0);
}
854*71438e86SJunchao Zhang #endif
855*71438e86SJunchao Zhang 
856