xref: /petsc/src/dm/impls/plex/kokkos/plexlocalizationletkf.kokkos.cxx (revision cc6b3d485b7186cdb6e8e6b8ee2dca644baa10ec)
1 #include <petsc/private/dmpleximpl.h>
2 #include <petscdmplex.h>
3 #include <petscmat.h>
4 #include <petsc_kokkos.hpp>
5 #include <cmath>
6 #include <cstdlib>
7 #include <algorithm>
8 #include <Kokkos_Core.hpp>
9 
10 typedef struct {
11   PetscReal distance;
12   PetscInt  obs_index;
13 } DistObsPair;
14 
15 KOKKOS_INLINE_FUNCTION
16 static PetscReal GaspariCohn(PetscReal distance, PetscReal radius)
17 {
18   if (radius <= 0.0) return 0.0;
19   PetscReal r      = distance / radius;
20   PetscReal weight = 0.0;
21 
22   if (r >= 2.0) {
23     weight = 0.0;
24   } else if (r >= 1.0) {
25     // Region [1, 2]
26     PetscReal r2 = r * r;
27     PetscReal r3 = r2 * r;
28     PetscReal r4 = r3 * r;
29     PetscReal r5 = r4 * r;
30     weight       = (1.0 / 12.0) * r5 - (0.5) * r4 + (0.625) * r3 + (5.0 / 3.0) * r2 - 5.0 * r + 4.0 - (2.0 / 3.0) / r;
31   } else {
32     // Region [0, 1]
33     PetscReal r2 = r * r;
34     PetscReal r3 = r2 * r;
35     PetscReal r4 = r3 * r;
36     PetscReal r5 = r4 * r;
37     weight       = -0.25 * r5 + 0.5 * r4 + 0.625 * r3 - (5.0 / 3.0) * r2 + 1.0;
38   }
39   return weight;
40 }
41 
42 /*@
43   DMPlexGetLETKFLocalizationMatrix - Compute localization weight matrix for LETKF
44 
45   Collective
46 
47   Input Parameters:
48 + plex - The DMPlex object
49 . numobservations - Number of nearest observations to use per vertex
50 . numglobalobs - Total number of observations
51 - H - Observation operator matrix
52 
53   Output Parameter:
54 . Q - Localization weight matrix (sparse, AIJ format)
55 
56   Notes:
57   The output matrix Q has dimensions (numVertices x numglobalobs) where
58   numVertices is the number of vertices in the DMPlex. Each row contains
59   exactly numobservations non-zero entries corresponding to the nearest
60   observations, weighted by the Gaspari-Cohn fifth-order piecewise
61   rational function.
62 
63   The observation locations are computed as H * V where V is the vector
64   of vertex coordinates. The localization weights ensure smooth tapering
65   of observation influence with distance.
66 
67   Kokkos is required for this routine. LETKF has a lot of fine grain parallelism and is not useful without threads or GPUs.
68 
69   Level: intermediate
70 
71 .seealso: `DMPLEX`, `DMPlexGetDepthStratum()`, `DMGetCoordinatesLocal()`
72 @*/
73 PetscErrorCode DMPlexGetLETKFLocalizationMatrix(DM plex, PetscInt numobservations, PetscInt numglobalobs, Mat H, Mat *Q)
74 {
75   PetscInt      dim, vStart, vEnd, numVertices, d;
76   PetscInt      M, N;
77   Vec           coordinates;
78   Vec          *obs_vecs;
79   PetscScalar **obs_coords;
80   PetscInt      localRows, globalRows;
81   MPI_Comm      comm;
82 
83   PetscFunctionBegin;
84   PetscValidHeaderSpecific(plex, DM_CLASSID, 1);
85   PetscValidHeaderSpecific(H, MAT_CLASSID, 4);
86   PetscAssertPointer(Q, 5);
87 
88   PetscCall(PetscKokkosInitializeCheck());
89 
90   PetscCall(PetscObjectGetComm((PetscObject)plex, &comm));
91   PetscCall(DMGetCoordinateDim(plex, &dim));
92   PetscCall(DMPlexGetDepthStratum(plex, 0, &vStart, &vEnd));
93   numVertices = vEnd - vStart;
94 
95   /* Check H dimensions */
96   PetscCall(MatGetSize(H, &M, &N));
97   PetscCheck(M == numglobalobs, comm, PETSC_ERR_ARG_SIZ, "H matrix rows %" PetscInt_FMT " != numglobalobs %" PetscInt_FMT, M, numglobalobs);
98 
99   PetscCall(DMGetCoordinates(plex, &coordinates));
100   PetscCheck(coordinates, comm, PETSC_ERR_ARG_WRONGSTATE, "DM must have coordinates");
101 
102   /* Allocate storage for observation locations */
103   PetscCall(PetscMalloc1(dim, &obs_vecs));
104   PetscCall(PetscMalloc1(dim, &obs_coords));
105 
106   /* Compute observation locations per dimension */
107   for (d = 0; d < dim; ++d) {
108     Vec coord_comp;
109     PetscCall(MatCreateVecs(H, &coord_comp, &obs_vecs[d]));
110     PetscCall(VecStrideGather(coordinates, d, coord_comp, INSERT_VALUES));
111     PetscCall(MatMult(H, coord_comp, obs_vecs[d]));
112     PetscCall(VecGetArray(obs_vecs[d], &obs_coords[d]));
113     PetscCall(VecDestroy(&coord_comp));
114   }
115 
116   /* Create output matrix Q */
117   localRows = numVertices;
118   PetscCallMPI(MPIU_Allreduce(&localRows, &globalRows, 1, MPIU_INT, MPI_SUM, comm));
119 
120   PetscCall(MatCreate(comm, Q));
121   PetscCall(MatSetSizes(*Q, localRows, PETSC_DECIDE, globalRows, numglobalobs));
122   PetscCall(MatSetType(*Q, MATMPIAIJ));
123   PetscCall(MatMPIAIJSetPreallocation(*Q, numobservations, NULL, numobservations, NULL));
124   PetscCall(MatSetUp(*Q));
125 
126   /* Prepare Kokkos Views */
127   using ExecSpace = Kokkos::DefaultExecutionSpace;
128   using MemSpace  = ExecSpace::memory_space;
129 
130   /* Vertex Coordinates */
131   Kokkos::View<PetscReal **, Kokkos::LayoutRight, MemSpace> vertex_coords_dev("vertex_coords", numVertices, dim);
132   {
133     Kokkos::View<PetscReal **, Kokkos::LayoutRight, Kokkos::HostSpace> vertex_coords_host("vertex_coords_host", numVertices, dim);
134     Vec                                                                localCoords;
135     PetscScalar                                                       *local_coords_array;
136     PetscSection                                                       coordSection;
137     PetscCall(DMGetCoordinatesLocal(plex, &localCoords));
138     PetscCall(DMGetCoordinateSection(plex, &coordSection));
139     PetscCall(VecGetArray(localCoords, &local_coords_array));
140 
141     for (PetscInt v = 0; v < numVertices; ++v) {
142       PetscInt off;
143       PetscCall(PetscSectionGetOffset(coordSection, vStart + v, &off));
144       for (d = 0; d < dim; ++d) vertex_coords_host(v, d) = PetscRealPart(local_coords_array[off + d]);
145     }
146     PetscCall(VecRestoreArray(localCoords, &local_coords_array));
147     Kokkos::deep_copy(vertex_coords_dev, vertex_coords_host);
148   }
149 
150   /* Observation Coordinates */
151   Kokkos::View<PetscReal **, Kokkos::LayoutRight, MemSpace> obs_coords_dev("obs_coords", numglobalobs, dim);
152   {
153     Kokkos::View<PetscReal **, Kokkos::LayoutRight, Kokkos::HostSpace> obs_coords_host("obs_coords_host", numglobalobs, dim);
154     for (PetscInt j = 0; j < numglobalobs; ++j) {
155       for (d = 0; d < dim; ++d) obs_coords_host(j, d) = PetscRealPart(obs_coords[d][j]);
156     }
157     Kokkos::deep_copy(obs_coords_dev, obs_coords_host);
158   }
159 
160   /* Global Rows */
161   Kokkos::View<PetscInt *, MemSpace> global_rows_dev("global_rows", numVertices);
162   {
163     Kokkos::View<PetscInt *, Kokkos::HostSpace> global_rows_host("global_rows_host", numVertices);
164     PetscSection                                globalSection;
165     PetscCall(DMGetGlobalSection(plex, &globalSection));
166     for (PetscInt v = 0; v < numVertices; ++v) {
167       PetscInt globalRow;
168       PetscCall(PetscSectionGetOffset(globalSection, vStart + v, &globalRow));
169       global_rows_host(v) = globalRow;
170     }
171     Kokkos::deep_copy(global_rows_dev, global_rows_host);
172   }
173 
174   /* Output Views */
175   Kokkos::View<PetscInt **, Kokkos::LayoutRight, MemSpace>    indices_dev("indices", numVertices, numobservations);
176   Kokkos::View<PetscScalar **, Kokkos::LayoutRight, MemSpace> values_dev("values", numVertices, numobservations);
177 
178   /* Temporary storage for top-k per vertex */
179   Kokkos::View<PetscReal **, Kokkos::LayoutRight, MemSpace> best_dists_dev("best_dists", numVertices, numobservations);
180   Kokkos::View<PetscInt **, Kokkos::LayoutRight, MemSpace>  best_idxs_dev("best_idxs", numVertices, numobservations);
181 
182   Kokkos::deep_copy(best_dists_dev, 1.0e30);
183 
184   /* Main Kernel */
185   Kokkos::parallel_for(
186     "ComputeLocalization", Kokkos::RangePolicy<ExecSpace>(0, numVertices), KOKKOS_LAMBDA(const PetscInt i) {
187       PetscReal current_max_dist = 1.0e30;
188       PetscInt  count            = 0;
189 
190       // Iterate over all observations
191       for (PetscInt j = 0; j < numglobalobs; ++j) {
192         PetscReal dist2 = 0.0;
193         for (PetscInt d = 0; d < dim; ++d) {
194           PetscReal diff = vertex_coords_dev(i, d) - obs_coords_dev(j, d);
195           dist2 += diff * diff;
196         }
197 
198         if (count < numobservations) {
199           // Insert sorted
200           PetscInt pos = count;
201           while (pos > 0 && best_dists_dev(i, pos - 1) > dist2) {
202             best_dists_dev(i, pos) = best_dists_dev(i, pos - 1);
203             best_idxs_dev(i, pos)  = best_idxs_dev(i, pos - 1);
204             pos--;
205           }
206           best_dists_dev(i, pos) = dist2;
207           best_idxs_dev(i, pos)  = j;
208           count++;
209           if (count == numobservations) current_max_dist = best_dists_dev(i, numobservations - 1);
210         } else if (dist2 < current_max_dist) {
211           // Insert sorted
212           PetscInt pos = numobservations - 1;
213           while (pos > 0 && best_dists_dev(i, pos - 1) > dist2) {
214             best_dists_dev(i, pos) = best_dists_dev(i, pos - 1);
215             best_idxs_dev(i, pos)  = best_idxs_dev(i, pos - 1);
216             pos--;
217           }
218           best_dists_dev(i, pos) = dist2;
219           best_idxs_dev(i, pos)  = j;
220           current_max_dist       = best_dists_dev(i, numobservations - 1);
221         }
222       }
223 
224       // Compute weights
225       PetscReal radius2 = best_dists_dev(i, numobservations - 1);
226       PetscReal radius  = std::sqrt(radius2);
227       if (radius == 0.0) radius = 1.0;
228 
229       for (PetscInt k = 0; k < numobservations; ++k) {
230         PetscReal dist    = std::sqrt(best_dists_dev(i, k));
231         indices_dev(i, k) = best_idxs_dev(i, k);
232         values_dev(i, k)  = GaspariCohn(dist, radius);
233       }
234     });
235 
236   /* Copy back to host and fill matrix */
237   Kokkos::View<PetscInt **, Kokkos::LayoutRight, Kokkos::HostSpace>    indices_host     = Kokkos::create_mirror_view(indices_dev);
238   Kokkos::View<PetscScalar **, Kokkos::LayoutRight, Kokkos::HostSpace> values_host      = Kokkos::create_mirror_view(values_dev);
239   Kokkos::View<PetscInt *, Kokkos::HostSpace>                          global_rows_host = Kokkos::create_mirror_view(global_rows_dev);
240 
241   Kokkos::deep_copy(indices_host, indices_dev);
242   Kokkos::deep_copy(values_host, values_dev);
243   Kokkos::deep_copy(global_rows_host, global_rows_dev);
244 
245   for (PetscInt i = 0; i < numVertices; ++i) {
246     PetscInt globalRow = global_rows_host(i);
247     PetscCall(MatSetValues(*Q, 1, &globalRow, numobservations, &indices_host(i, 0), &values_host(i, 0), INSERT_VALUES));
248   }
249 
250   /* Cleanup Phase 2 storage */
251   for (d = 0; d < dim; ++d) {
252     PetscCall(VecRestoreArray(obs_vecs[d], &obs_coords[d]));
253     PetscCall(VecDestroy(&obs_vecs[d]));
254   }
255   PetscCall(PetscFree(obs_vecs));
256   PetscCall(PetscFree(obs_coords));
257 
258   /* Assemble matrix */
259   PetscCall(MatAssemblyBegin(*Q, MAT_FINAL_ASSEMBLY));
260   PetscCall(MatAssemblyEnd(*Q, MAT_FINAL_ASSEMBLY));
261   PetscFunctionReturn(PETSC_SUCCESS);
262 }
263