14c07ec22SJames Wright // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 24c07ec22SJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 34c07ec22SJames Wright // 44c07ec22SJames Wright // SPDX-License-Identifier: BSD-2-Clause 54c07ec22SJames Wright // 64c07ec22SJames Wright // This file is part of CEED: http://github.com/ceed 74c07ec22SJames Wright 84c07ec22SJames Wright #include <petsc.h> 94c07ec22SJames Wright 104c07ec22SJames Wright #ifdef __cplusplus 114c07ec22SJames Wright extern "C" { 124c07ec22SJames Wright #endif 134c07ec22SJames Wright 144c07ec22SJames Wright typedef enum { 154c07ec22SJames Wright TORCH_DEVICE_CPU, 164c07ec22SJames Wright TORCH_DEVICE_CUDA, 174c07ec22SJames Wright TORCH_DEVICE_HIP, 184c07ec22SJames Wright TORCH_DEVICE_XPU, 194c07ec22SJames Wright } TorchDeviceType; 20*6dfcbb05SJames Wright static const char *const TorchDeviceTypes[] = {"CPU", "CUDA", "HIP", "XPU", "TorchDeviceType", "TORCH_DEVICE_", NULL}; 214c07ec22SJames Wright 224c07ec22SJames Wright PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum); 234c07ec22SJames Wright PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc); 244c07ec22SJames Wright 254c07ec22SJames Wright #ifdef __cplusplus 264c07ec22SJames Wright } 274c07ec22SJames Wright #endif 28