xref: /libCEED/rust/libceed/src/basis.rs (revision c68be7a2e45631197b626561fad53d5b146fcd59)
1 // Copyright (c) 2017, Lawrence Livermore National Security, LLC. Produced at
2 // the Lawrence Livermore National Laboratory. LLNL-CODE-734707. All Rights
3 // reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative
16 
17 //! A Ceed Basis defines the discrete finite element basis and associated
18 //! quadrature rule.
19 
20 use crate::prelude::*;
21 
22 // -----------------------------------------------------------------------------
23 // CeedBasis option
24 // -----------------------------------------------------------------------------
25 #[derive(Debug)]
26 pub enum BasisOpt<'a> {
27     Some(&'a Basis<'a>),
28     Collocated,
29 }
30 /// Construct a BasisOpt reference from a Basis reference
31 impl<'a> From<&'a Basis<'_>> for BasisOpt<'a> {
32     fn from(basis: &'a Basis) -> Self {
33         debug_assert!(basis.ptr != unsafe { bind_ceed::CEED_BASIS_COLLOCATED });
34         Self::Some(basis)
35     }
36 }
37 impl<'a> BasisOpt<'a> {
38     /// Transform a Rust libCEED BasisOpt into C libCEED CeedBasis
39     pub(crate) fn to_raw(self) -> bind_ceed::CeedBasis {
40         match self {
41             Self::Some(basis) => basis.ptr,
42             Self::Collocated => unsafe { bind_ceed::CEED_BASIS_COLLOCATED },
43         }
44     }
45 }
46 
47 // -----------------------------------------------------------------------------
48 // CeedBasis context wrapper
49 // -----------------------------------------------------------------------------
50 #[derive(Debug)]
51 pub struct Basis<'a> {
52     ceed: &'a crate::Ceed,
53     pub(crate) ptr: bind_ceed::CeedBasis,
54 }
55 
56 // -----------------------------------------------------------------------------
57 // Destructor
58 // -----------------------------------------------------------------------------
59 impl<'a> Drop for Basis<'a> {
60     fn drop(&mut self) {
61         unsafe {
62             if self.ptr != bind_ceed::CEED_BASIS_COLLOCATED {
63                 bind_ceed::CeedBasisDestroy(&mut self.ptr);
64             }
65         }
66     }
67 }
68 
69 // -----------------------------------------------------------------------------
70 // Display
71 // -----------------------------------------------------------------------------
72 impl<'a> fmt::Display for Basis<'a> {
73     /// View a Basis
74     ///
75     /// ```
76     /// # use libceed::prelude::*;
77     /// # fn main() -> Result<(), libceed::CeedError> {
78     /// # let ceed = libceed::Ceed::default_init();
79     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
80     /// println!("{}", b);
81     /// # Ok(())
82     /// # }
83     /// ```
84     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
85         let mut ptr = std::ptr::null_mut();
86         let mut sizeloc = crate::MAX_BUFFER_LENGTH;
87         let cstring = unsafe {
88             let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc);
89             bind_ceed::CeedBasisView(self.ptr, file);
90             bind_ceed::fclose(file);
91             CString::from_raw(ptr)
92         };
93         cstring.to_string_lossy().fmt(f)
94     }
95 }
96 
97 // -----------------------------------------------------------------------------
98 // Implementations
99 // -----------------------------------------------------------------------------
100 impl<'a> Basis<'a> {
101     // Constructors
102     pub fn create_tensor_H1(
103         ceed: &'a crate::Ceed,
104         dim: usize,
105         ncomp: usize,
106         P1d: usize,
107         Q1d: usize,
108         interp1d: &[crate::Scalar],
109         grad1d: &[crate::Scalar],
110         qref1d: &[crate::Scalar],
111         qweight1d: &[crate::Scalar],
112     ) -> crate::Result<Self> {
113         let mut ptr = std::ptr::null_mut();
114         let (dim, ncomp, P1d, Q1d) = (
115             i32::try_from(dim).unwrap(),
116             i32::try_from(ncomp).unwrap(),
117             i32::try_from(P1d).unwrap(),
118             i32::try_from(Q1d).unwrap(),
119         );
120         let ierr = unsafe {
121             bind_ceed::CeedBasisCreateTensorH1(
122                 ceed.ptr,
123                 dim,
124                 ncomp,
125                 P1d,
126                 Q1d,
127                 interp1d.as_ptr(),
128                 grad1d.as_ptr(),
129                 qref1d.as_ptr(),
130                 qweight1d.as_ptr(),
131                 &mut ptr,
132             )
133         };
134         ceed.check_error(ierr)?;
135         Ok(Self { ceed, ptr })
136     }
137 
138     pub fn create_tensor_H1_Lagrange(
139         ceed: &'a crate::Ceed,
140         dim: usize,
141         ncomp: usize,
142         P: usize,
143         Q: usize,
144         qmode: crate::QuadMode,
145     ) -> crate::Result<Self> {
146         let mut ptr = std::ptr::null_mut();
147         let (dim, ncomp, P, Q, qmode) = (
148             i32::try_from(dim).unwrap(),
149             i32::try_from(ncomp).unwrap(),
150             i32::try_from(P).unwrap(),
151             i32::try_from(Q).unwrap(),
152             qmode as bind_ceed::CeedQuadMode,
153         );
154         let ierr = unsafe {
155             bind_ceed::CeedBasisCreateTensorH1Lagrange(ceed.ptr, dim, ncomp, P, Q, qmode, &mut ptr)
156         };
157         ceed.check_error(ierr)?;
158         Ok(Self { ceed, ptr })
159     }
160 
161     pub fn create_H1(
162         ceed: &'a crate::Ceed,
163         topo: crate::ElemTopology,
164         ncomp: usize,
165         nnodes: usize,
166         nqpts: usize,
167         interp: &[crate::Scalar],
168         grad: &[crate::Scalar],
169         qref: &[crate::Scalar],
170         qweight: &[crate::Scalar],
171     ) -> crate::Result<Self> {
172         let mut ptr = std::ptr::null_mut();
173         let (topo, ncomp, nnodes, nqpts) = (
174             topo as bind_ceed::CeedElemTopology,
175             i32::try_from(ncomp).unwrap(),
176             i32::try_from(nnodes).unwrap(),
177             i32::try_from(nqpts).unwrap(),
178         );
179         let ierr = unsafe {
180             bind_ceed::CeedBasisCreateH1(
181                 ceed.ptr,
182                 topo,
183                 ncomp,
184                 nnodes,
185                 nqpts,
186                 interp.as_ptr(),
187                 grad.as_ptr(),
188                 qref.as_ptr(),
189                 qweight.as_ptr(),
190                 &mut ptr,
191             )
192         };
193         ceed.check_error(ierr)?;
194         Ok(Self { ceed, ptr })
195     }
196 
197     /// Apply basis evaluation from nodes to quadrature points or vice versa
198     ///
199     /// * `nelem` - The number of elements to apply the basis evaluation to
200     /// * `tmode` - `TrasposeMode::NoTranspose` to evaluate from nodes to
201     ///               quadrature points, `TransposeMode::Transpose` to apply the
202     ///               transpose, mapping from quadrature points to nodes
203     /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp`
204     ///               to use interpolated values, `EvalMode::Grad` to use
205     ///               gradients, `EvalMode::Weight` to use quadrature weights
206     /// * `u`     - Input Vector
207     /// * `v`     - Output Vector
208     ///
209     /// ```
210     /// # use libceed::prelude::*;
211     /// # fn main() -> Result<(), libceed::CeedError> {
212     /// # let ceed = libceed::Ceed::default_init();
213     /// const Q: usize = 6;
214     /// let bu = ceed.basis_tensor_H1_Lagrange(1, 1, Q, Q, QuadMode::GaussLobatto)?;
215     /// let bx = ceed.basis_tensor_H1_Lagrange(1, 1, 2, Q, QuadMode::Gauss)?;
216     ///
217     /// let x_corners = ceed.vector_from_slice(&[-1., 1.])?;
218     /// let mut x_qpts = ceed.vector(Q)?;
219     /// let mut x_nodes = ceed.vector(Q)?;
220     /// bx.apply(
221     ///     1,
222     ///     TransposeMode::NoTranspose,
223     ///     EvalMode::Interp,
224     ///     &x_corners,
225     ///     &mut x_nodes,
226     /// );
227     /// bu.apply(
228     ///     1,
229     ///     TransposeMode::NoTranspose,
230     ///     EvalMode::Interp,
231     ///     &x_nodes,
232     ///     &mut x_qpts,
233     /// );
234     ///
235     /// // Create function x^3 + 1 on Gauss Lobatto points
236     /// let mut u_arr = [0.; Q];
237     /// u_arr
238     ///     .iter_mut()
239     ///     .zip(x_nodes.view().iter())
240     ///     .for_each(|(u, x)| *u = x * x * x + 1.);
241     /// let u = ceed.vector_from_slice(&u_arr)?;
242     ///
243     /// // Map function to Gauss points
244     /// let mut v = ceed.vector(Q)?;
245     /// v.set_value(0.);
246     /// bu.apply(1, TransposeMode::NoTranspose, EvalMode::Interp, &u, &mut v)?;
247     ///
248     /// // Verify results
249     /// v.view()
250     ///     .iter()
251     ///     .zip(x_qpts.view().iter())
252     ///     .for_each(|(v, x)| {
253     ///         let true_value = x * x * x + 1.;
254     ///         assert!(
255     ///             (*v - true_value).abs() < 10.0 * libceed::EPSILON,
256     ///             "Incorrect basis application"
257     ///         );
258     ///     });
259     /// # Ok(())
260     /// # }
261     /// ```
262     pub fn apply(
263         &self,
264         nelem: usize,
265         tmode: TransposeMode,
266         emode: EvalMode,
267         u: &Vector,
268         v: &mut Vector,
269     ) -> crate::Result<i32> {
270         let (nelem, tmode, emode) = (
271             i32::try_from(nelem).unwrap(),
272             tmode as bind_ceed::CeedTransposeMode,
273             emode as bind_ceed::CeedEvalMode,
274         );
275         let ierr =
276             unsafe { bind_ceed::CeedBasisApply(self.ptr, nelem, tmode, emode, u.ptr, v.ptr) };
277         self.ceed.check_error(ierr)
278     }
279 
280     /// Returns the dimension for given CeedBasis
281     ///
282     /// ```
283     /// # use libceed::prelude::*;
284     /// # fn main() -> Result<(), libceed::CeedError> {
285     /// # let ceed = libceed::Ceed::default_init();
286     /// let dim = 2;
287     /// let b = ceed.basis_tensor_H1_Lagrange(dim, 1, 3, 4, QuadMode::Gauss)?;
288     ///
289     /// let d = b.dimension();
290     /// assert_eq!(d, dim, "Incorrect dimension");
291     /// # Ok(())
292     /// # }
293     /// ```
294     pub fn dimension(&self) -> usize {
295         let mut dim = 0;
296         unsafe { bind_ceed::CeedBasisGetDimension(self.ptr, &mut dim) };
297         usize::try_from(dim).unwrap()
298     }
299 
300     /// Returns number of components for given CeedBasis
301     ///
302     /// ```
303     /// # use libceed::prelude::*;
304     /// # fn main() -> Result<(), libceed::CeedError> {
305     /// # let ceed = libceed::Ceed::default_init();
306     /// let ncomp = 2;
307     /// let b = ceed.basis_tensor_H1_Lagrange(1, ncomp, 3, 4, QuadMode::Gauss)?;
308     ///
309     /// let n = b.num_components();
310     /// assert_eq!(n, ncomp, "Incorrect number of components");
311     /// # Ok(())
312     /// # }
313     /// ```
314     pub fn num_components(&self) -> usize {
315         let mut ncomp = 0;
316         unsafe { bind_ceed::CeedBasisGetNumComponents(self.ptr, &mut ncomp) };
317         usize::try_from(ncomp).unwrap()
318     }
319 
320     /// Returns total number of nodes (in dim dimensions) of a CeedBasis
321     ///
322     /// ```
323     /// # use libceed::prelude::*;
324     /// # fn main() -> Result<(), libceed::CeedError> {
325     /// # let ceed = libceed::Ceed::default_init();
326     /// let p = 3;
327     /// let b = ceed.basis_tensor_H1_Lagrange(2, 1, p, 4, QuadMode::Gauss)?;
328     ///
329     /// let nnodes = b.num_nodes();
330     /// assert_eq!(nnodes, p * p, "Incorrect number of nodes");
331     /// # Ok(())
332     /// # }
333     /// ```
334     pub fn num_nodes(&self) -> usize {
335         let mut nnodes = 0;
336         unsafe { bind_ceed::CeedBasisGetNumNodes(self.ptr, &mut nnodes) };
337         usize::try_from(nnodes).unwrap()
338     }
339 
340     /// Returns total number of quadrature points (in dim dimensions) of a
341     /// CeedBasis
342     ///
343     /// ```
344     /// # use libceed::prelude::*;
345     /// # fn main() -> Result<(), libceed::CeedError> {
346     /// # let ceed = libceed::Ceed::default_init();
347     /// let q = 4;
348     /// let b = ceed.basis_tensor_H1_Lagrange(2, 1, 3, q, QuadMode::Gauss)?;
349     ///
350     /// let nqpts = b.num_quadrature_points();
351     /// assert_eq!(nqpts, q * q, "Incorrect number of quadrature points");
352     /// # Ok(())
353     /// # }
354     /// ```
355     pub fn num_quadrature_points(&self) -> usize {
356         let mut Q = 0;
357         unsafe {
358             bind_ceed::CeedBasisGetNumQuadraturePoints(self.ptr, &mut Q);
359         }
360         usize::try_from(Q).unwrap()
361     }
362 }
363 
364 // -----------------------------------------------------------------------------
365