xref: /honee/problems/sgs_dd_model.c (revision 62b7942e37b3edd7bd547f0acf1de11d9ff29152)
1*62b7942eSJames Wright // Copyright (c) 2017-2023, Lawrence Livermore National Security, LLC and other CEED contributors.
2*62b7942eSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*62b7942eSJames Wright //
4*62b7942eSJames Wright // SPDX-License-Identifier: BSD-2-Clause
5*62b7942eSJames Wright //
6*62b7942eSJames Wright // This file is part of CEED:  http://github.com/ceed
7*62b7942eSJames Wright 
8*62b7942eSJames Wright #include "../qfunctions/sgs_dd_model.h"
9*62b7942eSJames Wright 
10*62b7942eSJames Wright #include <petscdmplex.h>
11*62b7942eSJames Wright 
12*62b7942eSJames Wright #include "../navierstokes.h"
13*62b7942eSJames Wright 
14*62b7942eSJames Wright // @brief B = A^T, A is NxM, B is MxN
15*62b7942eSJames Wright PetscErrorCode TransposeMatrix(const PetscScalar *A, PetscScalar *B, const PetscInt N, const PetscInt M) {
16*62b7942eSJames Wright   PetscFunctionBeginUser;
17*62b7942eSJames Wright   for (PetscInt i = 0; i < N; i++) {
18*62b7942eSJames Wright     for (PetscInt j = 0; j < M; j++) {
19*62b7942eSJames Wright       B[j * N + i] = A[i * M + j];
20*62b7942eSJames Wright     }
21*62b7942eSJames Wright   }
22*62b7942eSJames Wright   PetscFunctionReturn(0);
23*62b7942eSJames Wright }
24*62b7942eSJames Wright 
25*62b7942eSJames Wright // @brief Read neural network coefficients from file and put into context struct
26*62b7942eSJames Wright PetscErrorCode SGS_DD_ModelContextFill(MPI_Comm comm, char data_dir[PETSC_MAX_PATH_LEN], SGS_DDModelContext *psgsdd_ctx) {
27*62b7942eSJames Wright   SGS_DDModelContext sgsdd_ctx;
28*62b7942eSJames Wright   PetscInt           num_inputs = (*psgsdd_ctx)->num_inputs, num_outputs = (*psgsdd_ctx)->num_outputs, num_neurons = (*psgsdd_ctx)->num_neurons;
29*62b7942eSJames Wright   char               file_path[PETSC_MAX_PATH_LEN];
30*62b7942eSJames Wright   PetscScalar       *temp;
31*62b7942eSJames Wright 
32*62b7942eSJames Wright   PetscFunctionBeginUser;
33*62b7942eSJames Wright   {
34*62b7942eSJames Wright     SGS_DDModelContext sgsdd_temp;
35*62b7942eSJames Wright     PetscCall(PetscNew(&sgsdd_temp));
36*62b7942eSJames Wright     *sgsdd_temp                     = **psgsdd_ctx;
37*62b7942eSJames Wright     sgsdd_temp->offsets.bias1       = 0;
38*62b7942eSJames Wright     sgsdd_temp->offsets.bias2       = sgsdd_temp->offsets.bias1 + num_neurons;
39*62b7942eSJames Wright     sgsdd_temp->offsets.weight1     = sgsdd_temp->offsets.bias2 + num_neurons;
40*62b7942eSJames Wright     sgsdd_temp->offsets.weight2     = sgsdd_temp->offsets.weight1 + num_neurons * num_inputs;
41*62b7942eSJames Wright     sgsdd_temp->offsets.out_scaling = sgsdd_temp->offsets.weight2 + num_inputs * num_neurons;
42*62b7942eSJames Wright     PetscInt total_num_scalars      = sgsdd_temp->offsets.out_scaling + 2 * num_outputs;
43*62b7942eSJames Wright     sgsdd_temp->total_bytes         = sizeof(*sgsdd_ctx) + total_num_scalars * sizeof(sgsdd_ctx->data[0]);
44*62b7942eSJames Wright     PetscCall(PetscMalloc(sgsdd_temp->total_bytes, &sgsdd_ctx));
45*62b7942eSJames Wright     *sgsdd_ctx = *sgsdd_temp;
46*62b7942eSJames Wright     PetscCall(PetscFree(sgsdd_temp));
47*62b7942eSJames Wright   }
48*62b7942eSJames Wright 
49*62b7942eSJames Wright   PetscCall(PetscSNPrintf(file_path, sizeof file_path, "%s/%s", data_dir, "b1.dat"));
50*62b7942eSJames Wright   PetscCall(PHASTADatFileReadToArrayReal(comm, file_path, &sgsdd_ctx->data[sgsdd_ctx->offsets.bias1]));
51*62b7942eSJames Wright   PetscCall(PetscSNPrintf(file_path, sizeof file_path, "%s/%s", data_dir, "b2.dat"));
52*62b7942eSJames Wright   PetscCall(PHASTADatFileReadToArrayReal(comm, file_path, &sgsdd_ctx->data[sgsdd_ctx->offsets.bias2]));
53*62b7942eSJames Wright   PetscCall(PetscSNPrintf(file_path, sizeof file_path, "%s/%s", data_dir, "OutScaling.dat"));
54*62b7942eSJames Wright   PetscCall(PHASTADatFileReadToArrayReal(comm, file_path, &sgsdd_ctx->data[sgsdd_ctx->offsets.out_scaling]));
55*62b7942eSJames Wright 
56*62b7942eSJames Wright   {
57*62b7942eSJames Wright     PetscCall(PetscMalloc1(num_inputs * num_neurons, &temp));
58*62b7942eSJames Wright     PetscCall(PetscSNPrintf(file_path, sizeof file_path, "%s/%s", data_dir, "w1.dat"));
59*62b7942eSJames Wright     PetscCall(PHASTADatFileReadToArrayReal(comm, file_path, temp));
60*62b7942eSJames Wright     PetscCall(TransposeMatrix(temp, &sgsdd_ctx->data[sgsdd_ctx->offsets.weight1], num_inputs, num_neurons));
61*62b7942eSJames Wright     PetscCall(PetscFree(temp));
62*62b7942eSJames Wright   }
63*62b7942eSJames Wright   {
64*62b7942eSJames Wright     PetscCall(PetscMalloc1(num_outputs * num_neurons, &temp));
65*62b7942eSJames Wright     PetscCall(PetscSNPrintf(file_path, sizeof file_path, "%s/%s", data_dir, "w2.dat"));
66*62b7942eSJames Wright     PetscCall(PHASTADatFileReadToArrayReal(comm, file_path, temp));
67*62b7942eSJames Wright     PetscCall(TransposeMatrix(temp, &sgsdd_ctx->data[sgsdd_ctx->offsets.weight2], num_neurons, num_outputs));
68*62b7942eSJames Wright     PetscCall(PetscFree(temp));
69*62b7942eSJames Wright   }
70*62b7942eSJames Wright 
71*62b7942eSJames Wright   PetscCall(PetscFree(*psgsdd_ctx));
72*62b7942eSJames Wright   *psgsdd_ctx = sgsdd_ctx;
73*62b7942eSJames Wright   PetscFunctionReturn(0);
74*62b7942eSJames Wright }
75*62b7942eSJames Wright 
76*62b7942eSJames Wright PetscErrorCode SGS_DD_ModelSetup(Ceed ceed, User user, CeedData ceed_data, ProblemData *problem) {
77*62b7942eSJames Wright   PetscReal          alpha;
78*62b7942eSJames Wright   SGS_DDModelContext sgsdd_ctx;
79*62b7942eSJames Wright   MPI_Comm           comm                           = user->comm;
80*62b7942eSJames Wright   char               sgs_dd_dir[PETSC_MAX_PATH_LEN] = "./dd_sgs_data";
81*62b7942eSJames Wright   PetscFunctionBeginUser;
82*62b7942eSJames Wright 
83*62b7942eSJames Wright   PetscCall(PetscNew(&sgsdd_ctx));
84*62b7942eSJames Wright 
85*62b7942eSJames Wright   PetscOptionsBegin(comm, NULL, "SGS Data-Drive Model Options", NULL);
86*62b7942eSJames Wright   PetscCall(PetscOptionsReal("-sgs_model_dd_leakyrelu_alpha", "Slope parameter for Leaky ReLU activation function", NULL, alpha, &alpha, NULL));
87*62b7942eSJames Wright   PetscCall(PetscOptionsString("-sgs_model_dd_parameter_dir", "Path to directory with model parameters (weights, biases, etc.)", NULL, sgs_dd_dir,
88*62b7942eSJames Wright                                sgs_dd_dir, sizeof(sgs_dd_dir), NULL));
89*62b7942eSJames Wright   PetscOptionsEnd();
90*62b7942eSJames Wright 
91*62b7942eSJames Wright   sgsdd_ctx->num_layers  = 2;
92*62b7942eSJames Wright   sgsdd_ctx->num_inputs  = 6;
93*62b7942eSJames Wright   sgsdd_ctx->num_outputs = 6;
94*62b7942eSJames Wright   sgsdd_ctx->num_neurons = 20;
95*62b7942eSJames Wright   sgsdd_ctx->alpha       = alpha;
96*62b7942eSJames Wright 
97*62b7942eSJames Wright   // PetscCall(SGS_DD_ModelContextFill(comm, sgs_dd_dir, &sgsdd_ctx));
98*62b7942eSJames Wright 
99*62b7942eSJames Wright   PetscFunctionReturn(0);
100*62b7942eSJames Wright }
101