1*4c07ec22SJames Wright // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2*4c07ec22SJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*4c07ec22SJames Wright // 4*4c07ec22SJames Wright // SPDX-License-Identifier: BSD-2-Clause 5*4c07ec22SJames Wright // 6*4c07ec22SJames Wright // This file is part of CEED: http://github.com/ceed 7*4c07ec22SJames Wright 8*4c07ec22SJames Wright #include <petsc.h> 9*4c07ec22SJames Wright 10*4c07ec22SJames Wright #ifdef __cplusplus 11*4c07ec22SJames Wright extern "C" { 12*4c07ec22SJames Wright #endif 13*4c07ec22SJames Wright 14*4c07ec22SJames Wright typedef enum { 15*4c07ec22SJames Wright TORCH_DEVICE_CPU, 16*4c07ec22SJames Wright TORCH_DEVICE_CUDA, 17*4c07ec22SJames Wright TORCH_DEVICE_HIP, 18*4c07ec22SJames Wright TORCH_DEVICE_XPU, 19*4c07ec22SJames Wright } TorchDeviceType; 20*4c07ec22SJames Wright static const char *const TorchDeviceTypes[] = {"cpu", "cuda", "hip", "xpu", "TorchDeviceType", "TORCH_DEVICE_", NULL}; 21*4c07ec22SJames Wright 22*4c07ec22SJames Wright PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum); 23*4c07ec22SJames Wright PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc); 24*4c07ec22SJames Wright 25*4c07ec22SJames Wright #ifdef __cplusplus 26*4c07ec22SJames Wright } 27*4c07ec22SJames Wright #endif 28