1828beda2SMark Adams #include <petsc/private/dmpleximpl.h> 2828beda2SMark Adams #include <petscdmplex.h> 3828beda2SMark Adams #include <petscmat.h> 4828beda2SMark Adams #include <petsc_kokkos.hpp> 5828beda2SMark Adams #include <cmath> 6828beda2SMark Adams #include <cstdlib> 7828beda2SMark Adams #include <algorithm> 8828beda2SMark Adams #include <Kokkos_Core.hpp> 9828beda2SMark Adams 10828beda2SMark Adams typedef struct { 11828beda2SMark Adams PetscReal distance; 12828beda2SMark Adams PetscInt obs_index; 13828beda2SMark Adams } DistObsPair; 14828beda2SMark Adams 15828beda2SMark Adams KOKKOS_INLINE_FUNCTION 16828beda2SMark Adams static PetscReal GaspariCohn(PetscReal distance, PetscReal radius) 17828beda2SMark Adams { 18828beda2SMark Adams if (radius <= 0.0) return 0.0; 19*ee102026SMark Adams const PetscReal r = distance / radius; 20828beda2SMark Adams 21*ee102026SMark Adams if (r >= 2.0) return 0.0; 22*ee102026SMark Adams 23*ee102026SMark Adams const PetscReal r2 = r * r; 24*ee102026SMark Adams const PetscReal r3 = r2 * r; 25*ee102026SMark Adams const PetscReal r4 = r3 * r; 26*ee102026SMark Adams const PetscReal r5 = r4 * r; 27*ee102026SMark Adams 28*ee102026SMark Adams if (r <= 1.0) { 29828beda2SMark Adams // Region [0, 1] 30*ee102026SMark Adams return -0.25 * r5 + 0.5 * r4 + 0.625 * r3 - (5.0 / 3.0) * r2 + 1.0; 31*ee102026SMark Adams } else { 32*ee102026SMark Adams // Region [1, 2] 33*ee102026SMark Adams return (1.0 / 12.0) * r5 - 0.5 * r4 + 0.625 * r3 + (5.0 / 3.0) * r2 - 5.0 * r + 4.0 - (2.0 / 3.0) / r; 34828beda2SMark Adams } 35828beda2SMark Adams } 36828beda2SMark Adams 37828beda2SMark Adams /*@ 38*ee102026SMark Adams DMPlexGetLETKFLocalizationMatrix - Compute localization weight matrix for LETKF [move to ml/da/interface] 39828beda2SMark Adams 40828beda2SMark Adams Collective 41828beda2SMark Adams 42828beda2SMark Adams Input Parameters: 43*ee102026SMark Adams + n_obs_vertex - Number of nearest observations to use per vertex (eg, MAX_Q_NUM_LOCAL_OBSERVATIONS in LETKF) 44*ee102026SMark Adams . n_obs_local - Number of local observations 45*ee102026SMark Adams . n_dof - Number of degrees of freedom 46*ee102026SMark Adams . Vecxyz - Array of vectors containing the coordinates 47828beda2SMark Adams - H - Observation operator matrix 48828beda2SMark Adams 49828beda2SMark Adams Output Parameter: 50828beda2SMark Adams . Q - Localization weight matrix (sparse, AIJ format) 51828beda2SMark Adams 52828beda2SMark Adams Notes: 53*ee102026SMark Adams The output matrix Q has dimensions (n_vert_global x n_obs_global) where 54*ee102026SMark Adams n_vert_global is the number of vertices in the DMPlex. Each row contains 55*ee102026SMark Adams exactly n_obs_vertex non-zero entries corresponding to the nearest 56828beda2SMark Adams observations, weighted by the Gaspari-Cohn fifth-order piecewise 57828beda2SMark Adams rational function. 58828beda2SMark Adams 59828beda2SMark Adams The observation locations are computed as H * V where V is the vector 60828beda2SMark Adams of vertex coordinates. The localization weights ensure smooth tapering 61828beda2SMark Adams of observation influence with distance. 62828beda2SMark Adams 63*ee102026SMark Adams Kokkos is required for this routine. 64828beda2SMark Adams 65828beda2SMark Adams Level: intermediate 66828beda2SMark Adams 67*ee102026SMark Adams .seealso: 68828beda2SMark Adams @*/ 69*ee102026SMark Adams PetscErrorCode DMPlexGetLETKFLocalizationMatrix(const PetscInt n_obs_vertex, const PetscInt n_obs_local, const PetscInt n_dof, Vec Vecxyz[3], Mat H, Mat *Q) 70828beda2SMark Adams { 71*ee102026SMark Adams PetscInt dim = 0, n_vert_local, d, N, n_obs_global, n_state_local; 72828beda2SMark Adams Vec *obs_vecs; 73828beda2SMark Adams MPI_Comm comm; 74*ee102026SMark Adams PetscInt n_state_global; 75828beda2SMark Adams 76828beda2SMark Adams PetscFunctionBegin; 77*ee102026SMark Adams PetscValidHeaderSpecific(H, MAT_CLASSID, 5); 78*ee102026SMark Adams PetscAssertPointer(Q, 6); 79828beda2SMark Adams 80828beda2SMark Adams PetscCall(PetscKokkosInitializeCheck()); 81828beda2SMark Adams 82*ee102026SMark Adams PetscCall(PetscObjectGetComm((PetscObject)H, &comm)); 83*ee102026SMark Adams 84*ee102026SMark Adams /* Infer dim from the number of vectors in Vecxyz */ 85*ee102026SMark Adams for (d = 0; d < 3; ++d) { 86*ee102026SMark Adams if (Vecxyz[d]) dim++; 87*ee102026SMark Adams else break; 88*ee102026SMark Adams } 89*ee102026SMark Adams 90*ee102026SMark Adams PetscCheck(dim > 0, comm, PETSC_ERR_ARG_WRONG, "Dim must be > 0"); 91*ee102026SMark Adams PetscCheck(n_obs_vertex > 0, comm, PETSC_ERR_ARG_WRONG, "n_obs_vertex must be > 0"); 92*ee102026SMark Adams 93*ee102026SMark Adams PetscCall(VecGetSize(Vecxyz[0], &n_state_global)); 94*ee102026SMark Adams PetscCall(VecGetLocalSize(Vecxyz[0], &n_state_local)); 95*ee102026SMark Adams n_vert_local = n_state_local / n_dof; 96828beda2SMark Adams 97828beda2SMark Adams /* Check H dimensions */ 98*ee102026SMark Adams PetscCall(MatGetSize(H, &n_obs_global, &N)); 99*ee102026SMark Adams PetscCheck(N == n_state_global, comm, PETSC_ERR_ARG_SIZ, "H number of columns %" PetscInt_FMT " != global state size %" PetscInt_FMT, N, n_state_global); 100*ee102026SMark Adams // If n_obs_global < n_obs_vertex, we will pad with -1 indices and 0.0 weights. 101*ee102026SMark Adams // This is not an error condition, but rather a case where we have fewer observations than requested neighbors. 102828beda2SMark Adams 103828beda2SMark Adams /* Allocate storage for observation locations */ 104828beda2SMark Adams PetscCall(PetscMalloc1(dim, &obs_vecs)); 105828beda2SMark Adams 106828beda2SMark Adams /* Compute observation locations per dimension */ 107828beda2SMark Adams for (d = 0; d < dim; ++d) { 108*ee102026SMark Adams PetscCall(MatCreateVecs(H, NULL, &obs_vecs[d])); 109*ee102026SMark Adams PetscCall(MatMult(H, Vecxyz[d], obs_vecs[d])); 110828beda2SMark Adams } 111828beda2SMark Adams 112*ee102026SMark Adams /* Create output matrix Q in N/n_dof x P */ 113828beda2SMark Adams PetscCall(MatCreate(comm, Q)); 114*ee102026SMark Adams PetscCall(MatSetSizes(*Q, n_vert_local, n_obs_local, PETSC_DETERMINE, n_obs_global)); 115*ee102026SMark Adams PetscCall(MatSetType(*Q, MATAIJ)); 116*ee102026SMark Adams PetscCall(MatSeqAIJSetPreallocation(*Q, n_obs_vertex, NULL)); 117*ee102026SMark Adams PetscCall(MatMPIAIJSetPreallocation(*Q, n_obs_vertex, NULL, n_obs_vertex, NULL)); 118*ee102026SMark Adams PetscCall(MatSetFromOptions(*Q)); 119828beda2SMark Adams PetscCall(MatSetUp(*Q)); 120828beda2SMark Adams 121*ee102026SMark Adams PetscCall(PetscInfo((PetscObject)*Q, "Computing LETKF localization matrix: %" PetscInt_FMT " vertices, %" PetscInt_FMT " observations, %" PetscInt_FMT " neighbors\n", n_vert_local, n_obs_global, n_obs_vertex)); 122*ee102026SMark Adams 123828beda2SMark Adams /* Prepare Kokkos Views */ 124828beda2SMark Adams using ExecSpace = Kokkos::DefaultExecutionSpace; 125828beda2SMark Adams using MemSpace = ExecSpace::memory_space; 126828beda2SMark Adams 127828beda2SMark Adams /* Vertex Coordinates */ 128*ee102026SMark Adams // Use LayoutLeft for coalesced access on GPU (i is contiguous) 129*ee102026SMark Adams Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, MemSpace> vertex_coords_dev("vertex_coords", n_vert_local, dim); 130828beda2SMark Adams { 131*ee102026SMark Adams // Host view must match the data layout from VecGetArray (d-major, i-minor implies LayoutLeft for (i,d) view) 132*ee102026SMark Adams Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace> vertex_coords_host("vertex_coords_host", n_vert_local, dim); 133*ee102026SMark Adams for (d = 0; d < dim; ++d) { 134*ee102026SMark Adams const PetscScalar *local_coords_array; 135*ee102026SMark Adams PetscCall(VecGetArrayRead(Vecxyz[d], &local_coords_array)); 136*ee102026SMark Adams // Copy data. Since vertex_coords_host is LayoutLeft, &vertex_coords_host(0, d) is the start of column d. 137*ee102026SMark Adams for (PetscInt i = 0; i < n_vert_local; ++i) vertex_coords_host(i, d) = local_coords_array[i]; 138*ee102026SMark Adams PetscCall(VecRestoreArrayRead(Vecxyz[d], &local_coords_array)); 139828beda2SMark Adams } 140828beda2SMark Adams Kokkos::deep_copy(vertex_coords_dev, vertex_coords_host); 141828beda2SMark Adams } 142828beda2SMark Adams 143828beda2SMark Adams /* Observation Coordinates */ 144*ee102026SMark Adams Kokkos::View<PetscReal **, Kokkos::LayoutRight, MemSpace> obs_coords_dev("obs_coords", n_obs_global, dim); 145828beda2SMark Adams { 146*ee102026SMark Adams Kokkos::View<PetscReal **, Kokkos::LayoutRight, Kokkos::HostSpace> obs_coords_host("obs_coords_host", n_obs_global, dim); 147*ee102026SMark Adams for (d = 0; d < dim; ++d) { 148*ee102026SMark Adams VecScatter ctx; 149*ee102026SMark Adams Vec seq_vec; 150*ee102026SMark Adams const PetscScalar *array; 151*ee102026SMark Adams 152*ee102026SMark Adams PetscCall(VecScatterCreateToAll(obs_vecs[d], &ctx, &seq_vec)); 153*ee102026SMark Adams PetscCall(VecScatterBegin(ctx, obs_vecs[d], seq_vec, INSERT_VALUES, SCATTER_FORWARD)); 154*ee102026SMark Adams PetscCall(VecScatterEnd(ctx, obs_vecs[d], seq_vec, INSERT_VALUES, SCATTER_FORWARD)); 155*ee102026SMark Adams 156*ee102026SMark Adams PetscCall(VecGetArrayRead(seq_vec, &array)); 157*ee102026SMark Adams for (PetscInt j = 0; j < n_obs_global; ++j) obs_coords_host(j, d) = PetscRealPart(array[j]); 158*ee102026SMark Adams PetscCall(VecRestoreArrayRead(seq_vec, &array)); 159*ee102026SMark Adams PetscCall(VecScatterDestroy(&ctx)); 160*ee102026SMark Adams PetscCall(VecDestroy(&seq_vec)); 161828beda2SMark Adams } 162828beda2SMark Adams Kokkos::deep_copy(obs_coords_dev, obs_coords_host); 163828beda2SMark Adams } 164828beda2SMark Adams 165*ee102026SMark Adams PetscInt rstart; 166*ee102026SMark Adams PetscCall(VecGetOwnershipRange(Vecxyz[0], &rstart, NULL)); 167828beda2SMark Adams 168828beda2SMark Adams /* Output Views */ 169*ee102026SMark Adams // LayoutLeft for coalesced access on GPU 170*ee102026SMark Adams Kokkos::View<PetscInt **, Kokkos::LayoutLeft, MemSpace> indices_dev("indices", n_vert_local, n_obs_vertex); 171*ee102026SMark Adams Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, MemSpace> values_dev("values", n_vert_local, n_obs_vertex); 172828beda2SMark Adams 173828beda2SMark Adams /* Temporary storage for top-k per vertex */ 174*ee102026SMark Adams // LayoutLeft for coalesced access on GPU. 175*ee102026SMark Adams // Note: For the insertion sort within a thread, LayoutRight would offer better cache locality for the thread's private list. 176*ee102026SMark Adams // However, LayoutLeft is preferred for coalesced access across threads during the final weight computation and initialization. 177*ee102026SMark Adams // Given the random access nature of the sort (divergence), we stick to the default GPU layout (Left). 178*ee102026SMark Adams Kokkos::View<PetscReal **, Kokkos::LayoutLeft, MemSpace> best_dists_dev("best_dists", n_vert_local, n_obs_vertex); 179*ee102026SMark Adams Kokkos::View<PetscInt **, Kokkos::LayoutLeft, MemSpace> best_idxs_dev("best_idxs", n_vert_local, n_obs_vertex); 180828beda2SMark Adams 181828beda2SMark Adams /* Main Kernel */ 182828beda2SMark Adams Kokkos::parallel_for( 183*ee102026SMark Adams "ComputeLocalization", Kokkos::RangePolicy<ExecSpace>(0, n_vert_local), KOKKOS_LAMBDA(const PetscInt i) { 184*ee102026SMark Adams PetscReal current_max_dist = PETSC_MAX_REAL; 185*ee102026SMark Adams 186*ee102026SMark Adams // Cache vertex coordinates in registers to avoid repeated global memory access 187*ee102026SMark Adams // dim is small (<= 3), so this fits easily in registers 188*ee102026SMark Adams PetscReal v_coords[3] = {0.0, 0.0, 0.0}; 189*ee102026SMark Adams for (PetscInt d = 0; d < dim; ++d) v_coords[d] = PetscRealPart(vertex_coords_dev(i, d)); 190*ee102026SMark Adams 191*ee102026SMark Adams // Initialize with infinity 192*ee102026SMark Adams for (PetscInt k = 0; k < n_obs_vertex; ++k) { 193*ee102026SMark Adams best_dists_dev(i, k) = PETSC_MAX_REAL; 194*ee102026SMark Adams best_idxs_dev(i, k) = -1; 195*ee102026SMark Adams } 196828beda2SMark Adams 197828beda2SMark Adams // Iterate over all observations 198*ee102026SMark Adams for (PetscInt j = 0; j < n_obs_global; ++j) { 199828beda2SMark Adams PetscReal dist2 = 0.0; 200828beda2SMark Adams for (PetscInt d = 0; d < dim; ++d) { 201*ee102026SMark Adams PetscReal diff = v_coords[d] - obs_coords_dev(j, d); 202828beda2SMark Adams dist2 += diff * diff; 203828beda2SMark Adams } 204828beda2SMark Adams 205*ee102026SMark Adams // Check if this observation is closer than the furthest stored observation 206*ee102026SMark Adams if (dist2 < current_max_dist) { 207828beda2SMark Adams // Insert sorted 208*ee102026SMark Adams PetscInt pos = n_obs_vertex - 1; 209828beda2SMark Adams while (pos > 0 && best_dists_dev(i, pos - 1) > dist2) { 210828beda2SMark Adams best_dists_dev(i, pos) = best_dists_dev(i, pos - 1); 211828beda2SMark Adams best_idxs_dev(i, pos) = best_idxs_dev(i, pos - 1); 212828beda2SMark Adams pos--; 213828beda2SMark Adams } 214828beda2SMark Adams best_dists_dev(i, pos) = dist2; 215828beda2SMark Adams best_idxs_dev(i, pos) = j; 216*ee102026SMark Adams 217*ee102026SMark Adams // Update current max distance 218*ee102026SMark Adams current_max_dist = best_dists_dev(i, n_obs_vertex - 1); 219828beda2SMark Adams } 220828beda2SMark Adams } 221828beda2SMark Adams 222828beda2SMark Adams // Compute weights 223*ee102026SMark Adams PetscReal radius2 = best_dists_dev(i, n_obs_vertex - 1); 224828beda2SMark Adams PetscReal radius = std::sqrt(radius2); 225828beda2SMark Adams if (radius == 0.0) radius = 1.0; 226828beda2SMark Adams 227*ee102026SMark Adams for (PetscInt k = 0; k < n_obs_vertex; ++k) { 228*ee102026SMark Adams if (best_idxs_dev(i, k) != -1) { 229828beda2SMark Adams PetscReal dist = std::sqrt(best_dists_dev(i, k)); 230828beda2SMark Adams indices_dev(i, k) = best_idxs_dev(i, k); 231828beda2SMark Adams values_dev(i, k) = GaspariCohn(dist, radius); 232*ee102026SMark Adams } else { 233*ee102026SMark Adams indices_dev(i, k) = -1; // Ignore this entry 234*ee102026SMark Adams values_dev(i, k) = 0.0; 235*ee102026SMark Adams } 236828beda2SMark Adams } 237828beda2SMark Adams }); 238828beda2SMark Adams 239828beda2SMark Adams /* Copy back to host and fill matrix */ 240*ee102026SMark Adams // Host views must be LayoutRight for MatSetValues (row-major) 241*ee102026SMark Adams Kokkos::View<PetscInt **, Kokkos::LayoutRight, Kokkos::HostSpace> indices_host("indices_host", n_vert_local, n_obs_vertex); 242*ee102026SMark Adams Kokkos::View<PetscScalar **, Kokkos::LayoutRight, Kokkos::HostSpace> values_host("values_host", n_vert_local, n_obs_vertex); 243828beda2SMark Adams 244*ee102026SMark Adams // Deep copy will handle layout conversion (transpose) if device views are LayoutLeft 245*ee102026SMark Adams // Note: Kokkos::deep_copy cannot copy between different layouts if the memory spaces are different (e.g. GPU to Host). 246*ee102026SMark Adams // We need an intermediate mirror view on the host with the same layout as the device view. 247*ee102026SMark Adams Kokkos::View<PetscInt **, Kokkos::LayoutLeft, Kokkos::HostSpace> indices_host_left = Kokkos::create_mirror_view(indices_dev); 248*ee102026SMark Adams Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace> values_host_left = Kokkos::create_mirror_view(values_dev); 249828beda2SMark Adams 250*ee102026SMark Adams Kokkos::deep_copy(indices_host_left, indices_dev); 251*ee102026SMark Adams Kokkos::deep_copy(values_host_left, values_dev); 252*ee102026SMark Adams 253*ee102026SMark Adams // Now copy from LayoutLeft host view to LayoutRight host view 254*ee102026SMark Adams Kokkos::deep_copy(indices_host, indices_host_left); 255*ee102026SMark Adams Kokkos::deep_copy(values_host, values_host_left); 256*ee102026SMark Adams 257*ee102026SMark Adams for (PetscInt i = 0; i < n_vert_local; ++i) { 258*ee102026SMark Adams PetscInt globalRow = rstart + i; 259*ee102026SMark Adams PetscCall(MatSetValues(*Q, 1, &globalRow, n_obs_vertex, &indices_host(i, 0), &values_host(i, 0), INSERT_VALUES)); 260828beda2SMark Adams } 261828beda2SMark Adams 262828beda2SMark Adams /* Cleanup Phase 2 storage */ 263*ee102026SMark Adams for (d = 0; d < dim; ++d) PetscCall(VecDestroy(&obs_vecs[d])); 264828beda2SMark Adams PetscCall(PetscFree(obs_vecs)); 265828beda2SMark Adams 266828beda2SMark Adams /* Assemble matrix */ 267828beda2SMark Adams PetscCall(MatAssemblyBegin(*Q, MAT_FINAL_ASSEMBLY)); 268828beda2SMark Adams PetscCall(MatAssemblyEnd(*Q, MAT_FINAL_ASSEMBLY)); 269828beda2SMark Adams PetscFunctionReturn(PETSC_SUCCESS); 270828beda2SMark Adams } 271