1*7cd70835SJames Wright // Copyright (c) 2017-2023, Lawrence Livermore National Security, LLC and other CEED contributors. 2*7cd70835SJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*7cd70835SJames Wright // 4*7cd70835SJames Wright // SPDX-License-Identifier: BSD-2-Clause 5*7cd70835SJames Wright // 6*7cd70835SJames Wright // This file is part of CEED: http://github.com/ceed 7*7cd70835SJames Wright // Based on the instructions from https://www.craylabs.org/docs/sr_integration.html and PHASTA implementation 8*7cd70835SJames Wright 9*7cd70835SJames Wright #include "../../include/smartsim.h" 10*7cd70835SJames Wright 11*7cd70835SJames Wright #include "../../navierstokes.h" 12*7cd70835SJames Wright 13*7cd70835SJames Wright PetscErrorCode SmartRedisVerifyPutTensor(void *c_client, const char *name, const size_t name_length) { 14*7cd70835SJames Wright bool does_exist = true; 15*7cd70835SJames Wright 16*7cd70835SJames Wright PetscFunctionBeginUser; 17*7cd70835SJames Wright SmartRedisCall(tensor_exists(c_client, name, name_length, &does_exist)); 18*7cd70835SJames Wright PetscCheck(does_exist, PETSC_COMM_SELF, -1, "Tensor of name '%s' was not written to the database successfully", name); 19*7cd70835SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 20*7cd70835SJames Wright } 21*7cd70835SJames Wright 22*7cd70835SJames Wright PetscErrorCode SmartSimTrainingSetup(User user) { 23*7cd70835SJames Wright SmartSimData smartsim = user->smartsim; 24*7cd70835SJames Wright PetscMPIInt rank; 25*7cd70835SJames Wright PetscReal checkrun[2] = {1}; 26*7cd70835SJames Wright size_t dim_2[1] = {2}; 27*7cd70835SJames Wright PetscInt num_ranks; 28*7cd70835SJames Wright 29*7cd70835SJames Wright PetscFunctionBeginUser; 30*7cd70835SJames Wright PetscCallMPI(MPI_Comm_rank(user->comm, &rank)); 31*7cd70835SJames Wright PetscCallMPI(MPI_Comm_size(user->comm, &num_ranks)); 32*7cd70835SJames Wright 33*7cd70835SJames Wright if (rank % smartsim->collocated_database_num_ranks == 0) { 34*7cd70835SJames Wright // -- Send array that communicates when ML is done training 35*7cd70835SJames Wright SmartRedisCall(put_tensor(smartsim->client, "check-run", 9, checkrun, dim_2, 1, SRTensorTypeDouble, SRMemLayoutContiguous)); 36*7cd70835SJames Wright PetscCall(SmartRedisVerifyPutTensor(smartsim->client, "check-run", 9)); 37*7cd70835SJames Wright } 38*7cd70835SJames Wright 39*7cd70835SJames Wright { // -- Get minimum per-rank global vec size 40*7cd70835SJames Wright PetscInt GlobalVecSize; 41*7cd70835SJames Wright PetscCall(DMGetGlobalVectorInfo(user->dm, &GlobalVecSize, NULL, NULL)); 42*7cd70835SJames Wright PetscCallMPI(MPI_Allreduce(&GlobalVecSize, &smartsim->num_tensor_nodes, 1, MPIU_INT, MPI_MIN, user->comm)); 43*7cd70835SJames Wright smartsim->num_nodes_to_remove = GlobalVecSize - smartsim->num_tensor_nodes; 44*7cd70835SJames Wright } 45*7cd70835SJames Wright 46*7cd70835SJames Wright // Determine the size of the training data arrays... somehow 47*7cd70835SJames Wright if (rank % smartsim->collocated_database_num_ranks == 0) { 48*7cd70835SJames Wright size_t array_dims[2] = {smartsim->num_tensor_nodes, 6}, array_info_dim = 6; 49*7cd70835SJames Wright PetscInt array_info[6] = {0}, num_features = 6; 50*7cd70835SJames Wright 51*7cd70835SJames Wright array_info[0] = array_dims[0]; 52*7cd70835SJames Wright array_info[1] = array_dims[1]; 53*7cd70835SJames Wright array_info[2] = num_features; 54*7cd70835SJames Wright array_info[3] = num_ranks; 55*7cd70835SJames Wright array_info[4] = smartsim->collocated_database_num_ranks; 56*7cd70835SJames Wright array_info[5] = rank; 57*7cd70835SJames Wright 58*7cd70835SJames Wright SmartRedisCall(put_tensor(smartsim->client, "array_info", 10, array_info, &array_info_dim, 1, SRTensorTypeInt32, SRMemLayoutContiguous)); 59*7cd70835SJames Wright PetscCall(SmartRedisVerifyPutTensor(smartsim->client, "array_info", 10)); 60*7cd70835SJames Wright } 61*7cd70835SJames Wright PetscFunctionReturn(0); 62*7cd70835SJames Wright } 63*7cd70835SJames Wright 64*7cd70835SJames Wright PetscErrorCode SmartSimSetup(User user) { 65*7cd70835SJames Wright PetscMPIInt rank; 66*7cd70835SJames Wright size_t rank_id_name_len; 67*7cd70835SJames Wright PetscInt num_orchestrator_nodes = 1; 68*7cd70835SJames Wright 69*7cd70835SJames Wright PetscFunctionBeginUser; 70*7cd70835SJames Wright PetscCall(PetscNew(&user->smartsim)); 71*7cd70835SJames Wright SmartSimData smartsim = user->smartsim; 72*7cd70835SJames Wright 73*7cd70835SJames Wright smartsim->collocated_database_num_ranks = 1; 74*7cd70835SJames Wright PetscOptionsBegin(user->comm, NULL, "Options for SmartSim integration", NULL); 75*7cd70835SJames Wright PetscCall(PetscOptionsInt("-smartsim_collocated_database_num_ranks", "Number of ranks per collocated database instance", NULL, 76*7cd70835SJames Wright smartsim->collocated_database_num_ranks, &smartsim->collocated_database_num_ranks, NULL)); 77*7cd70835SJames Wright PetscOptionsEnd(); 78*7cd70835SJames Wright 79*7cd70835SJames Wright PetscCall(PetscStrlen(smartsim->rank_id_name, &rank_id_name_len)); 80*7cd70835SJames Wright // Create prefix to be put on tensor names 81*7cd70835SJames Wright PetscCallMPI(MPI_Comm_rank(user->comm, &rank)); 82*7cd70835SJames Wright PetscCall(PetscSNPrintf(smartsim->rank_id_name, sizeof smartsim->rank_id_name, "y.%d", rank)); 83*7cd70835SJames Wright 84*7cd70835SJames Wright SmartRedisCall(SmartRedisCClient(num_orchestrator_nodes != 1, smartsim->rank_id_name, rank_id_name_len, &smartsim->client)); 85*7cd70835SJames Wright 86*7cd70835SJames Wright PetscCall(SmartSimTrainingSetup(user)); 87*7cd70835SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 88*7cd70835SJames Wright } 89