1 #include <petsc/private/dmpleximpl.h> 2 #include <petscdmplex.h> 3 #include <petscmat.h> 4 #include <petsc_kokkos.hpp> 5 #include <cmath> 6 #include <cstdlib> 7 #include <algorithm> 8 #include <Kokkos_Core.hpp> 9 10 typedef struct { 11 PetscReal distance; 12 PetscInt obs_index; 13 } DistObsPair; 14 15 KOKKOS_INLINE_FUNCTION 16 static PetscReal GaspariCohn(PetscReal distance, PetscReal radius) 17 { 18 if (radius <= 0.0) return 0.0; 19 PetscReal r = distance / radius; 20 PetscReal weight = 0.0; 21 22 if (r >= 2.0) { 23 weight = 0.0; 24 } else if (r >= 1.0) { 25 // Region [1, 2] 26 PetscReal r2 = r * r; 27 PetscReal r3 = r2 * r; 28 PetscReal r4 = r3 * r; 29 PetscReal r5 = r4 * r; 30 weight = (1.0 / 12.0) * r5 - (0.5) * r4 + (0.625) * r3 + (5.0 / 3.0) * r2 - 5.0 * r + 4.0 - (2.0 / 3.0) / r; 31 } else { 32 // Region [0, 1] 33 PetscReal r2 = r * r; 34 PetscReal r3 = r2 * r; 35 PetscReal r4 = r3 * r; 36 PetscReal r5 = r4 * r; 37 weight = -0.25 * r5 + 0.5 * r4 + 0.625 * r3 - (5.0 / 3.0) * r2 + 1.0; 38 } 39 return weight; 40 } 41 42 /*@ 43 DMPlexGetLETKFLocalizationMatrix - Compute localization weight matrix for LETKF 44 45 Collective 46 47 Input Parameters: 48 + plex - The DMPlex object 49 . numobservations - Number of nearest observations to use per vertex 50 . numglobalobs - Total number of observations 51 - H - Observation operator matrix 52 53 Output Parameter: 54 . Q - Localization weight matrix (sparse, AIJ format) 55 56 Notes: 57 The output matrix Q has dimensions (numVertices x numglobalobs) where 58 numVertices is the number of vertices in the DMPlex. Each row contains 59 exactly numobservations non-zero entries corresponding to the nearest 60 observations, weighted by the Gaspari-Cohn fifth-order piecewise 61 rational function. 62 63 The observation locations are computed as H * V where V is the vector 64 of vertex coordinates. The localization weights ensure smooth tapering 65 of observation influence with distance. 66 67 Kokkos is required for this routine. LETKF has a lot of fine grain parallelism and is not useful without threads or GPUs. 68 69 Level: intermediate 70 71 .seealso: `DMPLEX`, `DMPlexGetDepthStratum()`, `DMGetCoordinatesLocal()` 72 @*/ 73 PetscErrorCode DMPlexGetLETKFLocalizationMatrix(DM plex, PetscInt numobservations, PetscInt numglobalobs, Mat H, Mat *Q) 74 { 75 PetscInt dim, vStart, vEnd, numVertices, d; 76 PetscInt M, N; 77 Vec coordinates; 78 Vec *obs_vecs; 79 PetscScalar **obs_coords; 80 PetscInt localRows, globalRows; 81 MPI_Comm comm; 82 83 PetscFunctionBegin; 84 PetscValidHeaderSpecific(plex, DM_CLASSID, 1); 85 PetscValidHeaderSpecific(H, MAT_CLASSID, 4); 86 PetscAssertPointer(Q, 5); 87 88 PetscCall(PetscKokkosInitializeCheck()); 89 90 PetscCall(PetscObjectGetComm((PetscObject)plex, &comm)); 91 PetscCall(DMGetCoordinateDim(plex, &dim)); 92 PetscCall(DMPlexGetDepthStratum(plex, 0, &vStart, &vEnd)); 93 numVertices = vEnd - vStart; 94 95 /* Check H dimensions */ 96 PetscCall(MatGetSize(H, &M, &N)); 97 PetscCheck(M == numglobalobs, comm, PETSC_ERR_ARG_SIZ, "H matrix rows %" PetscInt_FMT " != numglobalobs %" PetscInt_FMT, M, numglobalobs); 98 99 PetscCall(DMGetCoordinates(plex, &coordinates)); 100 PetscCheck(coordinates, comm, PETSC_ERR_ARG_WRONGSTATE, "DM must have coordinates"); 101 102 /* Allocate storage for observation locations */ 103 PetscCall(PetscMalloc1(dim, &obs_vecs)); 104 PetscCall(PetscMalloc1(dim, &obs_coords)); 105 106 /* Compute observation locations per dimension */ 107 for (d = 0; d < dim; ++d) { 108 Vec coord_comp; 109 PetscCall(MatCreateVecs(H, &coord_comp, &obs_vecs[d])); 110 PetscCall(VecStrideGather(coordinates, d, coord_comp, INSERT_VALUES)); 111 PetscCall(MatMult(H, coord_comp, obs_vecs[d])); 112 PetscCall(VecGetArray(obs_vecs[d], &obs_coords[d])); 113 PetscCall(VecDestroy(&coord_comp)); 114 } 115 116 /* Create output matrix Q */ 117 localRows = numVertices; 118 PetscCallMPI(MPIU_Allreduce(&localRows, &globalRows, 1, MPIU_INT, MPI_SUM, comm)); 119 120 PetscCall(MatCreate(comm, Q)); 121 PetscCall(MatSetSizes(*Q, localRows, PETSC_DECIDE, globalRows, numglobalobs)); 122 PetscCall(MatSetType(*Q, MATMPIAIJ)); 123 PetscCall(MatMPIAIJSetPreallocation(*Q, numobservations, NULL, numobservations, NULL)); 124 PetscCall(MatSetUp(*Q)); 125 126 /* Prepare Kokkos Views */ 127 using ExecSpace = Kokkos::DefaultExecutionSpace; 128 using MemSpace = ExecSpace::memory_space; 129 130 /* Vertex Coordinates */ 131 Kokkos::View<PetscReal **, Kokkos::LayoutRight, MemSpace> vertex_coords_dev("vertex_coords", numVertices, dim); 132 { 133 Kokkos::View<PetscReal **, Kokkos::LayoutRight, Kokkos::HostSpace> vertex_coords_host("vertex_coords_host", numVertices, dim); 134 Vec localCoords; 135 PetscScalar *local_coords_array; 136 PetscSection coordSection; 137 PetscCall(DMGetCoordinatesLocal(plex, &localCoords)); 138 PetscCall(DMGetCoordinateSection(plex, &coordSection)); 139 PetscCall(VecGetArray(localCoords, &local_coords_array)); 140 141 for (PetscInt v = 0; v < numVertices; ++v) { 142 PetscInt off; 143 PetscCall(PetscSectionGetOffset(coordSection, vStart + v, &off)); 144 for (d = 0; d < dim; ++d) vertex_coords_host(v, d) = PetscRealPart(local_coords_array[off + d]); 145 } 146 PetscCall(VecRestoreArray(localCoords, &local_coords_array)); 147 Kokkos::deep_copy(vertex_coords_dev, vertex_coords_host); 148 } 149 150 /* Observation Coordinates */ 151 Kokkos::View<PetscReal **, Kokkos::LayoutRight, MemSpace> obs_coords_dev("obs_coords", numglobalobs, dim); 152 { 153 Kokkos::View<PetscReal **, Kokkos::LayoutRight, Kokkos::HostSpace> obs_coords_host("obs_coords_host", numglobalobs, dim); 154 for (PetscInt j = 0; j < numglobalobs; ++j) { 155 for (d = 0; d < dim; ++d) obs_coords_host(j, d) = PetscRealPart(obs_coords[d][j]); 156 } 157 Kokkos::deep_copy(obs_coords_dev, obs_coords_host); 158 } 159 160 /* Global Rows */ 161 Kokkos::View<PetscInt *, MemSpace> global_rows_dev("global_rows", numVertices); 162 { 163 Kokkos::View<PetscInt *, Kokkos::HostSpace> global_rows_host("global_rows_host", numVertices); 164 PetscSection globalSection; 165 PetscCall(DMGetGlobalSection(plex, &globalSection)); 166 for (PetscInt v = 0; v < numVertices; ++v) { 167 PetscInt globalRow; 168 PetscCall(PetscSectionGetOffset(globalSection, vStart + v, &globalRow)); 169 global_rows_host(v) = globalRow; 170 } 171 Kokkos::deep_copy(global_rows_dev, global_rows_host); 172 } 173 174 /* Output Views */ 175 Kokkos::View<PetscInt **, Kokkos::LayoutRight, MemSpace> indices_dev("indices", numVertices, numobservations); 176 Kokkos::View<PetscScalar **, Kokkos::LayoutRight, MemSpace> values_dev("values", numVertices, numobservations); 177 178 /* Temporary storage for top-k per vertex */ 179 Kokkos::View<PetscReal **, Kokkos::LayoutRight, MemSpace> best_dists_dev("best_dists", numVertices, numobservations); 180 Kokkos::View<PetscInt **, Kokkos::LayoutRight, MemSpace> best_idxs_dev("best_idxs", numVertices, numobservations); 181 182 Kokkos::deep_copy(best_dists_dev, 1.0e30); 183 184 /* Main Kernel */ 185 Kokkos::parallel_for( 186 "ComputeLocalization", Kokkos::RangePolicy<ExecSpace>(0, numVertices), KOKKOS_LAMBDA(const PetscInt i) { 187 PetscReal current_max_dist = 1.0e30; 188 PetscInt count = 0; 189 190 // Iterate over all observations 191 for (PetscInt j = 0; j < numglobalobs; ++j) { 192 PetscReal dist2 = 0.0; 193 for (PetscInt d = 0; d < dim; ++d) { 194 PetscReal diff = vertex_coords_dev(i, d) - obs_coords_dev(j, d); 195 dist2 += diff * diff; 196 } 197 198 if (count < numobservations) { 199 // Insert sorted 200 PetscInt pos = count; 201 while (pos > 0 && best_dists_dev(i, pos - 1) > dist2) { 202 best_dists_dev(i, pos) = best_dists_dev(i, pos - 1); 203 best_idxs_dev(i, pos) = best_idxs_dev(i, pos - 1); 204 pos--; 205 } 206 best_dists_dev(i, pos) = dist2; 207 best_idxs_dev(i, pos) = j; 208 count++; 209 if (count == numobservations) current_max_dist = best_dists_dev(i, numobservations - 1); 210 } else if (dist2 < current_max_dist) { 211 // Insert sorted 212 PetscInt pos = numobservations - 1; 213 while (pos > 0 && best_dists_dev(i, pos - 1) > dist2) { 214 best_dists_dev(i, pos) = best_dists_dev(i, pos - 1); 215 best_idxs_dev(i, pos) = best_idxs_dev(i, pos - 1); 216 pos--; 217 } 218 best_dists_dev(i, pos) = dist2; 219 best_idxs_dev(i, pos) = j; 220 current_max_dist = best_dists_dev(i, numobservations - 1); 221 } 222 } 223 224 // Compute weights 225 PetscReal radius2 = best_dists_dev(i, numobservations - 1); 226 PetscReal radius = std::sqrt(radius2); 227 if (radius == 0.0) radius = 1.0; 228 229 for (PetscInt k = 0; k < numobservations; ++k) { 230 PetscReal dist = std::sqrt(best_dists_dev(i, k)); 231 indices_dev(i, k) = best_idxs_dev(i, k); 232 values_dev(i, k) = GaspariCohn(dist, radius); 233 } 234 }); 235 236 /* Copy back to host and fill matrix */ 237 Kokkos::View<PetscInt **, Kokkos::LayoutRight, Kokkos::HostSpace> indices_host = Kokkos::create_mirror_view(indices_dev); 238 Kokkos::View<PetscScalar **, Kokkos::LayoutRight, Kokkos::HostSpace> values_host = Kokkos::create_mirror_view(values_dev); 239 Kokkos::View<PetscInt *, Kokkos::HostSpace> global_rows_host = Kokkos::create_mirror_view(global_rows_dev); 240 241 Kokkos::deep_copy(indices_host, indices_dev); 242 Kokkos::deep_copy(values_host, values_dev); 243 Kokkos::deep_copy(global_rows_host, global_rows_dev); 244 245 for (PetscInt i = 0; i < numVertices; ++i) { 246 PetscInt globalRow = global_rows_host(i); 247 PetscCall(MatSetValues(*Q, 1, &globalRow, numobservations, &indices_host(i, 0), &values_host(i, 0), INSERT_VALUES)); 248 } 249 250 /* Cleanup Phase 2 storage */ 251 for (d = 0; d < dim; ++d) { 252 PetscCall(VecRestoreArray(obs_vecs[d], &obs_coords[d])); 253 PetscCall(VecDestroy(&obs_vecs[d])); 254 } 255 PetscCall(PetscFree(obs_vecs)); 256 PetscCall(PetscFree(obs_coords)); 257 258 /* Assemble matrix */ 259 PetscCall(MatAssemblyBegin(*Q, MAT_FINAL_ASSEMBLY)); 260 PetscCall(MatAssemblyEnd(*Q, MAT_FINAL_ASSEMBLY)); 261 PetscFunctionReturn(PETSC_SUCCESS); 262 } 263