#include <petsc/private/cudavecimpl.h>
#include <petsccublas.h>
#include <../src/vec/is/sf/impls/basic/sfpack.h>
#include <mpi.h>
#include <nvshmem.h>
#include <nvshmemx.h>

/* Initialize NVSHMEM over PETSC_COMM_WORLD if that has not been done yet (CUDA is initialized first).

   NVSHMEM does not provide a routine to check whether it is initialized, so the global flag
   PetscNvshmemInitialized records it. PetscBeganNvshmem is set as well (presumably so PETSc knows it
   must finalize NVSHMEM itself -- confirm against the finalize path).
*/
PetscErrorCode PetscNvshmemInitializeCheck(void)
{
  PetscErrorCode ierr;

  PetscFunctionBegin;
  if (!PetscNvshmemInitialized) {
    nvshmemx_init_attr_t attr;
    int                  nverr;

    attr.mpi_comm = &PETSC_COMM_WORLD;
    ierr = PetscCUDAInitializeCheck();CHKERRQ(ierr);
    /* nvshmemx_init_attr() returns an NVSHMEM status (0 on success), not a PetscErrorCode,
       so translate it into a PETSc error instead of pushing it through CHKERRQ() */
    nverr = nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM,&attr);
    if (nverr) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_LIB,"nvshmemx_init_attr() failed with error code %d",nverr);
    PetscNvshmemInitialized = PETSC_TRUE;
    PetscBeganNvshmem       = PETSC_TRUE;
  }
  PetscFunctionReturn(0);
}

/* Allocate <size> bytes from the NVSHMEM symmetric heap, initializing NVSHMEM first if needed.

   NOTE: per the NVSHMEM API, nvshmem_malloc() is collective over all PEs and returns NULL for a
   zero-byte request, so only a NULL result for a nonzero <size> is treated as a failure.
*/
PetscErrorCode PetscNvshmemMalloc(size_t size, void** ptr)
{
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = PetscNvshmemInitializeCheck();CHKERRQ(ierr);
  *ptr = nvshmem_malloc(size);
  if (size && !*ptr) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_MEM,"nvshmem_malloc() failed to allocate %zu bytes",size);
  PetscFunctionReturn(0);
}

/* Same as PetscNvshmemMalloc(), but the memory is zeroed (nvshmem_calloc(size,1)).
   A zero-byte request legitimately yields NULL and is not reported as an error.
*/
PetscErrorCode PetscNvshmemCalloc(size_t size, void**ptr)
{
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = PetscNvshmemInitializeCheck();CHKERRQ(ierr);
  *ptr = nvshmem_calloc(size,1);
  if (size && !*ptr) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_MEM,"nvshmem_calloc() failed to allocate %zu bytes",size);
  PetscFunctionReturn(0);
}

/* Release memory obtained from PetscNvshmemMalloc()/PetscNvshmemCalloc() */
PetscErrorCode PetscNvshmemFree_Private(void* ptr)
{
  PetscFunctionBegin;
  nvshmem_free(ptr);
  PetscFunctionReturn(0);
}

/* Finalize NVSHMEM */
PetscErrorCode PetscNvshmemFinalize(void)
{
  PetscFunctionBegin;
  nvshmem_finalize();
  PetscFunctionReturn(0);
}

/* Free the NVSHMEM related fields of an SFBASIC: the host arrays allocated with PetscMalloc2() and the
   device copies allocated in PetscSFSetUp_Basic_NVSHMEM(), for both the leaf side (bas) and the root side (sf) */
PetscErrorCode PetscSFReset_Basic_NVSHMEM(PetscSF sf)
{
  PetscErrorCode ierr;
  PetscSF_Basic  *bas = (PetscSF_Basic*)sf->data;

  PetscFunctionBegin;
  ierr = PetscFree2(bas->leafsigdisp,bas->leafbufdisp);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,bas->leafbufdisp_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,bas->leafsigdisp_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,bas->iranks_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,bas->ioffset_d);CHKERRQ(ierr);

  ierr = PetscFree2(sf->rootsigdisp,sf->rootbufdisp);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->rootbufdisp_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->rootsigdisp_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->ranks_d);CHKERRQ(ierr);
  ierr = PetscSFFree(sf,PETSC_MEMTYPE_CUDA,sf->roffset_d);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/* Set up NVSHMEM related fields for an SF of type SFBASIC (only after PetscSFSetUp_Basic() has set up the fields this depends on).

   Collective on the SF's communicator. Besides an MPI_MAX allreduce of the remote rank counts and remote buffer
   lengths, it does four rounds of point-to-point messages:
   1) each root tells each remote leaf rank the index i under which it knows that leaf; the leaf stores it in
      rootsigdisp[], i.e. which signal slot to use on that root;
   2) each root tells each remote leaf rank where that leaf's data starts in the root's remote buffer (rootbufdisp[]);
   3),4) the same from leaves to roots, filling leafsigdisp[] and leafbufdisp[].
   Host arrays are then mirrored to the device (asynchronously on PetscDefaultCudaStream).
*/
static PetscErrorCode PetscSFSetUp_Basic_NVSHMEM(PetscSF sf)
{
  PetscErrorCode ierr;
  cudaError_t    cerr;
  PetscSF_Basic  *bas = (PetscSF_Basic*)sf->data;
  PetscInt       i,nRemoteRootRanks,nRemoteLeafRanks;
  PetscMPIInt    tag;
  MPI_Comm       comm;
  MPI_Request    *rootreqs,*leafreqs;
  PetscInt       tmp,stmp[4],rtmp[4]; /* tmps for send/recv buffers */

  PetscFunctionBegin;
  ierr = PetscObjectGetComm((PetscObject)sf,&comm);CHKERRQ(ierr);
  ierr = PetscObjectGetNewTag((PetscObject)sf,&tag);CHKERRQ(ierr);

  nRemoteRootRanks      = sf->nranks-sf->ndranks;
  nRemoteLeafRanks      = bas->niranks-bas->ndiranks;
  sf->nRemoteRootRanks  = nRemoteRootRanks;
  bas->nRemoteLeafRanks = nRemoteLeafRanks;

  ierr = PetscMalloc2(nRemoteLeafRanks,&rootreqs,nRemoteRootRanks,&leafreqs);CHKERRQ(ierr);

  stmp[0] = nRemoteRootRanks;
  stmp[1] = sf->leafbuflen[PETSCSF_REMOTE];
  stmp[2] = nRemoteLeafRanks;
  stmp[3] = bas->rootbuflen[PETSCSF_REMOTE];

  ierr = MPIU_Allreduce(stmp,rtmp,4,MPIU_INT,MPI_MAX,comm);CHKERRMPI(ierr);

  sf->nRemoteRootRanksMax  = rtmp[0];
  sf->leafbuflen_rmax      = rtmp[1];
  bas->nRemoteLeafRanksMax = rtmp[2];
  bas->rootbuflen_rmax     = rtmp[3];

  /* Total four rounds of MPI communications to set up the nvshmem fields. All receives of a round are posted
     before any send of that round, so the blocking MPI_Send()s cannot deadlock. */

  /* Root ranks to leaf ranks: send info about rootsigdisp[] and rootbufdisp[] */
  ierr = PetscMalloc2(nRemoteRootRanks,&sf->rootsigdisp,nRemoteRootRanks,&sf->rootbufdisp);CHKERRQ(ierr);
  for (i=0; i<nRemoteRootRanks; i++) {ierr = MPI_Irecv(&sf->rootsigdisp[i],1,MPIU_INT,sf->ranks[i+sf->ndranks],tag,comm,&leafreqs[i]);CHKERRMPI(ierr);} /* Leaves recv */
  for (i=0; i<nRemoteLeafRanks; i++) {ierr = MPI_Send(&i,1,MPIU_INT,bas->iranks[i+bas->ndiranks],tag,comm);CHKERRMPI(ierr);} /* Roots send. Note i changes, so we use MPI_Send. */
  ierr = MPI_Waitall(nRemoteRootRanks,leafreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);

  for (i=0; i<nRemoteRootRanks; i++) {ierr = MPI_Irecv(&sf->rootbufdisp[i],1,MPIU_INT,sf->ranks[i+sf->ndranks],tag,comm,&leafreqs[i]);CHKERRMPI(ierr);} /* Leaves recv */
  for (i=0; i<nRemoteLeafRanks; i++) {
    tmp  = bas->ioffset[i+bas->ndiranks] - bas->ioffset[bas->ndiranks];
    ierr = MPI_Send(&tmp,1,MPIU_INT,bas->iranks[i+bas->ndiranks],tag,comm);CHKERRMPI(ierr); /* Roots send. Note tmp changes, so we use MPI_Send. */
  }
  ierr = MPI_Waitall(nRemoteRootRanks,leafreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);

  cerr = cudaMalloc((void**)&sf->rootbufdisp_d,nRemoteRootRanks*sizeof(PetscInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&sf->rootsigdisp_d,nRemoteRootRanks*sizeof(PetscInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&sf->ranks_d,nRemoteRootRanks*sizeof(PetscMPIInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&sf->roffset_d,(nRemoteRootRanks+1)*sizeof(PetscInt));CHKERRCUDA(cerr);

  cerr = cudaMemcpyAsync(sf->rootbufdisp_d,sf->rootbufdisp,nRemoteRootRanks*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(sf->rootsigdisp_d,sf->rootsigdisp,nRemoteRootRanks*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(sf->ranks_d,sf->ranks+sf->ndranks,nRemoteRootRanks*sizeof(PetscMPIInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(sf->roffset_d,sf->roffset+sf->ndranks,(nRemoteRootRanks+1)*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);

  /* Leaf ranks to root ranks: send info about leafsigdisp[] and leafbufdisp[] */
  ierr = PetscMalloc2(nRemoteLeafRanks,&bas->leafsigdisp,nRemoteLeafRanks,&bas->leafbufdisp);CHKERRQ(ierr);
  for (i=0; i<nRemoteLeafRanks; i++) {ierr = MPI_Irecv(&bas->leafsigdisp[i],1,MPIU_INT,bas->iranks[i+bas->ndiranks],tag,comm,&rootreqs[i]);CHKERRMPI(ierr);} /* Roots recv */
  for (i=0; i<nRemoteRootRanks; i++) {ierr = MPI_Send(&i,1,MPIU_INT,sf->ranks[i+sf->ndranks],tag,comm);CHKERRMPI(ierr);} /* Leaves send */
  ierr = MPI_Waitall(nRemoteLeafRanks,rootreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);

  for (i=0; i<nRemoteLeafRanks; i++) {ierr = MPI_Irecv(&bas->leafbufdisp[i],1,MPIU_INT,bas->iranks[i+bas->ndiranks],tag,comm,&rootreqs[i]);CHKERRMPI(ierr);} /* Roots recv */
  for (i=0; i<nRemoteRootRanks; i++) {
    tmp  = sf->roffset[i+sf->ndranks] - sf->roffset[sf->ndranks];
    ierr = MPI_Send(&tmp,1,MPIU_INT,sf->ranks[i+sf->ndranks],tag,comm);CHKERRMPI(ierr); /* Leaves send */
  }
  ierr = MPI_Waitall(nRemoteLeafRanks,rootreqs,MPI_STATUSES_IGNORE);CHKERRMPI(ierr);

  cerr = cudaMalloc((void**)&bas->leafbufdisp_d,nRemoteLeafRanks*sizeof(PetscInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&bas->leafsigdisp_d,nRemoteLeafRanks*sizeof(PetscInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&bas->iranks_d,nRemoteLeafRanks*sizeof(PetscMPIInt));CHKERRCUDA(cerr);
  cerr = cudaMalloc((void**)&bas->ioffset_d,(nRemoteLeafRanks+1)*sizeof(PetscInt));CHKERRCUDA(cerr);

  cerr = cudaMemcpyAsync(bas->leafbufdisp_d,bas->leafbufdisp,nRemoteLeafRanks*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(bas->leafsigdisp_d,bas->leafsigdisp,nRemoteLeafRanks*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(bas->iranks_d,bas->iranks+bas->ndiranks,nRemoteLeafRanks*sizeof(PetscMPIInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);
  cerr = cudaMemcpyAsync(bas->ioffset_d,bas->ioffset+bas->ndiranks,(nRemoteLeafRanks+1)*sizeof(PetscInt),cudaMemcpyHostToDevice,PetscDefaultCudaStream);CHKERRCUDA(cerr);

  ierr = PetscFree2(rootreqs,leafreqs);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/* Decide whether an SF communication with the given root/leaf data goes through NVSHMEM.

   Input parameters:
+  sf        - the star forest
.  rootmtype - memory type of rootdata
.  rootdata  - root data (may be NULL)
.  leafmtype - memory type of leafdata
-  leafdata  - leaf data (may be NULL)

   Output parameter:
.  use_nvshmem - PETSC_TRUE if NVSHMEM is to be used; NVSHMEM is then initialized and the SF's NVSHMEM fields set up

   Notes:
   May perform collective MPI operations on the SF's communicator (eligibility check, debug-mode consistency check,
   on-demand NVSHMEM setup), so all ranks must call it together. The result must be the same on all ranks, since an
   SFLink must be collectively either NVSHMEM or MPI.
*/
PetscErrorCode PetscSFLinkNvshmemCheck(PetscSF sf,PetscMemType rootmtype,const void *rootdata,PetscMemType leafmtype,const void *leafdata,PetscBool *use_nvshmem)
{
  PetscErrorCode ierr;
  MPI_Comm       comm;
  PetscBool      isBasic;
  PetscMPIInt    result = MPI_UNEQUAL;

  PetscFunctionBegin;
  ierr = PetscObjectGetComm((PetscObject)sf,&comm);CHKERRQ(ierr);
  /* Check if the sf is eligible for NVSHMEM, if we have not checked yet.
     The flag checked_nvshmem_eligibility is set only at the end of the check. Setting it before the test would turn
     the check into dead code and let non-SFBASIC SFs reach PetscSFSetUp_Basic_NVSHMEM(), which assumes PetscSF_Basic data.
  */
  if (sf->use_nvshmem && !sf->checked_nvshmem_eligibility) {
    /* Only use NVSHMEM for SFBASIC on PETSC_COMM_WORLD */
    ierr = PetscObjectTypeCompare((PetscObject)sf,PETSCSFBASIC,&isBasic);CHKERRQ(ierr);
    if (isBasic) {ierr = MPI_Comm_compare(PETSC_COMM_WORLD,comm,&result);CHKERRMPI(ierr);}
    if (!isBasic || (result != MPI_IDENT && result != MPI_CONGRUENT)) sf->use_nvshmem = PETSC_FALSE; /* If not eligible, clear the flag so that we don't try again */

    /* Do further check: If on a rank, both rootdata and leafdata are NULL, we might think they are PETSC_MEMTYPE_CUDA (or HOST)
       and then use NVSHMEM. But if root/leafmtypes on other ranks are PETSC_MEMTYPE_HOST (or DEVICE), this would lead to
       inconsistency on the return value <use_nvshmem>. To be safe, we simply disable nvshmem on these rare SFs.
    */
    if (sf->use_nvshmem) {
      PetscInt hasNullRank = (!rootdata && !leafdata) ? 1 : 0;
      ierr = MPI_Allreduce(MPI_IN_PLACE,&hasNullRank,1,MPIU_INT,MPI_LOR,comm);CHKERRMPI(ierr);
      if (hasNullRank) sf->use_nvshmem = PETSC_FALSE;
    }
    sf->checked_nvshmem_eligibility = PETSC_TRUE; /* Don't do the above check again */
  }

  /* Check if rootmtype and leafmtype collectively are PETSC_MEMTYPE_CUDA */
  if (sf->use_nvshmem) {
    PetscInt oneCuda = (!rootdata || PetscMemTypeCUDA(rootmtype)) && (!leafdata || PetscMemTypeCUDA(leafmtype)) ? 1 : 0; /* Do I use cuda for both root&leafmtype? */
    PetscInt allCuda = oneCuda; /* Assume the same for all ranks. But if not, in opt mode, return value <use_nvshmem> won't be collective! */
#if defined(PETSC_USE_DEBUG) /* Check in debug mode. Note MPI_Allreduce is expensive, so only in debug mode */
    {
      /* flags[0]: every rank uses CUDA; flags[1]: no rank uses CUDA. If neither holds, the ranks disagree.
         Reducing both flags in one allreduce lets every rank see the disagreement, so the error below is raised collectively. */
      PetscInt flags[2];

      flags[0] = oneCuda;
      flags[1] = !oneCuda;
      ierr = MPI_Allreduce(MPI_IN_PLACE,flags,2,MPIU_INT,MPI_LAND,comm);CHKERRMPI(ierr);
      if (!flags[0] && !flags[1]) SETERRQ(comm,PETSC_ERR_SUP,"root/leaf mtypes are inconsistent among ranks, which may lead to SF nvshmem failure in opt mode. Add -use_nvshmem 0 to disable it.");
      allCuda = flags[0];
    }
#endif
    if (allCuda) {
      ierr = PetscNvshmemInitializeCheck();CHKERRQ(ierr);
      if (!sf->setup_nvshmem) { /* Set up nvshmem related fields on this SF on-demand */
        ierr = PetscSFSetUp_Basic_NVSHMEM(sf);CHKERRQ(ierr);
        sf->setup_nvshmem = PETSC_TRUE;
      }
      *use_nvshmem = PETSC_TRUE;
    } else {
      *use_nvshmem = PETSC_FALSE;
    }
  } else {
    *use_nvshmem = PETSC_FALSE;
  }
  PetscFunctionReturn(0);
}

/* Build dependence between <stream> and <remoteCommStream> at the entry of NVSHMEM communication:
   if this rank has data to send in <direction>, make link->remoteCommStream wait for the work already queued on link->stream */
static PetscErrorCode PetscSFLinkBuildDependenceBegin(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  cudaError_t   cerr;
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscInt      buflen = (direction == PETSCSF_ROOT2LEAF)? bas->rootbuflen[PETSCSF_REMOTE] : sf->leafbuflen[PETSCSF_REMOTE]; /* send-side length */

  PetscFunctionBegin;
  if (buflen) {
    cerr = cudaEventRecord(link->dataReady,link->stream);CHKERRCUDA(cerr);
    cerr = cudaStreamWaitEvent(link->remoteCommStream,link->dataReady,0);CHKERRCUDA(cerr);
  }
  PetscFunctionReturn(0);
}

/* Build dependence between <stream> and <remoteCommStream> at the exit of NVSHMEM communication:
   if this rank receives data in <direction>, make link->stream wait for the work queued on link->remoteCommStream */
static PetscErrorCode PetscSFLinkBuildDependenceEnd(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  cudaError_t   cerr;
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscInt      buflen = (direction == PETSCSF_ROOT2LEAF)? sf->leafbuflen[PETSCSF_REMOTE] : bas->rootbuflen[PETSCSF_REMOTE]; /* recv-side length */

  PetscFunctionBegin;
  /* If unpack to non-null device buffer, build the endRemoteComm dependence */
  if (buflen) {
    cerr = cudaEventRecord(link->endRemoteComm,link->remoteCommStream);CHKERRCUDA(cerr);
    cerr = cudaStreamWaitEvent(link->stream,link->endRemoteComm,0);CHKERRCUDA(cerr);
  }
  PetscFunctionReturn(0);
}

/* Send/Put signals to remote ranks

   Input parameters:
+  n       - Number of remote ranks
.  sig     - Signal address in symmetric heap
. 
sigdisp - To i-th rank, use its signal at offset sigdisp[i]
.  ranks   - remote ranks
-  newval  - Set signals to this value

   One thread per remote rank: the thread with global index t sets the signal sig+sigdisp[t] on PE ranks[t]
   to newval. Threads whose global index is >= n do nothing.
*/
__global__ static void NvshmemSendSignals(PetscInt n,uint64_t *sig,PetscInt *sigdisp,PetscMPIInt *ranks,uint64_t newval)
{
  const int tid = threadIdx.x + blockIdx.x*blockDim.x;

  if (tid >= n) return; /* surplus thread in the last block */
  nvshmemx_uint64_signal(&sig[sigdisp[tid]],newval,ranks[tid]);
}

/* Wait until local signals equal to the expected value and then set them to a new value

   Input parameters:
+  n       - Number of signals
.  sig     - Local signal address
. 
expval  - expected value
-  newval  - Set signals to this new value

   Callers launch this with a single thread (<<<1,1>>>). That thread waits on all n signals at once with
   nvshmem_uint64_wait_until_all() and then resets them to newval. Akhil Langer@NVIDIA advised that one thread
   with nvshmem_uint64_wait_until_all is better than one thread per signal each calling nvshmem_signal_wait_until.
*/
__global__ static void NvshmemWaitSignals(PetscInt n,uint64_t *sig,uint64_t expval,uint64_t newval)
{
  nvshmem_uint64_wait_until_all(sig,n,NULL/*no mask*/,NVSHMEM_CMP_EQ,expval);
  for (PetscInt i=0; i<n; i++) sig[i] = newval;
}

/* ===========================================================================================================

   A set of routines to support receiver initiated communication using the get method

   The getting protocol is:

   Sender has a send buf (sbuf) and a signal variable (ssig); Receiver has a recv buf (rbuf) and a signal variable (rsig);
   All signal variables have an initial value 0.

   Sender:                                    |  Receiver:
   1. Wait for ssig to be 0, then set it to 1 |
   2. Pack data into stand-alone sbuf         |
   3. Put 1 to receiver's rsig                |  1. Wait for rsig to be 1, then set it to 0
                                              |  2. Get data from remote sbuf to local rbuf
                                              |  3. 
Put 0 to sender's ssig
                                              |  4. Unpack data from local rbuf
   ===========================================================================================================*/

/* PrePack operation -- the sender will overwrite the send buffer, which receivers might still be getting data from.
   So the sender waits (on link->remoteCommStream) until every receiver has reset its send signal to 0, meaning it
   has finished getting data, and then sets the signals to 1 for the new round.
*/
PetscErrorCode PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscSF_Basic *bas = (PetscSF_Basic*)sf->data;
  uint64_t      *sig;
  PetscInt      n;

  PetscFunctionBegin;
  if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
    sig = link->rootSendSig; /* leaf ranks set my rootSendSig */
    n   = bas->nRemoteLeafRanks;
  } else { /* LEAF2ROOT: root ranks are getting data */
    sig = link->leafSendSig; /* root ranks set my leafSendSig */
    n   = sf->nRemoteRootRanks;
  }

  if (n) {
    cudaError_t cerr;

    NvshmemWaitSignals<<<1,1,0,link->remoteCommStream>>>(n,sig,0,1); /* wait the signals to be 0, then set them to 1 */
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }
  PetscFunctionReturn(0);
}

/* n thread blocks.
Each takes charge of one remote rank: block b gets data from PE srcranks[b], unless that PE is directly
   addressable through nvshmem_ptr(); such PEs are instead handled from the host in PetscSFLinkGetDataBegin_NVSHMEM(). */
__global__ static void GetDataFromRemotelyAccessible(PetscInt nsrcranks,PetscMPIInt *srcranks,const char *src,PetscInt *srcdisp,char *dst,PetscInt *dstdisp,PetscInt unitbytes)
{
  int         bid = blockIdx.x;
  PetscMPIInt pe  = srcranks[bid];

  if (!nvshmem_ptr(src,pe)) {
    /* dstdisp[] are unit offsets into the local recv buf, which begins at dstdisp[0]; srcdisp[bid] is the unit offset in pe's send buf */
    PetscInt nbytes = (dstdisp[bid+1]-dstdisp[bid])*unitbytes; /* a byte count */
    nvshmem_getmem_nbi(dst+(dstdisp[bid]-dstdisp[0])*unitbytes,src+srcdisp[bid]*unitbytes,nbytes,pe);
  }
}

/* Start communication -- Get data in the given direction (receiver-initiated protocol described above).

   All work is enqueued on link->remoteCommStream, after PetscSFLinkBuildDependenceBegin() has ordered it after link->stream:
   1) as a sender, set to 1 the signals of my destination ranks (the data is packed and may be got);
   2) as a receiver, wait until all my source ranks have signaled me, then reset those signals to 0;
   3) issue non-blocking gets, which are completed in PetscSFLinkGetDataEnd_NVSHMEM().
*/
PetscErrorCode PetscSFLinkGetDataBegin_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscErrorCode ierr;
  cudaError_t    cerr;
  PetscSF_Basic  *bas = (PetscSF_Basic*)sf->data;
  PetscInt       nsrcranks,ndstranks,nLocallyAccessible = 0;
  char           *src,*dst;
  PetscInt       *srcdisp_h,*dstdisp_h;
  PetscInt       *srcdisp_d,*dstdisp_d;
  PetscMPIInt    *srcranks_h;
  PetscMPIInt    *srcranks_d,*dstranks_d;
  uint64_t       *dstsig;
  PetscInt       *dstsigdisp_d;

  PetscFunctionBegin;
  ierr = PetscSFLinkBuildDependenceBegin(sf,link,direction);CHKERRQ(ierr);
  if (direction == PETSCSF_ROOT2LEAF) { /* src is root, dst is leaf; we will move data from src to dst */
    nsrcranks    = sf->nRemoteRootRanks;
    src          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* root buf is the send buf; it is in symmetric heap */

    srcdisp_h    = sf->rootbufdisp;          /* for my i-th remote root rank, I will access its buf at offset rootbufdisp[i] */
    srcdisp_d    = sf->rootbufdisp_d;
    srcranks_h   = sf->ranks+sf->ndranks;    /* my (remote) root ranks */
    srcranks_d   = sf->ranks_d;

    ndstranks    = bas->nRemoteLeafRanks;
    dst          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* recv buf is the local leaf buf, also in symmetric heap */

    dstdisp_h    = sf->roffset+sf->ndranks;  /* offsets of the local leaf buf. Note dstdisp[0] is not necessarily 0 */
    dstdisp_d    = sf->roffset_d;
    dstranks_d   = bas->iranks_d;            /* my (remote) leaf ranks */

    dstsig       = link->leafRecvSig;
    dstsigdisp_d = bas->leafsigdisp_d;
  } else { /* src is leaf, dst is root; we will move data from src to dst */
    nsrcranks    = bas->nRemoteLeafRanks;
    src          = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* leaf buf is the send buf */

    srcdisp_h    = bas->leafbufdisp;         /* for my i-th remote leaf rank, I will access its buf at offset leafbufdisp[i] */
    srcdisp_d    = bas->leafbufdisp_d;
    srcranks_h   = bas->iranks+bas->ndiranks; /* my (remote) leaf ranks */
    srcranks_d   = bas->iranks_d;

    ndstranks    = sf->nRemoteRootRanks;
    dst          = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* the local root buf is the recv buf */

    dstdisp_h    = bas->ioffset+bas->ndiranks; /* offsets of the local root buf. Note dstdisp[0] is not necessarily 0 */
    dstdisp_d    = bas->ioffset_d;
    dstranks_d   = sf->ranks_d;              /* my (remote) root ranks */

    dstsig       = link->rootRecvSig;
    dstsigdisp_d = sf->rootsigdisp_d;
  }

  /* After Pack operation -- src tells dst ranks that they are allowed to get data */
  if (ndstranks) {
    NvshmemSendSignals<<<(ndstranks+255)/256,256,0,link->remoteCommStream>>>(ndstranks,dstsig,dstsigdisp_d,dstranks_d,1); /* set signals to 1 */
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }

  /* dst waits for signals (permissions) from src ranks to start getting data */
  if (nsrcranks) {
    NvshmemWaitSignals<<<1,1,0,link->remoteCommStream>>>(nsrcranks,dstsig,1,0); /* wait the signals to be 1, then set them to 0 */
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }

  /* dst gets data from src ranks using non-blocking nvshmem_gets, which are finished in PetscSFLinkGetDataEnd_NVSHMEM() */

  /* Count number of locally accessible src ranks, which should be a small number */
  for (PetscInt i=0; i<nsrcranks; i++) {if (nvshmem_ptr(src,srcranks_h[i])) nLocallyAccessible++;}

  /* Get data from remotely accessible PEs (device-initiated, one thread block per src rank) */
  if (nLocallyAccessible < nsrcranks) {
    GetDataFromRemotelyAccessible<<<nsrcranks,1,0,link->remoteCommStream>>>(nsrcranks,srcranks_d,src,srcdisp_d,dst,dstdisp_d,link->unitbytes);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }

  /* Get data from locally accessible PEs (host-initiated, on the stream) */
  if (nLocallyAccessible) {
    for (PetscInt i=0; i<nsrcranks; i++) {
      int pe = srcranks_h[i];
      if (nvshmem_ptr(src,pe)) {
        size_t nbytes = (dstdisp_h[i+1]-dstdisp_h[i])*link->unitbytes;
        nvshmemx_getmem_nbi_on_stream(dst+(dstdisp_h[i]-dstdisp_h[0])*link->unitbytes,src+srcdisp_h[i]*link->unitbytes,nbytes,pe,link->remoteCommStream);
      }
    }
  }
  PetscFunctionReturn(0);
}

/* Finish the communication (can be done before Unpack)
   Receiver tells its senders that they are allowed to reuse their send buffer (since receiver has got data from their send buffer).

   On link->remoteCommStream: nvshmemx_quiet_on_stream() completes the non-blocking gets issued in
   PetscSFLinkGetDataBegin_NVSHMEM(), then the senders' send signals are set to 0. Finally
   PetscSFLinkBuildDependenceEnd() orders link->stream after this remote communication.
*/
PetscErrorCode PetscSFLinkGetDataEnd_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscErrorCode ierr;
  cudaError_t    cerr;
  PetscSF_Basic  *bas = (PetscSF_Basic*)sf->data;
  uint64_t       *srcsig;
  PetscInt       nsrcranks,*srcsigdisp;
  PetscMPIInt    *srcranks;

  PetscFunctionBegin;
  if (direction == PETSCSF_ROOT2LEAF) { /* leaf ranks are getting data */
    nsrcranks  = sf->nRemoteRootRanks;
    srcsig     = link->rootSendSig;    /* I want to set their root signal */
    srcsigdisp = sf->rootsigdisp_d;    /* offset of each root signal */
    srcranks   = sf->ranks_d;          /* ranks of the n root ranks */
  } else { /* LEAF2ROOT, root ranks are getting data */
    nsrcranks  = bas->nRemoteLeafRanks;
    srcsig     = link->leafSendSig;    /* I want to set their leaf signal */
    srcsigdisp = bas->leafsigdisp_d;   /* offset of each leaf signal */
    srcranks   = bas->iranks_d;        /* ranks of the n leaf ranks */
  }

  if (nsrcranks) {
    nvshmemx_quiet_on_stream(link->remoteCommStream); /* Finish the nonblocking get, so that we can unpack afterwards */
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
    NvshmemSendSignals<<<(nsrcranks+511)/512,512,0,link->remoteCommStream>>>(nsrcranks,srcsig,srcsigdisp,srcranks,0); /* set signals to 0 */
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }
  ierr = PetscSFLinkBuildDependenceEnd(sf,link,direction);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/* ===========================================================================================================

   A set of routines to support sender initiated communication using the put-based method (the default)

   The putting protocol is:

   Sender has a send buf (sbuf) and a send signal var (ssig); Receiver has a
stand-alone recv buf (rbuf)
  and a recv signal var (rsig). All signal variables have an initial value of 0. rbuf is allocated by SF and
  lives in nvshmem space.

   Sender:                                    |  Receiver:
                                              |
   1. Pack data into sbuf                     |
   2. Wait for ssig to be 0, then set it to 1 |
   3. Put data to the remote stand-alone rbuf |
   4. Fence // make sure 5 happens after 3    |
   5. Put 1 to the receiver's rsig            |  1. Wait for rsig to be 1, then set it to 0
                                              |  2. Unpack data from the local rbuf
                                              |  3. Put 0 to the sender's ssig
  ===========================================================================================================*/

/* Launched with one single-thread block per destination rank (block k serves dstranks[k]).
   Blocks whose destination PE is locally accessible do nothing; the others wait until their send signal
   srcsig[k] is back to 0, mark it busy (1), and start a nonblocking put of their slice of src into dst on that PE.
*/
__global__ static void WaitAndPutDataToRemotelyAccessible(PetscInt ndstranks,PetscMPIInt *dstranks,char *dst,PetscInt *dstdisp,const char *src,PetscInt *srcdisp,uint64_t *srcsig,PetscInt unitbytes)
{
  const int         k    = blockIdx.x;
  const PetscMPIInt peer = dstranks[k];
  PetscInt          nbytes;
  const char        *from;
  char              *to;

  if (nvshmem_ptr(dst,peer)) return; /* non-NULL means <peer> is locally accessible; it is served by host-initiated puts instead */

  nbytes = (srcdisp[k+1]-srcdisp[k])*unitbytes;
  from   = src + (srcdisp[k]-srcdisp[0])*unitbytes; /* src starts at offset srcdisp[0], hence the rebase */
  to     = dst + dstdisp[k]*unitbytes;

  nvshmem_uint64_wait_until(&srcsig[k],NVSHMEM_CMP_EQ,0); /* block until the signal is back to 0 */
  srcsig[k] = 1;                                          /* mark this send slot busy again */
  nvshmem_putmem_nbi(to,from,nbytes,peer);
}

/* One-thread kernel that handles all
locally accessible PEs: for each one, wait until its send signal is 0, then set it to 1 */
__global__ static void WaitSignalsFromLocallyAccessible(PetscInt ndstranks,PetscMPIInt *dstranks,uint64_t *srcsig,const char *dst)
{
  for (int i=0; i<ndstranks; i++) {
    int pe = dstranks[i];
    if (nvshmem_ptr(dst,pe)) { /* non-NULL means <pe> is locally accessible */
      nvshmem_uint64_wait_until(srcsig+i,NVSHMEM_CMP_EQ,0); /* Wait until the sig = 0 */
      srcsig[i] = 1;
    }
  }
}

/* Put data in the given direction (sender side of the put-based protocol).

   For each remote destination rank: wait until its send signal is 0, set it to 1, then start a nonblocking put of
   the packed data (src) into the destination's receive buffer (dst). Both src and dst must be symmetric (nvshmem) buffers.
   Remotely accessible PEs are served by one device kernel (WaitAndPutDataToRemotelyAccessible); locally accessible
   PEs are served by host-initiated puts on link->remoteCommStream. The puts are finished, and the receivers are
   signaled, in PetscSFLinkPutDataEnd_NVSHMEM.
*/
PetscErrorCode PetscSFLinkPutDataBegin_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscErrorCode ierr;
  cudaError_t    cerr;
  PetscSF_Basic  *bas = (PetscSF_Basic*)sf->data;
  PetscInt       ndstranks,nLocallyAccessible = 0;
  char           *src,*dst;
  PetscInt       *srcdisp_h,*dstdisp_h;
  PetscInt       *srcdisp_d,*dstdisp_d;
  PetscMPIInt    *dstranks_h;
  PetscMPIInt    *dstranks_d;
  uint64_t       *srcsig;

  PetscFunctionBegin;
  ierr = PetscSFLinkBuildDependenceBegin(sf,link,direction);CHKERRQ(ierr);
  if (direction == PETSCSF_ROOT2LEAF) { /* put data in rootbuf to leafbuf */
    ndstranks  = bas->nRemoteLeafRanks; /* number of (remote) leaf ranks */
    src        = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* Both src & dst must be symmetric */
    dst        = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

    srcdisp_h  = bas->ioffset+bas->ndiranks; /* offsets of rootbuf. srcdisp[0] is not necessarily zero */
    srcdisp_d  = bas->ioffset_d;
    srcsig     = link->rootSendSig;

    dstdisp_h  = bas->leafbufdisp; /* for my i-th remote leaf rank, I will access its leaf buf at offset leafbufdisp[i] */
    dstdisp_d  = bas->leafbufdisp_d;
    dstranks_h = bas->iranks+bas->ndiranks; /* remote leaf ranks */
    dstranks_d = bas->iranks_d;
  } else { /* put data in leafbuf to rootbuf */
    ndstranks  = sf->nRemoteRootRanks;
    src        = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
    dst        = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

    srcdisp_h  = sf->roffset+sf->ndranks; /* offsets of leafbuf */
    srcdisp_d  = sf->roffset_d;
    srcsig     = link->leafSendSig;

    dstdisp_h  = sf->rootbufdisp; /* for my i-th remote root rank, I will access its root buf at offset rootbufdisp[i] */
    dstdisp_d  = sf->rootbufdisp_d;
    dstranks_h = sf->ranks+sf->ndranks; /* remote root ranks */
    dstranks_d = sf->ranks_d;
  }

  /* Wait for signals and then put data to dst ranks using non-blocking nvshmem_put, which are finished in PetscSFLinkPutDataEnd_NVSHMEM */

  /* Count number of locally accessible neighbors, which should be a small number */
  for (int i=0; i<ndstranks; i++) {if (nvshmem_ptr(dst,dstranks_h[i])) nLocallyAccessible++;}

  /* For remotely accessible PEs, send data to them in one kernel call */
  if (nLocallyAccessible < ndstranks) {
    WaitAndPutDataToRemotelyAccessible<<<ndstranks,1,0,link->remoteCommStream>>>(ndstranks,dstranks_d,dst,dstdisp_d,src,srcdisp_d,srcsig,link->unitbytes);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }

  /* For locally accessible PEs, use host API, which uses CUDA copy-engines and is much faster than device API */
  if (nLocallyAccessible) {
    /* The puts below are enqueued on the same stream, so they start only after this kernel has acquired the send signals */
    WaitSignalsFromLocallyAccessible<<<1,1,0,link->remoteCommStream>>>(ndstranks,dstranks_d,srcsig,dst);
    cerr = cudaGetLastError();CHKERRCUDA(cerr); /* check the launch like every other kernel launch in this file */
    for (int i=0; i<ndstranks; i++) {
      int pe = dstranks_h[i];
      if (nvshmem_ptr(dst,pe)) { /* If return a non-null pointer, then <pe> is locally accessible */
        size_t nelems = (srcdisp_h[i+1]-srcdisp_h[i])*link->unitbytes;
        /* Initiate the nonblocking communication */
        nvshmemx_putmem_nbi_on_stream(dst+dstdisp_h[i]*link->unitbytes,src+(srcdisp_h[i]-srcdisp_h[0])*link->unitbytes,nelems,pe,link->remoteCommStream);
      }
    }
    nvshmemx_quiet_on_stream(link->remoteCommStream); /* Calling nvshmem_fence/quiet() does not fence the above nvshmemx_putmem_nbi_on_stream! */
  }
  PetscFunctionReturn(0);
}

/* A one-thread kernel.
The thread handles all remote PEs: it signals every dst rank, then waits for the signals of all src ranks */
__global__ static void PutDataEnd(PetscInt nsrcranks,PetscInt ndstranks,PetscMPIInt *dstranks,uint64_t *dstsig,PetscInt *dstsigdisp)
{
  /* TODO: Shall we finish the non-blocking remote puts? */

  /* 1. Send a signal to each dst rank */

  /* According to Akhil@NVIDIA, IB is ordered, so no fence is needed for remote PEs.
     For local PEs, we already called nvshmemx_quiet_on_stream(). Therefore, we are good to send signals to all dst ranks now.
  */
  for (int i=0; i<ndstranks; i++) {nvshmemx_uint64_signal(dstsig+dstsigdisp[i],1,dstranks[i]);} /* set sig to 1 */

  /* 2. Wait for signals from src ranks (if any). My own recv signals are dstsig[0..nsrcranks) */
  if (nsrcranks) {
    nvshmem_uint64_wait_until_all(dstsig,nsrcranks,NULL/*no mask*/,NVSHMEM_CMP_EQ,1); /* wait sigs to be 1, then set them to 0 */
    for (int i=0; i<nsrcranks; i++) dstsig[i] = 0;
  }
}

/* Finish the communication -- A receiver waits until it can access its receive buffer.
   Launches PutDataEnd on link->remoteCommStream to signal the dst ranks and wait for the src ranks' signals, then
   calls PetscSFLinkBuildDependenceEnd().
*/
PetscErrorCode PetscSFLinkPutDataEnd_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscErrorCode ierr;
  cudaError_t    cerr;
  PetscSF_Basic  *bas = (PetscSF_Basic*)sf->data;
  PetscMPIInt    *dstranks;
  uint64_t       *dstsig;
  PetscInt       nsrcranks,ndstranks,*dstsigdisp;

  PetscFunctionBegin;
  if (direction == PETSCSF_ROOT2LEAF) { /* put root data to leaf */
    nsrcranks  = sf->nRemoteRootRanks;

    ndstranks  = bas->nRemoteLeafRanks;
    dstranks   = bas->iranks_d; /* leaf ranks */
    dstsig     = link->leafRecvSig; /* I will set my leaf ranks's RecvSig */
    dstsigdisp = bas->leafsigdisp_d; /* for my i-th remote leaf rank, I will access its signal at offset leafsigdisp[i] */
  } else { /* LEAF2ROOT */
    nsrcranks  = bas->nRemoteLeafRanks;

    ndstranks  = sf->nRemoteRootRanks;
    dstranks   = sf->ranks_d;
    dstsig     = link->rootRecvSig;
    dstsigdisp = sf->rootsigdisp_d;
  }

  if (nsrcranks || ndstranks) {
    PutDataEnd<<<1,1,0,link->remoteCommStream>>>(nsrcranks,ndstranks,dstranks,dstsig,dstsigdisp);
    cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }
  ierr = PetscSFLinkBuildDependenceEnd(sf,link,direction);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/* PostUnpack operation -- A receiver tells its senders that they are allowed to put data to here (it implies recv buf is free to take new data).
   It does so by setting the senders' send signals back to 0 through NvshmemSendSignals on link->remoteCommStream.
*/
PetscErrorCode PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM(PetscSF sf,PetscSFLink link,PetscSFDirection direction)
{
  PetscSF_Basic *bas = (PetscSF_Basic*)sf->data;
  uint64_t      *srcsig;
  PetscInt      nsrcranks,*srcsigdisp_d;
  PetscMPIInt   *srcranks_d;

  PetscFunctionBegin;
  if (direction == PETSCSF_ROOT2LEAF) { /* I allow my root ranks to put data to me */
    nsrcranks    = sf->nRemoteRootRanks;
    srcsig       = link->rootSendSig; /* I want to set their send signals */
    srcsigdisp_d = sf->rootsigdisp_d; /* offset of each root signal */
    srcranks_d   = sf->ranks_d; /* ranks of the n root ranks */
  } else { /* LEAF2ROOT */
    nsrcranks    = bas->nRemoteLeafRanks;
    srcsig       = link->leafSendSig;
    srcsigdisp_d = bas->leafsigdisp_d;
    srcranks_d   = bas->iranks_d;
  }

  if (nsrcranks) {
    /* One thread per src rank, 256 threads per block */
    NvshmemSendSignals<<<(nsrcranks+255)/256,256,0,link->remoteCommStream>>>(nsrcranks,srcsig,srcsigdisp_d,srcranks_d,0); /* Set remote signals to 0 */
    cudaError_t cerr = cudaGetLastError();CHKERRCUDA(cerr);
  }
  PetscFunctionReturn(0);
}

/* Destructor when the link uses nvshmem for communication.
   Destroys the link's CUDA events and remote-communication stream, then frees its nvshmem buffers and signal arrays.
   NOTE(review): nvshmem frees are presumably collective across PEs, so links should be destroyed in the same order on all ranks -- confirm.
*/
static PetscErrorCode PetscSFLinkDestroy_NVSHMEM(PetscSF sf,PetscSFLink link)
{
  PetscErrorCode ierr;
  cudaError_t    cerr;

  PetscFunctionBegin;
  cerr = cudaEventDestroy(link->dataReady);CHKERRCUDA(cerr);
  cerr = cudaEventDestroy(link->endRemoteComm);CHKERRCUDA(cerr);
  cerr = cudaStreamDestroy(link->remoteCommStream);CHKERRCUDA(cerr);

  /* nvshmem does not need buffers on host, which should be NULL */
  ierr = PetscNvshmemFree(link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
  ierr = PetscNvshmemFree(link->leafSendSig);CHKERRQ(ierr);
  ierr = PetscNvshmemFree(link->leafRecvSig);CHKERRQ(ierr);
  ierr = PetscNvshmemFree(link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
  ierr = PetscNvshmemFree(link->rootSendSig);CHKERRQ(ierr);
  ierr = PetscNvshmemFree(link->rootRecvSig);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/* Get a link that uses NVSHMEM to communicate data of type <unit> for the given sfop/op, and push it onto bas->inuse.

   A cached link with a matching unit is reused from bas->avail if there is one. Otherwise a new link is created:
   host/backend setup, zero-initialized nvshmem signal arrays, protocol (get- or put-based) callbacks, a
   highest-priority non-blocking stream and two timing-disabled events.
   In both cases the remote root/leaf buffers are then pointed either directly at root/leafdata (when allowed, see
   root/leafdirect) or at a stand-alone nvshmem buffer, which is allocated on first use.

   NOTE(review): nvshmem allocations are presumably collective, so the calls below must be reached in the same order on all
   ranks; the *Max sizes (nRemoteLeafRanksMax, rootbuflen_rmax, ...) presumably give every PE the same symmetric size -- confirm.
*/
PetscErrorCode PetscSFLinkCreate_NVSHMEM(PetscSF sf,MPI_Datatype unit,PetscMemType rootmtype,const void *rootdata,PetscMemType leafmtype,const void *leafdata,MPI_Op op,PetscSFOperation sfop,PetscSFLink *mylink)
{
  PetscErrorCode    ierr;
  cudaError_t       cerr;
  PetscSF_Basic     *bas = (PetscSF_Basic*)sf->data;
  PetscSFLink       *p,link;
  PetscBool         match,rootdirect[2],leafdirect[2];
  int               greatestPriority;

  PetscFunctionBegin;
  /* Check to see if we can directly send/recv root/leafdata with the given sf, sfop and op.
     We only care about root/leafdirect[PETSCSF_REMOTE], since we never need intermediate buffers in local communication with NVSHMEM.
  */
  if (sfop == PETSCSF_BCAST) { /* Move data from rootbuf to leafbuf */
    if (sf->use_nvshmem_get) {
      rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* send buffer has to be stand-alone (can't be rootdata) */
      leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
    } else {
      rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
      leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* Our put-protocol always needs a nvshmem alloc'ed recv buffer */
    }
  } else if (sfop == PETSCSF_REDUCE) { /* Move data from leafbuf to rootbuf */
    if (sf->use_nvshmem_get) {
      rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
      leafdirect[PETSCSF_REMOTE] = PETSC_FALSE;
    } else {
      rootdirect[PETSCSF_REMOTE] = PETSC_FALSE;
      leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
    }
  } else { /* PETSCSF_FETCH */
    rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* FETCH always needs a separate rootbuf */
    leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* We also force allocating a separate leafbuf so that leafdata and leafupdate can share mpi requests */
  }

  /* Look for free nvshmem links in cache */
  for (p=&bas->avail; (link=*p); p=&link->next) {
    if (link->use_nvshmem) {
      ierr = MPIPetsc_Type_compare(unit,link->unit,&match);CHKERRQ(ierr);
      if (match) {
        *p = link->next; /* Remove from available list */
        goto found;
      }
    }
  }
  ierr = PetscNew(&link);CHKERRQ(ierr);
  ierr = PetscSFLinkSetUp_Host(sf,link,unit);CHKERRQ(ierr); /* Compute link->unitbytes, dup link->unit etc. */
  if (sf->backend == PETSCSF_BACKEND_CUDA) {ierr = PetscSFLinkSetUp_CUDA(sf,link,unit);CHKERRQ(ierr);} /* Setup pack routines, streams etc */
#if defined(PETSC_HAVE_KOKKOS)
  else if (sf->backend == PETSCSF_BACKEND_KOKKOS) {ierr = PetscSFLinkSetUp_Kokkos(sf,link,unit);CHKERRQ(ierr);}
#endif

  link->rootdirect[PETSCSF_LOCAL] = PETSC_TRUE; /* For the local part we directly use root/leafdata */
  link->leafdirect[PETSCSF_LOCAL] = PETSC_TRUE;

  /* Init signals to zero (nvshmem_calloc'ed). Root-side arrays have one entry per remote leaf rank, leaf-side ones one per remote root rank */
  if (!link->rootSendSig) {ierr = PetscNvshmemCalloc(bas->nRemoteLeafRanksMax*sizeof(uint64_t),(void**)&link->rootSendSig);CHKERRQ(ierr);}
  if (!link->rootRecvSig) {ierr = PetscNvshmemCalloc(bas->nRemoteLeafRanksMax*sizeof(uint64_t),(void**)&link->rootRecvSig);CHKERRQ(ierr);}
  if (!link->leafSendSig) {ierr = PetscNvshmemCalloc(sf->nRemoteRootRanksMax*sizeof(uint64_t),(void**)&link->leafSendSig);CHKERRQ(ierr);}
  if (!link->leafRecvSig) {ierr = PetscNvshmemCalloc(sf->nRemoteRootRanksMax*sizeof(uint64_t),(void**)&link->leafRecvSig);CHKERRQ(ierr);}

  link->use_nvshmem = PETSC_TRUE;
  link->rootmtype   = PETSC_MEMTYPE_DEVICE; /* Only need 0/1-based mtype from now on */
  link->leafmtype   = PETSC_MEMTYPE_DEVICE;
  /* Overwrite function pointers possibly set by the backend setup above (e.g. PetscSFLinkSetUp_CUDA) */
  link->Destroy = PetscSFLinkDestroy_NVSHMEM;
  if (sf->use_nvshmem_get) { /* get-based protocol */
    link->PrePack             = PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM;
    link->StartCommunication  = PetscSFLinkGetDataBegin_NVSHMEM;
    link->FinishCommunication = PetscSFLinkGetDataEnd_NVSHMEM;
  } else { /* put-based protocol */
    link->StartCommunication  = PetscSFLinkPutDataBegin_NVSHMEM;
    link->FinishCommunication = PetscSFLinkPutDataEnd_NVSHMEM;
    link->PostUnpack          = PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM;
  }

  /* Remote communication runs on a non-blocking stream created with the greatest priority the device reports */
  cerr = cudaDeviceGetStreamPriorityRange(NULL,&greatestPriority);CHKERRCUDA(cerr);
  cerr = cudaStreamCreateWithPriority(&link->remoteCommStream,cudaStreamNonBlocking,greatestPriority);CHKERRCUDA(cerr);

  cerr = cudaEventCreateWithFlags(&link->dataReady,cudaEventDisableTiming);CHKERRCUDA(cerr);
  cerr = cudaEventCreateWithFlags(&link->endRemoteComm,cudaEventDisableTiming);CHKERRCUDA(cerr);

found:
  /* Both new and cached links: bind the remote buffers for this particular call */
  if (rootdirect[PETSCSF_REMOTE]) {
    link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char*)rootdata + bas->rootstart[PETSCSF_REMOTE]*link->unitbytes;
  } else {
    if (!link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) {
      ierr = PetscNvshmemMalloc(bas->rootbuflen_rmax*link->unitbytes,(void**)&link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
    }
    link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
  }

  if (leafdirect[PETSCSF_REMOTE]) {
    link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char*)leafdata + sf->leafstart[PETSCSF_REMOTE]*link->unitbytes;
  } else {
    if (!link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) {
      ierr = PetscNvshmemMalloc(sf->leafbuflen_rmax*link->unitbytes,(void**)&link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]);CHKERRQ(ierr);
    }
    link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
  }

  link->rootdirect[PETSCSF_REMOTE] = rootdirect[PETSCSF_REMOTE];
  link->leafdirect[PETSCSF_REMOTE] = leafdirect[PETSCSF_REMOTE];
  link->rootdata                   = rootdata; /* root/leafdata are keys to look up links in PetscSFXxxEnd */
  link->leafdata                   = leafdata;
  link->next                       = bas->inuse;
  bas->inuse                       = link;
  *mylink                          = link;
  PetscFunctionReturn(0);
}

/* PetscNvshmemSum/Max: reduce <count> reals from src into dst over NVSHMEM_TEAM_WORLD, enqueued on PetscDefaultCudaStream.
   <count> must fit in a PetscMPIInt (checked by PetscMPIIntCast). Only single and double precision are provided.
   NOTE(review): any status returned by the nvshmemx_*_reduce_on_stream calls is not checked here.
*/
#if defined(PETSC_USE_REAL_SINGLE)
PetscErrorCode PetscNvshmemSum(PetscInt count,float *dst,const float *src)
{
  PetscErrorCode ierr;
  PetscMPIInt    num; /* Assume nvshmem's int is MPI's int */

  PetscFunctionBegin;
  ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
  nvshmemx_float_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
  PetscFunctionReturn(0);
}

PetscErrorCode PetscNvshmemMax(PetscInt count,float *dst,const float *src)
{
  PetscErrorCode ierr;
  PetscMPIInt    num;

  PetscFunctionBegin;
  ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
  nvshmemx_float_max_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
  PetscFunctionReturn(0);
}
#elif defined(PETSC_USE_REAL_DOUBLE)
PetscErrorCode PetscNvshmemSum(PetscInt count,double *dst,const double *src)
{
  PetscErrorCode ierr;
  PetscMPIInt    num;

  PetscFunctionBegin;
  ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
  nvshmemx_double_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
  PetscFunctionReturn(0);
}

PetscErrorCode PetscNvshmemMax(PetscInt count,double *dst,const double *src)
{
  PetscErrorCode ierr;
  PetscMPIInt    num;

  PetscFunctionBegin;
  ierr = PetscMPIIntCast(count,&num);CHKERRQ(ierr);
  nvshmemx_double_max_reduce_on_stream(NVSHMEM_TEAM_WORLD,dst,src,num,PetscDefaultCudaStream);
  PetscFunctionReturn(0);
}
#endif