xref: /honee/include/sgs_model_torch.h (revision 4c07ec2294887c4a114ef13a7c2da0ab5f5dc208)
1*4c07ec22SJames Wright // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2*4c07ec22SJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*4c07ec22SJames Wright //
4*4c07ec22SJames Wright // SPDX-License-Identifier: BSD-2-Clause
5*4c07ec22SJames Wright //
6*4c07ec22SJames Wright // This file is part of CEED:  http://github.com/ceed
7*4c07ec22SJames Wright 
#pragma once

#include <petsc.h>
9*4c07ec22SJames Wright 
10*4c07ec22SJames Wright #ifdef __cplusplus
11*4c07ec22SJames Wright extern "C" {
12*4c07ec22SJames Wright #endif
13*4c07ec22SJames Wright 
// Device backends a Torch model can be targeted at.
// Each enumerator's value is its index into TorchDeviceTypes below, so the two lists must be kept in the same order.
typedef enum {
  TORCH_DEVICE_CPU  = 0,
  TORCH_DEVICE_CUDA = 1,
  TORCH_DEVICE_HIP  = 2,
  TORCH_DEVICE_XPU  = 3,
} TorchDeviceType;

// Name table for TorchDeviceType, laid out like a PETSc enum-name list (as consumed by e.g. PetscOptionsEnum):
// one lowercase name per enumerator in enumerator order, then the enum type name, then the shared enumerator prefix,
// then a NULL terminator.
static const char *const TorchDeviceTypes[] = {
  "cpu", "cuda", "hip", "xpu",         // names indexed by TorchDeviceType value
  "TorchDeviceType", "TORCH_DEVICE_",  // enum type name, enumerator prefix
  NULL,                                // end of list
};
21*4c07ec22SJames Wright 
22*4c07ec22SJames Wright PetscErrorCode LoadModel_Torch(const char *model_path, TorchDeviceType device_enum);
23*4c07ec22SJames Wright PetscErrorCode ModelInference_Torch(Vec DD_Inputs_loc, Vec DD_Outputs_loc);
24*4c07ec22SJames Wright 
25*4c07ec22SJames Wright #ifdef __cplusplus
26*4c07ec22SJames Wright }
27*4c07ec22SJames Wright #endif
28