/* Defines the basic SNES object */
#include <../src/snes/impls/ngmres/snesngmres.h>
#include <petscblaslapack.h>


/*MC
  SNESNGMRES - The Nonlinear Generalized Minimum Residual (NGMRES) method of Oosterlee and Washio.

   Level: beginner

   "Krylov Subspace Acceleration of Nonlinear Multigrid with Application to Recirculating Flows", C. W. Oosterlee and T. Washio,
   SIAM Journal on Scientific Computing, 21(5), 2000.

.seealso: SNESCreate(), SNES, SNESSetType(), SNESType (for list of available types)
M*/

#undef __FUNCT__
#define __FUNCT__ "SNESReset_NGMRES"
PetscErrorCode SNESReset_NGMRES(SNES snes)
{
  SNES_NGMRES   *ngmres = (SNES_NGMRES*) snes->data;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = VecDestroyVecs(ngmres->msize, &ngmres->rdot);CHKERRQ(ierr);
  ierr = VecDestroyVecs(ngmres->msize, &ngmres->xdot);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

#undef __FUNCT__
#define __FUNCT__ "SNESDestroy_NGMRES"
PetscErrorCode SNESDestroy_NGMRES(SNES snes)
{
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = SNESReset_NGMRES(snes);CHKERRQ(ierr);
  if (snes->work) {ierr = VecDestroyVecs(snes->nwork, &snes->work);CHKERRQ(ierr);}
  if (snes->data) {
    SNES_NGMRES * ngmres = (SNES_NGMRES *)snes->data;
    ierr = PetscFree5(ngmres->h, ngmres->beta, ngmres->xi, ngmres->r_norms, ngmres->q);CHKERRQ(ierr);
    ierr = PetscFree(ngmres->s);CHKERRQ(ierr);
#if PETSC_USE_COMPLEX
    ierr = PetscFree(ngmres->rwork);
#endif
    ierr = PetscFree(ngmres->work);
  }
  ierr = PetscFree(snes->data);
  PetscFunctionReturn(0);
}

#undef __FUNCT__
#define __FUNCT__ "SNESSetUp_NGMRES"
PetscErrorCode SNESSetUp_NGMRES(SNES snes)
{
  SNES_NGMRES   *ngmres = (SNES_NGMRES *) snes->data;
  PetscInt msize,hsize;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  msize         = ngmres->msize;  /* restart size */
  hsize         = msize * msize;


  /* explicit least squares minimization solve */
  ierr = PetscMalloc5(hsize,PetscScalar,&ngmres->h,
		      msize,PetscScalar,&ngmres->beta,
		      msize,PetscScalar,&ngmres->xi,
		      msize,PetscReal,&ngmres->r_norms,
		      hsize,PetscScalar,&ngmres->q);CHKERRQ(ierr);
  ngmres->nrhs = 1;
  ngmres->lda = msize;
  ngmres->ldb = msize;
  ierr = PetscMalloc(msize*sizeof(PetscScalar),&ngmres->s);CHKERRQ(ierr);
  
  ierr = PetscMemzero(ngmres->h,hsize*sizeof(PetscScalar));CHKERRQ(ierr);
  ierr = PetscMemzero(ngmres->q,hsize*sizeof(PetscScalar));CHKERRQ(ierr);
  ierr = PetscMemzero(ngmres->xi,msize*sizeof(PetscScalar));CHKERRQ(ierr);
  ierr = PetscMemzero(ngmres->beta,msize*sizeof(PetscScalar));CHKERRQ(ierr);
  
  ngmres->lwork = 12*msize;
#if PETSC_USE_COMPLEX
  ierr = PetscMalloc(sizeof(PetscReal)*ngmres->lwork,&ngmres->rwork);
#endif
  ierr = PetscMalloc(sizeof(PetscScalar)*ngmres->lwork,&ngmres->work);

  ierr = VecDuplicateVecs(snes->vec_sol, ngmres->msize, &ngmres->xdot);CHKERRQ(ierr);
  ierr = VecDuplicateVecs(snes->vec_sol, ngmres->msize, &ngmres->rdot);CHKERRQ(ierr);
  ierr = SNESDefaultGetWork(snes, 2);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

#undef __FUNCT__
#define __FUNCT__ "SNESSetFromOptions_NGMRES"
PetscErrorCode SNESSetFromOptions_NGMRES(SNES snes)
{
  SNES_NGMRES   *ngmres = (SNES_NGMRES *) snes->data;
  PetscErrorCode ierr;
  
  PetscFunctionBegin;
  ierr = PetscOptionsHead("SNES NGMRES options");CHKERRQ(ierr);
  ierr = PetscOptionsInt("-snes_ngmres_m", "Number of directions", "SNES", ngmres->msize, &ngmres->msize, PETSC_NULL);CHKERRQ(ierr);
  ierr = PetscOptionsInt("-snes_ngmres_restart", "Maximum iterations before restart.", "SNES", ngmres->k_rmax, &ngmres->k_rmax, PETSC_NULL);CHKERRQ(ierr);
  ierr = PetscOptionsBool("-snes_ngmres_debug", "Debugging output for NGMRES", "SNES", ngmres->debug, &ngmres->debug, PETSC_NULL);CHKERRQ(ierr); 
  ierr = PetscOptionsReal("-snes_ngmres_gammaA", "Residual selection constant", "SNES", ngmres->gammaA, &ngmres->gammaA, PETSC_NULL);CHKERRQ(ierr);
  ierr = PetscOptionsReal("-snes_ngmres_gammaC", "Residual restart constant", "SNES", ngmres->gammaC, &ngmres->gammaC, PETSC_NULL);CHKERRQ(ierr);
  ierr = PetscOptionsReal("-snes_ngmres_epsilonB", "Difference selection constant", "SNES", ngmres->epsilonB, &ngmres->epsilonB, PETSC_NULL);CHKERRQ(ierr);
  ierr = PetscOptionsReal("-snes_ngmres_deltaB", "Difference residual selection constant", "SNES", ngmres->deltaB, &ngmres->deltaB, PETSC_NULL);CHKERRQ(ierr);
  ierr = PetscOptionsTail();CHKERRQ(ierr);
  if ((ngmres->gammaA > ngmres->gammaC) && (ngmres->gammaC > 2.)) ngmres->gammaC = ngmres->gammaA;
  PetscFunctionReturn(0);
}

#undef __FUNCT__
#define __FUNCT__ "SNESView_NGMRES"
PetscErrorCode SNESView_NGMRES(SNES snes, PetscViewer viewer)
{
  SNES_NGMRES   *ngmres = (SNES_NGMRES *) snes->data;
  PetscBool      iascii;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = PetscTypeCompare((PetscObject) viewer, PETSCVIEWERASCII, &iascii);CHKERRQ(ierr);
  if (iascii) {
    ierr = PetscViewerASCIIPrintf(viewer, "  Size of space %d\n", ngmres->msize);CHKERRQ(ierr);
    ierr = PetscViewerASCIIPrintf(viewer, "  Maximum iterations before restart %d\n", ngmres->k_rmax);CHKERRQ(ierr);
  }
  PetscFunctionReturn(0);
}


#undef __FUNCT__
#define __FUNCT__ "SNESSolve_NGMRES"

PetscErrorCode SNESSolve_NGMRES(SNES snes)
{
  SNES           pc;
  SNES_NGMRES   *ngmres = (SNES_NGMRES *) snes->data;

  
  
  /* present solution, residual, and preconditioned residual */
  Vec            x, r, b, d;
  Vec            x_A, r_A;

  /* previous iterations to construct the subspace */
  Vec            *rdot = ngmres->rdot;
  Vec            *xdot = ngmres->xdot;

  /* coefficients and RHS to the minimization problem */
  PetscScalar    *beta = ngmres->beta;
  PetscScalar    *xi = ngmres->xi;
  PetscReal      r_norm, r_A_norm;
  PetscReal      nu;
  PetscScalar    alph_total = 0.;
  PetscScalar    qentry;
  PetscInt       i, j, k, k_restart, l, ivec;

  /* solution selection data */
  PetscBool selectA, selectRestart;
  PetscReal d_norm, d_min_norm, d_cur_norm;
  PetscReal r_min_norm;

  PetscErrorCode ierr;

  PetscFunctionBegin;
  /* variable initialization */
  snes->reason  = SNES_CONVERGED_ITERATING;
  x             = snes->vec_sol;
  r             = snes->vec_func;
  b             = snes->vec_rhs;
  x_A           = snes->vec_sol_update;
  r_A           = snes->work[0];
  d             = snes->work[1];
  ierr = VecDuplicate(x,&r);CHKERRQ(ierr);

  ierr = SNESGetPC(snes, &pc);CHKERRQ(ierr);
  ierr = PetscObjectTakeAccess(snes);CHKERRQ(ierr);
  snes->iter = 0;
  snes->norm = 0.;
  ierr = PetscObjectGrantAccess(snes);CHKERRQ(ierr);

  /* initialization */

  /* r = F(x) */
  ierr = SNESComputeFunction(snes, x, r);CHKERRQ(ierr);
  if (snes->domainerror) {
    snes->reason = SNES_DIVERGED_FUNCTION_DOMAIN;
    PetscFunctionReturn(0);
  }

  /* nu = (r, r) */
  ierr = VecNorm(r, NORM_2, &r_norm);CHKERRQ(ierr);
  r_min_norm = r_norm;
  nu = r_norm*r_norm;
  if (PetscIsInfOrNanReal(r_norm)) SETERRQ(PETSC_COMM_SELF, PETSC_ERR_FP, "Infinite or not-a-number generated in norm");

  /* q_{00} = nu  */
  Q(0,0) = nu;
  ngmres->r_norms[0] = r_norm;
  /* rdot[0] = r */
  ierr = VecCopy(x, xdot[0]);CHKERRQ(ierr);
  ierr = VecCopy(r, rdot[0]);CHKERRQ(ierr);

  ierr = PetscObjectTakeAccess(snes);CHKERRQ(ierr);
  snes->norm = r_norm;
  ierr = PetscObjectGrantAccess(snes);CHKERRQ(ierr);
  SNESLogConvHistory(snes, r_norm, 0);
  ierr = SNESMonitor(snes, 0, r_norm);CHKERRQ(ierr);
  ierr = (*snes->ops->converged)(snes,0,0.0,0.0,r_norm,&snes->reason,snes->cnvP);CHKERRQ(ierr);
  if (snes->reason) PetscFunctionReturn(0);

  k_restart = 1;
  l = 1;
  for (k=1; k<snes->max_its; k++) {

    /* select which vector of the stored subspace will be updated */
    ivec = k_restart % ngmres->msize; /* replace the last used part of the subspace */


    /* Computation of x^M */
    ierr = SNESSolve(pc, b, x);CHKERRQ(ierr);
    /* r = F(x) */
    ierr = SNESComputeFunction(snes, x, r);CHKERRQ(ierr);
    ierr = VecNorm(r, NORM_2, &r_norm);CHKERRQ(ierr);
    /* nu = (r, r) */
    ngmres->r_norms[ivec] = r_norm;
    nu = r_norm*r_norm;    
    if (r_min_norm > r_norm) r_min_norm = r_norm;  /* the minimum norm is now of r^M */

    /* construct the right hand side and xi factors */
    for (i = 0; i < l; i++) {
      VecDot(rdot[i], r, &xi[i]);
      beta[i] = nu - xi[i]; 
    }

    /* construct h */
    for (j = 0; j < l; j++) {
      for (i = 0; i < l; i++) {
	H(i, j) = Q(i, j) - xi[i] - xi[j] + nu;
      }
    }
#ifdef PETSC_MISSING_LAPACK_GELSS
    SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "NGMRES with LS requires the LAPACK GELSS routine."); 
#else
    ngmres->m = PetscBLASIntCast(l);
    ngmres->n = PetscBLASIntCast(l);
    ngmres->info = PetscBLASIntCast(0);
    ngmres->rcond = -1.;
    ngmres->nrhs = 1;
#ifdef PETSC_USE_COMPLEX
    LAPACKgelss_(&ngmres->m,
		 &ngmres->n,
		 &ngmres->nrhs,
		 ngmres->h,
		 &ngmres->lda,
		 ngmres->beta,
		 &ngmres->ldb,
		 ngmres->s,
		 &ngmres->rcond,
		 &ngmres->rank,
		 ngmres->work,
		 &ngmres->lwork,
		 ngmres->rwork,
		 &ngmres->info);
#else
    LAPACKgelss_(&ngmres->m,
		 &ngmres->n,
		 &ngmres->nrhs,
		 ngmres->h,
		 &ngmres->lda,
		 ngmres->beta,
		 &ngmres->ldb,
		 ngmres->s,
		 &ngmres->rcond,
		 &ngmres->rank,
		 ngmres->work,
		 &ngmres->lwork,
		 &ngmres->info);
#endif
    if (ngmres->info < 0) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_LIB,"Bad argument to GELSS");
    if (ngmres->info > 0) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_LIB,"SVD failed to converge");
#endif

    alph_total = 0.;
    for (i = 0; i < l; i++) {
      alph_total += beta[i];
    }
    ierr = VecCopy(x, x_A);CHKERRQ(ierr);
    ierr = VecScale(x_A, 1. - alph_total);CHKERRQ(ierr);

    for(i=0;i<l;i++){
      ierr= VecAXPY(x_A, beta[i], xdot[i]);CHKERRQ(ierr);
    }
    ierr = SNESComputeFunction(snes, x_A, r_A);CHKERRQ(ierr);
    ierr = VecNorm(r_A, NORM_2, &r_A_norm);CHKERRQ(ierr);

    selectA = PETSC_TRUE;
    /* Conditions for choosing the accelerated answer */

    /* Criterion A -- the norm of the function isn't increased above the minimum by too much */
    if (r_A_norm >= ngmres->gammaA*r_min_norm) {
      selectA = PETSC_FALSE;
    }
    
    /* Criterion B -- the choice of x^A isn't too close to some other choice */
    ierr=VecCopy(x_A,d);CHKERRQ(ierr);   
    ierr=VecAXPY(d,-1,x);CHKERRQ(ierr);   
    ierr=VecNorm(d,NORM_2,&d_norm);CHKERRQ(ierr);     
    d_min_norm = -1.0;
    for(i=0;i<l;i++) {
      ierr=VecCopy(x_A,d);CHKERRQ(ierr);   
      ierr=VecAXPY(d,-1,xdot[i]);CHKERRQ(ierr);   
      ierr=VecNorm(d,NORM_2,&d_cur_norm);CHKERRQ(ierr);        
      if((d_cur_norm < d_min_norm) || (d_min_norm < 0.0)) d_min_norm = d_cur_norm;
    }
    if (ngmres->epsilonB*d_norm<d_min_norm || sqrt(r_norm)<ngmres->deltaB*sqrt(r_min_norm)) {
    } else {
      selectA=PETSC_FALSE;
    }


    if (selectA) {
      if (ngmres->debug) 
	PetscPrintf(PETSC_COMM_WORLD, "picked r_A, ||r_A||_2 = %e, ||r_M||_2 = %e\n", r_A_norm, r_norm);
      /* copy it over */
      r_norm = r_A_norm;
      nu = r_norm*r_norm;
      ierr = VecCopy(r_A, r);CHKERRQ(ierr);
      ierr = VecCopy(x_A, x);CHKERRQ(ierr);
    } else {
      if(ngmres->debug)
	PetscPrintf(PETSC_COMM_WORLD, "picked r_M, ||r_A||_2 = %e, ||r_M||_2 = %e\n", r_A_norm, r_norm);
    }

    selectRestart = PETSC_FALSE;
    
    /* maximum iteration criterion */
    if (k_restart > ngmres->k_rmax) {
      selectRestart = PETSC_TRUE;
    }

    /* difference stagnation restart */
    if 	((ngmres->epsilonB*d_norm > d_min_norm) && 
	 (sqrt(r_A_norm) > ngmres->deltaB*sqrt(r_min_norm))) {
      if (ngmres->debug)
	PetscPrintf(PETSC_COMM_WORLD, "difference restart: %e > %e\n", ngmres->epsilonB*d_norm, d_min_norm);
      selectRestart = PETSC_TRUE;
    }
    
    /* residual stagnation restart */
    if (sqrt(r_A_norm) > ngmres->gammaC*sqrt(r_min_norm)) {
      if (ngmres->debug)
	PetscPrintf(PETSC_COMM_WORLD, "residual restart: %e > %e\n", sqrt(r_A_norm), ngmres->gammaC*sqrt(r_min_norm));
      selectRestart = PETSC_TRUE;
    }

    if (selectRestart) {
      if (ngmres->debug)
	PetscPrintf(PETSC_COMM_WORLD, "Restarted at iteration %d\n", k_restart);
      k_restart = 1;
      l = 1;
      /* q_{00} = nu */
      ngmres->r_norms[0] = r_norm;
      nu = r_norm*r_norm;
      Q(0,0) = nu;
      /* rdot[0] = r */
      ierr = VecCopy(x, xdot[0]);CHKERRQ(ierr);
      ierr = VecCopy(r, rdot[0]);CHKERRQ(ierr);
    } else {
      /* select the current size of the subspace */
      if (l < ngmres->msize) {
	l++;
      }
      k_restart++;
      /* place the current entry in the list of previous entries */
      ierr = VecCopy(r, rdot[ivec]);CHKERRQ(ierr);
      ierr = VecCopy(x, xdot[ivec]);CHKERRQ(ierr);
      ngmres->r_norms[ivec] = r_norm;
      if (r_min_norm > r_norm) r_min_norm = r_norm;  /* the minimum norm is now of r^A */
      for (i = 0; i < l; i++) {
	VecDot(r, rdot[i], &qentry);
	Q(i, ivec) = qentry;
	Q(ivec, i) = qentry;
      }
    }
    
    SNESLogConvHistory(snes, r_norm, k);
    ierr = SNESMonitor(snes, k, r_norm);CHKERRQ(ierr);

    snes->iter =k;
    ierr = (*snes->ops->converged)(snes,snes->iter,0.0,0.0,r_norm,&snes->reason,snes->cnvP);CHKERRQ(ierr);
    if (snes->reason) PetscFunctionReturn(0);
  }
  snes->reason = SNES_DIVERGED_MAX_IT;
  PetscFunctionReturn(0);
}



EXTERN_C_BEGIN
#undef __FUNCT__
#define __FUNCT__ "SNESCreate_NGMRES"
PetscErrorCode SNESCreate_NGMRES(SNES snes)
{
  SNES_NGMRES   *ngmres;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  snes->ops->destroy        = SNESDestroy_NGMRES;
  snes->ops->setup          = SNESSetUp_NGMRES;
  snes->ops->setfromoptions = SNESSetFromOptions_NGMRES;
  snes->ops->view           = SNESView_NGMRES;
  snes->ops->solve          = SNESSolve_NGMRES;
  snes->ops->reset          = SNESReset_NGMRES;

  ierr = PetscNewLog(snes, SNES_NGMRES, &ngmres);CHKERRQ(ierr);
  snes->data = (void*) ngmres;
  ngmres->msize = 10;
  ngmres->debug = PETSC_FALSE;

  ngmres->gammaA = 2.;
  ngmres->gammaC = 2.;
  ngmres->deltaB = 0.9;
  ngmres->epsilonB = 0.1;
  ngmres->k_rmax = 200;

  ierr = SNESGetPC(snes, &snes->pc);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}
EXTERN_C_END
