xref: /honee/src/smartsim/smartsim.c (revision 7ebeccb998f2e038b4992256a1377876ce929ed3)
1ae2b091fSJames Wright // SPDX-FileCopyrightText: Copyright (c) 2017-2024, HONEE contributors.
2ae2b091fSJames Wright // SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause
37cd70835SJames Wright // Based on the instructions from https://www.craylabs.org/docs/sr_integration.html and PHASTA implementation
47cd70835SJames Wright 
5149fb536SJames Wright #include <smartsim.h>
67cd70835SJames Wright 
7149fb536SJames Wright #include <navierstokes.h>
87cd70835SJames Wright 
9*7ebeccb9SJames Wright #define SMARTSIM_KEY "SmartSimData"
10797f7eedSJames Wright 
11*7ebeccb9SJames Wright static PetscErrorCode SmartSimDataDestroy(SmartSimData *smartsim) {
12*7ebeccb9SJames Wright   SmartSimData smartsim_ = *smartsim;
13*7ebeccb9SJames Wright   PetscFunctionBeginUser;
14*7ebeccb9SJames Wright   if (!smartsim_) PetscFunctionReturn(PETSC_SUCCESS);
15*7ebeccb9SJames Wright 
16*7ebeccb9SJames Wright   PetscCallSmartRedis(DeleteCClient(&smartsim_->client));
17*7ebeccb9SJames Wright   PetscCall(PetscFree(smartsim_));
18*7ebeccb9SJames Wright   *smartsim = NULL;
197cd70835SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
207cd70835SJames Wright }
217cd70835SJames Wright 
22797f7eedSJames Wright static PetscErrorCode SmartSimTrainingSetup(Honee honee) {
23*7ebeccb9SJames Wright   SmartSimData smartsim;
247cd70835SJames Wright   PetscMPIInt  rank;
257cd70835SJames Wright   PetscReal    checkrun[2] = {1};
267cd70835SJames Wright   size_t       dim_2[1]    = {2};
277cd70835SJames Wright 
287cd70835SJames Wright   PetscFunctionBeginUser;
29*7ebeccb9SJames Wright   PetscCall(HoneeGetSmartSimData(honee, &smartsim));
300c373b74SJames Wright   PetscCallMPI(MPI_Comm_rank(honee->comm, &rank));
317cd70835SJames Wright 
327cd70835SJames Wright   if (rank % smartsim->collocated_database_num_ranks == 0) {
337cd70835SJames Wright     // -- Send array that communicates when ML is done training
34ea615d4cSJames Wright     PetscCall(PetscLogEventBegin(HONEE_SmartRedis_Meta, 0, 0, 0, 0));
3543e9749fSJames Wright     PetscCallSmartRedis(put_tensor(smartsim->client, "check-run", 9, checkrun, dim_2, 1, SRTensorTypeDouble, SRMemLayoutContiguous));
367cd70835SJames Wright     PetscCall(SmartRedisVerifyPutTensor(smartsim->client, "check-run", 9));
37ea615d4cSJames Wright     PetscCall(PetscLogEventEnd(HONEE_SmartRedis_Meta, 0, 0, 0, 0));
387cd70835SJames Wright   }
39aa0b7f76SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
407cd70835SJames Wright }
417cd70835SJames Wright 
42*7ebeccb9SJames Wright static PetscErrorCode SmartSimSetup(Honee honee) {
437cd70835SJames Wright   PetscMPIInt  rank;
447cd70835SJames Wright   PetscInt     num_orchestrator_nodes = 1;
45*7ebeccb9SJames Wright   SmartSimData smartsim;
467cd70835SJames Wright 
477cd70835SJames Wright   PetscFunctionBeginUser;
48*7ebeccb9SJames Wright   PetscCall(PetscNew(&smartsim));
497cd70835SJames Wright 
507cd70835SJames Wright   smartsim->collocated_database_num_ranks = 1;
510c373b74SJames Wright   PetscOptionsBegin(honee->comm, NULL, "Options for SmartSim integration", NULL);
527cd70835SJames Wright   PetscCall(PetscOptionsInt("-smartsim_collocated_database_num_ranks", "Number of ranks per collocated database instance", NULL,
537cd70835SJames Wright                             smartsim->collocated_database_num_ranks, &smartsim->collocated_database_num_ranks, NULL));
547cd70835SJames Wright   PetscOptionsEnd();
557cd70835SJames Wright 
567cd70835SJames Wright   // Create prefix to be put on tensor names
570c373b74SJames Wright   PetscCallMPI(MPI_Comm_rank(honee->comm, &rank));
584fa1625aSJames Wright   PetscCall(PetscSNPrintf(smartsim->rank_id_name, sizeof(smartsim->rank_id_name), "y.%d", rank));
597cd70835SJames Wright 
60ea615d4cSJames Wright   PetscCall(PetscLogEventBegin(HONEE_SmartRedis_Init, 0, 0, 0, 0));
6143e9749fSJames Wright   PetscCallSmartRedis(SmartRedisCClient(num_orchestrator_nodes != 1, smartsim->rank_id_name, strlen(smartsim->rank_id_name), &smartsim->client));
62ea615d4cSJames Wright   PetscCall(PetscLogEventEnd(HONEE_SmartRedis_Init, 0, 0, 0, 0));
637cd70835SJames Wright 
64*7ebeccb9SJames Wright   PetscCall(PetscObjectContainerCompose((PetscObject)honee, SMARTSIM_KEY, smartsim, (PetscCtxDestroyFn *)SmartSimDataDestroy));
65*7ebeccb9SJames Wright 
660c373b74SJames Wright   PetscCall(SmartSimTrainingSetup(honee));
677cd70835SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
687cd70835SJames Wright }
69ec6e4151SJames Wright 
70*7ebeccb9SJames Wright PetscErrorCode HoneeGetSmartSimData(Honee honee, SmartSimData *smartsim) {
71*7ebeccb9SJames Wright   PetscFunctionBeginUser;
72*7ebeccb9SJames Wright   PetscCall(PetscObjectContainerQuery((PetscObject)honee, SMARTSIM_KEY, smartsim));
73*7ebeccb9SJames Wright   if (*smartsim == NULL) {
74*7ebeccb9SJames Wright     PetscCall(SmartSimSetup(honee));
75*7ebeccb9SJames Wright     PetscCall(PetscObjectContainerQuery((PetscObject)honee, SMARTSIM_KEY, smartsim));
76*7ebeccb9SJames Wright     PetscCheck(*smartsim, honee->comm, PETSC_ERR_ARG_WRONGSTATE, "SmartSimData struct is not in Honee after setup.");
77*7ebeccb9SJames Wright   }
78*7ebeccb9SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
79*7ebeccb9SJames Wright }
80*7ebeccb9SJames Wright 
81797f7eedSJames Wright PetscErrorCode SmartRedisVerifyPutTensor(void *c_client, const char *name, const size_t name_length) {
82797f7eedSJames Wright   bool does_exist = true;
83ec6e4151SJames Wright 
84797f7eedSJames Wright   PetscFunctionBeginUser;
85797f7eedSJames Wright   PetscCall(PetscLogEventBegin(HONEE_SmartRedis_Meta, 0, 0, 0, 0));
86797f7eedSJames Wright   PetscCallSmartRedis(tensor_exists(c_client, name, name_length, &does_exist));
87797f7eedSJames Wright   PetscCheck(does_exist, PETSC_COMM_SELF, -1, "Tensor of name '%s' was not written to the database successfully", name);
88797f7eedSJames Wright   PetscCall(PetscLogEventEnd(HONEE_SmartRedis_Meta, 0, 0, 0, 0));
89ec6e4151SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
90ec6e4151SJames Wright }
91