// Copyright (c) 2017, Lawrence Livermore National Security, LLC. Produced at
// the Lawrence Livermore National Laboratory. LLNL-CODE-734707. All Rights
// reserved. See files LICENSE and NOTICE for details.
//
// This file is part of CEED, a collection of benchmarks, miniapps, software
// libraries and APIs for efficient high-order finite element and spectral
// element discretizations for exascale applications. For more information and
// source code availability see http://github.com/ceed.
//
// The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
// a collaborative effort of two U.S. Department of Energy organizations (Office
// of Science and the National Nuclear Security Administration) responsible for
// the planning and preparation of a capable exascale ecosystem, including
// software, applications, hardware, advanced system engineering and early
// testbed platforms, in support of the nation's exascale computing imperative.

#include <ceed-impl.h>
#include <string.h>

/// @file
/// Implementation of public CeedOperator interfaces
///
/// @addtogroup CeedOperator
///   @{

/**
  @brief Create a CeedOperator from a CeedQFunction and optional Jacobian
         CeedQFunctions

  Element restrictions and bases are not supplied here; they are attached to
  the operator's fields afterwards with CeedOperatorSetField().

  The new operator takes a reference on @a ceed, @a qf, and (when non-NULL)
  @a dqf and @a dqfT.  If the backend fails to create the operator, those
  references are released again and *@a op is freed (set to NULL).

  @param ceed    A Ceed object where the CeedOperator will be created
  @param qf      QFunction defining the action of the operator at quadrature
                   points (must not be NULL)
  @param dqf     QFunction defining the action of the Jacobian of @a qf (or NULL)
  @param dqfT    QFunction defining the action of the transpose of the Jacobian
                   of @a qf (or NULL)
  @param[out] op Address of the variable where the newly created
                     CeedOperator will be stored

  @return An error code: 0 - success, otherwise - failure

  @ref Basic
 */
int CeedOperatorCreate(Ceed ceed, CeedQFunction qf, CeedQFunction dqf,
                       CeedQFunction dqfT, CeedOperator *op) {
  int ierr;

  if (!ceed->OperatorCreate) return CeedError(ceed, 1,
                                      "Backend does not support OperatorCreate");
  // qf is dereferenced unconditionally below, so reject NULL before anything
  // is allocated or any reference count is modified.
  if (!qf) return CeedError(ceed, 1, "Operator must have a valid QFunction");

  ierr = CeedCalloc(1, op); CeedChk(ierr);
  (*op)->ceed = ceed;
  ceed->refcount++;
  (*op)->refcount = 1;
  (*op)->qf = qf;
  qf->refcount++;
  (*op)->dqf = dqf;
  if (dqf) dqf->refcount++;
  (*op)->dqfT = dqfT;
  if (dqfT) dqfT->refcount++;

  ierr = ceed->OperatorCreate(*op);
  if (ierr) {
    // Undo the reference-taking above, following the same release sequence
    // CeedOperatorDestroy() uses, so a backend failure neither leaks the
    // operator shell nor leaves ceed/qf/dqf/dqfT with inflated refcounts.
    // The backend Destroy hook is deliberately not called: the backend
    // reported failure and its private state may be incomplete.
    CeedQFunctionDestroy(&(*op)->qf);
    CeedQFunctionDestroy(&(*op)->dqf);
    CeedQFunctionDestroy(&(*op)->dqfT);
    CeedDestroy(&(*op)->ceed);
    CeedFree(op);
    return ierr;
  }
  return 0;
}

/**
  @brief Provide a field to a CeedOperator for use by its CeedQFunction

  This function is used to specify both active and passive fields to a
  CeedOperator.  For passive fields, a vector @arg v must be provided.  Passive
  fields can be inputs or outputs (updated in-place when operator is applied).

  Active fields must be specified using this function, but their data (in a
  CeedVector) is passed in CeedOperatorApply().  There can be at most one active
  input and at most one active output.

  The field name is resolved, and all size checks are performed, before the
  operator is modified.  A call that returns an error therefore leaves @a op
  unchanged.  Setting a field that was already set replaces its restriction,
  basis, and vector without counting the field a second time.

  @param op         CeedOperator on which to provide the field
  @param fieldname  Name of the field (to be matched with the name used by CeedQFunction)
  @param r          CeedElemRestriction
  @param b          CeedBasis in which the field resides or CEED_BASIS_COLOCATED
                      if collocated with quadrature points
  @param v          CeedVector to be used by CeedOperator or CEED_VECTOR_ACTIVE
                      if field is active or CEED_VECTOR_NONE if using
                      CEED_EVAL_WEIGHT in the qfunction

  @return An error code: 0 - success, otherwise - failure

  @ref Basic
**/
int CeedOperatorSetField(CeedOperator op, const char *fieldname,
                         CeedElemRestriction r, CeedBasis b,
                         CeedVector v) {
  int ierr;
  CeedInt numelements, numqpoints = 0;
  struct CeedOperatorField *ofield = NULL;

  // Locate the slot for this field; inputs are searched before outputs and
  // the first matching name wins.
  for (CeedInt i=0; i<op->qf->numinputfields && !ofield; i++) {
    if (!strcmp(fieldname, op->qf->inputfields[i].fieldname))
      ofield = &op->inputfields[i];
  }
  for (CeedInt i=0; i<op->qf->numoutputfields && !ofield; i++) {
    if (!strcmp(fieldname, op->qf->outputfields[i].fieldname))
      ofield = &op->outputfields[i];
  }
  if (!ofield)
    return CeedError(op->ceed, 1, "QFunction has no knowledge of field '%s'",
                     fieldname);

  // Validate sizes against previously set fields without touching op yet.
  ierr = CeedElemRestrictionGetNumElements(r, &numelements); CeedChk(ierr);
  if (op->numelements && op->numelements != numelements)
    return CeedError(op->ceed, 1,
                     "ElemRestriction with %d elements incompatible with prior %d elements",
                     numelements, op->numelements);
  if (b != CEED_BASIS_COLOCATED) {
    ierr = CeedBasisGetNumQuadraturePoints(b, &numqpoints); CeedChk(ierr);
    if (op->numqpoints && op->numqpoints != numqpoints)
      return CeedError(op->ceed, 1,
                       "Basis with %d quadrature points incompatible with prior %d points",
                       numqpoints, op->numqpoints);
  }

  // All checks passed: commit sizes and field data together.
  op->numelements = numelements;
  if (b != CEED_BASIS_COLOCATED) op->numqpoints = numqpoints;
  // The operator is zero-allocated (CeedCalloc in CeedOperatorCreate), so an
  // unset slot has a NULL restriction.  Count each field only once so that
  // re-setting a field cannot satisfy the "all fields set" check in
  // CeedOperatorApply() while another field is still missing.
  if (!ofield->Erestrict) op->nfields += 1;
  ofield->Erestrict = r;
  ofield->basis = b;
  ofield->vec = v;
  return 0;
}

/**
  @brief Apply CeedOperator to a vector

  Computes the action of the operator on the given (active) input and stores
  it in the (active) output.  Before dispatching to the backend, the operator
  is checked to ensure that:
    - at least one field has been set,
    - every QFunction input and output field has been set via
      CeedOperatorSetField(),
    - a nonzero number of elements is known, and
    - a nonzero number of quadrature points is known.

  @param op        CeedOperator to apply
  @param[in] in    CeedVector containing input state or NULL if there are no
                     active inputs
  @param[out] out  CeedVector to store result of applying operator (must be
                     distinct from @a in) or NULL if there are no active outputs
  @param request   Address of CeedRequest for non-blocking completion, else
                     CEED_REQUEST_IMMEDIATE

  @return An error code: 0 - success, otherwise - failure

  @ref Basic
**/
int CeedOperatorApply(CeedOperator op, CeedVector in,
                      CeedVector out, CeedRequest *request) {
  int ierr;

  // Reject incompletely configured operators, in order of severity.
  if (!op->nfields)
    return CeedError(op->ceed, 1, "No operator fields set");
  if (op->nfields < op->qf->numinputfields + op->qf->numoutputfields)
    return CeedError(op->ceed, 1, "Not all operator fields set");
  if (!op->numelements)
    return CeedError(op->ceed, 1, "At least one restriction required");
  if (!op->numqpoints)
    return CeedError(op->ceed, 1, "At least one non-colocated basis required");

  // Hand off to the backend implementation.
  ierr = op->Apply(op, in, out, request);
  CeedChk(ierr);
  return 0;
}

/**
  @brief Destroy a CeedOperator

  Drops one reference to the operator.  Only when the last reference goes
  away does this run the backend's Destroy hook (if any), then release the
  operator's QFunctions and Ceed and free the operator.  Passing a pointer
  to NULL is a no-op.

  @param op Address of the CeedOperator to destroy

  @return An error code: 0 - success, otherwise - failure

  @ref Basic
**/
int CeedOperatorDestroy(CeedOperator *op) {
  int ierr;
  CeedOperator oper = *op;

  // Nothing to release for a NULL handle.
  if (!oper) return 0;
  // Other holders still reference this operator; keep it alive.
  if (--oper->refcount > 0) return 0;

  // Backend-private teardown comes first, while the operator is intact.
  if (oper->Destroy) {
    ierr = oper->Destroy(oper); CeedChk(ierr);
  }

  // Drop the references taken in CeedOperatorCreate.
  ierr = CeedQFunctionDestroy(&oper->qf); CeedChk(ierr);
  ierr = CeedQFunctionDestroy(&oper->dqf); CeedChk(ierr);
  ierr = CeedQFunctionDestroy(&oper->dqfT); CeedChk(ierr);
  ierr = CeedDestroy(&oper->ceed); CeedChk(ierr);

  ierr = CeedFree(op); CeedChk(ierr);
  return 0;
}

/// @}
