#include <petsc/private/petscfeimpl.h> /*I "petscfe.h" I*/

/* Read the polynomial-space options from the options database.
   Currently the only option is -petscspace_poly_tensor, which controls the tensor flag;
   the flag's current value is offered as the default. */
static PetscErrorCode PetscSpaceSetFromOptions_Polynomial(PetscSpace sp, PetscOptionItems PetscOptionsObject)
{
  PetscSpace_Poly *impl = (PetscSpace_Poly *)sp->data;
  PetscBool        useTensor;

  PetscFunctionBegin;
  useTensor = impl->tensor;
  PetscOptionsHeadBegin(PetscOptionsObject, "PetscSpace polynomial options");
  PetscCall(PetscOptionsBool("-petscspace_poly_tensor", "Use the tensor product polynomials", "PetscSpacePolynomialSetTensor", useTensor, &useTensor, NULL));
  /* store the (possibly updated) flag before closing the options header */
  impl->tensor = useTensor;
  PetscOptionsHeadEnd();
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Write a one-line summary of the space to an ASCII viewer: whether it is a tensor
   polynomial space, and its degree. */
static PetscErrorCode PetscSpacePolynomialView_Ascii(PetscSpace sp, PetscViewer v)
{
  const PetscSpace_Poly *impl  = (const PetscSpace_Poly *)sp->data;
  const char            *label = "Polynomial";

  PetscFunctionBegin;
  if (impl->tensor) label = "Tensor polynomial";
  PetscCall(PetscViewerASCIIPrintf(v, "%s space of degree %" PetscInt_FMT "\n", label, sp->degree));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* View implementation for PETSCSPACEPOLYNOMIAL.
   Only ASCII viewers produce output; any other viewer type is accepted and ignored. */
static PetscErrorCode PetscSpaceView_Polynomial(PetscSpace sp, PetscViewer viewer)
{
  PetscBool isAscii = PETSC_FALSE;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(sp, PETSCSPACE_CLASSID, 1);
  PetscValidHeaderSpecific(viewer, PETSC_VIEWER_CLASSID, 2);
  PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isAscii));
  if (!isAscii) PetscFunctionReturn(PETSC_SUCCESS);
  PetscCall(PetscSpacePolynomialView_Ascii(sp, viewer));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Free everything owned by the polynomial implementation.
   This removes the composed Get/SetTensor methods, destroys any cached height subspaces,
   and frees the cache array and the implementation struct. */
static PetscErrorCode PetscSpaceDestroy_Polynomial(PetscSpace sp)
{
  PetscSpace_Poly *impl = (PetscSpace_Poly *)sp->data;

  PetscFunctionBegin;
  PetscCall(PetscObjectComposeFunction((PetscObject)sp, "PetscSpacePolynomialGetTensor_C", NULL));
  PetscCall(PetscObjectComposeFunction((PetscObject)sp, "PetscSpacePolynomialSetTensor_C", NULL));
  /* the height-subspace cache, when it exists, has one slot per variable */
  for (PetscInt h = 0; impl->subspaces && h < sp->Nv; h++) PetscCall(PetscSpaceDestroy(&impl->subspaces[h]));
  PetscCall(PetscFree(impl->subspaces));
  PetscCall(PetscFree(impl));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Set up a PETSCSPACEPOLYNOMIAL space.

   Only a scalar (Nc == 1), non-tensor space is handled by this implementation itself. In every other
   case the object changes its own type and then sets itself up again under that type:
   - Nc != 1: it becomes a PETSCSPACESUM with Nc subspaces. All of them are the same scalar
     polynomial space, with the same degree, number of variables and tensor flag.
   - tensor flag set (only possible when Nv > 1): it becomes a PETSCSPACETENSOR.
   Returns immediately if this implementation has already completed setup. */
static PetscErrorCode PetscSpaceSetUp_Polynomial(PetscSpace sp)
{
  PetscSpace_Poly *poly = (PetscSpace_Poly *)sp->data;

  PetscFunctionBegin;
  if (poly->setupcalled) PetscFunctionReturn(PETSC_SUCCESS);
  // With at most one variable, "degree in each variable" and "total degree" describe the same space,
  // so the tensor construction is not needed.
  if (sp->Nv <= 1) poly->tensor = PETSC_FALSE;
  if (sp->Nc != 1) {
    // Everything used below is copied into locals before the type change.
    // NOTE(review): PetscSpaceSetType() replaces the implementation and presumably frees `poly`,
    // so `poly` must not be read after that call.
    PetscInt    Nc     = sp->Nc;
    PetscBool   tensor = poly->tensor;
    PetscInt    Nv     = sp->Nv;
    PetscInt    degree = sp->degree;
    const char *prefix;
    const char *name;
    char        subname[PETSC_MAX_PATH_LEN];
    PetscSpace  subsp;

    PetscCall(PetscSpaceSetType(sp, PETSCSPACESUM));
    PetscCall(PetscSpaceSumSetNumSubspaces(sp, Nc));
    PetscCall(PetscSpaceSumSetInterleave(sp, PETSC_TRUE, PETSC_FALSE));
    // Build one scalar component space. It inherits sp's options prefix (plus "sumcomp_"),
    // so it can be configured separately from the options database.
    PetscCall(PetscSpaceCreate(PetscObjectComm((PetscObject)sp), &subsp));
    PetscCall(PetscObjectGetOptionsPrefix((PetscObject)sp, &prefix));
    PetscCall(PetscObjectSetOptionsPrefix((PetscObject)subsp, prefix));
    PetscCall(PetscObjectAppendOptionsPrefix((PetscObject)subsp, "sumcomp_"));
    // Derive the component's name from sp's name when sp has one.
    if (((PetscObject)sp)->name) {
      PetscCall(PetscObjectGetName((PetscObject)sp, &name));
      PetscCall(PetscSNPrintf(subname, PETSC_MAX_PATH_LEN - 1, "%s sum component", name));
      PetscCall(PetscObjectSetName((PetscObject)subsp, subname));
    } else PetscCall(PetscObjectSetName((PetscObject)subsp, "sum component"));
    // The component is scalar (Nc = 1), so its own setup takes one of the scalar paths below.
    PetscCall(PetscSpaceSetType(subsp, PETSCSPACEPOLYNOMIAL));
    PetscCall(PetscSpaceSetDegree(subsp, degree, PETSC_DETERMINE));
    PetscCall(PetscSpaceSetNumComponents(subsp, 1));
    PetscCall(PetscSpaceSetNumVariables(subsp, Nv));
    PetscCall(PetscSpacePolynomialSetTensor(subsp, tensor));
    PetscCall(PetscSpaceSetUp(subsp));
    // The same subspace object is installed for every component.
    for (PetscInt i = 0; i < Nc; i++) PetscCall(PetscSpaceSumSetSubspace(sp, i, subsp));
    // Drop the local reference; presumably the sum space holds its own references (NOTE(review): confirm).
    PetscCall(PetscSpaceDestroy(&subsp));
    // Set up again, now dispatched to the sum implementation.
    PetscCall(PetscSpaceSetUp(sp));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  if (poly->tensor) {
    // maxDegree is reset before the type change, presumably so the tensor implementation determines it.
    sp->maxDegree = PETSC_DETERMINE;
    PetscCall(PetscSpaceSetType(sp, PETSCSPACETENSOR));
    PetscCall(PetscSpaceSetUp(sp));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  // Scalar total-degree space: handled natively by this implementation.
  PetscCheck(sp->degree >= 0, PetscObjectComm((PetscObject)sp), PETSC_ERR_ARG_OUTOFRANGE, "Negative degree %" PetscInt_FMT " invalid", sp->degree);
  // No monomial in a total-degree-k space has degree above k.
  sp->maxDegree     = sp->degree;
  poly->setupcalled = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Dimension of the space: the number of scalar basis functions times the number of components.

   The total-degree space P_k in n variables has binomial(n + k, n) scalar basis functions. The tensor
   space Q_k (degree <= k in each variable) has (k + 1)^n. The tensor flag must be honored here:
   PetscSpaceSetUp_Polynomial() converts a tensor space to PETSCSPACETENSOR, but nothing here forces
   setup to happen before this query. For n <= 1 the two counts coincide, so the binomial formula is used.
   Negative degrees are also sent to PetscDTBinomialInt(), which keeps the original handling of them. */
static PetscErrorCode PetscSpaceGetDimension_Polynomial(PetscSpace sp, PetscInt *dim)
{
  PetscSpace_Poly *poly = (PetscSpace_Poly *)sp->data;
  PetscInt         deg  = sp->degree;
  PetscInt         n    = sp->Nv;

  PetscFunctionBegin;
  if (poly->tensor && n > 1 && deg >= 0) {
    PetscInt Nb = 1;

    /* (deg + 1)^n: one univariate factor of degree <= deg per variable */
    for (PetscInt i = 0; i < n; i++) Nb *= deg + 1;
    *dim = Nb;
  } else {
    PetscCall(PetscDTBinomialInt(n + deg, n, dim));
  }
  *dim *= sp->Nc;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Evaluate the affine basis {1, x_0, ..., x_{dim-1}} and its jet at npoints points.

   Output layout: pScalar[b * Njet * npoints + j * npoints + pt] is jet entry j of basis function b
   at point pt. Jet entry 0 is the value. Jet entry j >= 1 (for j <= dim) is taken to be the
   derivative with respect to x_{j-1}. That derivative is 1 for basis function b = j and 0 otherwise.
   Any further jet entries (higher derivatives) are left at zero.
   Point coordinates are read as points[pt * dim + d].
   `jet` is not read; Njet alone determines how many jet entries exist. */
static PetscErrorCode CoordinateBasis(PetscInt dim, PetscInt npoints, const PetscReal points[], PetscInt jet, PetscInt Njet, PetscReal pScalar[])
{
  const PetscInt stride = Njet * npoints;         /* distance between consecutive basis functions */
  const PetscInt nj     = PetscMin(1 + dim, Njet); /* jet entries that can be nonzero */

  PetscFunctionBegin;
  PetscCall(PetscArrayzero(pScalar, (1 + dim) * stride));
  if (nj > 0) {
    /* basis function 0: the constant 1 */
    for (PetscInt pt = 0; pt < npoints; pt++) pScalar[pt] = 1.;
    /* basis function b >= 1: the coordinate x_{b-1}, with unit derivative in direction b-1 */
    for (PetscInt b = 1; b <= dim; b++) {
      PetscReal *fb = &pScalar[b * stride];

      for (PetscInt pt = 0; pt < npoints; pt++) fb[pt] = points[pt * dim + (b - 1)];
      if (b < nj) {
        for (PetscInt pt = 0; pt < npoints; pt++) fb[b * npoints + pt] = 1.;
      }
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Evaluate the basis of a scalar (Nc == 1) total-degree polynomial space at npoints points.
   First and second derivatives are evaluated too when requested.

   Output layouts (each of B, D, H may be NULL, in which case it is not filled):
     B[p * Nb + b]                           value of basis function b at point p
     D[(p * Nb + b) * dim + d]               derivative with respect to variable d
     H[((p * Nb + b) * dim + d1) * dim + d2] second derivative with respect to variables d1 and d2
   Here dim = sp->Nv and Nb = binomial(dim + degree, dim).
   Points are read as points[p * dim + d]. This is how the degree-1 path indexes them;
   NOTE(review): assumed to match what PetscDTPKDEvalJet() expects.

   If setup has not run yet, the space is set up and evaluation is re-dispatched through
   PetscSpaceEvaluate(). Setup may have converted a multi-component or tensor space to another type. */
static PetscErrorCode PetscSpaceEvaluate_Polynomial(PetscSpace sp, PetscInt npoints, const PetscReal points[], PetscReal B[], PetscReal D[], PetscReal H[])
{
  PetscSpace_Poly *poly = (PetscSpace_Poly *)sp->data;
  DM               dm   = sp->dm;
  PetscInt         dim  = sp->Nv;
  PetscInt         Nb, jet, Njet;
  PetscReal       *pScalar;

  PetscFunctionBegin;
  if (!poly->setupcalled) {
    PetscCall(PetscSpaceSetUp(sp));
    PetscCall(PetscSpaceEvaluate(sp, npoints, points, B, D, H));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  PetscCheck(!poly->tensor && sp->Nc == 1, PETSC_COMM_SELF, PETSC_ERR_PLIB, "tensor and multicomponent spaces should have been converted");
  PetscCall(PetscDTBinomialInt(dim + sp->degree, dim, &Nb));
  // jet = highest derivative order requested; it decides how many jet entries are computed.
  if (H) {
    jet = 2;
  } else if (D) {
    jet = 1;
  } else {
    jet = 0;
  }
  // Njet = number of partial derivatives of total order <= jet in dim variables.
  PetscCall(PetscDTBinomialInt(dim + jet, dim, &Njet));
  // pScalar[b * Njet * npoints + j * npoints + p] holds jet entry j of basis function b at point p.
  PetscCall(DMGetWorkArray(dm, Nb * Njet * npoints, MPIU_REAL, &pScalar));
  // Why are we handling the case degree == 1 specially?  Because we don't want numerical noise when we evaluate hat
  // functions at the vertices of a simplex, which happens when we invert the Vandermonde matrix of the PKD basis.
  // We don't make any promise about which basis is used.
  if (sp->degree == 1) {
    PetscCall(CoordinateBasis(dim, npoints, points, jet, Njet, pScalar));
  } else {
    PetscCall(PetscDTPKDEvalJet(dim, npoints, points, sp->degree, jet, pScalar));
  }
  if (B) {
    // Values are jet entry 0; transpose from [b][p] to [p][b].
    PetscInt p_strl = Nb;
    PetscInt b_strl = 1;

    PetscInt b_strr = Njet * npoints;
    PetscInt p_strr = 1;

    PetscCall(PetscArrayzero(B, npoints * Nb));
    for (PetscInt b = 0; b < Nb; b++) {
      for (PetscInt p = 0; p < npoints; p++) B[p * p_strl + b * b_strl] = pScalar[b * b_strr + p * p_strr];
    }
  }
  if (D) {
    // Jet entry 1 + d is read as the first derivative with respect to variable d.
    PetscInt p_strl = dim * Nb;
    PetscInt b_strl = dim;
    PetscInt d_strl = 1;

    PetscInt b_strr = Njet * npoints;
    PetscInt d_strr = npoints;
    PetscInt p_strr = 1;

    PetscCall(PetscArrayzero(D, npoints * Nb * dim));
    for (PetscInt d = 0; d < dim; d++) {
      for (PetscInt b = 0; b < Nb; b++) {
        for (PetscInt p = 0; p < npoints; p++) D[p * p_strl + b * b_strl + d * d_strl] = pScalar[b * b_strr + (1 + d) * d_strr + p * p_strr];
      }
    }
  }
  if (H) {
    // Second-derivative jet entries are located by converting a derivative multi-index to its
    // graded-order index. Both (d1, d2) and (d2, d1) are filled, so H is stored as a full matrix.
    PetscInt p_strl  = dim * dim * Nb;
    PetscInt b_strl  = dim * dim;
    PetscInt d1_strl = dim;
    PetscInt d2_strl = 1;

    PetscInt b_strr = Njet * npoints;
    PetscInt j_strr = npoints;
    PetscInt p_strr = 1;

    // derivs is a multi-index scratch array; it is returned to all zeros after each lookup.
    PetscInt *derivs;
    PetscCall(PetscCalloc1(dim, &derivs));
    PetscCall(PetscArrayzero(H, npoints * Nb * dim * dim));
    for (PetscInt d1 = 0; d1 < dim; d1++) {
      for (PetscInt d2 = 0; d2 < dim; d2++) {
        PetscInt j;
        derivs[d1]++;
        derivs[d2]++;
        PetscCall(PetscDTGradedOrderToIndex(dim, derivs, &j));
        derivs[d1]--;
        derivs[d2]--;
        for (PetscInt b = 0; b < Nb; b++) {
          for (PetscInt p = 0; p < npoints; p++) H[p * p_strl + b * b_strl + d1 * d1_strl + d2 * d2_strl] = pScalar[b * b_strr + j * j_strr + p * p_strr];
        }
      }
    }
    PetscCall(PetscFree(derivs));
  }
  PetscCall(DMRestoreWorkArray(dm, Nb * Njet * npoints, MPIU_REAL, &pScalar));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  PetscSpacePolynomialSetTensor - Set whether a function space is a space of tensor polynomials.

  Input Parameters:
+ sp     - the function space object
- tensor - `PETSC_TRUE` for a tensor polynomial space, `PETSC_FALSE` for a polynomial space

  Options Database Key:
. -petscspace_poly_tensor <bool> - Whether to use tensor product polynomials in higher dimension

  Level: intermediate

  Notes:
  It is a tensor space if it is spanned by polynomials whose degree in each variable is
  bounded by the given degree, as opposed to the space spanned by polynomials
  whose total degree---summing over all variables---is bounded by the given degree.

  The call has no effect unless `sp` provides this method (i.e. is a `PETSCSPACEPOLYNOMIAL`).
  Set the flag before `PetscSpaceSetUp()`. Setup converts a tensor polynomial space with more than one
  variable to a `PETSCSPACETENSOR`. With a single variable the two spaces coincide and the flag is cleared.

.seealso: `PetscSpace`, `PetscSpacePolynomialGetTensor()`, `PetscSpaceSetDegree()`, `PetscSpaceSetNumVariables()`
@*/
PetscErrorCode PetscSpacePolynomialSetTensor(PetscSpace sp, PetscBool tensor)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(sp, PETSCSPACE_CLASSID, 1);
  PetscTryMethod(sp, "PetscSpacePolynomialSetTensor_C", (PetscSpace, PetscBool), (sp, tensor));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*@
  PetscSpacePolynomialGetTensor - Get whether a function space is a space of tensor
  polynomials.

  Input Parameter:
. sp - the function space object

  Output Parameter:
. tensor - `PETSC_TRUE` for a tensor polynomial space, `PETSC_FALSE` for a polynomial space

  Level: intermediate

  Notes:
  The space is a tensor space if it is spanned by polynomials whose degree in each variable is
  bounded by the given degree, as opposed to the space spanned by polynomials
  whose total degree---summing over all variables---is bounded by the given degree.

  `sp` may not provide this query. For example, it may not be a `PETSCSPACEPOLYNOMIAL`, possibly because
  `PetscSpaceSetUp()` converted it to another type. In that case `tensor` is set to `PETSC_FALSE`
  rather than being left unset.

.seealso: `PetscSpace`, `PetscSpacePolynomialSetTensor()`, `PetscSpaceSetDegree()`, `PetscSpaceSetNumVariables()`
@*/
PetscErrorCode PetscSpacePolynomialGetTensor(PetscSpace sp, PetscBool *tensor)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(sp, PETSCSPACE_CLASSID, 1);
  PetscAssertPointer(tensor, 2);
  /* PetscTryMethod() does nothing when the method is absent, so give the output a defined value first */
  *tensor = PETSC_FALSE;
  PetscTryMethod(sp, "PetscSpacePolynomialGetTensor_C", (PetscSpace, PetscBool *), (sp, tensor));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Composed implementation behind PetscSpacePolynomialSetTensor(): record the flag.
   It is only acted on later, in PetscSpaceSetUp_Polynomial(). */
static PetscErrorCode PetscSpacePolynomialSetTensor_Polynomial(PetscSpace sp, PetscBool tensor)
{
  PetscFunctionBegin;
  ((PetscSpace_Poly *)sp->data)->tensor = tensor;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Composed implementation behind PetscSpacePolynomialGetTensor(): report the stored tensor flag. */
static PetscErrorCode PetscSpacePolynomialGetTensor_Polynomial(PetscSpace sp, PetscBool *tensor)
{
  PetscFunctionBegin;
  PetscValidHeaderSpecific(sp, PETSCSPACE_CLASSID, 1);
  PetscAssertPointer(tensor, 2);
  *tensor = ((PetscSpace_Poly *)sp->data)->tensor;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Return the polynomial space on a boundary stratum of the given height.

   For 1 <= height <= dim the result has dim - height variables, and the same number of components,
   degree, tensor flag and name as sp. It is created on first request and cached in
   poly->subspaces[height - 1]. The cache is owned by sp (it is destroyed in
   PetscSpaceDestroy_Polynomial()), so callers must not destroy the returned space.
   Height 0 is the space itself, so sp is returned. Before this was handled, height 0 passed the
   range check and then indexed poly->subspaces[-1]. */
static PetscErrorCode PetscSpaceGetHeightSubspace_Polynomial(PetscSpace sp, PetscInt height, PetscSpace *subsp)
{
  PetscSpace_Poly *poly = (PetscSpace_Poly *)sp->data;
  PetscInt         Nc, dim, order;
  PetscBool        tensor;

  PetscFunctionBegin;
  PetscCall(PetscSpaceGetNumComponents(sp, &Nc));
  PetscCall(PetscSpaceGetNumVariables(sp, &dim));
  PetscCall(PetscSpaceGetDegree(sp, &order, NULL));
  PetscCall(PetscSpacePolynomialGetTensor(sp, &tensor));
  PetscCheck(height <= dim && height >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Asked for space at height %" PetscInt_FMT " for dimension %" PetscInt_FMT " space", height, dim);
  if (height == 0) {
    /* the cache only has slots for heights 1..dim */
    *subsp = sp;
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  if (!poly->subspaces) PetscCall(PetscCalloc1(dim, &poly->subspaces));
  if (!poly->subspaces[height - 1]) {
    PetscSpace  sub;
    const char *name;

    PetscCall(PetscSpaceCreate(PetscObjectComm((PetscObject)sp), &sub));
    PetscCall(PetscObjectGetName((PetscObject)sp, &name));
    PetscCall(PetscObjectSetName((PetscObject)sub, name));
    PetscCall(PetscSpaceSetType(sub, PETSCSPACEPOLYNOMIAL));
    PetscCall(PetscSpaceSetNumComponents(sub, Nc));
    PetscCall(PetscSpaceSetDegree(sub, order, PETSC_DETERMINE));
    PetscCall(PetscSpaceSetNumVariables(sub, dim - height));
    PetscCall(PetscSpacePolynomialSetTensor(sub, tensor));
    PetscCall(PetscSpaceSetUp(sub));
    poly->subspaces[height - 1] = sub;
  }
  *subsp = poly->subspaces[height - 1];
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Install the PETSCSPACEPOLYNOMIAL method table on sp.
   Also compose the type-specific methods reached through
   PetscSpacePolynomialGetTensor() / PetscSpacePolynomialSetTensor(). */
static PetscErrorCode PetscSpaceInitialize_Polynomial(PetscSpace sp)
{
  PetscFunctionBegin;
  /* object lifecycle */
  sp->ops->setfromoptions = PetscSpaceSetFromOptions_Polynomial;
  sp->ops->setup          = PetscSpaceSetUp_Polynomial;
  sp->ops->view           = PetscSpaceView_Polynomial;
  sp->ops->destroy        = PetscSpaceDestroy_Polynomial;

  /* queries and evaluation */
  sp->ops->getdimension      = PetscSpaceGetDimension_Polynomial;
  sp->ops->getheightsubspace = PetscSpaceGetHeightSubspace_Polynomial;
  sp->ops->evaluate          = PetscSpaceEvaluate_Polynomial;

  /* type-specific methods looked up by name */
  PetscCall(PetscObjectComposeFunction((PetscObject)sp, "PetscSpacePolynomialGetTensor_C", PetscSpacePolynomialGetTensor_Polynomial));
  PetscCall(PetscObjectComposeFunction((PetscObject)sp, "PetscSpacePolynomialSetTensor_C", PetscSpacePolynomialSetTensor_Polynomial));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*MC
  PETSCSPACEPOLYNOMIAL = "poly" - A `PetscSpace` object that encapsulates a polynomial space, e.g. P1 is the space of
  linear polynomials. The space is replicated for each component.

  Level: intermediate

.seealso: `PetscSpace`, `PetscSpaceType`, `PetscSpaceCreate()`, `PetscSpaceSetType()`
M*/

PETSC_EXTERN PetscErrorCode PetscSpaceCreate_Polynomial(PetscSpace sp)
{
  PetscSpace_Poly *impl = NULL;

  PetscFunctionBegin;
  PetscValidHeaderSpecific(sp, PETSCSPACE_CLASSID, 1);
  PetscCall(PetscNew(&impl));
  /* start as a total-degree (non-tensor) space with no cached height subspaces */
  impl->tensor    = PETSC_FALSE;
  impl->subspaces = NULL;
  sp->data        = impl;

  PetscCall(PetscSpaceInitialize_Polynomial(sp));
  PetscFunctionReturn(PETSC_SUCCESS);
}
