xref: /libCEED/rust/libceed/src/basis.rs (revision 3e551a327d6c97f9de071b988b42ffdb7bed19a7)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 //! A Ceed Basis defines the discrete finite element basis and associated
9 //! quadrature rule.
10 
11 use crate::prelude::*;
12 
13 // -----------------------------------------------------------------------------
14 // Basis option
15 // -----------------------------------------------------------------------------
16 #[derive(Debug)]
17 pub enum BasisOpt<'a> {
18     Some(&'a Basis<'a>),
19     None,
20 }
21 /// Construct a BasisOpt reference from a Basis reference
22 impl<'a> From<&'a Basis<'_>> for BasisOpt<'a> {
23     fn from(basis: &'a Basis) -> Self {
24         debug_assert!(basis.ptr != unsafe { bind_ceed::CEED_BASIS_NONE });
25         Self::Some(basis)
26     }
27 }
28 impl<'a> BasisOpt<'a> {
29     /// Transform a Rust libCEED BasisOpt into C libCEED CeedBasis
30     pub(crate) fn to_raw(self) -> bind_ceed::CeedBasis {
31         match self {
32             Self::Some(basis) => basis.ptr,
33             Self::None => unsafe { bind_ceed::CEED_BASIS_NONE },
34         }
35     }
36 
37     /// Check if a BasisOpt is Some
38     ///
39     /// ```
40     /// # use libceed::prelude::*;
41     /// # fn main() -> libceed::Result<()> {
42     /// # let ceed = libceed::Ceed::default_init();
43     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
44     /// let b_opt = BasisOpt::from(&b);
45     /// assert!(b_opt.is_some(), "Incorrect BasisOpt");
46     ///
47     /// let b_opt = BasisOpt::None;
48     /// assert!(!b_opt.is_some(), "Incorrect BasisOpt");
49     /// # Ok(())
50     /// # }
51     /// ```
52     pub fn is_some(&self) -> bool {
53         match self {
54             Self::Some(_) => true,
55             Self::None => false,
56         }
57     }
58 
59     /// Check if a BasisOpt is None
60     ///
61     /// ```
62     /// # use libceed::prelude::*;
63     /// # fn main() -> libceed::Result<()> {
64     /// # let ceed = libceed::Ceed::default_init();
65     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
66     /// let b_opt = BasisOpt::from(&b);
67     /// assert!(!b_opt.is_none(), "Incorrect BasisOpt");
68     ///
69     /// let b_opt = BasisOpt::None;
70     /// assert!(b_opt.is_none(), "Incorrect BasisOpt");
71     /// # Ok(())
72     /// # }
73     /// ```
74     pub fn is_none(&self) -> bool {
75         match self {
76             Self::Some(_) => false,
77             Self::None => true,
78         }
79     }
80 }
81 
82 // -----------------------------------------------------------------------------
83 // Basis context wrapper
84 // -----------------------------------------------------------------------------
85 #[derive(Debug)]
86 pub struct Basis<'a> {
87     pub(crate) ptr: bind_ceed::CeedBasis,
88     _lifeline: PhantomData<&'a ()>,
89 }
90 
91 // -----------------------------------------------------------------------------
92 // Destructor
93 // -----------------------------------------------------------------------------
94 impl<'a> Drop for Basis<'a> {
95     fn drop(&mut self) {
96         unsafe {
97             if self.ptr != bind_ceed::CEED_BASIS_NONE {
98                 bind_ceed::CeedBasisDestroy(&mut self.ptr);
99             }
100         }
101     }
102 }
103 
104 // -----------------------------------------------------------------------------
105 // Display
106 // -----------------------------------------------------------------------------
107 impl<'a> fmt::Display for Basis<'a> {
108     /// View a Basis
109     ///
110     /// ```
111     /// # use libceed::prelude::*;
112     /// # fn main() -> libceed::Result<()> {
113     /// # let ceed = libceed::Ceed::default_init();
114     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
115     /// println!("{}", b);
116     /// # Ok(())
117     /// # }
118     /// ```
119     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
120         let mut ptr = std::ptr::null_mut();
121         let mut sizeloc = crate::MAX_BUFFER_LENGTH;
122         let cstring = unsafe {
123             let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc);
124             bind_ceed::CeedBasisView(self.ptr, file);
125             bind_ceed::fclose(file);
126             CString::from_raw(ptr)
127         };
128         cstring.to_string_lossy().fmt(f)
129     }
130 }
131 
132 // -----------------------------------------------------------------------------
133 // Implementations
134 // -----------------------------------------------------------------------------
135 impl<'a> Basis<'a> {
136     // Constructors
137     pub fn create_tensor_H1(
138         ceed: &crate::Ceed,
139         dim: usize,
140         ncomp: usize,
141         P1d: usize,
142         Q1d: usize,
143         interp1d: &[crate::Scalar],
144         grad1d: &[crate::Scalar],
145         qref1d: &[crate::Scalar],
146         qweight1d: &[crate::Scalar],
147     ) -> crate::Result<Self> {
148         let mut ptr = std::ptr::null_mut();
149         let (dim, ncomp, P1d, Q1d) = (
150             i32::try_from(dim).unwrap(),
151             i32::try_from(ncomp).unwrap(),
152             i32::try_from(P1d).unwrap(),
153             i32::try_from(Q1d).unwrap(),
154         );
155         let ierr = unsafe {
156             bind_ceed::CeedBasisCreateTensorH1(
157                 ceed.ptr,
158                 dim,
159                 ncomp,
160                 P1d,
161                 Q1d,
162                 interp1d.as_ptr(),
163                 grad1d.as_ptr(),
164                 qref1d.as_ptr(),
165                 qweight1d.as_ptr(),
166                 &mut ptr,
167             )
168         };
169         ceed.check_error(ierr)?;
170         Ok(Self {
171             ptr,
172             _lifeline: PhantomData,
173         })
174     }
175 
176     pub(crate) fn from_raw(ptr: bind_ceed::CeedBasis) -> crate::Result<Self> {
177         Ok(Self {
178             ptr,
179             _lifeline: PhantomData,
180         })
181     }
182 
183     pub fn create_tensor_H1_Lagrange(
184         ceed: &crate::Ceed,
185         dim: usize,
186         ncomp: usize,
187         P: usize,
188         Q: usize,
189         qmode: crate::QuadMode,
190     ) -> crate::Result<Self> {
191         let mut ptr = std::ptr::null_mut();
192         let (dim, ncomp, P, Q, qmode) = (
193             i32::try_from(dim).unwrap(),
194             i32::try_from(ncomp).unwrap(),
195             i32::try_from(P).unwrap(),
196             i32::try_from(Q).unwrap(),
197             qmode as bind_ceed::CeedQuadMode,
198         );
199         let ierr = unsafe {
200             bind_ceed::CeedBasisCreateTensorH1Lagrange(ceed.ptr, dim, ncomp, P, Q, qmode, &mut ptr)
201         };
202         ceed.check_error(ierr)?;
203         Ok(Self {
204             ptr,
205             _lifeline: PhantomData,
206         })
207     }
208 
209     pub fn create_H1(
210         ceed: &crate::Ceed,
211         topo: crate::ElemTopology,
212         ncomp: usize,
213         nnodes: usize,
214         nqpts: usize,
215         interp: &[crate::Scalar],
216         grad: &[crate::Scalar],
217         qref: &[crate::Scalar],
218         qweight: &[crate::Scalar],
219     ) -> crate::Result<Self> {
220         let mut ptr = std::ptr::null_mut();
221         let (topo, ncomp, nnodes, nqpts) = (
222             topo as bind_ceed::CeedElemTopology,
223             i32::try_from(ncomp).unwrap(),
224             i32::try_from(nnodes).unwrap(),
225             i32::try_from(nqpts).unwrap(),
226         );
227         let ierr = unsafe {
228             bind_ceed::CeedBasisCreateH1(
229                 ceed.ptr,
230                 topo,
231                 ncomp,
232                 nnodes,
233                 nqpts,
234                 interp.as_ptr(),
235                 grad.as_ptr(),
236                 qref.as_ptr(),
237                 qweight.as_ptr(),
238                 &mut ptr,
239             )
240         };
241         ceed.check_error(ierr)?;
242         Ok(Self {
243             ptr,
244             _lifeline: PhantomData,
245         })
246     }
247 
248     pub fn create_Hdiv(
249         ceed: &crate::Ceed,
250         topo: crate::ElemTopology,
251         ncomp: usize,
252         nnodes: usize,
253         nqpts: usize,
254         interp: &[crate::Scalar],
255         div: &[crate::Scalar],
256         qref: &[crate::Scalar],
257         qweight: &[crate::Scalar],
258     ) -> crate::Result<Self> {
259         let mut ptr = std::ptr::null_mut();
260         let (topo, ncomp, nnodes, nqpts) = (
261             topo as bind_ceed::CeedElemTopology,
262             i32::try_from(ncomp).unwrap(),
263             i32::try_from(nnodes).unwrap(),
264             i32::try_from(nqpts).unwrap(),
265         );
266         let ierr = unsafe {
267             bind_ceed::CeedBasisCreateHdiv(
268                 ceed.ptr,
269                 topo,
270                 ncomp,
271                 nnodes,
272                 nqpts,
273                 interp.as_ptr(),
274                 div.as_ptr(),
275                 qref.as_ptr(),
276                 qweight.as_ptr(),
277                 &mut ptr,
278             )
279         };
280         ceed.check_error(ierr)?;
281         Ok(Self {
282             ptr,
283             _lifeline: PhantomData,
284         })
285     }
286 
287     pub fn create_Hcurl(
288         ceed: &crate::Ceed,
289         topo: crate::ElemTopology,
290         ncomp: usize,
291         nnodes: usize,
292         nqpts: usize,
293         interp: &[crate::Scalar],
294         curl: &[crate::Scalar],
295         qref: &[crate::Scalar],
296         qweight: &[crate::Scalar],
297     ) -> crate::Result<Self> {
298         let mut ptr = std::ptr::null_mut();
299         let (topo, ncomp, nnodes, nqpts) = (
300             topo as bind_ceed::CeedElemTopology,
301             i32::try_from(ncomp).unwrap(),
302             i32::try_from(nnodes).unwrap(),
303             i32::try_from(nqpts).unwrap(),
304         );
305         let ierr = unsafe {
306             bind_ceed::CeedBasisCreateHcurl(
307                 ceed.ptr,
308                 topo,
309                 ncomp,
310                 nnodes,
311                 nqpts,
312                 interp.as_ptr(),
313                 curl.as_ptr(),
314                 qref.as_ptr(),
315                 qweight.as_ptr(),
316                 &mut ptr,
317             )
318         };
319         ceed.check_error(ierr)?;
320         Ok(Self {
321             ptr,
322             _lifeline: PhantomData,
323         })
324     }
325 
326     // Error handling
327     #[doc(hidden)]
328     fn check_error(&self, ierr: i32) -> crate::Result<i32> {
329         unsafe { crate::check_error(bind_ceed::CeedBasisReturnCeed(self.ptr), ierr) }
330     }
331 
332     /// Apply basis evaluation from nodes to quadrature points or vice versa
333     ///
334     /// * `nelem` - The number of elements to apply the basis evaluation to
335     /// * `tmode` - `TrasposeMode::NoTranspose` to evaluate from nodes to
336     ///               quadrature points, `TransposeMode::Transpose` to apply the
337     ///               transpose, mapping from quadrature points to nodes
338     /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp`
339     ///               to use interpolated values, `EvalMode::Grad` to use
340     ///               gradients, `EvalMode::Weight` to use quadrature weights
341     /// * `u`     - Input Vector
342     /// * `v`     - Output Vector
343     ///
344     /// ```
345     /// # use libceed::prelude::*;
346     /// # fn main() -> libceed::Result<()> {
347     /// # let ceed = libceed::Ceed::default_init();
348     /// const Q: usize = 6;
349     /// let bu = ceed.basis_tensor_H1_Lagrange(1, 1, Q, Q, QuadMode::GaussLobatto)?;
350     /// let bx = ceed.basis_tensor_H1_Lagrange(1, 1, 2, Q, QuadMode::Gauss)?;
351     ///
352     /// let x_corners = ceed.vector_from_slice(&[-1., 1.])?;
353     /// let mut x_qpts = ceed.vector(Q)?;
354     /// let mut x_nodes = ceed.vector(Q)?;
355     /// bx.apply(
356     ///     1,
357     ///     TransposeMode::NoTranspose,
358     ///     EvalMode::Interp,
359     ///     &x_corners,
360     ///     &mut x_nodes,
361     /// )?;
362     /// bu.apply(
363     ///     1,
364     ///     TransposeMode::NoTranspose,
365     ///     EvalMode::Interp,
366     ///     &x_nodes,
367     ///     &mut x_qpts,
368     /// )?;
369     ///
370     /// // Create function x^3 + 1 on Gauss Lobatto points
371     /// let mut u_arr = [0.; Q];
372     /// u_arr
373     ///     .iter_mut()
374     ///     .zip(x_nodes.view()?.iter())
375     ///     .for_each(|(u, x)| *u = x * x * x + 1.);
376     /// let u = ceed.vector_from_slice(&u_arr)?;
377     ///
378     /// // Map function to Gauss points
379     /// let mut v = ceed.vector(Q)?;
380     /// v.set_value(0.);
381     /// bu.apply(1, TransposeMode::NoTranspose, EvalMode::Interp, &u, &mut v)?;
382     ///
383     /// // Verify results
384     /// v.view()?
385     ///     .iter()
386     ///     .zip(x_qpts.view()?.iter())
387     ///     .for_each(|(v, x)| {
388     ///         let true_value = x * x * x + 1.;
389     ///         assert!(
390     ///             (*v - true_value).abs() < 10.0 * libceed::EPSILON,
391     ///             "Incorrect basis application"
392     ///         );
393     ///     });
394     /// # Ok(())
395     /// # }
396     /// ```
397     pub fn apply(
398         &self,
399         nelem: usize,
400         tmode: TransposeMode,
401         emode: EvalMode,
402         u: &Vector,
403         v: &mut Vector,
404     ) -> crate::Result<i32> {
405         let (nelem, tmode, emode) = (
406             i32::try_from(nelem).unwrap(),
407             tmode as bind_ceed::CeedTransposeMode,
408             emode as bind_ceed::CeedEvalMode,
409         );
410         let ierr =
411             unsafe { bind_ceed::CeedBasisApply(self.ptr, nelem, tmode, emode, u.ptr, v.ptr) };
412         self.check_error(ierr)
413     }
414 
415     /// Returns the dimension for given Basis
416     ///
417     /// ```
418     /// # use libceed::prelude::*;
419     /// # fn main() -> libceed::Result<()> {
420     /// # let ceed = libceed::Ceed::default_init();
421     /// let dim = 2;
422     /// let b = ceed.basis_tensor_H1_Lagrange(dim, 1, 3, 4, QuadMode::Gauss)?;
423     ///
424     /// let d = b.dimension();
425     /// assert_eq!(d, dim, "Incorrect dimension");
426     /// # Ok(())
427     /// # }
428     /// ```
429     pub fn dimension(&self) -> usize {
430         let mut dim = 0;
431         unsafe { bind_ceed::CeedBasisGetDimension(self.ptr, &mut dim) };
432         usize::try_from(dim).unwrap()
433     }
434 
435     /// Returns number of components for given Basis
436     ///
437     /// ```
438     /// # use libceed::prelude::*;
439     /// # fn main() -> libceed::Result<()> {
440     /// # let ceed = libceed::Ceed::default_init();
441     /// let ncomp = 2;
442     /// let b = ceed.basis_tensor_H1_Lagrange(1, ncomp, 3, 4, QuadMode::Gauss)?;
443     ///
444     /// let n = b.num_components();
445     /// assert_eq!(n, ncomp, "Incorrect number of components");
446     /// # Ok(())
447     /// # }
448     /// ```
449     pub fn num_components(&self) -> usize {
450         let mut ncomp = 0;
451         unsafe { bind_ceed::CeedBasisGetNumComponents(self.ptr, &mut ncomp) };
452         usize::try_from(ncomp).unwrap()
453     }
454 
455     /// Returns total number of nodes (in dim dimensions) of a Basis
456     ///
457     /// ```
458     /// # use libceed::prelude::*;
459     /// # fn main() -> libceed::Result<()> {
460     /// # let ceed = libceed::Ceed::default_init();
461     /// let p = 3;
462     /// let b = ceed.basis_tensor_H1_Lagrange(2, 1, p, 4, QuadMode::Gauss)?;
463     ///
464     /// let nnodes = b.num_nodes();
465     /// assert_eq!(nnodes, p * p, "Incorrect number of nodes");
466     /// # Ok(())
467     /// # }
468     /// ```
469     pub fn num_nodes(&self) -> usize {
470         let mut nnodes = 0;
471         unsafe { bind_ceed::CeedBasisGetNumNodes(self.ptr, &mut nnodes) };
472         usize::try_from(nnodes).unwrap()
473     }
474 
475     /// Returns total number of quadrature points (in dim dimensions) of a
476     /// Basis
477     ///
478     /// ```
479     /// # use libceed::prelude::*;
480     /// # fn main() -> libceed::Result<()> {
481     /// # let ceed = libceed::Ceed::default_init();
482     /// let q = 4;
483     /// let b = ceed.basis_tensor_H1_Lagrange(2, 1, 3, q, QuadMode::Gauss)?;
484     ///
485     /// let nqpts = b.num_quadrature_points();
486     /// assert_eq!(nqpts, q * q, "Incorrect number of quadrature points");
487     /// # Ok(())
488     /// # }
489     /// ```
490     pub fn num_quadrature_points(&self) -> usize {
491         let mut Q = 0;
492         unsafe {
493             bind_ceed::CeedBasisGetNumQuadraturePoints(self.ptr, &mut Q);
494         }
495         usize::try_from(Q).unwrap()
496     }
497 
498     /// Create projection from self to specified Basis.
499     ///
500     /// Both bases must have the same quadrature space. The input bases need not
501     /// be nested as function spaces; this interface solves a least squares
502     /// problem to find a representation in the `to` basis that agrees at
503     /// quadrature points with the origin basis. Since the bases need not be
504     /// Lagrange, the resulting projection "basis" will have empty quadrature
505     /// points and weights.
506     ///
507     /// ```
508     /// # use libceed::prelude::*;
509     /// # fn main() -> libceed::Result<()> {
510     /// # let ceed = libceed::Ceed::default_init();
511     /// let coarse = ceed.basis_tensor_H1_Lagrange(1, 1, 2, 3, QuadMode::Gauss)?;
512     /// let fine = ceed.basis_tensor_H1_Lagrange(1, 1, 3, 3, QuadMode::Gauss)?;
513     /// let proj = coarse.create_projection(&fine)?;
514     /// let u = ceed.vector_from_slice(&[1., 2.])?;
515     /// let mut v = ceed.vector(3)?;
516     /// proj.apply(1, TransposeMode::NoTranspose, EvalMode::Interp, &u, &mut v)?;
517     /// let expected = [1., 1.5, 2.];
518     /// for (a, b) in v.view()?.iter().zip(expected) {
519     ///     assert!(
520     ///         (a - b).abs() < 10.0 * libceed::EPSILON,
521     ///         "Incorrect projection of linear Lagrange to quadratic Lagrange"
522     ///     );
523     /// }
524     /// # Ok(())
525     /// # }
526     /// ```
527     pub fn create_projection(&self, to: &Self) -> crate::Result<Self> {
528         let mut ptr = std::ptr::null_mut();
529         let ierr = unsafe { bind_ceed::CeedBasisCreateProjection(self.ptr, to.ptr, &mut ptr) };
530         self.check_error(ierr)?;
531         Ok(Self {
532             ptr,
533             _lifeline: PhantomData,
534         })
535     }
536 }
537 
538 // -----------------------------------------------------------------------------
539