#include "wash.h"

/* Subroutines for Pipe                                  */
/* -------------------------------------------------------*/

/*
   PipeCreate - Allocate a new Pipe context.

   Input Parameter:
.  comm - MPI communicator (not referenced by this routine)

   Output Parameter:
.  pipe - location that receives the newly allocated Pipe context

   Notes:
   Only the context itself is allocated, via PetscNew(); the DMDA, the solution
   vector and the Jacobian array are created by other routines in this file.
*/
PetscErrorCode PipeCreate(MPI_Comm comm,Pipe *pipe)
{
  PetscFunctionBegin;
  (void)comm; /* kept in the signature, unused here */
  PetscCall(PetscNew(pipe));
  PetscFunctionReturn(0);
}

/*
   PipeDestroy - Release the PETSc objects held by a Pipe.

   Input Parameter:
.  pipe - address of the Pipe whose contents are to be destroyed; nothing is done if *pipe is NULL

   Notes:
   Destroys the Jacobian array (through PipeDestroyJacobian()), the vector pipe->x
   and the DM pipe->da. The Pipe structure itself is not freed and *pipe is left
   unchanged by this routine.
*/
PetscErrorCode PipeDestroy(Pipe *pipe)
{
  Pipe p;

  PetscFunctionBegin;
  p = *pipe;
  if (p) {
    PetscCall(PipeDestroyJacobian(p));
    PetscCall(VecDestroy(&p->x));
    PetscCall(DMDestroy(&p->da));
  }
  PetscFunctionReturn(0);
}

/*
   PipeSetParameters - Store the physical parameters of a Pipe.

   Input Parameters:
+  pipe   - the Pipe object
.  length - pipe length; PipeSetUp() uses it as the upper bound of the uniform coordinates
.  D      - pipe diameter; PipeSetUp() derives the radius as D/2
.  a      - stored in pipe->a (presumably the wave speed - TODO confirm against the residual code)
-  fric   - friction parameter; PipeSetUp() uses it to compute pipe->R

   Notes:
   The number of grid nodes (pipe->nnodes) is not an argument of this routine;
   it must be assigned separately before PipeSetUp() is called.
*/
PetscErrorCode PipeSetParameters(Pipe pipe,PetscReal length,PetscReal D,PetscReal a,PetscReal fric)
{
  PetscFunctionBegin;
  /* geometry */
  pipe->D = D;
  pipe->length = length;
  /* remaining physical coefficients */
  pipe->fric = fric;
  pipe->a = a;
  PetscFunctionReturn(0);
}

/*
    PipeSetUp - Set up pipe based on set parameters.

    Input Parameter:
.   pipe - the Pipe object; pipe->nnodes, pipe->length, pipe->D and pipe->fric are read

    Notes:
    Creates pipe->da, a 1d DMDA on PETSC_COMM_SELF with pipe->nnodes grid points,
    ghosted boundaries, two fields per node named "Q" and "H", and stencil width 1,
    with uniform coordinates on [0, pipe->length]. A global vector pipe->x is then
    created from that DM.

    The derived quantities are
      rad = D/2,  A = pi*rad^2,  R = fric/(2*D*A).
    No check is made that D is nonzero.
*/
PetscErrorCode PipeSetUp(Pipe pipe)
{
  PetscFunctionBegin;
  PetscCall(DMDACreate1d(PETSC_COMM_SELF, DM_BOUNDARY_GHOSTED, pipe->nnodes, 2, 1, NULL, &pipe->da));
  PetscCall(DMSetFromOptions(pipe->da));
  PetscCall(DMSetUp(pipe->da));
  PetscCall(DMDASetFieldName(pipe->da, 0, "Q"));
  PetscCall(DMDASetFieldName(pipe->da, 1, "H"));
  PetscCall(DMDASetUniformCoordinates(pipe->da, 0, pipe->length, 0, 0, 0, 0));
  PetscCall(DMCreateGlobalVector(pipe->da, &(pipe->x)));

  /* quantities derived from the diameter and friction parameter */
  pipe->rad = pipe->D / 2;
  pipe->A   = PETSC_PI*pipe->rad*pipe->rad;
  pipe->R   = pipe->fric / (2*pipe->D*pipe->A);
  PetscFunctionReturn(0);
}

/*
    PipeCreateJacobian - Create Jacobian matrix structures for a Pipe.

    Input Parameters:
+   pipe - the Pipe object
-   Jin  - array of three already constructed Jacobian matrices to reuse, or NULL to build new ones

    Output Parameter:
.   J  - array of three structure-only Jacobian matrices; the same array is also stored in pipe->jacobian

    Notes:
    When Jin is NULL the array holds
      [0] the matrix produced by DMCreateMatrix() on pipe->da (structure only),
      [1] an M x 2 matrix on PETSC_COMM_SELF with a 2x2 block in rows 0,1       (upstream vertex),
      [2] an M x 2 matrix on PETSC_COMM_SELF with a 2x2 block in rows M-2,M-1   (downstream vertex),
    where M is the number of rows of matrix [0].

    When Jin is given, the array is adopted as is and only Jin[0] receives an extra reference.
    NOTE(review): PipeDestroyJacobian() destroys all three entries and frees the array, so the
    ownership rules for a Jin array shared between several pipes should be confirmed.

    Level: beginner
*/
PetscErrorCode PipeCreateJacobian(Pipe pipe,Mat *Jin,Mat *J[])
{
  Mat            *mats;
  PetscInt       nrows,k,first,r[2],c[2],*nnz;
  PetscScalar    *vals;

  PetscFunctionBegin;
  if (Jin) {
    /* reuse the caller-provided matrices */
    *J = Jin;
    pipe->jacobian = Jin;
    PetscCall(PetscObjectReference((PetscObject)(Jin[0])));
    PetscFunctionReturn(0);
  }

  PetscCall(PetscMalloc1(3,&mats));

  /* mats[0]: Jacobian of the pipe itself, nonzero structure only */
  PetscCall(DMSetMatrixStructureOnly(pipe->da,PETSC_TRUE));
  PetscCall(DMCreateMatrix(pipe->da,&mats[0]));
  PetscCall(DMSetMatrixStructureOnly(pipe->da,PETSC_FALSE));

  /* work arrays: per-row nonzero counts (zeroed) and a zero 2x2 block of values */
  PetscCall(MatGetSize(mats[0],&nrows,NULL));
  PetscCall(PetscCalloc2(nrows,&nnz,4,&vals));
  c[0] = 0; c[1] = 1;

  /* mats[1]: coupling to the upstream vertex; mats[2]: coupling to the downstream vertex */
  for (k=1; k<3; k++) {
    first = (k == 1) ? 0 : nrows - 2; /* first of the two coupled rows */

    PetscCall(MatCreate(PETSC_COMM_SELF,&mats[k]));
    PetscCall(MatSetSizes(mats[k],PETSC_DECIDE,PETSC_DECIDE,nrows,2));
    PetscCall(MatSetFromOptions(mats[k]));
    PetscCall(MatSetOption(mats[k],MAT_STRUCTURE_ONLY,PETSC_TRUE));

    /* clear the counts left by the upstream pass before marking the two coupled rows */
    nnz[0] = 0; nnz[1] = 0;
    nnz[first] = 2; nnz[first+1] = 2;
    r[0] = first; r[1] = first + 1;

    PetscCall(MatSeqAIJSetPreallocation(mats[k],0,nnz));
    PetscCall(MatSetValues(mats[k],2,r,2,c,vals,INSERT_VALUES));
    PetscCall(MatAssemblyBegin(mats[k],MAT_FINAL_ASSEMBLY));
    PetscCall(MatAssemblyEnd(mats[k],MAT_FINAL_ASSEMBLY));
  }

  PetscCall(PetscFree2(nnz,vals));

  *J = mats;
  pipe->jacobian = mats;
  PetscFunctionReturn(0);
}

/*
    PipeDestroyJacobian - Destroy the Jacobian matrices stored in a Pipe.

    Input Parameter:
.   pipe - the Pipe object

    Notes:
    Destroys the three matrices in pipe->jacobian and frees the array. The array
    is released through pipe->jacobian itself so that PetscFree() resets the stored
    pointer to NULL; a repeated call is then a no-op instead of a double free.
    Nothing is done if pipe->jacobian is already NULL.
*/
PetscErrorCode PipeDestroyJacobian(Pipe pipe)
{
  PetscInt       i;

  PetscFunctionBegin;
  if (!pipe->jacobian) PetscFunctionReturn(0);

  for (i=0; i<3; i++) {
    PetscCall(MatDestroy(&pipe->jacobian[i]));
  }
  PetscCall(PetscFree(pipe->jacobian));
  PetscFunctionReturn(0);
}

/*
    JunctionCreateJacobian - Create Jacobian matrices for a vertex.

    Input Parameters:
+   dm  - the DMNetwork object
.   v   - vertex point
-   Jin - Jacobian patterns created by JunctionCreateJacobianSample() for reuse, or NULL to build new ones

    Output Parameter:
.   J  - array of 2*nedges+1 Jacobian matrices (see dmnetworkimpl.h), where nedges is the
         number of edges supporting v

    Notes:
    Layout of the returned array:
      J[0]     - Jacobian(v,v); left NULL by this routine
      J[2*e+1] - Jacobian(v,e) for the e-th supporting edge
      J[2*e+2] - Jacobian(v,vc) for the vertex at the other end of that edge (empty pattern)

    When Jin is given, J[2*e+1] is Jin[0] if v is the first vertex of the edge (cone[0]) and
    Jin[1] otherwise, J[2*e+2] is Jin[2], and each reused matrix gains one reference.

    Errors with PETSC_ERR_ARG_SIZ if v has no supporting edge or if the number of
    variables reported for v is not 2.

    Level: beginner
*/
PetscErrorCode JunctionCreateJacobian(DM dm,PetscInt v,Mat *Jin,Mat *J[])
{
  Mat            *Jv;
  PetscInt       nedges,e,i,M,N,*rows,*cols;
  PetscBool      isSelf;
  const PetscInt *edges,*cone;
  PetscScalar    *zeros;

  PetscFunctionBegin;
  /* Get array size of Jv */
  PetscCall(DMNetworkGetSupportingEdges(dm,v,&nedges,&edges));
  PetscCheck(nedges > 0,PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"%" PetscInt_FMT " vertex, nedges %" PetscInt_FMT,v,nedges);

  /* M: number of variables reported for this vertex; validated before anything is allocated */
  PetscCall(DMNetworkGetComponent(dm,v,-1,NULL,NULL,&M));
  PetscCheck(M == 2,PETSC_COMM_SELF,PETSC_ERR_ARG_SIZ,"M %" PetscInt_FMT " != 2",M);

  /* two Jacobians for each connected edge: J(v,e) and J(v,vc); adding J(v,v), total 2*nedges+1 Jacobians.
     The array is zeroed, so Jv[0] = Jacobian(v,v) stays NULL */
  PetscCall(PetscCalloc1(2*nedges+1,&Jv));

  /* row indices and a zero M x M block used to insert the coupling pattern */
  PetscCall(PetscMalloc3(M,&rows,M,&cols,M*M,&zeros));
  PetscCall(PetscArrayzero(zeros,M*M));
  for (i=0; i<M; i++) rows[i] = i;

  for (e=0; e<nedges; e++) {
    /* create Jv[2*e+1] = Jacobian(v,e), e: supporting edge */
    PetscCall(DMNetworkGetConnectedVertices(dm,edges[e],&cone));
    isSelf = (v == cone[0]) ? PETSC_TRUE:PETSC_FALSE;

    if (Jin) {
      if (isSelf) {
        Jv[2*e+1] = Jin[0];
      } else {
        Jv[2*e+1] = Jin[1];
      }
      Jv[2*e+2] = Jin[2];
      PetscCall(PetscObjectReference((PetscObject)(Jv[2*e+1])));
      PetscCall(PetscObjectReference((PetscObject)(Jv[2*e+2])));
    } else {
      /* create J(v,e) */
      PetscCall(MatCreate(PETSC_COMM_SELF,&Jv[2*e+1]));
      PetscCall(DMNetworkGetComponent(dm,edges[e],-1,NULL,NULL,&N));
      PetscCall(MatSetSizes(Jv[2*e+1],PETSC_DECIDE,PETSC_DECIDE,M,N));
      PetscCall(MatSetFromOptions(Jv[2*e+1]));
      PetscCall(MatSetOption(Jv[2*e+1],MAT_STRUCTURE_ONLY,PETSC_TRUE));
      PetscCall(MatSeqAIJSetPreallocation(Jv[2*e+1],2,NULL));
      if (N) {
        if (isSelf) { /* coupling at upstream */
          for (i=0; i<2; i++) cols[i] = i;
        } else { /* coupling at downstream */
          cols[0] = N-2; cols[1] = N-1;
        }
        PetscCall(MatSetValues(Jv[2*e+1],2,rows,2,cols,zeros,INSERT_VALUES));
      }
      PetscCall(MatAssemblyBegin(Jv[2*e+1],MAT_FINAL_ASSEMBLY));
      PetscCall(MatAssemblyEnd(Jv[2*e+1],MAT_FINAL_ASSEMBLY));

      /* create Jv[2*e+2] = Jacobian(v,vc), vc: connected vertex.
       In WashNetwork, v and vc are not connected, thus Jacobian(v,vc) is empty */
      PetscCall(MatCreate(PETSC_COMM_SELF,&Jv[2*e+2]));
      PetscCall(MatSetSizes(Jv[2*e+2],PETSC_DECIDE,PETSC_DECIDE,M,M)); /* empty matrix, sizes can be arbitrary */
      PetscCall(MatSetFromOptions(Jv[2*e+2]));
      PetscCall(MatSetOption(Jv[2*e+2],MAT_STRUCTURE_ONLY,PETSC_TRUE));
      PetscCall(MatSeqAIJSetPreallocation(Jv[2*e+2],1,NULL));
      PetscCall(MatAssemblyBegin(Jv[2*e+2],MAT_FINAL_ASSEMBLY));
      PetscCall(MatAssemblyEnd(Jv[2*e+2],MAT_FINAL_ASSEMBLY));
    }
  }
  PetscCall(PetscFree3(rows,cols,zeros));

  *J = Jv;
  PetscFunctionReturn(0);
}

/*
    JunctionDestroyJacobian - Destroy the Jacobian matrices stored in a Junction.

    Input Parameters:
+   dm   - the DMNetwork object
.   v    - vertex point the junction belongs to; its supporting edges give the array length
-   junc - the Junction whose jacobian array is destroyed

    Notes:
    For every supporting edge e of v the entries [2*e+1] and [2*e+2] are destroyed;
    entry [0] is not touched. The array is released through junc->jacobian so that
    PetscFree() resets the stored pointer to NULL, making a repeated call a no-op
    instead of a double free. Nothing is done if junc->jacobian is already NULL.
*/
PetscErrorCode JunctionDestroyJacobian(DM dm,PetscInt v,Junction junc)
{
  Mat            *Jv=junc->jacobian;
  const PetscInt *edges;
  PetscInt       nedges,e;

  PetscFunctionBegin;
  if (!Jv) PetscFunctionReturn(0);

  PetscCall(DMNetworkGetSupportingEdges(dm,v,&nedges,&edges));
  for (e=0; e<nedges; e++) {
    PetscCall(MatDestroy(&Jv[2*e+1]));
    PetscCall(MatDestroy(&Jv[2*e+2]));
  }
  /* free via the struct member, not the local copy, so the stored pointer is cleared */
  PetscCall(PetscFree(junc->jacobian));
  PetscFunctionReturn(0);
}
