xref: /honee/src/smartsim/smartsim.c (revision ad2e713ea4896e3cbe7eaaaac46996b5b5bf5c52)
17cd70835SJames Wright // Copyright (c) 2017-2023, Lawrence Livermore National Security, LLC and other CEED contributors.
27cd70835SJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
37cd70835SJames Wright //
47cd70835SJames Wright // SPDX-License-Identifier: BSD-2-Clause
57cd70835SJames Wright //
67cd70835SJames Wright // This file is part of CEED:  http://github.com/ceed
77cd70835SJames Wright // Based on the instructions from https://www.craylabs.org/docs/sr_integration.html and PHASTA implementation
87cd70835SJames Wright 
97cd70835SJames Wright #include "../../include/smartsim.h"
107cd70835SJames Wright 
117cd70835SJames Wright #include "../../navierstokes.h"
127cd70835SJames Wright 
137cd70835SJames Wright PetscErrorCode SmartRedisVerifyPutTensor(void *c_client, const char *name, const size_t name_length) {
147cd70835SJames Wright   bool does_exist = true;
157cd70835SJames Wright 
167cd70835SJames Wright   PetscFunctionBeginUser;
17ff6b888aSJames Wright   PetscSmartRedisCall(tensor_exists(c_client, name, name_length, &does_exist));
187cd70835SJames Wright   PetscCheck(does_exist, PETSC_COMM_SELF, -1, "Tensor of name '%s' was not written to the database successfully", name);
197cd70835SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
207cd70835SJames Wright }
217cd70835SJames Wright 
227cd70835SJames Wright PetscErrorCode SmartSimTrainingSetup(User user) {
237cd70835SJames Wright   SmartSimData smartsim = user->smartsim;
247cd70835SJames Wright   PetscMPIInt  rank;
257cd70835SJames Wright   PetscReal    checkrun[2] = {1};
267cd70835SJames Wright   size_t       dim_2[1]    = {2};
277cd70835SJames Wright   PetscInt     num_ranks;
287cd70835SJames Wright 
297cd70835SJames Wright   PetscFunctionBeginUser;
307cd70835SJames Wright   PetscCallMPI(MPI_Comm_rank(user->comm, &rank));
317cd70835SJames Wright   PetscCallMPI(MPI_Comm_size(user->comm, &num_ranks));
327cd70835SJames Wright 
337cd70835SJames Wright   if (rank % smartsim->collocated_database_num_ranks == 0) {
347cd70835SJames Wright     // -- Send array that communicates when ML is done training
35*ad2e713eSRiccardo Balin     PetscCall(PetscLogEventBegin(FLUIDS_SmartRedis_Meta, 0, 0, 0, 0));
36ff6b888aSJames Wright     PetscSmartRedisCall(put_tensor(smartsim->client, "check-run", 9, checkrun, dim_2, 1, SRTensorTypeDouble, SRMemLayoutContiguous));
377cd70835SJames Wright     PetscCall(SmartRedisVerifyPutTensor(smartsim->client, "check-run", 9));
38*ad2e713eSRiccardo Balin     PetscCall(PetscLogEventEnd(FLUIDS_SmartRedis_Meta, 0, 0, 0, 0));
397cd70835SJames Wright   }
40aa0b7f76SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
417cd70835SJames Wright }
427cd70835SJames Wright 
437cd70835SJames Wright PetscErrorCode SmartSimSetup(User user) {
447cd70835SJames Wright   PetscMPIInt rank;
457cd70835SJames Wright   size_t      rank_id_name_len;
467cd70835SJames Wright   PetscInt    num_orchestrator_nodes = 1;
477cd70835SJames Wright 
487cd70835SJames Wright   PetscFunctionBeginUser;
497cd70835SJames Wright   PetscCall(PetscNew(&user->smartsim));
507cd70835SJames Wright   SmartSimData smartsim = user->smartsim;
517cd70835SJames Wright 
527cd70835SJames Wright   smartsim->collocated_database_num_ranks = 1;
537cd70835SJames Wright   PetscOptionsBegin(user->comm, NULL, "Options for SmartSim integration", NULL);
547cd70835SJames Wright   PetscCall(PetscOptionsInt("-smartsim_collocated_database_num_ranks", "Number of ranks per collocated database instance", NULL,
557cd70835SJames Wright                             smartsim->collocated_database_num_ranks, &smartsim->collocated_database_num_ranks, NULL));
567cd70835SJames Wright   PetscOptionsEnd();
577cd70835SJames Wright 
587cd70835SJames Wright   PetscCall(PetscStrlen(smartsim->rank_id_name, &rank_id_name_len));
597cd70835SJames Wright   // Create prefix to be put on tensor names
607cd70835SJames Wright   PetscCallMPI(MPI_Comm_rank(user->comm, &rank));
617cd70835SJames Wright   PetscCall(PetscSNPrintf(smartsim->rank_id_name, sizeof smartsim->rank_id_name, "y.%d", rank));
627cd70835SJames Wright 
63*ad2e713eSRiccardo Balin   PetscCall(PetscLogEventBegin(FLUIDS_SmartRedis_Init, 0, 0, 0, 0));
64ff6b888aSJames Wright   PetscSmartRedisCall(SmartRedisCClient(num_orchestrator_nodes != 1, smartsim->rank_id_name, rank_id_name_len, &smartsim->client));
65*ad2e713eSRiccardo Balin   PetscCall(PetscLogEventEnd(FLUIDS_SmartRedis_Init, 0, 0, 0, 0));
667cd70835SJames Wright 
677cd70835SJames Wright   PetscCall(SmartSimTrainingSetup(user));
687cd70835SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
697cd70835SJames Wright }
70ec6e4151SJames Wright 
71ec6e4151SJames Wright PetscErrorCode SmartSimDataDestroy(SmartSimData smartsim) {
72ec6e4151SJames Wright   PetscFunctionBeginUser;
73ec6e4151SJames Wright   if (!smartsim) PetscFunctionReturn(PETSC_SUCCESS);
74ec6e4151SJames Wright 
75ff6b888aSJames Wright   PetscSmartRedisCall(DeleteCClient(&smartsim->client));
76ec6e4151SJames Wright   PetscCall(PetscFree(smartsim));
77ec6e4151SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
78ec6e4151SJames Wright }
79