// Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed

#include "ceed-magma-gemm-nontensor.h"
#include "ceed-magma-gemm-selector.h"

#ifdef CEED_MAGMA_USE_HIP
#define devblasDgemmStridedBatched hipblasDgemmStridedBatched
#define devblasSgemmStridedBatched hipblasSgemmStridedBatched
#define magma_queue_get_devblas_handle magma_queue_get_hipblas_handle
#define devblas_trans_const hipblas_trans_const
#else
#define devblasDgemmStridedBatched cublasDgemmStridedBatched
#define devblasSgemmStridedBatched cublasSgemmStridedBatched
#define magma_queue_get_devblas_handle magma_queue_get_cublas_handle
#define devblas_trans_const cublas_trans_const
#endif

////////////////////////////////////////////////////////////////////////////////
// Precision-dispatching wrapper over the MAGMA-blas GEMM kernels:
//   C = alpha * op(A) * op(B) + beta * C
// The single/double branch is resolved from CeedScalar's compile-time type.
static inline int magmablas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
                                 const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
                                 magma_int_t lddc, magma_queue_t queue) {
  if (CEED_SCALAR_TYPE != CEED_SCALAR_FP32) {
    // double precision
    magmablas_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc,
                    queue);
  } else {
    // single precision
    magmablas_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc,
                    queue);
  }
  return 0;
}

////////////////////////////////////////////////////////////////////////////////
// Precision-dispatching wrapper over the MAGMA-blas strided-batched GEMM:
// performs batchCount independent GEMMs whose operands are separated by the
// given element strides. Single vs. double is resolved at compile time.
static inline int magmablas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k,
                                                 CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA,
                                                 const CeedScalar *d_B, magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C,
                                                 magma_int_t lddc, magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) {
  if (CEED_SCALAR_TYPE != CEED_SCALAR_FP32) {
    // double precision
    magmablas_dgemm_batched_strided(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, strideA, (const double *)d_B, lddb, strideB,
                                    (double)beta, (double *)d_C, lddc, strideC, batchCount, queue);
  } else {
    // single precision
    magmablas_sgemm_batched_strided(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, strideA, (const float *)d_B, lddb, strideB,
                                    (float)beta, (float *)d_C, lddc, strideC, batchCount, queue);
  }
  return 0;
}

////////////////////////////////////////////////////////////////////////////////
// Precision-dispatching wrapper over magma_{s,d}gemm, which route a single
// GEMM to the vendor BLAS library. Single vs. double is resolved at compile
// time from CeedScalar.
static inline int devblas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
                               const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
                               magma_int_t lddc, magma_queue_t queue) {
  if (CEED_SCALAR_TYPE != CEED_SCALAR_FP32) {
    // double precision
    magma_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc,
                queue);
  } else {
    // single precision
    magma_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc, queue);
  }
  return 0;
}

////////////////////////////////////////////////////////////////////////////////
// Precision-dispatching wrapper that calls the vendor strided-batched GEMM
// (cuBLAS or hipBLAS, per the devblas* macros above) directly on the handle
// extracted from the MAGMA queue. Single vs. double is resolved at compile
// time from CeedScalar.
static inline int devblas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k,
                                               CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA, const CeedScalar *d_B,
                                               magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C, magma_int_t lddc,
                                               magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) {
  if (CEED_SCALAR_TYPE != CEED_SCALAR_FP32) {
    // double precision
    devblasDgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n,
                               (int)k, (const double *)&alpha, (const double *)d_A, (int)ldda, strideA, (const double *)d_B, (int)lddb, strideB,
                               (const double *)&beta, (double *)d_C, (int)lddc, strideC, (int)batchCount);
  } else {
    // single precision
    devblasSgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n,
                               (int)k, (const float *)&alpha, (const float *)d_A, (int)ldda, strideA, (const float *)d_B, (int)lddb, strideB,
                               (const float *)&beta, (float *)d_C, (int)lddc, strideC, (int)batchCount);
  }
  return 0;
}

////////////////////////////////////////////////////////////////////////////////
// Non-tensor GEMM entry point: C = alpha * op(A) * op(B) + beta * C.
// Uses the autotuned gemm_selector() to choose between the MAGMA-blas kernels
// and the vendor BLAS, optionally splitting the n-dimension into equal-width
// batches (plus a remainder panel). Synchronizes the queue before returning,
// except on the untuned fallback path.
int magma_gemm_nontensor(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
                         const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
                         magma_int_t lddc, magma_queue_t queue) {
  magma_int_t arch = magma_getdevice_arch();

  // tuning data covers only the NN and TN transposition cases; everything
  // else goes straight to the vendor BLAS with no sync
  bool is_nn = (trans_A == MagmaNoTrans) && (trans_B == MagmaNoTrans);
  bool is_tn = (trans_A == MagmaTrans) && (trans_B == MagmaNoTrans);
  if (!is_nn && !is_tn) {
    devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
    return 0;
  }

  // consult the autotuned selector for a batch width and a backend choice
  char        trans     = (trans_A == MagmaNoTrans) ? 'n' : 't';
  char        precision = (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) ? 's' : 'd';
  magma_int_t nbatch, use_magmablas;
  gemm_selector(arch, precision, trans, m, n, k, &nbatch, &use_magmablas);

  if (nbatch == n) {
    // selector chose a single unbatched gemm
    if (use_magmablas) {
      magmablas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
    } else {
      devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
    }
  } else {
    // split the n-dimension into batchCount panels of width nbatch, with a
    // remainder panel of width n2; A is shared by every panel (strideA = 0)
    magma_int_t batchCount = n / nbatch;
    magma_int_t n2         = n - (batchCount * nbatch);
    magma_int_t strideA    = 0;
    magma_int_t strideB    = lddb * nbatch;
    magma_int_t strideC    = lddc * nbatch;

    if (use_magmablas) {
      if (batchCount > 0) {
        magmablas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC,
                                       batchCount, queue);
      }

      // remainder panel via a single vendor-BLAS gemm
      if (n2 > 0) {
        devblas_gemm(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, d_B + batchCount * strideB, lddb, beta, d_C + batchCount * strideC, lddc, queue);
      }
    } else {
      if (batchCount > 0) {
        devblas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC,
                                     batchCount, queue);
      }

      // remainder panel as a strided-batched call with batchCount = 1
      if (n2 > 0) {
        devblas_gemm_batched_strided(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, strideA, d_B + batchCount * strideB, lddb, strideB, beta,
                                     d_C + batchCount * strideC, lddc, strideC, 1, queue);
      }
    }
  }

  // block until the device work submitted above has finished
  ceed_magma_queue_sync(queue);

  return 0;
}
