xref: /honee/include/sgs_model_torch.h (revision 6dfcbb0586fd7920baad6b612d8e992adb46e8d1)
14c07ec22SJames Wright // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
24c07ec22SJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
34c07ec22SJames Wright //
44c07ec22SJames Wright // SPDX-License-Identifier: BSD-2-Clause
54c07ec22SJames Wright //
64c07ec22SJames Wright // This file is part of CEED:  http://github.com/ceed
74c07ec22SJames Wright 
84c07ec22SJames Wright #include <petsc.h>
94c07ec22SJames Wright 
104c07ec22SJames Wright #ifdef __cplusplus
114c07ec22SJames Wright extern "C" {
124c07ec22SJames Wright #endif
134c07ec22SJames Wright 
144c07ec22SJames Wright typedef enum {
154c07ec22SJames Wright   TORCH_DEVICE_CPU,
164c07ec22SJames Wright   TORCH_DEVICE_CUDA,
174c07ec22SJames Wright   TORCH_DEVICE_HIP,
184c07ec22SJames Wright   TORCH_DEVICE_XPU,
194c07ec22SJames Wright } TorchDeviceType;
20*6dfcbb05SJames Wright static const char *const TorchDeviceTypes[] = {"CPU", "CUDA", "HIP", "XPU", "TorchDeviceType", "TORCH_DEVICE_", NULL};
214c07ec22SJames Wright 
224c07ec22SJames Wright PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum);
234c07ec22SJames Wright PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc);
244c07ec22SJames Wright 
254c07ec22SJames Wright #ifdef __cplusplus
264c07ec22SJames Wright }
274c07ec22SJames Wright #endif
28