1*5aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 24548da4eSSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 34548da4eSSebastian Grimberg // 44548da4eSSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause 54548da4eSSebastian Grimberg // 64548da4eSSebastian Grimberg // This file is part of CEED: http://github.com/ceed 74548da4eSSebastian Grimberg 84548da4eSSebastian Grimberg #include <ceed.h> 94548da4eSSebastian Grimberg #include <ceed/backend.h> 104548da4eSSebastian Grimberg #include <libxsmm.h> 114548da4eSSebastian Grimberg 124548da4eSSebastian Grimberg #include "ceed-xsmm.h" 134548da4eSSebastian Grimberg 144548da4eSSebastian Grimberg //------------------------------------------------------------------------------ 154548da4eSSebastian Grimberg // Tensor Contract Apply 164548da4eSSebastian Grimberg //------------------------------------------------------------------------------ 174548da4eSSebastian Grimberg static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 184548da4eSSebastian Grimberg CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 194548da4eSSebastian Grimberg Ceed ceed; 20ad70ee2cSJeremy L Thompson 214548da4eSSebastian Grimberg CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed)); 224548da4eSSebastian Grimberg 234548da4eSSebastian Grimberg if (C == 1) { 244548da4eSSebastian Grimberg // Build or query the required kernel 254548da4eSSebastian Grimberg const int flags_t = LIBXSMM_GEMM_FLAGS(!t_mode ? 'T' : 'N', 'N'); 264548da4eSSebastian Grimberg const int flags_ab = (!add) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE; 274548da4eSSebastian Grimberg const int flags = (flags_t | flags_ab); 284548da4eSSebastian Grimberg const libxsmm_gemm_shape gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64) 294548da4eSSebastian Grimberg ? libxsmm_create_gemm_shape(J, A, B, !t_mode ? B : J, B, J, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, 304548da4eSSebastian Grimberg LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64) 314548da4eSSebastian Grimberg : libxsmm_create_gemm_shape(J, A, B, !t_mode ? B : J, B, J, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, 324548da4eSSebastian Grimberg LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32); 334548da4eSSebastian Grimberg const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm_v2(gemm_shape, (libxsmm_bitfield)(flags), (libxsmm_bitfield)LIBXSMM_GEMM_PREFETCH_NONE); 34ad70ee2cSJeremy L Thompson libxsmm_gemm_param gemm_param; 35ad70ee2cSJeremy L Thompson 364548da4eSSebastian Grimberg CeedCheck(kernel, ceed, CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build."); 374548da4eSSebastian Grimberg 384548da4eSSebastian Grimberg // Run kernel 394548da4eSSebastian Grimberg gemm_param.a.primary = (CeedScalar *)&t[0]; 404548da4eSSebastian Grimberg gemm_param.b.primary = (CeedScalar *)&u[0]; 414548da4eSSebastian Grimberg gemm_param.c.primary = (CeedScalar *)&v[0]; 424548da4eSSebastian Grimberg kernel(&gemm_param); 434548da4eSSebastian Grimberg } else { 444548da4eSSebastian Grimberg // Build or query the required kernel 454548da4eSSebastian Grimberg const int flags_t = LIBXSMM_GEMM_FLAGS('N', t_mode ? 'T' : 'N'); 464548da4eSSebastian Grimberg const int flags_ab = (!add) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE; 474548da4eSSebastian Grimberg const int flags = (flags_t | flags_ab); 484548da4eSSebastian Grimberg const libxsmm_gemm_shape gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64) 494548da4eSSebastian Grimberg ? libxsmm_create_gemm_shape(C, J, B, C, !t_mode ? B : J, C, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, 504548da4eSSebastian Grimberg LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64) 514548da4eSSebastian Grimberg : libxsmm_create_gemm_shape(C, J, B, C, !t_mode ? B : J, C, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, 524548da4eSSebastian Grimberg LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32); 534548da4eSSebastian Grimberg const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm_v2(gemm_shape, (libxsmm_bitfield)(flags), (libxsmm_bitfield)LIBXSMM_GEMM_PREFETCH_NONE); 54ad70ee2cSJeremy L Thompson libxsmm_gemm_param gemm_param; 55ad70ee2cSJeremy L Thompson 564548da4eSSebastian Grimberg CeedCheck(kernel, ceed, CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build."); 574548da4eSSebastian Grimberg 584548da4eSSebastian Grimberg // Run kernel 594548da4eSSebastian Grimberg gemm_param.b.primary = (CeedScalar *)&t[0]; 604548da4eSSebastian Grimberg for (CeedInt a = 0; a < A; a++) { 614548da4eSSebastian Grimberg gemm_param.a.primary = (CeedScalar *)&u[a * B * C]; 624548da4eSSebastian Grimberg gemm_param.c.primary = (CeedScalar *)&v[a * J * C]; 634548da4eSSebastian Grimberg kernel(&gemm_param); 644548da4eSSebastian Grimberg } 654548da4eSSebastian Grimberg } 664548da4eSSebastian Grimberg return CEED_ERROR_SUCCESS; 674548da4eSSebastian Grimberg } 684548da4eSSebastian Grimberg 694548da4eSSebastian Grimberg //------------------------------------------------------------------------------ 704548da4eSSebastian Grimberg // Tensor Contract Create 714548da4eSSebastian Grimberg //------------------------------------------------------------------------------ 72a71faab1SSebastian Grimberg int CeedTensorContractCreate_Xsmm(CeedTensorContract contract) { 736e536b99SJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(CeedTensorContractReturnCeed(contract), "TensorContract", contract, "Apply", CeedTensorContractApply_Xsmm)); 744548da4eSSebastian Grimberg return CEED_ERROR_SUCCESS; 754548da4eSSebastian Grimberg } 764548da4eSSebastian Grimberg 774548da4eSSebastian Grimberg //------------------------------------------------------------------------------ 78