xref: /honee/src/smartsim/smartsim.c (revision 7cd70835138b7dacb5693d4e4e0578aebf9b5d9c)
1*7cd70835SJames Wright // Copyright (c) 2017-2023, Lawrence Livermore National Security, LLC and other CEED contributors.
2*7cd70835SJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*7cd70835SJames Wright //
4*7cd70835SJames Wright // SPDX-License-Identifier: BSD-2-Clause
5*7cd70835SJames Wright //
6*7cd70835SJames Wright // This file is part of CEED:  http://github.com/ceed
7*7cd70835SJames Wright // Based on the instructions from https://www.craylabs.org/docs/sr_integration.html and PHASTA implementation
8*7cd70835SJames Wright 
9*7cd70835SJames Wright #include "../../include/smartsim.h"
10*7cd70835SJames Wright 
11*7cd70835SJames Wright #include "../../navierstokes.h"
12*7cd70835SJames Wright 
13*7cd70835SJames Wright PetscErrorCode SmartRedisVerifyPutTensor(void *c_client, const char *name, const size_t name_length) {
14*7cd70835SJames Wright   bool does_exist = true;
15*7cd70835SJames Wright 
16*7cd70835SJames Wright   PetscFunctionBeginUser;
17*7cd70835SJames Wright   SmartRedisCall(tensor_exists(c_client, name, name_length, &does_exist));
18*7cd70835SJames Wright   PetscCheck(does_exist, PETSC_COMM_SELF, -1, "Tensor of name '%s' was not written to the database successfully", name);
19*7cd70835SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
20*7cd70835SJames Wright }
21*7cd70835SJames Wright 
22*7cd70835SJames Wright PetscErrorCode SmartSimTrainingSetup(User user) {
23*7cd70835SJames Wright   SmartSimData smartsim = user->smartsim;
24*7cd70835SJames Wright   PetscMPIInt  rank;
25*7cd70835SJames Wright   PetscReal    checkrun[2] = {1};
26*7cd70835SJames Wright   size_t       dim_2[1]    = {2};
27*7cd70835SJames Wright   PetscInt     num_ranks;
28*7cd70835SJames Wright 
29*7cd70835SJames Wright   PetscFunctionBeginUser;
30*7cd70835SJames Wright   PetscCallMPI(MPI_Comm_rank(user->comm, &rank));
31*7cd70835SJames Wright   PetscCallMPI(MPI_Comm_size(user->comm, &num_ranks));
32*7cd70835SJames Wright 
33*7cd70835SJames Wright   if (rank % smartsim->collocated_database_num_ranks == 0) {
34*7cd70835SJames Wright     // -- Send array that communicates when ML is done training
35*7cd70835SJames Wright     SmartRedisCall(put_tensor(smartsim->client, "check-run", 9, checkrun, dim_2, 1, SRTensorTypeDouble, SRMemLayoutContiguous));
36*7cd70835SJames Wright     PetscCall(SmartRedisVerifyPutTensor(smartsim->client, "check-run", 9));
37*7cd70835SJames Wright   }
38*7cd70835SJames Wright 
39*7cd70835SJames Wright   {  // -- Get minimum per-rank global vec size
40*7cd70835SJames Wright     PetscInt GlobalVecSize;
41*7cd70835SJames Wright     PetscCall(DMGetGlobalVectorInfo(user->dm, &GlobalVecSize, NULL, NULL));
42*7cd70835SJames Wright     PetscCallMPI(MPI_Allreduce(&GlobalVecSize, &smartsim->num_tensor_nodes, 1, MPIU_INT, MPI_MIN, user->comm));
43*7cd70835SJames Wright     smartsim->num_nodes_to_remove = GlobalVecSize - smartsim->num_tensor_nodes;
44*7cd70835SJames Wright   }
45*7cd70835SJames Wright 
46*7cd70835SJames Wright   // Determine the size of the training data arrays... somehow
47*7cd70835SJames Wright   if (rank % smartsim->collocated_database_num_ranks == 0) {
48*7cd70835SJames Wright     size_t   array_dims[2] = {smartsim->num_tensor_nodes, 6}, array_info_dim = 6;
49*7cd70835SJames Wright     PetscInt array_info[6] = {0}, num_features = 6;
50*7cd70835SJames Wright 
51*7cd70835SJames Wright     array_info[0] = array_dims[0];
52*7cd70835SJames Wright     array_info[1] = array_dims[1];
53*7cd70835SJames Wright     array_info[2] = num_features;
54*7cd70835SJames Wright     array_info[3] = num_ranks;
55*7cd70835SJames Wright     array_info[4] = smartsim->collocated_database_num_ranks;
56*7cd70835SJames Wright     array_info[5] = rank;
57*7cd70835SJames Wright 
58*7cd70835SJames Wright     SmartRedisCall(put_tensor(smartsim->client, "array_info", 10, array_info, &array_info_dim, 1, SRTensorTypeInt32, SRMemLayoutContiguous));
59*7cd70835SJames Wright     PetscCall(SmartRedisVerifyPutTensor(smartsim->client, "array_info", 10));
60*7cd70835SJames Wright   }
61*7cd70835SJames Wright   PetscFunctionReturn(0);
62*7cd70835SJames Wright }
63*7cd70835SJames Wright 
64*7cd70835SJames Wright PetscErrorCode SmartSimSetup(User user) {
65*7cd70835SJames Wright   PetscMPIInt rank;
66*7cd70835SJames Wright   size_t      rank_id_name_len;
67*7cd70835SJames Wright   PetscInt    num_orchestrator_nodes = 1;
68*7cd70835SJames Wright 
69*7cd70835SJames Wright   PetscFunctionBeginUser;
70*7cd70835SJames Wright   PetscCall(PetscNew(&user->smartsim));
71*7cd70835SJames Wright   SmartSimData smartsim = user->smartsim;
72*7cd70835SJames Wright 
73*7cd70835SJames Wright   smartsim->collocated_database_num_ranks = 1;
74*7cd70835SJames Wright   PetscOptionsBegin(user->comm, NULL, "Options for SmartSim integration", NULL);
75*7cd70835SJames Wright   PetscCall(PetscOptionsInt("-smartsim_collocated_database_num_ranks", "Number of ranks per collocated database instance", NULL,
76*7cd70835SJames Wright                             smartsim->collocated_database_num_ranks, &smartsim->collocated_database_num_ranks, NULL));
77*7cd70835SJames Wright   PetscOptionsEnd();
78*7cd70835SJames Wright 
79*7cd70835SJames Wright   PetscCall(PetscStrlen(smartsim->rank_id_name, &rank_id_name_len));
80*7cd70835SJames Wright   // Create prefix to be put on tensor names
81*7cd70835SJames Wright   PetscCallMPI(MPI_Comm_rank(user->comm, &rank));
82*7cd70835SJames Wright   PetscCall(PetscSNPrintf(smartsim->rank_id_name, sizeof smartsim->rank_id_name, "y.%d", rank));
83*7cd70835SJames Wright 
84*7cd70835SJames Wright   SmartRedisCall(SmartRedisCClient(num_orchestrator_nodes != 1, smartsim->rank_id_name, rank_id_name_len, &smartsim->client));
85*7cd70835SJames Wright 
86*7cd70835SJames Wright   PetscCall(SmartSimTrainingSetup(user));
87*7cd70835SJames Wright   PetscFunctionReturn(PETSC_SUCCESS);
88*7cd70835SJames Wright }
89