xref: /petsc/src/dm/impls/plex/kokkos/plexlocalizationletkf.kokkos.cxx (revision 828beda2407cdd56b6d541804943d5bc46cf6ab9)
1*828beda2SMark Adams #include <petsc/private/dmpleximpl.h>
2*828beda2SMark Adams #include <petscdmplex.h>
3*828beda2SMark Adams #include <petscmat.h>
4*828beda2SMark Adams #include <petsc_kokkos.hpp>
5*828beda2SMark Adams #include <cmath>
6*828beda2SMark Adams #include <cstdlib>
7*828beda2SMark Adams #include <algorithm>
8*828beda2SMark Adams #include <Kokkos_Core.hpp>
9*828beda2SMark Adams 
/* Pairs a distance value with an observation index.
   NOTE(review): not referenced anywhere in the visible part of this file; presumably a
   leftover from a host-side nearest-observation search -- confirm before removing. */
typedef struct {
  PetscReal distance;  /* distance associated with the observation */
  PetscInt  obs_index; /* index of the observation */
} DistObsPair;
14*828beda2SMark Adams 
15*828beda2SMark Adams KOKKOS_INLINE_FUNCTION
16*828beda2SMark Adams static PetscReal GaspariCohn(PetscReal distance, PetscReal radius)
17*828beda2SMark Adams {
18*828beda2SMark Adams   if (radius <= 0.0) return 0.0;
19*828beda2SMark Adams   PetscReal r      = distance / radius;
20*828beda2SMark Adams   PetscReal weight = 0.0;
21*828beda2SMark Adams 
22*828beda2SMark Adams   if (r >= 2.0) {
23*828beda2SMark Adams     weight = 0.0;
24*828beda2SMark Adams   } else if (r >= 1.0) {
25*828beda2SMark Adams     // Region [1, 2]
26*828beda2SMark Adams     PetscReal r2 = r * r;
27*828beda2SMark Adams     PetscReal r3 = r2 * r;
28*828beda2SMark Adams     PetscReal r4 = r3 * r;
29*828beda2SMark Adams     PetscReal r5 = r4 * r;
30*828beda2SMark Adams     weight       = (1.0 / 12.0) * r5 - (0.5) * r4 + (0.625) * r3 + (5.0 / 3.0) * r2 - 5.0 * r + 4.0 - (2.0 / 3.0) / r;
31*828beda2SMark Adams   } else {
32*828beda2SMark Adams     // Region [0, 1]
33*828beda2SMark Adams     PetscReal r2 = r * r;
34*828beda2SMark Adams     PetscReal r3 = r2 * r;
35*828beda2SMark Adams     PetscReal r4 = r3 * r;
36*828beda2SMark Adams     PetscReal r5 = r4 * r;
37*828beda2SMark Adams     weight       = -0.25 * r5 + 0.5 * r4 + 0.625 * r3 - (5.0 / 3.0) * r2 + 1.0;
38*828beda2SMark Adams   }
39*828beda2SMark Adams   return weight;
40*828beda2SMark Adams }
41*828beda2SMark Adams 
42*828beda2SMark Adams /*@
43*828beda2SMark Adams   DMPlexGetLETKFLocalizationMatrix - Compute localization weight matrix for LETKF
44*828beda2SMark Adams 
45*828beda2SMark Adams   Collective
46*828beda2SMark Adams 
47*828beda2SMark Adams   Input Parameters:
48*828beda2SMark Adams + plex - The DMPlex object
49*828beda2SMark Adams . numobservations - Number of nearest observations to use per vertex
50*828beda2SMark Adams . numglobalobs - Total number of observations
51*828beda2SMark Adams - H - Observation operator matrix
52*828beda2SMark Adams 
53*828beda2SMark Adams   Output Parameter:
54*828beda2SMark Adams . Q - Localization weight matrix (sparse, AIJ format)
55*828beda2SMark Adams 
56*828beda2SMark Adams   Notes:
57*828beda2SMark Adams   The output matrix Q has dimensions (numVertices x numglobalobs) where
58*828beda2SMark Adams   numVertices is the number of vertices in the DMPlex. Each row contains
59*828beda2SMark Adams   exactly numobservations non-zero entries corresponding to the nearest
60*828beda2SMark Adams   observations, weighted by the Gaspari-Cohn fifth-order piecewise
61*828beda2SMark Adams   rational function.
62*828beda2SMark Adams 
63*828beda2SMark Adams   The observation locations are computed as H * V where V is the vector
64*828beda2SMark Adams   of vertex coordinates. The localization weights ensure smooth tapering
65*828beda2SMark Adams   of observation influence with distance.
66*828beda2SMark Adams 
67*828beda2SMark Adams   Kokkos is required for this routine. LETKF has a lot of fine grain parallelism and is not useful without threads or GPUs.
68*828beda2SMark Adams 
69*828beda2SMark Adams   Level: intermediate
70*828beda2SMark Adams 
71*828beda2SMark Adams .seealso: `DMPLEX`, `DMPlexGetDepthStratum()`, `DMGetCoordinatesLocal()`
72*828beda2SMark Adams @*/
73*828beda2SMark Adams PetscErrorCode DMPlexGetLETKFLocalizationMatrix(DM plex, PetscInt numobservations, PetscInt numglobalobs, Mat H, Mat *Q)
74*828beda2SMark Adams {
75*828beda2SMark Adams   PetscInt      dim, vStart, vEnd, numVertices, d;
76*828beda2SMark Adams   PetscInt      M, N;
77*828beda2SMark Adams   Vec           coordinates;
78*828beda2SMark Adams   Vec          *obs_vecs;
79*828beda2SMark Adams   PetscScalar **obs_coords;
80*828beda2SMark Adams   PetscInt      localRows, globalRows;
81*828beda2SMark Adams   MPI_Comm      comm;
82*828beda2SMark Adams 
83*828beda2SMark Adams   PetscFunctionBegin;
84*828beda2SMark Adams   PetscValidHeaderSpecific(plex, DM_CLASSID, 1);
85*828beda2SMark Adams   PetscValidHeaderSpecific(H, MAT_CLASSID, 4);
86*828beda2SMark Adams   PetscAssertPointer(Q, 5);
87*828beda2SMark Adams 
88*828beda2SMark Adams   PetscCall(PetscKokkosInitializeCheck());
89*828beda2SMark Adams 
90*828beda2SMark Adams   PetscCall(PetscObjectGetComm((PetscObject)plex, &comm));
91*828beda2SMark Adams   PetscCall(DMGetCoordinateDim(plex, &dim));
92*828beda2SMark Adams   PetscCall(DMPlexGetDepthStratum(plex, 0, &vStart, &vEnd));
93*828beda2SMark Adams   numVertices = vEnd - vStart;
94*828beda2SMark Adams 
95*828beda2SMark Adams   /* Check H dimensions */
96*828beda2SMark Adams   PetscCall(MatGetSize(H, &M, &N));
97*828beda2SMark Adams   PetscCheck(M == numglobalobs, comm, PETSC_ERR_ARG_SIZ, "H matrix rows %" PetscInt_FMT " != numglobalobs %" PetscInt_FMT, M, numglobalobs);
98*828beda2SMark Adams 
99*828beda2SMark Adams   PetscCall(DMGetCoordinates(plex, &coordinates));
100*828beda2SMark Adams   PetscCheck(coordinates, comm, PETSC_ERR_ARG_WRONGSTATE, "DM must have coordinates");
101*828beda2SMark Adams 
102*828beda2SMark Adams   /* Allocate storage for observation locations */
103*828beda2SMark Adams   PetscCall(PetscMalloc1(dim, &obs_vecs));
104*828beda2SMark Adams   PetscCall(PetscMalloc1(dim, &obs_coords));
105*828beda2SMark Adams 
106*828beda2SMark Adams   /* Compute observation locations per dimension */
107*828beda2SMark Adams   for (d = 0; d < dim; ++d) {
108*828beda2SMark Adams     Vec coord_comp;
109*828beda2SMark Adams     PetscCall(MatCreateVecs(H, &coord_comp, &obs_vecs[d]));
110*828beda2SMark Adams     PetscCall(VecStrideGather(coordinates, d, coord_comp, INSERT_VALUES));
111*828beda2SMark Adams     PetscCall(MatMult(H, coord_comp, obs_vecs[d]));
112*828beda2SMark Adams     PetscCall(VecGetArray(obs_vecs[d], &obs_coords[d]));
113*828beda2SMark Adams     PetscCall(VecDestroy(&coord_comp));
114*828beda2SMark Adams   }
115*828beda2SMark Adams 
116*828beda2SMark Adams   /* Create output matrix Q */
117*828beda2SMark Adams   localRows = numVertices;
118*828beda2SMark Adams   PetscCallMPI(MPIU_Allreduce(&localRows, &globalRows, 1, MPIU_INT, MPI_SUM, comm));
119*828beda2SMark Adams 
120*828beda2SMark Adams   PetscCall(MatCreate(comm, Q));
121*828beda2SMark Adams   PetscCall(MatSetSizes(*Q, localRows, PETSC_DECIDE, globalRows, numglobalobs));
122*828beda2SMark Adams   PetscCall(MatSetType(*Q, MATMPIAIJ));
123*828beda2SMark Adams   PetscCall(MatMPIAIJSetPreallocation(*Q, numobservations, NULL, numobservations, NULL));
124*828beda2SMark Adams   PetscCall(MatSetUp(*Q));
125*828beda2SMark Adams 
126*828beda2SMark Adams   /* Prepare Kokkos Views */
127*828beda2SMark Adams   using ExecSpace = Kokkos::DefaultExecutionSpace;
128*828beda2SMark Adams   using MemSpace  = ExecSpace::memory_space;
129*828beda2SMark Adams 
130*828beda2SMark Adams   /* Vertex Coordinates */
131*828beda2SMark Adams   Kokkos::View<PetscReal **, Kokkos::LayoutRight, MemSpace> vertex_coords_dev("vertex_coords", numVertices, dim);
132*828beda2SMark Adams   {
133*828beda2SMark Adams     Kokkos::View<PetscReal **, Kokkos::LayoutRight, Kokkos::HostSpace> vertex_coords_host("vertex_coords_host", numVertices, dim);
134*828beda2SMark Adams     Vec                                                                localCoords;
135*828beda2SMark Adams     PetscScalar                                                       *local_coords_array;
136*828beda2SMark Adams     PetscSection                                                       coordSection;
137*828beda2SMark Adams     PetscCall(DMGetCoordinatesLocal(plex, &localCoords));
138*828beda2SMark Adams     PetscCall(DMGetCoordinateSection(plex, &coordSection));
139*828beda2SMark Adams     PetscCall(VecGetArray(localCoords, &local_coords_array));
140*828beda2SMark Adams 
141*828beda2SMark Adams     for (PetscInt v = 0; v < numVertices; ++v) {
142*828beda2SMark Adams       PetscInt off;
143*828beda2SMark Adams       PetscCall(PetscSectionGetOffset(coordSection, vStart + v, &off));
144*828beda2SMark Adams       for (d = 0; d < dim; ++d) vertex_coords_host(v, d) = PetscRealPart(local_coords_array[off + d]);
145*828beda2SMark Adams     }
146*828beda2SMark Adams     PetscCall(VecRestoreArray(localCoords, &local_coords_array));
147*828beda2SMark Adams     Kokkos::deep_copy(vertex_coords_dev, vertex_coords_host);
148*828beda2SMark Adams   }
149*828beda2SMark Adams 
150*828beda2SMark Adams   /* Observation Coordinates */
151*828beda2SMark Adams   Kokkos::View<PetscReal **, Kokkos::LayoutRight, MemSpace> obs_coords_dev("obs_coords", numglobalobs, dim);
152*828beda2SMark Adams   {
153*828beda2SMark Adams     Kokkos::View<PetscReal **, Kokkos::LayoutRight, Kokkos::HostSpace> obs_coords_host("obs_coords_host", numglobalobs, dim);
154*828beda2SMark Adams     for (PetscInt j = 0; j < numglobalobs; ++j) {
155*828beda2SMark Adams       for (d = 0; d < dim; ++d) obs_coords_host(j, d) = PetscRealPart(obs_coords[d][j]);
156*828beda2SMark Adams     }
157*828beda2SMark Adams     Kokkos::deep_copy(obs_coords_dev, obs_coords_host);
158*828beda2SMark Adams   }
159*828beda2SMark Adams 
160*828beda2SMark Adams   /* Global Rows */
161*828beda2SMark Adams   Kokkos::View<PetscInt *, MemSpace> global_rows_dev("global_rows", numVertices);
162*828beda2SMark Adams   {
163*828beda2SMark Adams     Kokkos::View<PetscInt *, Kokkos::HostSpace> global_rows_host("global_rows_host", numVertices);
164*828beda2SMark Adams     PetscSection                                globalSection;
165*828beda2SMark Adams     PetscCall(DMGetGlobalSection(plex, &globalSection));
166*828beda2SMark Adams     for (PetscInt v = 0; v < numVertices; ++v) {
167*828beda2SMark Adams       PetscInt globalRow;
168*828beda2SMark Adams       PetscCall(PetscSectionGetOffset(globalSection, vStart + v, &globalRow));
169*828beda2SMark Adams       global_rows_host(v) = globalRow;
170*828beda2SMark Adams     }
171*828beda2SMark Adams     Kokkos::deep_copy(global_rows_dev, global_rows_host);
172*828beda2SMark Adams   }
173*828beda2SMark Adams 
174*828beda2SMark Adams   /* Output Views */
175*828beda2SMark Adams   Kokkos::View<PetscInt **, Kokkos::LayoutRight, MemSpace>    indices_dev("indices", numVertices, numobservations);
176*828beda2SMark Adams   Kokkos::View<PetscScalar **, Kokkos::LayoutRight, MemSpace> values_dev("values", numVertices, numobservations);
177*828beda2SMark Adams 
178*828beda2SMark Adams   /* Temporary storage for top-k per vertex */
179*828beda2SMark Adams   Kokkos::View<PetscReal **, Kokkos::LayoutRight, MemSpace> best_dists_dev("best_dists", numVertices, numobservations);
180*828beda2SMark Adams   Kokkos::View<PetscInt **, Kokkos::LayoutRight, MemSpace>  best_idxs_dev("best_idxs", numVertices, numobservations);
181*828beda2SMark Adams 
182*828beda2SMark Adams   Kokkos::deep_copy(best_dists_dev, 1.0e30);
183*828beda2SMark Adams 
184*828beda2SMark Adams   /* Main Kernel */
185*828beda2SMark Adams   Kokkos::parallel_for(
186*828beda2SMark Adams     "ComputeLocalization", Kokkos::RangePolicy<ExecSpace>(0, numVertices), KOKKOS_LAMBDA(const PetscInt i) {
187*828beda2SMark Adams       PetscReal current_max_dist = 1.0e30;
188*828beda2SMark Adams       PetscInt  count            = 0;
189*828beda2SMark Adams 
190*828beda2SMark Adams       // Iterate over all observations
191*828beda2SMark Adams       for (PetscInt j = 0; j < numglobalobs; ++j) {
192*828beda2SMark Adams         PetscReal dist2 = 0.0;
193*828beda2SMark Adams         for (PetscInt d = 0; d < dim; ++d) {
194*828beda2SMark Adams           PetscReal diff = vertex_coords_dev(i, d) - obs_coords_dev(j, d);
195*828beda2SMark Adams           dist2 += diff * diff;
196*828beda2SMark Adams         }
197*828beda2SMark Adams 
198*828beda2SMark Adams         if (count < numobservations) {
199*828beda2SMark Adams           // Insert sorted
200*828beda2SMark Adams           PetscInt pos = count;
201*828beda2SMark Adams           while (pos > 0 && best_dists_dev(i, pos - 1) > dist2) {
202*828beda2SMark Adams             best_dists_dev(i, pos) = best_dists_dev(i, pos - 1);
203*828beda2SMark Adams             best_idxs_dev(i, pos)  = best_idxs_dev(i, pos - 1);
204*828beda2SMark Adams             pos--;
205*828beda2SMark Adams           }
206*828beda2SMark Adams           best_dists_dev(i, pos) = dist2;
207*828beda2SMark Adams           best_idxs_dev(i, pos)  = j;
208*828beda2SMark Adams           count++;
209*828beda2SMark Adams           if (count == numobservations) current_max_dist = best_dists_dev(i, numobservations - 1);
210*828beda2SMark Adams         } else if (dist2 < current_max_dist) {
211*828beda2SMark Adams           // Insert sorted
212*828beda2SMark Adams           PetscInt pos = numobservations - 1;
213*828beda2SMark Adams           while (pos > 0 && best_dists_dev(i, pos - 1) > dist2) {
214*828beda2SMark Adams             best_dists_dev(i, pos) = best_dists_dev(i, pos - 1);
215*828beda2SMark Adams             best_idxs_dev(i, pos)  = best_idxs_dev(i, pos - 1);
216*828beda2SMark Adams             pos--;
217*828beda2SMark Adams           }
218*828beda2SMark Adams           best_dists_dev(i, pos) = dist2;
219*828beda2SMark Adams           best_idxs_dev(i, pos)  = j;
220*828beda2SMark Adams           current_max_dist       = best_dists_dev(i, numobservations - 1);
221*828beda2SMark Adams         }
222*828beda2SMark Adams       }
223*828beda2SMark Adams 
224*828beda2SMark Adams       // Compute weights
225*828beda2SMark Adams       PetscReal radius2 = best_dists_dev(i, numobservations - 1);
226*828beda2SMark Adams       PetscReal radius  = std::sqrt(radius2);
227*828beda2SMark Adams       if (radius == 0.0) radius = 1.0;
228*828beda2SMark Adams 
229*828beda2SMark Adams       for (PetscInt k = 0; k < numobservations; ++k) {
230*828beda2SMark Adams         PetscReal dist    = std::sqrt(best_dists_dev(i, k));
231*828beda2SMark Adams         indices_dev(i, k) = best_idxs_dev(i, k);
232*828beda2SMark Adams         values_dev(i, k)  = GaspariCohn(dist, radius);
233*828beda2SMark Adams       }
234*828beda2SMark Adams     });
235*828beda2SMark Adams 
236*828beda2SMark Adams   /* Copy back to host and fill matrix */
237*828beda2SMark Adams   Kokkos::View<PetscInt **, Kokkos::LayoutRight, Kokkos::HostSpace>    indices_host     = Kokkos::create_mirror_view(indices_dev);
238*828beda2SMark Adams   Kokkos::View<PetscScalar **, Kokkos::LayoutRight, Kokkos::HostSpace> values_host      = Kokkos::create_mirror_view(values_dev);
239*828beda2SMark Adams   Kokkos::View<PetscInt *, Kokkos::HostSpace>                          global_rows_host = Kokkos::create_mirror_view(global_rows_dev);
240*828beda2SMark Adams 
241*828beda2SMark Adams   Kokkos::deep_copy(indices_host, indices_dev);
242*828beda2SMark Adams   Kokkos::deep_copy(values_host, values_dev);
243*828beda2SMark Adams   Kokkos::deep_copy(global_rows_host, global_rows_dev);
244*828beda2SMark Adams 
245*828beda2SMark Adams   for (PetscInt i = 0; i < numVertices; ++i) {
246*828beda2SMark Adams     PetscInt globalRow = global_rows_host(i);
247*828beda2SMark Adams     PetscCall(MatSetValues(*Q, 1, &globalRow, numobservations, &indices_host(i, 0), &values_host(i, 0), INSERT_VALUES));
248*828beda2SMark Adams   }
249*828beda2SMark Adams 
250*828beda2SMark Adams   /* Cleanup Phase 2 storage */
251*828beda2SMark Adams   for (d = 0; d < dim; ++d) {
252*828beda2SMark Adams     PetscCall(VecRestoreArray(obs_vecs[d], &obs_coords[d]));
253*828beda2SMark Adams     PetscCall(VecDestroy(&obs_vecs[d]));
254*828beda2SMark Adams   }
255*828beda2SMark Adams   PetscCall(PetscFree(obs_vecs));
256*828beda2SMark Adams   PetscCall(PetscFree(obs_coords));
257*828beda2SMark Adams 
258*828beda2SMark Adams   /* Assemble matrix */
259*828beda2SMark Adams   PetscCall(MatAssemblyBegin(*Q, MAT_FINAL_ASSEMBLY));
260*828beda2SMark Adams   PetscCall(MatAssemblyEnd(*Q, MAT_FINAL_ASSEMBLY));
261*828beda2SMark Adams   PetscFunctionReturn(PETSC_SUCCESS);
262*828beda2SMark Adams }
263