xref: /petsc/src/vec/is/sf/impls/basic/nvshmem/sfnvshmem.cu (revision a4af0ceea8a251db97ee0dc5c0d52d4adf50264a)
171438e86SJunchao Zhang #include <petsc/private/cudavecimpl.h>
271438e86SJunchao Zhang #include <../src/vec/is/sf/impls/basic/sfpack.h>
371438e86SJunchao Zhang #include <mpi.h>
471438e86SJunchao Zhang #include <nvshmem.h>
571438e86SJunchao Zhang #include <nvshmemx.h>
671438e86SJunchao Zhang 
/* Initialize NVSHMEM over PETSC_COMM_WORLD on first use; a no-op once PetscNvshmemInitialized is set */
PetscErrorCode PetscNvshmemInitializeCheck(void)
{
  PetscErrorCode       ierr;
  nvshmemx_init_attr_t attr;

  PetscFunctionBegin;
  /* NVSHMEM does not provide a routine to check whether it is initialized, so rely on our own flag */
  if (PetscNvshmemInitialized) {
    PetscFunctionReturn(0);
  }

  attr.mpi_comm = &PETSC_COMM_WORLD;
  /* The CUDA device is initialized before NVSHMEM is started */
  ierr = PetscDeviceInitialize(PETSC_DEVICE_CUDA);CHKERRQ(ierr);
  ierr = nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM,&attr);CHKERRQ(ierr);

  /* Record both that NVSHMEM is up and that it was this routine that started it */
  PetscBeganNvshmem       = PETSC_TRUE;
  PetscNvshmemInitialized = PETSC_TRUE;
  PetscFunctionReturn(0);
}
2271438e86SJunchao Zhang 
/* Allocate memory with nvshmem_malloc(), initializing NVSHMEM first if that has not happened yet

  Input parameter:
  . size - number of bytes to allocate

  Output parameter:
  . ptr  - address returned by nvshmem_malloc()

  Errors with PETSC_ERR_MEM when nvshmem_malloc() returns NULL.
*/
PetscErrorCode PetscNvshmemMalloc(size_t size, void** ptr)
{
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = PetscNvshmemInitializeCheck();CHKERRQ(ierr);
  *ptr = nvshmem_malloc(size);
  /* A NULL result is an out-of-memory condition, not a bad argument, so report it as PETSC_ERR_MEM */
  if (!*ptr) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_MEM,"nvshmem_malloc() failed to allocate %zu bytes",size);
  PetscFunctionReturn(0);
}
3371438e86SJunchao Zhang 
/* Allocate memory with nvshmem_calloc(size,1), initializing NVSHMEM first if that has not happened yet

  Input parameter:
  . size - number of bytes to allocate

  Output parameter:
  . ptr  - address returned by nvshmem_calloc()

  Errors with PETSC_ERR_MEM when nvshmem_calloc() returns NULL.
*/
PetscErrorCode PetscNvshmemCalloc(size_t size, void**ptr)
{
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = PetscNvshmemInitializeCheck();CHKERRQ(ierr);
  *ptr = nvshmem_calloc(size,1);
  /* A NULL result is an out-of-memory condition, not a bad argument, so report it as PETSC_ERR_MEM */
  if (!*ptr) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_MEM,"nvshmem_calloc() failed to allocate %zu bytes",size);
  PetscFunctionReturn(0);
}
4471438e86SJunchao Zhang 
/* Thin wrapper that hands <ptr> to nvshmem_free(); always returns 0 */
PetscErrorCode PetscNvshmemFree_Private(void* ptr)
{
  PetscFunctionBegin;
  nvshmem_free(ptr); /* nvshmem_free() yields no status to check */
  PetscFunctionReturn(0);
}
5171438e86SJunchao Zhang 
/* Thin wrapper that calls nvshmem_finalize(); always returns 0 */
PetscErrorCode PetscNvshmemFinalize(void)
{
  PetscFunctionBegin;
  nvshmem_finalize(); /* nvshmem_finalize() yields no status to check */
  PetscFunctionReturn(0);
}
5871438e86SJunchao Zhang 
/* Release the NVSHMEM-related host and device arrays attached to the SF and to its PetscSF_Basic data */
PetscErrorCode PetscSFReset_Basic_NVSHMEM(PetscSF sf)
{
  PetscSF_Basic     *basic = (PetscSF_Basic*)sf->data;
  PetscErrorCode    ierr;

  PetscFunctionBegin;
  /* Arrays kept in the PetscSF_Basic part: host displacement pair, then the device copies */
  ierr = PetscFree2(basic->leafsigdisp,basic->leafbufdisp);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,basic->leafbufdisp_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,basic->leafsigdisp_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,basic->iranks_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,basic->ioffset_d);CHKERRQ(ierr);

  /* Arrays kept directly in the SF: host displacement pair, then the device copies */
  ierr = PetscFree2(sf->rootsigdisp,sf->rootbufdisp);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->rootbufdisp_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->rootsigdisp_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->ranks_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->roffset_d);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}
7971438e86SJunchao Zhang 
/* Set up the NVSHMEM-related fields of an SF of type SFBASIC (to be called only after PetscSFSetup_Basic() has set up the fields this depends on).

   The routine computes the number of remote root/leaf ranks, takes their maxima (and the maxima of the remote buffer lengths) over the
   communicator, exchanges signal and buffer displacements between root and leaf ranks with MPI, and mirrors the resulting tables on the device.
*/
static PetscErrorCode PetscSFSetUp_Basic_NVSHMEM(PetscSF sf)
{
  PetscSF_Basic  *bas = (PetscSF_Basic*)sf->data;
  PetscErrorCode ierr;
  cudaError_t    cerr;
  MPI_Comm       comm;
  PetscMPIInt    tag;
  MPI_Request    *rootreqs,*leafreqs;
  PetscInt       k,nroots,nleaves;
  PetscInt       disp,mine[4],maxes[4]; /* scratch send/recv buffers */

  PetscFunctionBegin;
  ierr = PetscObjectGetComm((PetscObject)sf,&comm);CHKERRQ(ierr);
  ierr = PetscObjectGetNewTag((PetscObject)sf,&tag);CHKERRQ(ierr);

  /* Remote ranks are the ranks left after skipping the first ndranks/ndiranks entries */
  nroots                = sf->nranks-sf->ndranks;
  nleaves               = bas->niranks-bas->ndiranks;
  sf->nRemoteRootRanks  = nroots;
  bas->nRemoteLeafRanks = nleaves;

  /* rootreqs[] serves receives posted by roots (one per remote leaf rank); leafreqs[] serves receives posted by leaves */
  ierr = PetscMalloc2(nleaves,&rootreqs,nroots,&leafreqs);CHKERRQ(ierr);

  /* Maxima over the communicator of the rank counts and remote buffer lengths */
  mine[0] = nroots;
  mine[1] = sf->leafbuflen[PETSCSF_REMOTE];
  mine[2] = nleaves;
  mine[3] = bas->rootbuflen[PETSCSF_REMOTE];

  ierr = MPIU_Allreduce(mine,maxes,4,MPIU_INT,MPI_MAX,comm);CHKERRMPI(ierr);

  sf->nRemoteRootRanksMax  = maxes[0];
  sf->leafbuflen_rmax      = maxes[1];
  bas->nRemoteLeafRanksMax = maxes[2];
  bas->rootbuflen_rmax     = maxes[3];

  /* Total four rounds of MPI communications to set up the nvshmem fields */

  /* Rounds 1 and 2 -- root ranks to leaf ranks: send info about rootsigdisp[] and rootbufdisp[] */
  ierr = PetscMalloc2(nroots,&sf->rootsigdisp,nroots,&sf->rootbufdisp);CHKERRQ(ierr);
  for (k=0; k<nroots; k++) { /* Leaves recv */
    ierr = MPI_Irecv(&sf->rootsigdisp[k],1,MPIU_INT,sf->ranks[k+sf->ndranks],tag,comm,&leafreqs[k]);CHKERRMPI(ierr);
  }
  for (k=0; k<nleaves; k++) { /* Roots send. The payload k changes every iteration, hence the blocking MPI_Send */
    ierr = MPI_Send(&k,1,MPIU_INT,bas->iranks[k+bas->ndiranks],tag,comm);CHKERRMPI(ierr);
  }
  ierr = MPI_Waitall(nroots,leafreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);

  for (k=0; k<nroots; k++) { /* Leaves recv */
    ierr = MPI_Irecv(&sf->rootbufdisp[k],1,MPIU_INT,sf->ranks[k+sf->ndranks],tag,comm,&leafreqs[k]);CHKERRMPI(ierr);
  }
  for (k=0; k<nleaves; k++) { /* Roots send. The payload disp changes every iteration, hence the blocking MPI_Send */
    disp = bas->ioffset[k+bas->ndiranks] - bas->ioffset[bas->ndiranks];
    ierr = MPI_Send(&disp,1,MPIU_INT,bas->iranks[k+bas->ndiranks],tag,comm);CHKERRMPI(ierr);
  }
  ierr = MPI_Waitall(nroots,leafreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);

  /* Device copies of the tables describing my remote root ranks */
  cerr = cudaMalloc((void**)&sf->rootbufdisp_d,nroots*sizeof(PetscInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&sf->rootsigdisp_d,nroots*sizeof(PetscInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&sf->ranks_d,nroots*sizeof(PetscMPIInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&sf->roffset_d,(nroots+1)*sizeof(PetscInt));CHKERRCUDA(cerr);

  cerr = cudaMemcpyAsync(sf->rootbufdisp_d,sf->rootbufdisp,nroots*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(sf->rootsigdisp_d,sf->rootsigdisp,nroots*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(sf->ranks_d,sf->ranks+sf->ndranks,nroots*sizeof(PetscMPIInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(sf->roffset_d,sf->roffset+sf->ndranks,(nroots+1)*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);

  /* Rounds 3 and 4 -- leaf ranks to root ranks: send info about leafsigdisp[] and leafbufdisp[] */
  ierr = PetscMalloc2(nleaves,&bas->leafsigdisp,nleaves,&bas->leafbufdisp);CHKERRQ(ierr);
  for (k=0; k<nleaves; k++) { /* Roots recv */
    ierr = MPI_Irecv(&bas->leafsigdisp[k],1,MPIU_INT,bas->iranks[k+bas->ndiranks],tag,comm,&rootreqs[k]);CHKERRMPI(ierr);
  }
  for (k=0; k<nroots; k++) { /* Leaves send */
    ierr = MPI_Send(&k,1,MPIU_INT,sf->ranks[k+sf->ndranks],tag,comm);CHKERRMPI(ierr);
  }
  ierr = MPI_Waitall(nleaves,rootreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);

  for (k=0; k<nleaves; k++) { /* Roots recv */
    ierr = MPI_Irecv(&bas->leafbufdisp[k],1,MPIU_INT,bas->iranks[k+bas->ndiranks],tag,comm,&rootreqs[k]);CHKERRMPI(ierr);
  }
  for (k=0; k<nroots; k++) { /* Leaves send */
    disp = sf->roffset[k+sf->ndranks] - sf->roffset[sf->ndranks];
    ierr = MPI_Send(&disp,1,MPIU_INT,sf->ranks[k+sf->ndranks],tag,comm);CHKERRMPI(ierr);
  }
  ierr = MPI_Waitall(nleaves,rootreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);

  /* Device copies of the tables describing my remote leaf ranks */
  cerr = cudaMalloc((void**)&bas->leafbufdisp_d,nleaves*sizeof(PetscInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&bas->leafsigdisp_d,nleaves*sizeof(PetscInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&bas->iranks_d,nleaves*sizeof(PetscMPIInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&bas->ioffset_d,(nleaves+1)*sizeof(PetscInt));CHKERRCUDA(cerr);

  cerr = cudaMemcpyAsync(bas->leafbufdisp_d,bas->leafbufdisp,nleaves*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(bas->leafsigdisp_d,bas->leafsigdisp,nleaves*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(bas->iranks_d,bas->iranks+bas->ndiranks,nleaves*sizeof(PetscMPIInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(bas->ioffset_d,bas->ioffset+bas->ndiranks,(nleaves+1)*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);

  ierr = PetscFree2(rootreqs,leafreqs);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}
16671438e86SJunchao Zhang 
/* Decide whether a communication on this SF should use NVSHMEM

  Input parameters:
  + sf         - the star forest
  . rootmtype  - memory type of rootdata
  . rootdata   - root data pointer (may be NULL)
  . leafmtype  - memory type of leafdata
  - leafdata   - leaf data pointer (may be NULL)

  Output parameter:
  . use_nvshmem - PETSC_TRUE if NVSHMEM is to be used for this communication, PETSC_FALSE otherwise

  Notes:
  The first call on an SF with sf->use_nvshmem set runs a one-time eligibility check and clears sf->use_nvshmem if the SF is not eligible.
  On the first eligible call with CUDA data, NVSHMEM is initialized and the NVSHMEM fields of the SF are set up on demand.
*/
PetscErrorCode PetscSFLinkNvshmemCheck(PetscSF sf,PetscMemType rootmtype,const void *rootdata,PetscMemType leafmtype,const void *leafdata,PetscBool *use_nvshmem)
{
  PetscErrorCode   ierr;
  MPI_Comm         comm;
  PetscBool        isBasic;
  PetscMPIInt      result = MPI_UNEQUAL;

  PetscFunctionBegin;
  ierr = PetscObjectGetComm((PetscObject)sf,&comm);CHKERRQ(ierr);
  /* Check if the sf is eligible for NVSHMEM, if we have not checked yet.
     Note the check result <use_nvshmem> must be the same over comm, since an SFLink must be collectively either NVSHMEM or MPI.
     The checked flag must only be set AFTER the check has run; setting it beforehand would skip the check entirely.
  */
  if (sf->use_nvshmem && !sf->checked_nvshmem_eligibility) {
    /* Only use NVSHMEM for SFBASIC on PETSC_COMM_WORLD  */
    ierr = PetscObjectTypeCompare((PetscObject)sf,PETSCSFBASIC,&isBasic);CHKERRQ(ierr);
    if (isBasic) {ierr = MPI_Comm_compare(PETSC_COMM_WORLD,comm,&result);CHKERRMPI(ierr);}
    if (!isBasic || (result != MPI_IDENT && result != MPI_CONGRUENT)) sf->use_nvshmem = PETSC_FALSE; /* If not eligible, clear the flag so that we don't try again */

    /* Do further check: If on a rank, both rootdata and leafdata are NULL, we might think they are PETSC_MEMTYPE_CUDA (or HOST)
       and then use NVSHMEM. But if root/leafmtypes on other ranks are PETSC_MEMTYPE_HOST (or DEVICE), this would lead to
       inconsistency on the return value <use_nvshmem>. To be safe, we simply disable nvshmem on these rare SFs.
    */
    if (sf->use_nvshmem) {
      PetscInt hasNullRank = (!rootdata && !leafdata) ? 1 : 0;
      ierr = MPI_Allreduce(MPI_IN_PLACE,&hasNullRank,1,MPIU_INT,MPI_LOR,comm);CHKERRMPI(ierr);
      if (hasNullRank) sf->use_nvshmem = PETSC_FALSE;
    }
    sf->checked_nvshmem_eligibility = PETSC_TRUE; /* If eligible, don't do above check again */
  }

  /* Check if rootmtype and leafmtype collectively are PETSC_MEMTYPE_CUDA */
  if (sf->use_nvshmem) {
    PetscInt oneCuda = (!rootdata || PetscMemTypeCUDA(rootmtype)) && (!leafdata || PetscMemTypeCUDA(leafmtype)) ? 1 : 0; /* Do I use cuda for both root&leafmtype? */
    PetscInt allCuda = oneCuda; /* Assume the same for all ranks. But if not, in opt mode, return value <use_nvshmem> won't be collective! */
   #if defined(PETSC_USE_DEBUG)  /* Check in debug mode. Note MPI_Allreduce is expensive, so only in debug mode */
    ierr = MPI_Allreduce(&oneCuda,&allCuda,1,MPIU_INT,MPI_LAND,comm);CHKERRMPI(ierr);
    if (allCuda != oneCuda) SETERRQ(comm,PETSC_ERR_SUP,"root/leaf mtypes are inconsistent among ranks, which may lead to SF nvshmem failure in opt mode. Add -use_nvshmem 0 to disable it.");
   #endif
    if (allCuda) {
      ierr = PetscNvshmemInitializeCheck();CHKERRQ(ierr);
      if (!sf->setup_nvshmem) { /* Set up nvshmem related fields on this SF on-demand */
        ierr = PetscSFSetUp_Basic_NVSHMEM(sf);CHKERRQ(ierr);
        sf->setup_nvshmem = PETSC_TRUE;
      }
      *use_nvshmem = PETSC_TRUE;
    } else {
      *use_nvshmem = PETSC_FALSE;
    }
  } else {
    *use_nvshmem = PETSC_FALSE;
  }
  PetscFunctionReturn(0);
}
22171438e86SJunchao Zhang 
/* At the entry of NVSHMEM communication, make <remoteCommStream> wait on work already queued in <stream>.
   Nothing is done when the remote buffer length selected by <direction> is zero.
*/
static PetscErrorCode PetscSFLinkBuildDependenceBegin(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscSF_Basic  *bas = (PetscSF_Basic *)sf->data;
  cudaError_t    cerr;
  PetscInt       len;

  PetscFunctionBegin;
  /* ROOT2LEAF looks at the remote root buffer length, otherwise at the remote leaf buffer length */
  if (direction == PETSCSF_ROOT2LEAF) len = bas->rootbuflen[PETSCSF_REMOTE];
  else                                len = sf->leafbuflen[PETSCSF_REMOTE];

  if (len) {
    /* Mark the current point of <stream> and have <remoteCommStream> wait for it */
    cerr = cudaEventRecord(link->dataReady,link->stream);CHKERRCUDA(cerr);
    cerr = cudaStreamWaitEvent(link->remoteCommStream,link->dataReady,0);CHKERRCUDA(cerr);
  }
  PetscFunctionReturn(0);
}
23671438e86SJunchao Zhang 
/* At the exit of NVSHMEM communication, make <stream> wait on work queued in <remoteCommStream>.
   Nothing is done when the remote buffer length selected by <direction> is zero.
*/
static PetscErrorCode PetscSFLinkBuildDependenceEnd(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscSF_Basic  *bas = (PetscSF_Basic *)sf->data;
  cudaError_t    cerr;
  PetscInt       len;

  PetscFunctionBegin;
  /* ROOT2LEAF looks at the remote leaf buffer length, otherwise at the remote root buffer length */
  if (direction == PETSCSF_ROOT2LEAF) len = sf->leafbuflen[PETSCSF_REMOTE];
  else                                len = bas->rootbuflen[PETSCSF_REMOTE];

  /* If unpack to non-null device buffer, build the endRemoteComm dependence */
  if (len) {
    cerr = cudaEventRecord(link->endRemoteComm,link->remoteCommStream);CHKERRCUDA(cerr);
    cerr = cudaStreamWaitEvent(link->stream,link->endRemoteComm,0);CHKERRCUDA(cerr);
  }
  PetscFunctionReturn(0);
}
25271438e86SJunchao Zhang 
/* Kernel: put a signal value to each of n remote ranks, one thread per rank

 Input parameters:
  + n        - Number of remote ranks
  . sig      - Signal address in symmetric heap
  . sigdisp  - For the i-th rank, its signal is at offset sigdisp[i]
  . ranks    - Remote ranks
  - newval   - Value the signals are set to

 Threads whose global index is >= n do nothing.
*/
__global__ static void NvshmemSendSignals(PetscInt n,uint64_t *sig,PetscInt *sigdisp,PetscMPIInt *ranks,uint64_t newval)
{
  const int tid = blockIdx.x*blockDim.x + threadIdx.x;

  if (tid >= n) return; /* surplus threads of the last block */
  nvshmemx_uint64_signal(sig+sigdisp[tid],newval,ranks[tid]);
}
26971438e86SJunchao Zhang 
/* Kernel: wait until n local signals all equal an expected value, then overwrite them with a new value

 Input parameters:
  + n        - Number of signals
  . sig      - Local signal address
  . expval   - Value to wait for
  - newval   - Value the signals are set to afterwards

 The active code path has no thread-index guard: every launched thread waits on all n signals and rewrites them.
*/
__global__ static void NvshmemWaitSignals(PetscInt n,uint64_t *sig,uint64_t expval,uint64_t newval)
{
#if 0
  /* Akhil Langer@NVIDIA said using 1 thread and nvshmem_uint64_wait_until_all is better */
  int i = blockIdx.x*blockDim.x + threadIdx.x;
  if (i < n) {
    nvshmem_signal_wait_until(sig+i,NVSHMEM_CMP_EQ,expval);
    sig[i] = newval;
  }
#else
  /* Block until every one of the n signals equals expval (NULL means no mask) */
  nvshmem_uint64_wait_until_all(sig,n,NULL,NVSHMEM_CMP_EQ,expval);
  for (int k=0; k<n; k++) {
    sig[k] = newval;
  }
#endif
}
29271438e86SJunchao Zhang 
/* ===========================================================================================================

   A set of routines to support receiver-initiated communication using the get method

    The getting protocol is:

    The sender has a send buf (sbuf) and a signal variable (ssig); the receiver has a recv buf (rbuf) and a signal variable (rsig).
    All signal variables have an initial value of 0.

    Sender:                                     |  Receiver:
  1.  Wait for ssig to be 0, then set it to 1   |
  2.  Pack data into the stand-alone sbuf       |
  3.  Put 1 to the receiver's rsig              |   1. Wait for rsig to be 1, then set it to 0
                                                |   2. Get data from the remote sbuf to the local rbuf
                                                |   3. Put 1 to the sender's ssig
                                                |   4. Unpack data from the local rbuf
   ===========================================================================================================*/
/* PrePack operation -- the sender is about to overwrite a send buffer that receivers might still be getting data from,
   so it first waits for the signals (set by the receivers) telling it that they have finished getting data.

   The wait is queued as a single-thread kernel on link->remoteCommStream; nothing is queued when there are no remote ranks.
*/
PetscErrorCode PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
  uint64_t          *sendsig;
  PetscInt          nsig;
  cudaError_t       cerr;

  PetscFunctionBegin;
  if (direction != PETSCSF_ROOT2LEAF) { /* LEAF2ROOT */
    sendsig = link->leafSendSig;
    nsig    = sf->nRemoteRootRanks;
  } else {                              /* leaf ranks are getting data and set my rootSendSig */
    sendsig = link->rootSendSig;
    nsig    = bas->nRemoteLeafRanks;
  }

  if (nsig) {
    /* wait for the signals to be 0, then set them to 1 */
    NvshmemWaitSignals<<<1,1,0,link->remoteCommStream>>>(nsig,sendsig,0,1);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }
  PetscFunctionReturn(0);
}
33471438e86SJunchao Zhang 
/* Kernel: n thread blocks, each in charge of one remote rank (indexed by blockIdx.x).
   A block issues a non-blocking get only when nvshmem_ptr(src,pe) is NULL for its rank; otherwise it does nothing.
   The nsrcranks argument is not read inside the kernel.
*/
__global__ static void GetDataFromRemotelyAccessible(PetscInt nsrcranks,PetscMPIInt *srcranks,const char *src,PetscInt *srcdisp,char *dst,PetscInt *dstdisp,PetscInt unitbytes)
{
  const int         blk = blockIdx.x;
  const PetscMPIInt pe  = srcranks[blk];

  if (nvshmem_ptr(src,pe)) return; /* this PE has a non-NULL nvshmem_ptr; skip it here */

  /* Byte count and local/remote offsets for this rank's chunk; local offsets are taken relative to dstdisp[0] */
  PetscInt   nbytes  = (dstdisp[blk+1]-dstdisp[blk])*unitbytes;
  char       *local  = dst+(dstdisp[blk]-dstdisp[0])*unitbytes;
  const char *remote = src+srcdisp[blk]*unitbytes;

  nvshmem_getmem_nbi(local,remote,nbytes,pe);
}
34671438e86SJunchao Zhang 
/* Start communication -- get data in the given direction.

   Steps, all queued on link->remoteCommStream:
     1. signal the destination ranks that they may get data (value 1),
     2. wait for the same permission from my source ranks (wait for 1, then reset to 0),
     3. issue non-blocking gets, from a kernel for PEs where nvshmem_ptr() is NULL and from the host for the others.
*/
PetscErrorCode PetscSFLinkGetDataBegin_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
  PetscErrorCode    ierr;
  cudaError_t       cerr;

  PetscInt          nsrcranks,ndstranks,nlocal = 0;

  char              *src,*dst;
  PetscInt          *srcdisp_h,*dstdisp_h;
  PetscInt          *srcdisp_d,*dstdisp_d;
  PetscMPIInt       *srcranks_h;
  PetscMPIInt       *srcranks_d,*dstranks_d;
  uint64_t          *dstsig;
  PetscInt          *dstsigdisp_d;

  PetscFunctionBegin;
  ierr = PetscSFLinkBuildDependenceBegin(sf,link,direction);CHKERRQ(ierr);
  if (direction == PETSCSF_ROOT2LEAF) { /* src is root, dst is leaf; we will move data from src to dst */
    /* Source side: my remote root ranks */
    nsrcranks    = sf->nRemoteRootRanks;
    src          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* root buf is the send buf; it is in symmetric heap */
    srcdisp_h    = sf->rootbufdisp;       /* for my i-th remote root rank, I will access its buf at offset rootbufdisp[i] */
    srcdisp_d    = sf->rootbufdisp_d;
    srcranks_h   = sf->ranks+sf->ndranks; /* my (remote) root ranks */
    srcranks_d   = sf->ranks_d;

    /* Destination side: the local leaf buf */
    ndstranks    = bas->nRemoteLeafRanks;
    dst          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* recv buf is the local leaf buf, also in symmetric heap */
    dstdisp_h    = sf->roffset+sf->ndranks; /* offsets of the local leaf buf. Note dstdisp[0] is not necessarily 0 */
    dstdisp_d    = sf->roffset_d;
    dstranks_d   = bas->iranks_d; /* my (remote) leaf ranks */

    dstsig       = link->leafRecvSig;
    dstsigdisp_d = bas->leafsigdisp_d;
  } else { /* src is leaf, dst is root; we will move data from src to dst */
    /* Source side: my remote leaf ranks */
    nsrcranks    = bas->nRemoteLeafRanks;
    src          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* leaf buf is the send buf */
    srcdisp_h    = bas->leafbufdisp;       /* for my i-th remote leaf rank, I will access its buf at offset leafbufdisp[i] */
    srcdisp_d    = bas->leafbufdisp_d;
    srcranks_h   = bas->iranks+bas->ndiranks; /* my (remote) leaf ranks */
    srcranks_d   = bas->iranks_d;

    /* Destination side: the local root buf */
    ndstranks    = sf->nRemoteRootRanks;
    dst          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* the local root buf is the recv buf */
    dstdisp_h    = bas->ioffset+bas->ndiranks; /* offsets of the local root buf. Note dstdisp[0] is not necessarily 0 */
    dstdisp_d    = bas->ioffset_d;
    dstranks_d   = sf->ranks_d; /* my (remote) root ranks */

    dstsig       = link->rootRecvSig;
    dstsigdisp_d = sf->rootsigdisp_d;
  }

  /* After Pack operation -- src tells dst ranks that they are allowed to get data */
  if (ndstranks) {
    /* one thread per destination rank, 256 threads per block; set signals to 1 */
    NvshmemSendSignals<<<(ndstranks+255)/256,256,0,link->remoteCommStream>>>(ndstranks,dstsig,dstsigdisp_d,dstranks_d,1);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }

  /* dst waits for signals (permissions) from src ranks to start getting data */
  if (nsrcranks) {
    /* wait for the signals to be 1, then set them to 0 */
    NvshmemWaitSignals<<<1,1,0,link->remoteCommStream>>>(nsrcranks,dstsig,1,0);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }

  /* dst gets data from src ranks using non-blocking nvshmem_gets, which are finished in PetscSFLinkGetDataEnd_NVSHMEM() */

  /* Count number of locally accessible src ranks, which should be a small number */
  for (int r=0; r<nsrcranks; r++) {
    nlocal += nvshmem_ptr(src,srcranks_h[r]) ? 1 : 0;
  }

  /* Get data from remotely accessible PEs: one single-thread block per src rank */
  if (nlocal < nsrcranks) {
    GetDataFromRemotelyAccessible<<<nsrcranks,1,0,link->remoteCommStream>>>(nsrcranks,srcranks_d,src,srcdisp_d,dst,dstdisp_d,link->unitbytes);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }

  /* Get data from locally accessible PEs, issued from the host using the host copies of the tables */
  if (nlocal) {
    for (int r=0; r<nsrcranks; r++) {
      int pe = srcranks_h[r];
      if (!nvshmem_ptr(src,pe)) continue; /* already covered by the kernel above */
      size_t nbytes = (dstdisp_h[r+1]-dstdisp_h[r])*link->unitbytes;
      nvshmemx_getmem_nbi_on_stream(dst+(dstdisp_h[r]-dstdisp_h[0])*link->unitbytes,src+srcdisp_h[r]*link->unitbytes,nbytes,pe,link->remoteCommStream);
    }
  }
  PetscFunctionReturn(0);
}
43971438e86SJunchao Zhang 
/* Finish the get-based communication (can be done before Unpack).

   Completes the non-blocking gets queued on link->remoteCommStream, then writes 0 to the send signal of every
   source rank, telling it that its send buffer may be reused (the receiver has fetched the data).

   direction selects which side is receiving: leaf ranks for PETSCSF_ROOT2LEAF, root ranks otherwise.
*/
PetscErrorCode PetscSFLinkGetDataEnd_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscErrorCode    ierr;
  cudaError_t       cerr;
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
  PetscBool         root2leaf;
  uint64_t          *peersig;           /* send-signal array of the ranks we got data from */
  PetscInt          npeers,*peersigdisp; /* number of such ranks; offset of our slot in each rank's signal array */
  PetscMPIInt       *peers;              /* the ranks themselves (device copy) */

  PetscFunctionBegin;
  root2leaf   = (direction == PETSCSF_ROOT2LEAF) ? PETSC_TRUE : PETSC_FALSE;
  /* ROOT2LEAF: leaf ranks were getting data from roots; otherwise root ranks were getting data from leaves */
  npeers      = root2leaf ? sf->nRemoteRootRanks : bas->nRemoteLeafRanks;
  peersig     = root2leaf ? link->rootSendSig    : link->leafSendSig;
  peersigdisp = root2leaf ? sf->rootsigdisp_d    : bas->leafsigdisp_d;
  peers       = root2leaf ? sf->ranks_d          : bas->iranks_d;

  if (npeers) {
    const int nthreads = 512;
    const int nblocks  = (npeers+nthreads-1)/nthreads;

    /* Complete the non-blocking gets so that unpacking can follow */
    nvshmemx_quiet_on_stream(link->remoteCommStream);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
    /* Reset the senders' signals to 0 */
    NvshmemSendSignals<<<nblocks,nthreads,0,link->remoteCommStream>>>(npeers,peersig,peersigdisp,peers,0);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }
  ierr = PetscSFLinkBuildDependenceEnd(sf,link,direction);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}
47471438e86SJunchao Zhang 
47571438e86SJunchao Zhang /* ===========================================================================================================
47671438e86SJunchao Zhang 
47771438e86SJunchao Zhang    A set of routines to support sender initiated communication using the put-based method (the default)
47871438e86SJunchao Zhang 
47971438e86SJunchao Zhang     The putting protocol is:
48071438e86SJunchao Zhang 
48171438e86SJunchao Zhang     Sender has a send buf (sbuf) and a send signal var (ssig);  Receiver has a stand-alone recv buf (rbuf)
48271438e86SJunchao Zhang     and a recv signal var (rsig); All signal variables have an initial value 0. rbuf is allocated by SF and
48371438e86SJunchao Zhang     is in nvshmem space.
48471438e86SJunchao Zhang 
48571438e86SJunchao Zhang     Sender:                                 |  Receiver:
48671438e86SJunchao Zhang                                             |
48771438e86SJunchao Zhang   1.  Pack data into sbuf                   |
48871438e86SJunchao Zhang   2.  Wait ssig be 0, then set it to 1      |
48971438e86SJunchao Zhang   3.  Put data to remote stand-alone rbuf   |
49071438e86SJunchao Zhang   4.  Fence // make sure 5 happens after 3  |
49171438e86SJunchao Zhang   5.  Put 1 to receiver's rsig              |   1. Wait rsig to be 1, then set it 0
49271438e86SJunchao Zhang                                             |   2. Unpack data from local rbuf
49371438e86SJunchao Zhang                                             |   3. Put 0 to sender's ssig
49471438e86SJunchao Zhang    ===========================================================================================================*/
49571438e86SJunchao Zhang 
/* Launched with one thread block per destination rank; block b serves dstranks[b].

   A block whose PE is locally accessible (nvshmem_ptr() returns non-NULL) does nothing. Otherwise it waits for
   its send signal to become 0, sets it to 1, and starts a non-blocking put of its segment of src to the PE.
   srcdisp[] offsets are taken relative to srcdisp[0]; dstdisp[] offsets are used as is. Both are in units of unitbytes.
*/
__global__ static void WaitAndPutDataToRemotelyAccessible(PetscInt ndstranks,PetscMPIInt *dstranks,char *dst,PetscInt *dstdisp,const char *src,PetscInt *srcdisp,uint64_t *srcsig,PetscInt unitbytes)
{
  const int         b  = blockIdx.x;
  const PetscMPIInt pe = dstranks[b];
  PetscInt          nbytes;

  if (nvshmem_ptr(dst,pe)) return; /* locally accessible PEs are not handled by this kernel */

  nbytes = (srcdisp[b+1]-srcdisp[b])*unitbytes;
  nvshmem_uint64_wait_until(srcsig+b,NVSHMEM_CMP_EQ,0); /* block until the signal reads 0 ... */
  srcsig[b] = 1;                                         /* ... then mark it 1 */
  nvshmem_putmem_nbi(dst+dstdisp[b]*unitbytes,src+(srcdisp[b]-srcdisp[0])*unitbytes,nbytes,pe);
}
50971438e86SJunchao Zhang 
/* One-thread kernel that handles all locally accessible PEs.

   For each destination rank for which nvshmem_ptr(dst,pe) is non-NULL, waits until the corresponding
   send signal is 0 and then sets it to 1. Other ranks are skipped.
*/
__global__ static void WaitSignalsFromLocallyAccessible(PetscInt ndstranks,PetscMPIInt *dstranks,uint64_t *srcsig,const char *dst)
{
  int k = 0;

  while (k < ndstranks) {
    const int pe = dstranks[k];
    if (nvshmem_ptr(dst,pe)) {
      uint64_t *sig = srcsig+k;
      nvshmem_uint64_wait_until(sig,NVSHMEM_CMP_EQ,0); /* wait for the signal to read 0 */
      *sig = 1;
    }
    k++;
  }
}
52171438e86SJunchao Zhang 
/* Start putting data in the given direction (put-based protocol).

   For every destination rank, the sender waits for its send signal to be 0, sets it to 1 and starts a
   non-blocking put from its send buffer into the destination's receive buffer:
     - ranks that are not locally accessible are served by one kernel, WaitAndPutDataToRemotelyAccessible;
     - locally accessible ranks (nvshmem_ptr() returns non-NULL) wait in WaitSignalsFromLocallyAccessible and
       then use the host on-stream put API, followed by nvshmemx_quiet_on_stream().
   All work is enqueued on link->remoteCommStream. Every kernel launch is followed by a cudaGetLastError()
   check so that a failed launch is reported here instead of being silently ignored.

   direction selects the buffers: PETSCSF_ROOT2LEAF puts rootbuf into leafbuf, otherwise leafbuf into rootbuf.
*/
PetscErrorCode PetscSFLinkPutDataBegin_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscErrorCode    ierr;
  cudaError_t       cerr;
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
  PetscInt          ndstranks,nLocallyAccessible = 0;
  char              *src,*dst;
  PetscInt          *srcdisp_h,*dstdisp_h;
  PetscInt          *srcdisp_d,*dstdisp_d;
  PetscMPIInt       *dstranks_h;
  PetscMPIInt       *dstranks_d;
  uint64_t          *srcsig;

  PetscFunctionBegin;
  ierr = PetscSFLinkBuildDependenceBegin(sf,link,direction);CHKERRQ(ierr);
  if (direction == PETSCSF_ROOT2LEAF) { /* put data in rootbuf to leafbuf  */
    ndstranks    = bas->nRemoteLeafRanks; /* number of (remote) leaf ranks */
    src          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* Both src & dst must be symmetric */
    dst          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

    srcdisp_h    = bas->ioffset+bas->ndiranks;  /* offsets of rootbuf. srcdisp[0] is not necessarily zero */
    srcdisp_d    = bas->ioffset_d;
    srcsig       = link->rootSendSig;

    dstdisp_h    = bas->leafbufdisp;            /* for my i-th remote leaf rank, I will access its leaf buf at offset leafbufdisp[i] */
    dstdisp_d    = bas->leafbufdisp_d;
    dstranks_h   = bas->iranks+bas->ndiranks;   /* remote leaf ranks */
    dstranks_d   = bas->iranks_d;
  } else { /* put data in leafbuf to rootbuf */
    ndstranks    = sf->nRemoteRootRanks;
    src          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
    dst          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

    srcdisp_h    = sf->roffset+sf->ndranks; /* offsets of leafbuf */
    srcdisp_d    = sf->roffset_d;
    srcsig       = link->leafSendSig;

    dstdisp_h    = sf->rootbufdisp;         /* for my i-th remote root rank, I will access its root buf at offset rootbufdisp[i] */
    dstdisp_d    = sf->rootbufdisp_d;
    dstranks_h   = sf->ranks+sf->ndranks;   /* remote root ranks */
    dstranks_d   = sf->ranks_d;
  }

  /* Wait for signals and then put data to dst ranks using non-blocking nvshmem_put, which are finished in PetscSFLinkPutDataEnd_NVSHMEM */

  /* Count number of locally accessible neighbors, which should be a small number */
  for (int i=0; i<ndstranks; i++) {if (nvshmem_ptr(dst,dstranks_h[i])) nLocallyAccessible++;}

  /* For remotely accessible PEs, send data to them in one kernel call */
  if (nLocallyAccessible < ndstranks) {
    WaitAndPutDataToRemotelyAccessible<<<ndstranks,1,0,link->remoteCommStream>>>(ndstranks,dstranks_d,dst,dstdisp_d,src,srcdisp_d,srcsig,link->unitbytes);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }

  /* For locally accessible PEs, use host API, which uses CUDA copy-engines and is much faster than device API */
  if (nLocallyAccessible) {
    WaitSignalsFromLocallyAccessible<<<1,1,0,link->remoteCommStream>>>(ndstranks,dstranks_d,srcsig,dst);
    cerr = cudaGetLastError();CHKERRCUDA(cerr); /* report a failed launch, as done for every other kernel in this file */
    for (int i=0; i<ndstranks; i++) {
      int pe = dstranks_h[i];
      if (nvshmem_ptr(dst,pe)) { /* If return a non-null pointer, then <pe> is locally accessible */
        size_t nelems = (srcdisp_h[i+1]-srcdisp_h[i])*link->unitbytes;
        /* Initiate the nonblocking communication */
        nvshmemx_putmem_nbi_on_stream(dst+dstdisp_h[i]*link->unitbytes,src+(srcdisp_h[i]-srcdisp_h[0])*link->unitbytes,nelems,pe,link->remoteCommStream);
      }
    }
    nvshmemx_quiet_on_stream(link->remoteCommStream); /* Calling nvshmem_fence/quiet() does not fence the above nvshmemx_putmem_nbi_on_stream! */
  }
  PetscFunctionReturn(0);
}
59571438e86SJunchao Zhang 
/* A one-thread kernel; the single thread handles all remote PEs.

   Step 1: signal each of the ndstranks destination ranks by setting its slot dstsig[dstsigdisp[i]] to 1.
   Step 2: if there are source ranks, wait until the first nsrcranks local entries of dstsig are all 1,
           then reset them to 0.
*/
__global__ static void PutDataEnd(PetscInt nsrcranks,PetscInt ndstranks,PetscMPIInt *dstranks,uint64_t *dstsig,PetscInt *dstsigdisp)
{
  int k;

  /* TODO: Shall we finish the non-blocking remote puts? */

  /* Step 1.
     According to Akhil@NVIDIA, IB is ordered, so no fence is needed for remote PEs. For local PEs,
     nvshmemx_quiet_on_stream() has already been called. Hence signals can be sent to all dst ranks now.
  */
  k = 0;
  while (k < ndstranks) {
    nvshmemx_uint64_signal(dstsig+dstsigdisp[k],1,dstranks[k]); /* set sig to 1 */
    k++;
  }

  /* Step 2. Nothing to wait for when there are no src ranks */
  if (!nsrcranks) return;
  nvshmem_uint64_wait_until_all(dstsig,nsrcranks,NULL/*no mask*/,NVSHMEM_CMP_EQ,1); /* wait for all sigs to be 1 */
  for (k=0; k<nsrcranks; k++) dstsig[k] = 0;                                         /* then clear them */
}
61471438e86SJunchao Zhang 
/* Finish the put-based communication -- a receiver waits until it can access its receive buffer.

   Launches the one-thread PutDataEnd kernel on link->remoteCommStream, which signals the ranks this
   process put data to and waits for the signals from the ranks that put data here.
*/
PetscErrorCode PetscSFLinkPutDataEnd_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscErrorCode    ierr;
  cudaError_t       cerr;
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
  PetscBool         root2leaf;
  PetscMPIInt       *peers;       /* ranks I put data to (device copy) */
  uint64_t          *recvsig;     /* receive-signal array: mine to wait on, theirs to set */
  PetscInt          nsenders,nreceivers,*recvsigdisp;

  PetscFunctionBegin;
  root2leaf   = (direction == PETSCSF_ROOT2LEAF) ? PETSC_TRUE : PETSC_FALSE;
  /* ROOT2LEAF: root data was put to leaves, so my senders are root ranks and my receivers are leaf ranks */
  nsenders    = root2leaf ? sf->nRemoteRootRanks  : bas->nRemoteLeafRanks;
  nreceivers  = root2leaf ? bas->nRemoteLeafRanks : sf->nRemoteRootRanks;
  peers       = root2leaf ? bas->iranks_d         : sf->ranks_d;
  recvsig     = root2leaf ? link->leafRecvSig     : link->rootRecvSig;
  recvsigdisp = root2leaf ? bas->leafsigdisp_d    : sf->rootsigdisp_d; /* offset of my slot in the i-th receiver's signal array */

  if (nsenders || nreceivers) {
    PutDataEnd<<<1,1,0,link->remoteCommStream>>>(nsenders,nreceivers,peers,recvsig,recvsigdisp);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }
  ierr = PetscSFLinkBuildDependenceEnd(sf,link,direction);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}
64971438e86SJunchao Zhang 
/* PostUnpack operation -- a receiver tells its senders that they may put data here again
   (which implies the receive buffer is free to take new data). This is done by setting the
   senders' send signals to 0 with the NvshmemSendSignals kernel on link->remoteCommStream.
*/
PetscErrorCode PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
  PetscBool         root2leaf;
  uint64_t          *sendsig;
  PetscInt          nsenders,*sendsigdisp_d;
  PetscMPIInt       *senders_d;

  PetscFunctionBegin;
  root2leaf     = (direction == PETSCSF_ROOT2LEAF) ? PETSC_TRUE : PETSC_FALSE;
  /* ROOT2LEAF: I allow my root ranks to put data to me; otherwise I allow my leaf ranks */
  nsenders      = root2leaf ? sf->nRemoteRootRanks : bas->nRemoteLeafRanks;
  sendsig       = root2leaf ? link->rootSendSig    : link->leafSendSig;   /* the senders' send signals */
  sendsigdisp_d = root2leaf ? sf->rootsigdisp_d    : bas->leafsigdisp_d;  /* offset of each signal */
  senders_d     = root2leaf ? sf->ranks_d          : bas->iranks_d;       /* ranks of the senders */

  if (nsenders) {
    const int   nthreads = 256;
    const int   nblocks  = (nsenders+nthreads-1)/nthreads;
    cudaError_t cerr;

    /* Set remote signals to 0 */
    NvshmemSendSignals<<<nblocks,nthreads,0,link->remoteCommStream>>>(nsenders,sendsig,sendsigdisp_d,senders_d,0);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }
  PetscFunctionReturn(0);
}
67771438e86SJunchao Zhang 
/* Destructor used when the link communicates through nvshmem.
   Releases the CUDA events and stream owned by the link, then the nvshmem-allocated
   buffers and signal arrays of the leaf side followed by those of the root side.
*/
static PetscErrorCode PetscSFLinkDestroy_NVSHMEM(PetscSF sf,PetscSFLink link)
{
  PetscErrorCode    ierr;
  cudaError_t       cerr;

  PetscFunctionBegin;
  /* CUDA objects: events first, then the stream */
  cerr = cudaEventDestroy(link->dataReady);CHKERRCUDA(cerr);
  cerr = cudaEventDestroy(link->endRemoteComm);CHKERRCUDA(cerr);
  cerr = cudaStreamDestroy(link->remoteCommStream);CHKERRCUDA(cerr);

  /* nvshmem objects. Only device buffers exist; nvshmem does not need host buffers, which should be NULL */

  /* leaf side: buffer, send signals, recv signals */
  ierr = PetscNvshmemFree(link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
  ierr = PetscNvshmemFree(link->leafSendSig);CHKERRQ(ierr);
  ierr = PetscNvshmemFree(link->leafRecvSig);CHKERRQ(ierr);

  /* root side: buffer, send signals, recv signals */
  ierr = PetscNvshmemFree(link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
  ierr = PetscNvshmemFree(link->rootSendSig);CHKERRQ(ierr);
  ierr = PetscNvshmemFree(link->rootRecvSig);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}
69871438e86SJunchao Zhang 
/* Get a link that communicates through nvshmem for the given unit, operation and data.

   A free nvshmem link with a matching unit is taken from the cache (bas->avail) when there is one; otherwise
   a new link is built: pack routines are set up, signal arrays are allocated zero-initialized, the
   communication callbacks are installed and a high-priority non-blocking stream plus two events are created.
   In both cases the remote root/leaf buffers are then bound either directly to root/leafdata or to
   nvshmem-allocated buffers, and the link is pushed on bas->inuse and returned in *mylink.
*/
PetscErrorCode PetscSFLinkCreate_NVSHMEM(PetscSF sf,MPI_Datatype unit,PetscMemType rootmtype,const void *rootdata,PetscMemType leafmtype,const void *leafdata,MPI_Op op,PetscSFOperation sfop,PetscSFLink *mylink)
{
  PetscErrorCode    ierr;
  cudaError_t       cerr;
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
  PetscSFLink       *slot,link;
  PetscBool         match;
  PetscBool         rootDirectRemote = PETSC_FALSE,leafDirectRemote = PETSC_FALSE;
  int               greatestPriority;

  PetscFunctionBegin;
  /* Decide whether root/leafdata can be sent/received directly for the given sf, sfop and op.
     Only the remote part matters, since local communication with NVSHMEM never needs intermediate buffers.
     Both flags start out false; they are raised only in the cases below.
  */
  if (sfop == PETSCSF_BCAST) { /* data moves from rootbuf to leafbuf */
    if (sf->use_nvshmem_get) {
      /* get: the send buffer must be stand-alone (can't be rootdata), so only the leaf side may be direct */
      if (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) leafDirectRemote = PETSC_TRUE;
    } else {
      /* put: the protocol always needs a nvshmem alloc'ed recv buffer, so only the root side may be direct */
      if (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE]) rootDirectRemote = PETSC_TRUE;
    }
  } else if (sfop == PETSCSF_REDUCE) { /* data moves from leafbuf to rootbuf */
    if (sf->use_nvshmem_get) {
      if (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) rootDirectRemote = PETSC_TRUE;
    } else {
      if (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE]) leafDirectRemote = PETSC_TRUE;
    }
  }
  /* else PETSCSF_FETCH: always uses a separate rootbuf and a separate leafbuf, so both flags stay false */

  /* Search the cache for a free nvshmem link with a matching unit; link is NULL afterwards if none is found */
  for (slot=&bas->avail; (link=*slot); slot=&link->next) {
    if (!link->use_nvshmem) continue;
    ierr = MPIPetsc_Type_compare(unit,link->unit,&match);CHKERRQ(ierr);
    if (match) {
      *slot = link->next; /* unlink it from the available list */
      break;
    }
  }

  if (!link) { /* nothing reusable in the cache: build a new link */
    ierr = PetscNew(&link);CHKERRQ(ierr);
    ierr = PetscSFLinkSetUp_Host(sf,link,unit);CHKERRQ(ierr); /* Compute link->unitbytes, dup link->unit etc. */
    if (sf->backend == PETSCSF_BACKEND_CUDA) {ierr = PetscSFLinkSetUp_CUDA(sf,link,unit);CHKERRQ(ierr);} /* Setup pack routines, streams etc */
   #if defined(PETSC_HAVE_KOKKOS)
    else if (sf->backend == PETSCSF_BACKEND_KOKKOS) {ierr = PetscSFLinkSetUp_Kokkos(sf,link,unit);CHKERRQ(ierr);}
   #endif

    /* The local part always works directly on root/leafdata */
    link->rootdirect[PETSCSF_LOCAL]  = PETSC_TRUE;
    link->leafdirect[PETSCSF_LOCAL]  = PETSC_TRUE;

    /* Signal arrays, zero-initialized by the calloc */
    if (!link->rootSendSig) {ierr = PetscNvshmemCalloc(bas->nRemoteLeafRanksMax*sizeof(uint64_t),(void**)&link->rootSendSig);CHKERRQ(ierr);}
    if (!link->rootRecvSig) {ierr = PetscNvshmemCalloc(bas->nRemoteLeafRanksMax*sizeof(uint64_t),(void**)&link->rootRecvSig);CHKERRQ(ierr);}
    if (!link->leafSendSig) {ierr = PetscNvshmemCalloc(sf->nRemoteRootRanksMax*sizeof(uint64_t),(void**)&link->leafSendSig);CHKERRQ(ierr);}
    if (!link->leafRecvSig) {ierr = PetscNvshmemCalloc(sf->nRemoteRootRanksMax*sizeof(uint64_t),(void**)&link->leafRecvSig);CHKERRQ(ierr);}

    link->use_nvshmem                = PETSC_TRUE;
    link->rootmtype                  = PETSC_MEMTYPE_DEVICE; /* Only need 0/1-based mtype from now on */
    link->leafmtype                  = PETSC_MEMTYPE_DEVICE;
    /* Replace some of the function pointers installed by PetscSFLinkSetUp_CUDA */
    link->Destroy                    = PetscSFLinkDestroy_NVSHMEM;
    if (sf->use_nvshmem_get) { /* get-based protocol */
      link->PrePack                  = PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM;
      link->StartCommunication       = PetscSFLinkGetDataBegin_NVSHMEM;
      link->FinishCommunication      = PetscSFLinkGetDataEnd_NVSHMEM;
    } else { /* put-based protocol */
      link->StartCommunication       = PetscSFLinkPutDataBegin_NVSHMEM;
      link->FinishCommunication      = PetscSFLinkPutDataEnd_NVSHMEM;
      link->PostUnpack               = PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM;
    }

    /* Communication stream at the greatest available priority, plus two timing-disabled events */
    cerr = cudaDeviceGetStreamPriorityRange(NULL,&greatestPriority);CHKERRCUDA(cerr);
    cerr = cudaStreamCreateWithPriority(&link->remoteCommStream,cudaStreamNonBlocking,greatestPriority);CHKERRCUDA(cerr);

    cerr = cudaEventCreateWithFlags(&link->dataReady,cudaEventDisableTiming);CHKERRCUDA(cerr);
    cerr = cudaEventCreateWithFlags(&link->endRemoteComm,cudaEventDisableTiming);CHKERRCUDA(cerr);
  }

  /* Bind the remote root buffer: rootdata itself when direct, else a (lazily allocated) nvshmem buffer */
  if (!rootDirectRemote && !link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) {
    ierr = PetscNvshmemMalloc(bas->rootbuflen_rmax*link->unitbytes,(void**)&link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
  }
  if (rootDirectRemote) link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char*)rootdata + bas->rootstart[PETSCSF_REMOTE]*link->unitbytes;
  else                  link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

  /* Same for the remote leaf buffer */
  if (!leafDirectRemote && !link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) {
    ierr = PetscNvshmemMalloc(sf->leafbuflen_rmax*link->unitbytes,(void**)&link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
  }
  if (leafDirectRemote) link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char*)leafdata + sf->leafstart[PETSCSF_REMOTE]*link->unitbytes;
  else                  link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

  link->rootdirect[PETSCSF_REMOTE] = rootDirectRemote;
  link->leafdirect[PETSCSF_REMOTE] = leafDirectRemote;
  link->rootdata                   = rootdata; /* root/leafdata are keys to look up links in PetscSFXxxEnd */
  link->leafdata                   = leafdata;
  /* Push the link on the in-use list and hand it to the caller */
  link->next                       = bas->inuse;
  bas->inuse                       = link;
  *mylink                          = link;
  PetscFunctionReturn(0);
}
80871438e86SJunchao Zhang 
80971438e86SJunchao Zhang #if defined(PETSC_USE_REAL_SINGLE)
/* Sum-reduce count floats from src into dst over NVSHMEM_TEAM_WORLD; the operation is enqueued on PetscDefaultCudaStream.

   count is converted with PetscMPIIntCast(), whose error (if any) is propagated. A nonzero status
   returned by the NVSHMEM reduction is reported as PETSC_ERR_LIB instead of being dropped.
*/
PetscErrorCode PetscNvshmemSum(PetscInt count,float *dst,const float *src)
{
  PetscErrorCode    ierr;
  PetscMPIInt       num; /* Assume nvshmem's int is MPI's int */
  int               status;

  PetscFunctionBegin;
  ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
  status = nvshmemx_float_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
  if (status) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_LIB,"nvshmemx_float_sum_reduce_on_stream() failed with status %d",status);
  PetscFunctionReturn(0);
}
82071438e86SJunchao Zhang 
/* Max-reduce count floats from src into dst over NVSHMEM_TEAM_WORLD; the operation is enqueued on PetscDefaultCudaStream.

   count is converted with PetscMPIIntCast(), whose error (if any) is propagated. A nonzero status
   returned by the NVSHMEM reduction is reported as PETSC_ERR_LIB instead of being dropped.
*/
PetscErrorCode PetscNvshmemMax(PetscInt count,float *dst,const float *src)
{
  PetscErrorCode    ierr;
  PetscMPIInt       num; /* Assume nvshmem's int is MPI's int */
  int               status;

  PetscFunctionBegin;
  ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
  status = nvshmemx_float_max_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
  if (status) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_LIB,"nvshmemx_float_max_reduce_on_stream() failed with status %d",status);
  PetscFunctionReturn(0);
}
83171438e86SJunchao Zhang #elif defined(PETSC_USE_REAL_DOUBLE)
/* Sum-reduce count doubles from src into dst over NVSHMEM_TEAM_WORLD; the operation is enqueued on PetscDefaultCudaStream.

   count is converted with PetscMPIIntCast(), whose error (if any) is propagated. A nonzero status
   returned by the NVSHMEM reduction is reported as PETSC_ERR_LIB instead of being dropped.
*/
PetscErrorCode PetscNvshmemSum(PetscInt count,double *dst,const double *src)
{
  PetscErrorCode    ierr;
  PetscMPIInt       num; /* Assume nvshmem's int is MPI's int */
  int               status;

  PetscFunctionBegin;
  ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
  status = nvshmemx_double_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
  if (status) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_LIB,"nvshmemx_double_sum_reduce_on_stream() failed with status %d",status);
  PetscFunctionReturn(0);
}
84271438e86SJunchao Zhang 
/* Max-reduce count doubles from src into dst over NVSHMEM_TEAM_WORLD; the operation is enqueued on PetscDefaultCudaStream.

   count is converted with PetscMPIIntCast(), whose error (if any) is propagated. A nonzero status
   returned by the NVSHMEM reduction is reported as PETSC_ERR_LIB instead of being dropped.
*/
PetscErrorCode PetscNvshmemMax(PetscInt count,double *dst,const double *src)
{
  PetscErrorCode    ierr;
  PetscMPIInt       num; /* Assume nvshmem's int is MPI's int */
  int               status;

  PetscFunctionBegin;
  ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
  status = nvshmemx_double_max_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
  if (status) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_LIB,"nvshmemx_double_max_reduce_on_stream() failed with status %d",status);
  PetscFunctionReturn(0);
}
85371438e86SJunchao Zhang #endif
85471438e86SJunchao Zhang 
855