/*
  Defines matrix-matrix product routines for an MPIAIJ matrix A and an MPIDense matrix B
          C = A^T * B
  The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense().
*/
#include <../src/mat/impls/aij/seq/aij.h> /*I "petscmat.h" I*/
#include <../src/mat/impls/aij/mpi/mpiaij.h>
#include <../src/mat/impls/dense/mpi/mpidense.h>

/*
  MatProductCtxDestroy_MPIDense_MatTransMatMult - Releases the work data attached to a product matrix.

  Input Parameter:
. data - address of the pointer to the MatProductCtx_MatTransMatMult context

  Notes:
  Destroys the matrix mA and the vectors bt and ct held by the context, then frees the context itself.
  The pointer stored at data is read but not modified.
*/
static PetscErrorCode MatProductCtxDestroy_MPIDense_MatTransMatMult(PetscCtxRt data)
{
  MatProductCtx_MatTransMatMult **slot = (MatProductCtx_MatTransMatMult **)data;
  MatProductCtx_MatTransMatMult  *ctx  = *slot;

  PetscFunctionBegin;
  /* members first, container last */
  PetscCall(MatDestroy(&ctx->mA));
  PetscCall(VecDestroy(&ctx->bt));
  PetscCall(VecDestroy(&ctx->ct));
  PetscCall(PetscFree(ctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat, Mat, Mat);

/*
  MatTransposeMatMultSymbolic_MPIAIJ_MPIDense - Symbolic phase of C = A^T * B.

  Input Parameters:
+ A    - left operand of the product
. B    - right operand of the product
- fill - not used by this routine

  Output Parameter:
. C - product matrix; it is sized, typed and set up here, and receives the work context,
      the context destructor and the numeric routine

  Notes:
  C->product->data must be empty on entry. When B has at least one global column the context
  holds a MAIJ matrix built from A with dof = B->cmap->N together with two vectors obtained
  from MatCreateVecs() on that matrix; otherwise those members stay as PetscNew() left them.
*/
PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A, Mat B, PetscReal fill, Mat C)
{
  PetscBool                      isdense;
  MatProductCtx_MatTransMatMult *ctx;

  PetscFunctionBegin;
  MatCheckProduct(C, 4);
  PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty");

  /* rows of C follow the column sizes of A, columns of C follow the column sizes of B */
  PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->n, A->cmap->N, B->cmap->N));

  /* keep the type of C if it is already one of the listed dense types, otherwise take the type of B */
  PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &isdense, MATMPIDENSE, MATMPIDENSECUDA, ""));
  if (!isdense) {
    PetscCall(MatSetType(C, ((PetscObject)B)->type_name));
  }
  PetscCall(MatSetUp(C));

  /* work data reused by every numeric call */
  PetscCall(PetscNew(&ctx));
  if (B->cmap->N) {
    PetscCall(MatCreateMAIJ(A, B->cmap->N, &ctx->mA));
    if (!ctx->mA->assembled) {
      PetscCall(MatAssemblyBegin(ctx->mA, MAT_FINAL_ASSEMBLY));
      PetscCall(MatAssemblyEnd(ctx->mA, MAT_FINAL_ASSEMBLY));
    }
    /* ct is requested as the right vector of mA and bt as the left vector */
    PetscCall(MatCreateVecs(ctx->mA, &ctx->ct, &ctx->bt));
  }

  /* hand everything over to C */
  C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense;
  C->product->destroy             = MatProductCtxDestroy_MPIDense_MatTransMatMult;
  C->product->data                = ctx;
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  MatTransposeMatMultNumeric_MPIAIJ_MPIDense - Numeric phase of C = A^T * B.

  Input Parameters:
+ A - left operand; only its local row and column counts are read here
- B - right operand, accessed through the MatDense array interface

  Output Parameter:
. C - product matrix; C->product->data must hold the MatProductCtx_MatTransMatMult context

  Notes:
  The local array of B is copied, transposed, into the work vector bt; one MatMultTranspose()
  with the context matrix mA produces ct; ct is then copied, transposed, into the local array of C.
  When B has no global columns only the assembly of C is performed.
*/
static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A, Mat B, Mat C)
{
  MatProductCtx_MatTransMatMult *ctx;
  const PetscInt                 nrowA = A->rmap->n, ncolA = A->cmap->n, ncolB = B->cmap->N;

  PetscFunctionBegin;
  MatCheckProduct(C, 3);
  ctx = (MatProductCtx_MatTransMatMult *)C->product->data;
  PetscCheck(ctx, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct");

  if (ncolB) {
    Vec                xin = ctx->bt, yout = ctx->ct;
    const PetscScalar *bvals, *ctvals;
    PetscScalar       *btvals, *cvals;
    PetscInt           ldb, ldc;

    /* pack B: entry (row, col) of the column-major local array goes to slot row * ncolB + col of bt */
    PetscCall(MatDenseGetArrayRead(B, &bvals));
    PetscCall(MatDenseGetLDA(B, &ldb));
    PetscCall(VecGetArray(xin, &btvals));
    for (PetscInt col = 0; col < ncolB; col++) {
      for (PetscInt row = 0; row < nrowA; row++) {
        btvals[row * ncolB + col] = bvals[col * ldb + row];
      }
    }
    PetscCall(VecRestoreArray(xin, &btvals));
    PetscCall(MatDenseRestoreArrayRead(B, &bvals));

    /* ct = mA^T * bt */
    PetscCall(MatMultTranspose(ctx->mA, xin, yout));

    /* unpack ct: slot row * ncolB + col of ct goes to entry (row, col) of the column-major local array of C */
    PetscCall(MatDenseGetArray(C, &cvals));
    PetscCall(MatDenseGetLDA(C, &ldc));
    PetscCall(VecGetArrayRead(yout, &ctvals));
    for (PetscInt col = 0; col < ncolB; col++) {
      for (PetscInt row = 0; row < ncolA; row++) {
        cvals[col * ldc + row] = ctvals[row * ncolB + col];
      }
    }
    PetscCall(VecRestoreArrayRead(yout, &ctvals));
    PetscCall(MatDenseRestoreArray(C, &cvals));
    PetscCall(MatSetOption(C, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
  }

  /* C is assembled on both paths, with or without columns to compute */
  PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
  PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
  PetscFunctionReturn(PETSC_SUCCESS);
}
