17cd70835SJames Wright // Copyright (c) 2017-2023, Lawrence Livermore National Security, LLC and other CEED contributors. 27cd70835SJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 37cd70835SJames Wright // 47cd70835SJames Wright // SPDX-License-Identifier: BSD-2-Clause 57cd70835SJames Wright // 67cd70835SJames Wright // This file is part of CEED: http://github.com/ceed 77cd70835SJames Wright // Based on the instructions from https://www.craylabs.org/docs/sr_integration.html and PHASTA implementation 87cd70835SJames Wright 97cd70835SJames Wright #include "../../include/smartsim.h" 107cd70835SJames Wright 117cd70835SJames Wright #include "../../navierstokes.h" 127cd70835SJames Wright 137cd70835SJames Wright PetscErrorCode SmartRedisVerifyPutTensor(void *c_client, const char *name, const size_t name_length) { 147cd70835SJames Wright bool does_exist = true; 157cd70835SJames Wright 167cd70835SJames Wright PetscFunctionBeginUser; 177cd70835SJames Wright SmartRedisCall(tensor_exists(c_client, name, name_length, &does_exist)); 187cd70835SJames Wright PetscCheck(does_exist, PETSC_COMM_SELF, -1, "Tensor of name '%s' was not written to the database successfully", name); 197cd70835SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 207cd70835SJames Wright } 217cd70835SJames Wright 227cd70835SJames Wright PetscErrorCode SmartSimTrainingSetup(User user) { 237cd70835SJames Wright SmartSimData smartsim = user->smartsim; 247cd70835SJames Wright PetscMPIInt rank; 257cd70835SJames Wright PetscReal checkrun[2] = {1}; 267cd70835SJames Wright size_t dim_2[1] = {2}; 277cd70835SJames Wright PetscInt num_ranks; 287cd70835SJames Wright 297cd70835SJames Wright PetscFunctionBeginUser; 307cd70835SJames Wright PetscCallMPI(MPI_Comm_rank(user->comm, &rank)); 317cd70835SJames Wright PetscCallMPI(MPI_Comm_size(user->comm, &num_ranks)); 327cd70835SJames Wright 337cd70835SJames Wright if (rank % smartsim->collocated_database_num_ranks == 0) { 347cd70835SJames Wright // -- Send array that communicates when ML is done training 357cd70835SJames Wright SmartRedisCall(put_tensor(smartsim->client, "check-run", 9, checkrun, dim_2, 1, SRTensorTypeDouble, SRMemLayoutContiguous)); 367cd70835SJames Wright PetscCall(SmartRedisVerifyPutTensor(smartsim->client, "check-run", 9)); 377cd70835SJames Wright } 38*aa0b7f76SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 397cd70835SJames Wright } 407cd70835SJames Wright 417cd70835SJames Wright PetscErrorCode SmartSimSetup(User user) { 427cd70835SJames Wright PetscMPIInt rank; 437cd70835SJames Wright size_t rank_id_name_len; 447cd70835SJames Wright PetscInt num_orchestrator_nodes = 1; 457cd70835SJames Wright 467cd70835SJames Wright PetscFunctionBeginUser; 477cd70835SJames Wright PetscCall(PetscNew(&user->smartsim)); 487cd70835SJames Wright SmartSimData smartsim = user->smartsim; 497cd70835SJames Wright 507cd70835SJames Wright smartsim->collocated_database_num_ranks = 1; 517cd70835SJames Wright PetscOptionsBegin(user->comm, NULL, "Options for SmartSim integration", NULL); 527cd70835SJames Wright PetscCall(PetscOptionsInt("-smartsim_collocated_database_num_ranks", "Number of ranks per collocated database instance", NULL, 537cd70835SJames Wright smartsim->collocated_database_num_ranks, &smartsim->collocated_database_num_ranks, NULL)); 547cd70835SJames Wright PetscOptionsEnd(); 557cd70835SJames Wright 567cd70835SJames Wright PetscCall(PetscStrlen(smartsim->rank_id_name, &rank_id_name_len)); 577cd70835SJames Wright // Create prefix to be put on tensor names 587cd70835SJames Wright PetscCallMPI(MPI_Comm_rank(user->comm, &rank)); 597cd70835SJames Wright PetscCall(PetscSNPrintf(smartsim->rank_id_name, sizeof smartsim->rank_id_name, "y.%d", rank)); 607cd70835SJames Wright 617cd70835SJames Wright SmartRedisCall(SmartRedisCClient(num_orchestrator_nodes != 1, smartsim->rank_id_name, rank_id_name_len, &smartsim->client)); 627cd70835SJames Wright 637cd70835SJames Wright PetscCall(SmartSimTrainingSetup(user)); 647cd70835SJames Wright PetscFunctionReturn(PETSC_SUCCESS); 657cd70835SJames Wright } 66