xref: /libCEED/rust/libceed-sys/c-src/backends/sycl-ref/ceed-sycl-ref.hpp (revision bd882c8a454763a096666645dc9a6229d5263694)
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other
// CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE
// files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed
8*bd882c8aSJames Wright 
9*bd882c8aSJames Wright #ifndef _ceed_sycl_hpp
10*bd882c8aSJames Wright #define _ceed_sycl_hpp
11*bd882c8aSJames Wright 
12*bd882c8aSJames Wright #include <ceed/backend.h>
13*bd882c8aSJames Wright #include <ceed/ceed.h>
14*bd882c8aSJames Wright 
15*bd882c8aSJames Wright #include <sycl/sycl.hpp>
16*bd882c8aSJames Wright 
17*bd882c8aSJames Wright #include "../sycl/ceed-sycl-common.hpp"
18*bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp"
19*bd882c8aSJames Wright 
20*bd882c8aSJames Wright typedef struct {
21*bd882c8aSJames Wright   CeedScalar *h_array;
22*bd882c8aSJames Wright   CeedScalar *h_array_borrowed;
23*bd882c8aSJames Wright   CeedScalar *h_array_owned;
24*bd882c8aSJames Wright   CeedScalar *d_array;
25*bd882c8aSJames Wright   CeedScalar *d_array_borrowed;
26*bd882c8aSJames Wright   CeedScalar *d_array_owned;
27*bd882c8aSJames Wright   CeedScalar *reduction_norm;
28*bd882c8aSJames Wright } CeedVector_Sycl;
29*bd882c8aSJames Wright 
30*bd882c8aSJames Wright typedef struct {
31*bd882c8aSJames Wright   CeedInt  num_nodes;
32*bd882c8aSJames Wright   CeedInt  num_elem;
33*bd882c8aSJames Wright   CeedInt  num_comp;
34*bd882c8aSJames Wright   CeedInt  elem_size;
35*bd882c8aSJames Wright   CeedInt  comp_stride;
36*bd882c8aSJames Wright   CeedInt  strides[3];
37*bd882c8aSJames Wright   CeedInt *h_ind;
38*bd882c8aSJames Wright   CeedInt *h_ind_allocated;
39*bd882c8aSJames Wright   CeedInt *d_ind;
40*bd882c8aSJames Wright   CeedInt *d_ind_allocated;
41*bd882c8aSJames Wright   CeedInt *d_t_offsets;
42*bd882c8aSJames Wright   CeedInt *d_t_indices;
43*bd882c8aSJames Wright   CeedInt *d_l_vec_indices;
44*bd882c8aSJames Wright } CeedElemRestriction_Sycl;
45*bd882c8aSJames Wright 
46*bd882c8aSJames Wright typedef struct {
47*bd882c8aSJames Wright   CeedInt       dim;
48*bd882c8aSJames Wright   CeedInt       P_1d;
49*bd882c8aSJames Wright   CeedInt       Q_1d;
50*bd882c8aSJames Wright   CeedInt       num_comp;
51*bd882c8aSJames Wright   CeedInt       num_nodes;
52*bd882c8aSJames Wright   CeedInt       num_qpts;
53*bd882c8aSJames Wright   CeedInt       buf_len;
54*bd882c8aSJames Wright   CeedInt       op_len;
55*bd882c8aSJames Wright   SyclModule_t *sycl_module;
56*bd882c8aSJames Wright   CeedScalar   *d_interp_1d;
57*bd882c8aSJames Wright   CeedScalar   *d_grad_1d;
58*bd882c8aSJames Wright   CeedScalar   *d_q_weight_1d;
59*bd882c8aSJames Wright } CeedBasis_Sycl;
60*bd882c8aSJames Wright 
61*bd882c8aSJames Wright typedef struct {
62*bd882c8aSJames Wright   CeedInt     dim;
63*bd882c8aSJames Wright   CeedInt     num_comp;
64*bd882c8aSJames Wright   CeedInt     num_nodes;
65*bd882c8aSJames Wright   CeedInt     num_qpts;
66*bd882c8aSJames Wright   CeedScalar *d_interp;
67*bd882c8aSJames Wright   CeedScalar *d_grad;
68*bd882c8aSJames Wright   CeedScalar *d_q_weight;
69*bd882c8aSJames Wright } CeedBasisNonTensor_Sycl;
70*bd882c8aSJames Wright 
71*bd882c8aSJames Wright typedef struct {
72*bd882c8aSJames Wright   SyclModule_t *sycl_module;
73*bd882c8aSJames Wright   sycl::kernel *QFunction;
74*bd882c8aSJames Wright } CeedQFunction_Sycl;
75*bd882c8aSJames Wright 
// SYCL backend data associated with a CeedQFunctionContext; all members are untyped pointers.
// NOTE(review): by naming, h_* are presumably host pointers and d_* device pointers, with
//   *_borrowed / *_owned presumably recording who is responsible for the allocation - confirm
//   against the context implementation.
typedef struct {
  void *h_data;
  void *h_data_borrowed;
  void *h_data_owned;
  void *d_data;
  void *d_data_borrowed;
  void *d_data_owned;
} CeedQFunctionContext_Sycl;
84*bd882c8aSJames Wright 
85*bd882c8aSJames Wright typedef struct {
86*bd882c8aSJames Wright   CeedBasis           basisin, basisout;
87*bd882c8aSJames Wright   CeedElemRestriction diagrstr, pbdiagrstr;
88*bd882c8aSJames Wright   CeedVector          elemdiag, pbelemdiag;
89*bd882c8aSJames Wright   CeedInt             numemodein, numemodeout, nnodes;
90*bd882c8aSJames Wright   CeedInt             nqpts, ncomp;  // Kernel parameters
91*bd882c8aSJames Wright   CeedEvalMode       *h_emodein, *h_emodeout;
92*bd882c8aSJames Wright   CeedEvalMode       *d_emodein, *d_emodeout;
93*bd882c8aSJames Wright   CeedScalar         *d_identity, *d_interpin, *d_interpout, *d_gradin, *d_gradout;
94*bd882c8aSJames Wright } CeedOperatorDiag_Sycl;
95*bd882c8aSJames Wright 
96*bd882c8aSJames Wright typedef struct {
97*bd882c8aSJames Wright   CeedInt     nelem, block_size_x, block_size_y, elemsPerBlock;
98*bd882c8aSJames Wright   CeedInt     numemodein, numemodeout, nqpts, nnodes, block_size, ncomp;  // Kernel parameters
99*bd882c8aSJames Wright   bool        fallback;
100*bd882c8aSJames Wright   CeedScalar *d_B_in, *d_B_out;
101*bd882c8aSJames Wright } CeedOperatorAssemble_Sycl;
102*bd882c8aSJames Wright 
103*bd882c8aSJames Wright typedef struct {
104*bd882c8aSJames Wright   CeedVector                *evecs;     // E-vectors, inputs followed by outputs
105*bd882c8aSJames Wright   CeedVector                *qvecsin;   // Input Q-vectors needed to apply operator
106*bd882c8aSJames Wright   CeedVector                *qvecsout;  // Output Q-vectors needed to apply operator
107*bd882c8aSJames Wright   CeedInt                    numein;
108*bd882c8aSJames Wright   CeedInt                    numeout;
109*bd882c8aSJames Wright   CeedInt                    qfnumactivein, qfnumactiveout;
110*bd882c8aSJames Wright   CeedVector                *qfactivein;
111*bd882c8aSJames Wright   CeedOperatorDiag_Sycl     *diag;
112*bd882c8aSJames Wright   CeedOperatorAssemble_Sycl *asmb;
113*bd882c8aSJames Wright } CeedOperator_Sycl;
114*bd882c8aSJames Wright 
115*bd882c8aSJames Wright CEED_INTERN int CeedVectorCreate_Sycl(CeedSize n, CeedVector vec);
116*bd882c8aSJames Wright 
117*bd882c8aSJames Wright CEED_INTERN int CeedElemRestrictionCreate_Sycl(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, CeedElemRestriction r);
118*bd882c8aSJames Wright 
119*bd882c8aSJames Wright CEED_INTERN int CeedBasisApplyElems_Sycl(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode,
120*bd882c8aSJames Wright                                          const CeedVector u, CeedVector v);
121*bd882c8aSJames Wright 
122*bd882c8aSJames Wright CEED_INTERN int CeedQFunctionApplyElems_Sycl(CeedQFunction qf, const CeedInt Q, const CeedVector *const u, const CeedVector *v);
123*bd882c8aSJames Wright 
124*bd882c8aSJames Wright CEED_INTERN int CeedBasisCreateTensorH1_Sycl(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d,
125*bd882c8aSJames Wright                                              const CeedScalar *qref_1d, const CeedScalar *qweight_1d, CeedBasis basis);
126*bd882c8aSJames Wright 
127*bd882c8aSJames Wright CEED_INTERN int CeedBasisCreateH1_Sycl(CeedElemTopology, CeedInt, CeedInt, CeedInt, const CeedScalar *, const CeedScalar *, const CeedScalar *,
128*bd882c8aSJames Wright                                        const CeedScalar *, CeedBasis);
129*bd882c8aSJames Wright 
130*bd882c8aSJames Wright CEED_INTERN int CeedQFunctionCreate_Sycl(CeedQFunction qf);
131*bd882c8aSJames Wright 
132*bd882c8aSJames Wright CEED_INTERN int CeedQFunctionContextCreate_Sycl(CeedQFunctionContext ctx);
133*bd882c8aSJames Wright 
134*bd882c8aSJames Wright CEED_INTERN int CeedOperatorCreate_Sycl(CeedOperator op);
135*bd882c8aSJames Wright 
136*bd882c8aSJames Wright #endif
137