// SPDX-FileCopyrightText: Copyright (c) 2017-2024, HONEE contributors.
// SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause

/// @file
/// Utility functions for setting up Blasius Boundary Layer

#include "../qfunctions/blasius.h"

#include <ceed.h>
#include <petscdm.h>
#include <petscdt.h>

#include <differential_filter.h>
#include <navierstokes.h>
#include "stg_shur14.h"

/**
  @brief `SNES` residual of the compressible Blasius similarity equations in Chebyshev collocation form

  `X` stores the `n_cheb` Chebyshev coefficients of `f` followed by the `n_cheb - 1` coefficients of `h`.
  The residual `R` is laid out as:
  - rows `0..2`: boundary conditions on `f` (`f = f' = 0` at the left end, `f' = 1` at the right end)
  - rows `3..n_cheb-1`: `f` equation at each collocation point in `blasius->X`
  - rows `n_cheb`, `n_cheb+1`: boundary conditions on `h` (`T_wall / T_infty` at the left end, `1` at the right end)
  - rows `n_cheb+2..2*n_cheb-2`: `h` equation at each collocation point

  @param[in]  snes `SNES` object (unused)
  @param[in]  X    Current Chebyshev coefficients
  @param[out] R    Residual vector
  @param[in]  ctx  `BlasiusContext` providing `n_cheb`, `eta_max`, `X`, `T_wall`, `S_infty`, and the gas properties

  @return An error code: 0 - success, otherwise - failure
**/
PetscErrorCode CompressibleBlasiusResidual(SNES snes, Vec X, Vec R, void *ctx) {
  const BlasiusContext  blasius_ctx = (BlasiusContext)ctx;
  const PetscInt        num_cheb    = blasius_ctx->n_cheb;
  const PetscInt        num_colloc  = num_cheb - 3;
  State                 S_infty     = blasius_ctx->S_infty;
  NewtonianIGProperties gas         = blasius_ctx->newt_ctx.gas;
  CeedScalar            U_infty     = Norm3(S_infty.Y.velocity);
  const PetscScalar    *coeffs_f, *coeffs_h;  // Chebyshev coefficients of f and h
  PetscScalar          *resid, f[4], h[4];

  PetscFunctionBeginUser;
  const PetscScalar Ma    = Mach(gas, S_infty.Y.temperature, U_infty);
  const PetscScalar Pr    = Prandtl(gas);
  const PetscScalar gamma = HeatCapacityRatio(gas);

  PetscCall(VecGetArrayRead(X, &coeffs_f));
  coeffs_h = coeffs_f + num_cheb;
  PetscCall(VecGetArray(R, &resid));

  // f boundary conditions: f = f' = 0 at the left end, f' = 1 at the right end
  ChebyshevEval(num_cheb, coeffs_f, -1., blasius_ctx->eta_max, f);
  resid[0] = f[0];
  resid[1] = f[1];
  ChebyshevEval(num_cheb, coeffs_f, 1., blasius_ctx->eta_max, f);
  resid[2] = f[1] - 1.;

  // h boundary conditions: wall-to-freestream temperature ratio at the left end, 1 at the right end
  ChebyshevEval(num_cheb - 1, coeffs_h, -1., blasius_ctx->eta_max, h);
  resid[num_cheb] = h[0] - blasius_ctx->T_wall / S_infty.Y.temperature;
  ChebyshevEval(num_cheb - 1, coeffs_h, 1., blasius_ctx->eta_max, h);
  resid[num_cheb + 1] = h[0] - 1.;

  // Interior equations, one pair of rows per collocation point
  PetscScalar *resid_f = resid + 3;
  PetscScalar *resid_h = resid + num_cheb + 2;
  for (PetscInt i = 0; i < num_colloc; i++) {
    ChebyshevEval(num_cheb, coeffs_f, blasius_ctx->X[i], blasius_ctx->eta_max, f);
    ChebyshevEval(num_cheb - 1, coeffs_h, blasius_ctx->X[i], blasius_ctx->eta_max, h);

    // mu and rho generally depend on h; here mu is naively taken as constant.
    // For an ideal gas at constant pressure, density is inversely proportional to enthalpy.
    // All quantities are relative to their freestream values; the d* variables are their first derivatives.
    const PetscScalar mu = 1, dmu = 0;
    const PetscScalar rho     = 1 / h[0];
    const PetscScalar drho    = -h[1] / PetscSqr(h[0]);
    const PetscScalar mu_rho  = mu * rho;
    const PetscScalar dmu_rho = dmu * rho + mu * drho;  // product rule

    resid_f[i] = 2 * (mu_rho * f[3] + dmu_rho * f[2]) + f[2] * f[0];
    resid_h[i] = (mu_rho * h[2] + dmu_rho * h[1]) + Pr * f[0] * h[1] + Pr * (gamma - 1) * mu_rho * PetscSqr(Ma * f[2]);
  }

  PetscCall(VecRestoreArrayRead(X, &coeffs_f));
  PetscCall(VecRestoreArray(R, &resid));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/**
  @brief Solve for the Chebyshev coefficients of the Blasius similarity profiles

  Generates `n_cheb - 3` Gauss quadrature points on [-1, 1] as collocation points, solves the nonlinear system defined by
  `CompressibleBlasiusResidual()` with a serial `SNES` (options prefix `chebyshev_`), and copies the solution into
  `blasius->Tf_cheb` (first `n_cheb` entries) and `blasius->Th_cheb` (remaining `n_cheb - 1` entries).
  The collocation point array `blasius->X` is only allocated for the duration of this call.

  @param[in,out] blasius `BlasiusContext`; reads `n_cheb` (and whatever the residual reads), writes `Tf_cheb` and `Th_cheb`

  @return An error code: 0 - success, otherwise - failure (including `PETSC_ERR_CONV_FAILED` if the solve diverges)
**/
PetscErrorCode ComputeChebyshevCoefficients(BlasiusContext blasius) {
  SNES                snes;
  Vec                 sol, res;
  PetscReal          *w;
  PetscInt            N = blasius->n_cheb;
  SNESConvergedReason reason;
  const PetscScalar  *cheb_coefs;

  PetscFunctionBeginUser;
  // Allocate collocation points (kept in the context for the residual) and quadrature weights (unused afterwards)
  PetscCall(PetscMalloc2(N - 3, &blasius->X, N - 3, &w));
  PetscCall(PetscDTGaussQuadrature(N - 3, -1., 1., blasius->X, w));

  // Snes solve; the unknowns are N coefficients for f followed by N - 1 coefficients for h
  PetscCall(SNESCreate(PETSC_COMM_SELF, &snes));
  PetscCall(VecCreate(PETSC_COMM_SELF, &sol));
  PetscCall(VecSetSizes(sol, PETSC_DECIDE, 2 * N - 1));
  PetscCall(VecSetFromOptions(sol));
  // Initial guess: all coefficients zero except the leading h coefficient, i.e. constant relative enthalpy 1
  PetscCall(VecZeroEntries(sol));
  PetscCall(VecSetValue(sol, N, 1., INSERT_VALUES));
  // PETSc requires assembly after VecSetValue() before the vector is used
  PetscCall(VecAssemblyBegin(sol));
  PetscCall(VecAssemblyEnd(sol));
  PetscCall(VecDuplicate(sol, &res));
  PetscCall(SNESSetFunction(snes, res, CompressibleBlasiusResidual, blasius));
  PetscCall(SNESSetOptionsPrefix(snes, "chebyshev_"));
  PetscCall(SNESSetFromOptions(snes));
  PetscCall(SNESSolve(snes, NULL, sol));
  PetscCall(SNESGetConvergedReason(snes, &reason));
  // The solve lives on PETSC_COMM_SELF, so the failure is reported on that communicator as well
  PetscCheck(reason >= 0, PETSC_COMM_SELF, PETSC_ERR_CONV_FAILED, "The Chebyshev solve failed.");

  // Assign Chebyshev coefficients
  PetscCall(VecGetArrayRead(sol, &cheb_coefs));
  for (PetscInt i = 0; i < N; i++) blasius->Tf_cheb[i] = cheb_coefs[i];
  for (PetscInt i = 0; i < N - 1; i++) blasius->Th_cheb[i] = cheb_coefs[i + N];
  PetscCall(VecRestoreArrayRead(sol, &cheb_coefs));

  // Destroy objects
  PetscCall(PetscFree2(blasius->X, w));
  PetscCall(VecDestroy(&sol));
  PetscCall(VecDestroy(&res));
  PetscCall(SNESDestroy(&snes));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/**
  @brief Read y node locations from a text file

  The first line of the file is a header with the number of data rows and, optionally, the number of columns per row (1 is assumed if omitted).
  Each following row must contain exactly that many space-separated columns; the first column of each row is stored as the node location.

  @param[in]  comm    `MPI_Comm` used for the synchronized file reads
  @param[in]  path    Path to the file
  @param[out] pynodes Newly allocated array of node locations; the caller is responsible for freeing it with `PetscFree()`
  @param[out] nynodes Number of entries in `pynodes`

  @return An error code: 0 - success, otherwise - failure
**/
static PetscErrorCode GetYNodeLocs(const MPI_Comm comm, const char path[PETSC_MAX_PATH_LEN], PetscReal **pynodes, PetscInt *nynodes) {
  int            ndims = 0, dims[2] = {0, 0};
  FILE          *fp;
  const PetscInt char_array_len = 512;
  char           line[char_array_len];
  PetscReal     *node_locs;

  PetscFunctionBeginUser;
  PetscCall(PetscFOpen(comm, path, "r", &fp));
  PetscCall(PetscSynchronizedFGets(comm, fp, char_array_len, line));
  {
    char **array;

    PetscCall(PetscStrToArray(line, ' ', &ndims, &array));
    // Never write past dims[1], regardless of how many tokens the header contains
    for (int i = 0; i < PetscMin(ndims, 2); i++) dims[i] = atoi(array[i]);
    PetscCall(PetscStrToArrayDestroy(ndims, array));
  }
  PetscCheck(ndims <= 2, comm, PETSC_ERR_FILE_UNEXPECTED, "Header line of %s contains %d entries, but at most 2 (rows and columns) are allowed", path,
             ndims);
  if (ndims < 2) dims[1] = 1;  // Assume 1 column of data is not otherwise specified
  PetscCheck(dims[0] > 0 && dims[1] > 0, comm, PETSC_ERR_FILE_UNEXPECTED,
             "Header line of %s must specify a positive number of rows and columns (got %d rows, %d columns)", path, dims[0], dims[1]);
  *nynodes = dims[0];
  PetscCall(PetscMalloc1(*nynodes, &node_locs));

  for (PetscInt i = 0; i < dims[0]; i++) {
    char **array;

    PetscCall(PetscSynchronizedFGets(comm, fp, char_array_len, line));
    PetscCall(PetscStrToArray(line, ' ', &ndims, &array));
    PetscCheck(ndims == dims[1], comm, PETSC_ERR_FILE_UNEXPECTED,
               "Line %" PetscInt_FMT " of %s does not contain correct number of columns (%d instead of %d)", i, path, ndims, dims[1]);

    // Only the first column is used as the node location
    node_locs[i] = (PetscReal)atof(array[0]);
    PetscCall(PetscStrToArrayDestroy(ndims, array));
  }
  PetscCall(PetscFClose(comm, fp));
  *pynodes = node_locs;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/**
  @brief Modify the domain and mesh for blasius

  Modifies mesh such that `N` elements are within `refine_height` with a geometric growth ratio of `growth`.
  Excess elements are then distributed linearly in logspace to the top surface.

  The top surface is also angled downwards, so that it may be used as an outflow.
  Its angle is controlled by `top_angle` (in units of degrees).

  If `*node_locs` is not NULL, then the nodes will be placed at `*node_locs` locations.
  If it is NULL, then the modified coordinate values will be set in a newly allocated array, along with `num_node_locs`.

  The number of elements in the y direction is read from the `-dm_plex_box_faces` option, which must therefore be set with at least two entries.
  Nodes are identified by `round(y / dybox)`, which presumably expects the box mesh to start at `y = 0` - TODO confirm.
  NOTE(review): `coords[i][2]` and `domain_min[2]` are read unconditionally, so this looks like it expects a 3D `DM` - confirm.

  @param[in]     comm          `MPI_Comm` used for error and warning messages
  @param[in,out] dm            `DM` whose local coordinates are modified
  @param[in]     growth        Geometric growth ratio of element heights within `refine_height`
  @param[in]     N             Number of elements within `refine_height`
  @param[in]     refine_height Height of the refined region
  @param[in]     top_angle     Angle of the top surface, in degrees
  @param[in,out] node_locs     Pointer to array of y node locations (see above)
  @param[in,out] num_node_locs Number of entries in `*node_locs`

  @return An error code: 0 - success, otherwise - failure
**/
static PetscErrorCode ModifyMesh(MPI_Comm comm, DM dm, PetscReal growth, PetscInt N, PetscReal refine_height, PetscReal top_angle,
                                 PetscReal *node_locs[], PetscInt *num_node_locs) {
  PetscInt     narr, ncoords, dim;
  PetscReal    domain_min[3], domain_max[3], domain_size[3];
  PetscScalar *arr_coords;
  Vec          vec_coords;

  PetscFunctionBeginUser;
  PetscCall(DMGetDimension(dm, &dim));
  PetscReal angle_coeff = tan(top_angle * (M_PI / 180));
  // Get domain boundary information
  PetscCall(DMGetBoundingBox(dm, domain_min, domain_max));
  for (PetscInt i = 0; i < 3; i++) domain_size[i] = domain_max[i] - domain_min[i];

  // Get coords array from DM
  PetscCall(DMGetCoordinatesLocal(dm, &vec_coords));
  PetscCall(VecGetLocalSize(vec_coords, &narr));
  PetscCall(VecGetArray(vec_coords, &arr_coords));

  PetscScalar(*coords)[dim] = (PetscScalar(*)[dim])arr_coords;
  ncoords                   = narr / dim;

  // Get mesh information; faces[1] is required below, so refuse to continue if the option does not provide it
  PetscInt  nmax = 3, faces[3] = {0, 0, 0};
  PetscBool faces_set = PETSC_FALSE;
  PetscCall(PetscOptionsGetIntArray(NULL, NULL, "-dm_plex_box_faces", faces, &nmax, &faces_set));
  PetscCheck(faces_set && nmax >= 2 && faces[1] > 0, comm, PETSC_ERR_USER_INPUT,
             "-dm_plex_box_faces must be set with a positive number of faces in the y direction to modify the mesh");
  // Get element size of the box mesh, for indexing each node
  const PetscReal dybox = domain_size[1] / faces[1];

  if (!*node_locs) {
    // Calculate the first element height
    PetscReal dy1 = refine_height * (growth - 1) / (pow(growth, N) - 1);

    // Calculate log of sizing outside BL
    PetscReal logdy = (log(domain_max[1]) - log(refine_height)) / (faces[1] - N);

    *num_node_locs = faces[1] + 1;
    PetscReal *temp_node_locs;
    PetscCall(PetscMalloc1(*num_node_locs, &temp_node_locs));

    for (PetscInt i = 0; i < ncoords; i++) {
      PetscInt y_box_index = round(coords[i][1] / dybox);
      if (y_box_index <= N) {
        // Geometric growth inside the refined region, scaled down along x to angle the top surface
        coords[i][1] =
            (1 - (coords[i][0] - domain_min[0]) * angle_coeff / domain_max[1]) * dy1 * (pow(growth, coords[i][1] / dybox) - 1) / (growth - 1);
      } else {
        // Uniform spacing in log space above the refined region
        PetscInt j   = y_box_index - N;
        coords[i][1] = (1 - (coords[i][0] - domain_min[0]) * angle_coeff / domain_max[1]) * exp(log(refine_height) + logdy * j);
      }
      // Record the node heights along the edge at the minimum x and z of the domain
      if (coords[i][0] == domain_min[0] && coords[i][2] == domain_min[2]) temp_node_locs[y_box_index] = coords[i][1];
    }

    *node_locs = temp_node_locs;
  } else {
    // The format string itself appends "+ 1", so the element count (not the node count) is passed for the first placeholder
    PetscCheck(*num_node_locs >= faces[1] + 1, comm, PETSC_ERR_FILE_UNEXPECTED,
               "The y_node_locs_path has too few locations; There are %" PetscInt_FMT " + 1 nodes, but only %" PetscInt_FMT " locations given",
               faces[1], *num_node_locs);
    if (*num_node_locs > faces[1] + 1) {
      PetscCall(PetscPrintf(comm,
                            "WARNING: y_node_locs_path has more locations (%" PetscInt_FMT ") "
                            "than the mesh has nodes (%" PetscInt_FMT "). This maybe unintended.\n",
                            *num_node_locs, faces[1] + 1));
    }
    PetscScalar max_y = (*node_locs)[faces[1]];

    for (PetscInt i = 0; i < ncoords; i++) {
      // Determine which y-node we're at
      PetscInt y_box_index = round(coords[i][1] / dybox);
      coords[i][1]         = (1 - (coords[i][0] - domain_min[0]) * angle_coeff / max_y) * (*node_locs)[y_box_index];
    }
  }

  PetscCall(VecRestoreArray(vec_coords, &arr_coords));
  PetscCall(DMSetCoordinatesLocal(dm, vec_coords));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/**
  @brief Create the IFunction `CeedQFunction` for the Blasius inflow boundary condition

  Uses the `Blasius_Inflow` QFunction together with the `CeedQFunctionContext` stored in the boundary condition's `HoneeBCStruct` context.

  @param[in]  bc_def `BCDefinition` whose context is a `HoneeBCStruct`
  @param[out] qf     Created `CeedQFunction`

  @return An error code: 0 - success, otherwise - failure
**/
static PetscErrorCode BlasiusInflowBCSetup_CreateIFunctionQF(BCDefinition bc_def, CeedQFunction *qf) {
  HoneeBCStruct bc_ctx = NULL;

  PetscFunctionBeginUser;
  PetscCall(BCDefinitionGetContext(bc_def, &bc_ctx));
  PetscCall(HoneeBCCreateIFunctionQF(bc_def, Blasius_Inflow, Blasius_Inflow_loc, bc_ctx->qfctx, qf));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/**
  @brief Create the IJacobian `CeedQFunction` for the Blasius inflow boundary condition

  Uses the `Blasius_Inflow_Jacobian` QFunction together with the `CeedQFunctionContext` stored in the boundary condition's `HoneeBCStruct` context.

  @param[in]  bc_def `BCDefinition` whose context is a `HoneeBCStruct`
  @param[out] qf     Created `CeedQFunction`

  @return An error code: 0 - success, otherwise - failure
**/
static PetscErrorCode BlasiusInflowBCSetup_CreateIJacobianQF(BCDefinition bc_def, CeedQFunction *qf) {
  HoneeBCStruct bc_ctx = NULL;

  PetscFunctionBeginUser;
  PetscCall(BCDefinitionGetContext(bc_def, &bc_ctx));
  PetscCall(HoneeBCCreateIJacobianQF(bc_def, Blasius_Inflow_Jacobian, Blasius_Inflow_Jacobian_loc, bc_ctx->qfctx, qf));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/**
  @brief Set up the Blasius boundary layer problem

  Performs the Newtonian ideal gas setup via `NS_NEWTONIAN_IG()`, then:
  - replaces the initial condition QFunction with `ICsBlasius`,
  - reads the Blasius (and, for the platemesh transform, mesh) command line options,
  - optionally warps the mesh with `ModifyMesh()`,
  - fills a `BlasiusContext` and wraps it in a `CeedQFunctionContext` that replaces `problem->ics.qfctx`,
  - sets up the inflow: STG if `-stg_use`, differential filter MMS initial conditions if `-diff_filter_mms`,
    otherwise the Blasius weak inflow on every boundary condition definition named "inflow".

  @param[in,out] problem `ProblemData` to set up
  @param[in,out] dm      `DM` of the problem; its coordinates are modified for the platemesh transform
  @param[in]     ctx     Pointer to the `Honee` context

  @return An error code: 0 - success, otherwise - failure
**/
PetscErrorCode NS_BLASIUS(ProblemData problem, DM dm, void *ctx) {
  Honee                    honee   = *(Honee *)ctx;
  MPI_Comm                 comm    = honee->comm;
  Ceed                     ceed    = honee->ceed;
  PetscBool                use_stg = PETSC_FALSE;
  BlasiusContext           blasius_ctx;
  NewtonianIdealGasContext newtonian_ig_ctx;
  CeedQFunctionContext     blasius_qfctx;

  PetscFunctionBeginUser;
  PetscCall(NS_NEWTONIAN_IG(problem, dm, ctx));
  PetscCall(PetscNew(&blasius_ctx));

  // ------------------------------------------------------
  //               SET UP Blasius
  // ------------------------------------------------------
  // Keep the existing IC context for now; it is replaced by the Blasius context further below
  problem->ics = (HoneeQFSpec){.qf_func_ptr = ICsBlasius, .qf_loc = ICsBlasius_loc, .qfctx = problem->ics.qfctx};

  CeedScalar U_inf                                = 40;           // m/s
  CeedScalar T_inf                                = 288.;         // K
  CeedScalar T_wall                               = 288.;         // K
  CeedScalar delta0                               = 4.2e-3;       // m
  CeedScalar P_inf                                = 1.01e5;       // Pa
  PetscInt   N                                    = 20;           // Number of Chebyshev terms
  PetscBool  weakT                                = PETSC_FALSE;  // weak density or temperature
  PetscReal  mesh_refine_height                   = 5.9e-4;       // m
  PetscReal  mesh_growth                          = 1.08;         // [-]
  PetscInt   mesh_Ndelta                          = 45;           // [-]
  PetscReal  mesh_top_angle                       = 5;            // degrees
  char       mesh_ynodes_path[PETSC_MAX_PATH_LEN] = "";
  PetscBool  P0_set;

  PetscOptionsBegin(comm, NULL, "Options for BLASIUS problem", NULL);
  PetscCall(PetscOptionsBool("-weakT", "Change from rho weak to T weak at inflow", NULL, weakT, &weakT, NULL));
  PetscCall(PetscOptionsScalar("-velocity_infinity", "Velocity at boundary layer edge", NULL, U_inf, &U_inf, NULL));
  PetscCall(PetscOptionsScalar("-temperature_infinity", "Temperature at boundary layer edge", NULL, T_inf, &T_inf, NULL));
  PetscCall(PetscOptionsHasName(NULL, NULL, "-P0", &P0_set));  // For maintaining behavior of -P0 flag (which is deprecated)
  PetscCall(
      PetscOptionsDeprecated("-P0", "-pressure_infinity", "libCEED 0.12.0",
                             "Use -pressure_infinity to set pressure at boundary layer edge and -idl_pressure to set the IDL reference pressure"));
  PetscCall(PetscOptionsScalar("-pressure_infinity", "Pressure at boundary layer edge", NULL, P_inf, &P_inf, NULL));
  PetscCall(PetscOptionsScalar("-temperature_wall", "Temperature at wall", NULL, T_wall, &T_wall, NULL));
  PetscCall(PetscOptionsScalar("-delta0", "Boundary layer height at inflow", NULL, delta0, &delta0, NULL));
  PetscCall(PetscOptionsInt("-n_chebyshev", "Number of Chebyshev terms", NULL, N, &N, NULL));
  PetscCheck(3 <= N && N <= BLASIUS_MAX_N_CHEBYSHEV, comm, PETSC_ERR_ARG_OUTOFRANGE, "-n_chebyshev %" PetscInt_FMT " must be in range [3, %d]", N,
             BLASIUS_MAX_N_CHEBYSHEV);
  if (honee->app_ctx->mesh_transform_type == MESH_TRANSFORM_PLATEMESH) {
    PetscCall(PetscOptionsBoundedInt("-platemesh_Ndelta", "Number of elements within the boundary layer mesh refinement height", NULL, mesh_Ndelta,
                                     &mesh_Ndelta, NULL, 1));
    PetscCall(PetscOptionsScalar("-platemesh_refine_height", "Height of boundary layer mesh refinement", NULL, mesh_refine_height,
                                 &mesh_refine_height, NULL));
    PetscCall(PetscOptionsScalar("-platemesh_growth", "Geometric growth rate of boundary layer mesh", NULL, mesh_growth, &mesh_growth, NULL));
    PetscCall(PetscOptionsScalar("-platemesh_top_angle", "Downward angle of the top surface of the mesh, in degrees", NULL, mesh_top_angle,
                                 &mesh_top_angle, NULL));
    PetscCall(PetscOptionsString("-platemesh_y_node_locs_path",
                                 "Path to file with y node locations. "
                                 "If empty, will use the algorithmic mesh warping.",
                                 NULL, mesh_ynodes_path, mesh_ynodes_path, sizeof(mesh_ynodes_path), NULL));
  }
  PetscCall(PetscOptionsBool("-stg_use", "Use STG inflow boundary condition", NULL, use_stg, &use_stg, NULL));
  PetscOptionsEnd();

  // Convert the user-facing values into the solver's unit system
  Units units = honee->units;

  T_inf *= units->Kelvin;
  T_wall *= units->Kelvin;
  P_inf *= units->Pascal;
  U_inf *= units->meter / units->second;
  delta0 *= units->meter;

  if (honee->app_ctx->mesh_transform_type == MESH_TRANSFORM_PLATEMESH) {
    PetscReal *mesh_ynodes  = NULL;
    PetscInt   mesh_nynodes = 0;
    // A non-empty path selects file-provided node locations; otherwise ModifyMesh() computes (and returns) them
    if (strcmp(mesh_ynodes_path, "")) PetscCall(GetYNodeLocs(comm, mesh_ynodes_path, &mesh_ynodes, &mesh_nynodes));
    PetscCall(ModifyMesh(comm, dm, mesh_growth, mesh_Ndelta, mesh_refine_height, mesh_top_angle, &mesh_ynodes, &mesh_nynodes));
    PetscCall(PetscFree(mesh_ynodes));
  }

  // Some properties depend on parameters from NewtonianIdealGas
  PetscCallCeed(ceed, CeedQFunctionContextGetData(problem->apply_vol_rhs.qfctx, CEED_MEM_HOST, &newtonian_ig_ctx));

  StatePrimitive Y_inf = {
      .pressure    = P_inf,
      .velocity    = {U_inf, 0, 0},
      .temperature = T_inf,
  };
  State S_infty = StateFromPrimitive(newtonian_ig_ctx->gas, Y_inf);

  blasius_ctx->weakT    = weakT;
  blasius_ctx->T_wall   = T_wall;
  blasius_ctx->delta0   = delta0;
  blasius_ctx->S_infty  = S_infty;
  blasius_ctx->n_cheb   = N;
  blasius_ctx->implicit = honee->phys->implicit;
  if (P0_set) newtonian_ig_ctx->idl_pressure = P_inf;  // For maintaining behavior of -P0 flag (which is deprecated)
  // Copy taken after the -P0 adjustment so the Blasius context sees the same idl_pressure
  blasius_ctx->newt_ctx = *newtonian_ig_ctx;

  {
    // Bounding box is queried after any mesh modification above
    PetscReal domain_min[3], domain_max[3];
    PetscCall(DMGetBoundingBox(dm, domain_min, domain_max));
    blasius_ctx->x_inflow = domain_min[0];
    blasius_ctx->eta_max  = 5 * domain_max[1] / blasius_ctx->delta0;
  }
  PetscBool diff_filter_mms = PETSC_FALSE;
  PetscCall(PetscOptionsGetBool(NULL, NULL, "-diff_filter_mms", &diff_filter_mms, NULL));
  // The Chebyshev solve is only performed for the Blasius weak inflow path
  if (!use_stg && !diff_filter_mms) PetscCall(ComputeChebyshevCoefficients(blasius_ctx));

  PetscCallCeed(ceed, CeedQFunctionContextRestoreData(problem->apply_vol_rhs.qfctx, &newtonian_ig_ctx));

  // The QFunction context takes the blasius_ctx pointer and frees it with FreeContextPetsc on destroy
  PetscCallCeed(ceed, CeedQFunctionContextCreate(honee->ceed, &blasius_qfctx));
  PetscCallCeed(ceed, CeedQFunctionContextSetData(blasius_qfctx, CEED_MEM_HOST, CEED_USE_POINTER, sizeof(*blasius_ctx), blasius_ctx));
  PetscCallCeed(ceed, CeedQFunctionContextSetDataDestroy(blasius_qfctx, CEED_MEM_HOST, FreeContextPetsc));

  // Replace the initial condition context with the Blasius context
  PetscCallCeed(ceed, CeedQFunctionContextDestroy(&problem->ics.qfctx));
  problem->ics.qfctx = blasius_qfctx;
  if (use_stg) {
    PetscCall(SetupStg(comm, dm, problem, honee, weakT, S_infty.Y.temperature, S_infty.Y.pressure));
  } else if (diff_filter_mms) {
    PetscCall(DifferentialFilterMmsICSetup(honee));
  } else {
    PetscCheck((honee->phys->state_var == STATEVAR_CONSERVATIVE) || (honee->app_ctx->test_type == TESTTYPE_DIFF_FILTER), honee->comm,
               PETSC_ERR_ARG_INCOMP, "Can only use conservative variables with Blasius and weak inflow");
    for (PetscCount b = 0; b < problem->num_bc_defs; b++) {
      BCDefinition bc_def = problem->bc_defs[b];
      const char  *name;

      PetscCall(BCDefinitionGetInfo(bc_def, &name, NULL, NULL));
      if (!strcmp(name, "inflow")) {
        HoneeBCStruct honee_bc;

        // Each inflow boundary gets its own HoneeBCStruct holding a reference to the shared Blasius context
        PetscCall(PetscNew(&honee_bc));
        PetscCallCeed(ceed, CeedQFunctionContextReferenceCopy(blasius_qfctx, &honee_bc->qfctx));
        honee_bc->honee              = honee;
        honee_bc->num_comps_jac_data = 0;
        PetscCall(BCDefinitionSetContext(bc_def, HoneeBCDestroy, honee_bc));

        PetscCall(BCDefinitionSetIFunction(bc_def, BlasiusInflowBCSetup_CreateIFunctionQF, HoneeBCAddIFunctionOp));
        PetscCall(BCDefinitionSetIJacobian(bc_def, BlasiusInflowBCSetup_CreateIJacobianQF, HoneeBCAddIJacobianOp));
      }
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
