1*62b7942eSJames Wright // Copyright (c) 2017-2023, Lawrence Livermore National Security, LLC and other CEED contributors. 2*62b7942eSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*62b7942eSJames Wright // 4*62b7942eSJames Wright // SPDX-License-Identifier: BSD-2-Clause 5*62b7942eSJames Wright // 6*62b7942eSJames Wright // This file is part of CEED: http://github.com/ceed 7*62b7942eSJames Wright 8*62b7942eSJames Wright #include "../qfunctions/sgs_dd_model.h" 9*62b7942eSJames Wright 10*62b7942eSJames Wright #include <petscdmplex.h> 11*62b7942eSJames Wright 12*62b7942eSJames Wright #include "../navierstokes.h" 13*62b7942eSJames Wright 14*62b7942eSJames Wright // @brief B = A^T, A is NxM, B is MxN 15*62b7942eSJames Wright PetscErrorCode TransposeMatrix(const PetscScalar *A, PetscScalar *B, const PetscInt N, const PetscInt M) { 16*62b7942eSJames Wright PetscFunctionBeginUser; 17*62b7942eSJames Wright for (PetscInt i = 0; i < N; i++) { 18*62b7942eSJames Wright for (PetscInt j = 0; j < M; j++) { 19*62b7942eSJames Wright B[j * N + i] = A[i * M + j]; 20*62b7942eSJames Wright } 21*62b7942eSJames Wright } 22*62b7942eSJames Wright PetscFunctionReturn(0); 23*62b7942eSJames Wright } 24*62b7942eSJames Wright 25*62b7942eSJames Wright // @brief Read neural network coefficients from file and put into context struct 26*62b7942eSJames Wright PetscErrorCode SGS_DD_ModelContextFill(MPI_Comm comm, char data_dir[PETSC_MAX_PATH_LEN], SGS_DDModelContext *psgsdd_ctx) { 27*62b7942eSJames Wright SGS_DDModelContext sgsdd_ctx; 28*62b7942eSJames Wright PetscInt num_inputs = (*psgsdd_ctx)->num_inputs, num_outputs = (*psgsdd_ctx)->num_outputs, num_neurons = (*psgsdd_ctx)->num_neurons; 29*62b7942eSJames Wright char file_path[PETSC_MAX_PATH_LEN]; 30*62b7942eSJames Wright PetscScalar *temp; 31*62b7942eSJames Wright 32*62b7942eSJames Wright PetscFunctionBeginUser; 33*62b7942eSJames Wright { 34*62b7942eSJames Wright SGS_DDModelContext sgsdd_temp; 35*62b7942eSJames Wright PetscCall(PetscNew(&sgsdd_temp)); 36*62b7942eSJames Wright *sgsdd_temp = **psgsdd_ctx; 37*62b7942eSJames Wright sgsdd_temp->offsets.bias1 = 0; 38*62b7942eSJames Wright sgsdd_temp->offsets.bias2 = sgsdd_temp->offsets.bias1 + num_neurons; 39*62b7942eSJames Wright sgsdd_temp->offsets.weight1 = sgsdd_temp->offsets.bias2 + num_neurons; 40*62b7942eSJames Wright sgsdd_temp->offsets.weight2 = sgsdd_temp->offsets.weight1 + num_neurons * num_inputs; 41*62b7942eSJames Wright sgsdd_temp->offsets.out_scaling = sgsdd_temp->offsets.weight2 + num_inputs * num_neurons; 42*62b7942eSJames Wright PetscInt total_num_scalars = sgsdd_temp->offsets.out_scaling + 2 * num_outputs; 43*62b7942eSJames Wright sgsdd_temp->total_bytes = sizeof(*sgsdd_ctx) + total_num_scalars * sizeof(sgsdd_ctx->data[0]); 44*62b7942eSJames Wright PetscCall(PetscMalloc(sgsdd_temp->total_bytes, &sgsdd_ctx)); 45*62b7942eSJames Wright *sgsdd_ctx = *sgsdd_temp; 46*62b7942eSJames Wright PetscCall(PetscFree(sgsdd_temp)); 47*62b7942eSJames Wright } 48*62b7942eSJames Wright 49*62b7942eSJames Wright PetscCall(PetscSNPrintf(file_path, sizeof file_path, "%s/%s", data_dir, "b1.dat")); 50*62b7942eSJames Wright PetscCall(PHASTADatFileReadToArrayReal(comm, file_path, &sgsdd_ctx->data[sgsdd_ctx->offsets.bias1])); 51*62b7942eSJames Wright PetscCall(PetscSNPrintf(file_path, sizeof file_path, "%s/%s", data_dir, "b2.dat")); 52*62b7942eSJames Wright PetscCall(PHASTADatFileReadToArrayReal(comm, file_path, &sgsdd_ctx->data[sgsdd_ctx->offsets.bias2])); 53*62b7942eSJames Wright PetscCall(PetscSNPrintf(file_path, sizeof file_path, "%s/%s", data_dir, "OutScaling.dat")); 54*62b7942eSJames Wright PetscCall(PHASTADatFileReadToArrayReal(comm, file_path, &sgsdd_ctx->data[sgsdd_ctx->offsets.out_scaling])); 55*62b7942eSJames Wright 56*62b7942eSJames Wright { 57*62b7942eSJames Wright PetscCall(PetscMalloc1(num_inputs * num_neurons, &temp)); 58*62b7942eSJames Wright PetscCall(PetscSNPrintf(file_path, sizeof file_path, "%s/%s", data_dir, "w1.dat")); 59*62b7942eSJames Wright PetscCall(PHASTADatFileReadToArrayReal(comm, file_path, temp)); 60*62b7942eSJames Wright PetscCall(TransposeMatrix(temp, &sgsdd_ctx->data[sgsdd_ctx->offsets.weight1], num_inputs, num_neurons)); 61*62b7942eSJames Wright PetscCall(PetscFree(temp)); 62*62b7942eSJames Wright } 63*62b7942eSJames Wright { 64*62b7942eSJames Wright PetscCall(PetscMalloc1(num_outputs * num_neurons, &temp)); 65*62b7942eSJames Wright PetscCall(PetscSNPrintf(file_path, sizeof file_path, "%s/%s", data_dir, "w2.dat")); 66*62b7942eSJames Wright PetscCall(PHASTADatFileReadToArrayReal(comm, file_path, temp)); 67*62b7942eSJames Wright PetscCall(TransposeMatrix(temp, &sgsdd_ctx->data[sgsdd_ctx->offsets.weight2], num_neurons, num_outputs)); 68*62b7942eSJames Wright PetscCall(PetscFree(temp)); 69*62b7942eSJames Wright } 70*62b7942eSJames Wright 71*62b7942eSJames Wright PetscCall(PetscFree(*psgsdd_ctx)); 72*62b7942eSJames Wright *psgsdd_ctx = sgsdd_ctx; 73*62b7942eSJames Wright PetscFunctionReturn(0); 74*62b7942eSJames Wright } 75*62b7942eSJames Wright 76*62b7942eSJames Wright PetscErrorCode SGS_DD_ModelSetup(Ceed ceed, User user, CeedData ceed_data, ProblemData *problem) { 77*62b7942eSJames Wright PetscReal alpha; 78*62b7942eSJames Wright SGS_DDModelContext sgsdd_ctx; 79*62b7942eSJames Wright MPI_Comm comm = user->comm; 80*62b7942eSJames Wright char sgs_dd_dir[PETSC_MAX_PATH_LEN] = "./dd_sgs_data"; 81*62b7942eSJames Wright PetscFunctionBeginUser; 82*62b7942eSJames Wright 83*62b7942eSJames Wright PetscCall(PetscNew(&sgsdd_ctx)); 84*62b7942eSJames Wright 85*62b7942eSJames Wright PetscOptionsBegin(comm, NULL, "SGS Data-Drive Model Options", NULL); 86*62b7942eSJames Wright PetscCall(PetscOptionsReal("-sgs_model_dd_leakyrelu_alpha", "Slope parameter for Leaky ReLU activation function", NULL, alpha, &alpha, NULL)); 87*62b7942eSJames Wright PetscCall(PetscOptionsString("-sgs_model_dd_parameter_dir", "Path to directory with model parameters (weights, biases, etc.)", NULL, sgs_dd_dir, 88*62b7942eSJames Wright sgs_dd_dir, sizeof(sgs_dd_dir), NULL)); 89*62b7942eSJames Wright PetscOptionsEnd(); 90*62b7942eSJames Wright 91*62b7942eSJames Wright sgsdd_ctx->num_layers = 2; 92*62b7942eSJames Wright sgsdd_ctx->num_inputs = 6; 93*62b7942eSJames Wright sgsdd_ctx->num_outputs = 6; 94*62b7942eSJames Wright sgsdd_ctx->num_neurons = 20; 95*62b7942eSJames Wright sgsdd_ctx->alpha = alpha; 96*62b7942eSJames Wright 97*62b7942eSJames Wright // PetscCall(SGS_DD_ModelContextFill(comm, sgs_dd_dir, &sgsdd_ctx)); 98*62b7942eSJames Wright 99*62b7942eSJames Wright PetscFunctionReturn(0); 100*62b7942eSJames Wright } 101