xref: /petsc/src/snes/tutorials/ex55k.kokkos.cxx (revision 3ba1676111f5c958fe6c2729b46ca4d523958bb3)
1c5566c22SJunchao Zhang #include <Kokkos_Core.hpp>
2c5566c22SJunchao Zhang #include <petscdmda_kokkos.hpp>
3c5566c22SJunchao Zhang 
4c5566c22SJunchao Zhang #include <petscdm.h>
5c5566c22SJunchao Zhang #include <petscdmda.h>
6c5566c22SJunchao Zhang #include <petscsnes.h>
7c5566c22SJunchao Zhang #include "ex55.h"
8c5566c22SJunchao Zhang 
9c5566c22SJunchao Zhang using DefaultMemorySpace                 = Kokkos::DefaultExecutionSpace::memory_space;
10c5566c22SJunchao Zhang using ConstPetscScalarKokkosOffsetView2D = Kokkos::Experimental::OffsetView<const PetscScalar **, Kokkos::LayoutRight, DefaultMemorySpace>;
11c5566c22SJunchao Zhang using PetscScalarKokkosOffsetView2D      = Kokkos::Experimental::OffsetView<PetscScalar **, Kokkos::LayoutRight, DefaultMemorySpace>;
12c5566c22SJunchao Zhang 
13c5566c22SJunchao Zhang using PetscCountKokkosView     = Kokkos::View<PetscCount *, DefaultMemorySpace>;
14c5566c22SJunchao Zhang using PetscIntKokkosView       = Kokkos::View<PetscInt *, DefaultMemorySpace>;
15c5566c22SJunchao Zhang using PetscCountKokkosViewHost = Kokkos::View<PetscCount *, Kokkos::HostSpace>;
16c5566c22SJunchao Zhang using PetscScalarKokkosView    = Kokkos::View<PetscScalar *, DefaultMemorySpace>;
1733d966b4SJunchao Zhang using Kokkos::Iterate;
1833d966b4SJunchao Zhang using Kokkos::MDRangePolicy;
199371c9d4SSatish Balay using Kokkos::Rank;
20c5566c22SJunchao Zhang 
21d71ae5a4SJacob Faibussowitsch KOKKOS_INLINE_FUNCTION PetscErrorCode MMSSolution1(AppCtx *user, const DMDACoor2d *c, PetscScalar *u)
22d71ae5a4SJacob Faibussowitsch {
23c5566c22SJunchao Zhang   PetscReal x = PetscRealPart(c->x), y = PetscRealPart(c->y);
24c5566c22SJunchao Zhang   u[0] = x * (1 - x) * y * (1 - y);
25*3ba16761SJacob Faibussowitsch   return PETSC_SUCCESS;
26c5566c22SJunchao Zhang }
27c5566c22SJunchao Zhang 
28d71ae5a4SJacob Faibussowitsch KOKKOS_INLINE_FUNCTION PetscErrorCode MMSForcing1(PetscReal user_param, const DMDACoor2d *c, PetscScalar *f)
29d71ae5a4SJacob Faibussowitsch {
30c5566c22SJunchao Zhang   PetscReal x = PetscRealPart(c->x), y = PetscRealPart(c->y);
31c5566c22SJunchao Zhang   f[0] = 2 * x * (1 - x) + 2 * y * (1 - y) - user_param * PetscExpReal(x * (1 - x) * y * (1 - y));
32*3ba16761SJacob Faibussowitsch   return PETSC_SUCCESS;
33c5566c22SJunchao Zhang }
34c5566c22SJunchao Zhang 
35d71ae5a4SJacob Faibussowitsch PetscErrorCode FormFunctionLocalVec(DMDALocalInfo *info, Vec x, Vec f, AppCtx *user)
36d71ae5a4SJacob Faibussowitsch {
37c5566c22SJunchao Zhang   PetscReal lambda, hx, hy, hxdhy, hydhx;
38c5566c22SJunchao Zhang   PetscInt  xs = info->xs, ys = info->ys, xm = info->xm, ym = info->ym, mx = info->mx, my = info->my;
39c5566c22SJunchao Zhang   PetscReal user_param = user->param;
40c5566c22SJunchao Zhang 
41c5566c22SJunchao Zhang   ConstPetscScalarKokkosOffsetView2D xv;
42c5566c22SJunchao Zhang   PetscScalarKokkosOffsetView2D      fv;
43c5566c22SJunchao Zhang 
44c5566c22SJunchao Zhang   PetscFunctionBeginUser;
45c5566c22SJunchao Zhang   lambda = user->param;
46c5566c22SJunchao Zhang   hx     = 1.0 / (PetscReal)(info->mx - 1);
47c5566c22SJunchao Zhang   hy     = 1.0 / (PetscReal)(info->my - 1);
48c5566c22SJunchao Zhang   hxdhy  = hx / hy;
49c5566c22SJunchao Zhang   hydhx  = hy / hx;
50c5566c22SJunchao Zhang   /*
51c5566c22SJunchao Zhang      Compute function over the locally owned part of the grid
52c5566c22SJunchao Zhang   */
53*3ba16761SJacob Faibussowitsch   PetscCall(DMDAVecGetKokkosOffsetView(info->da, x, &xv));
54*3ba16761SJacob Faibussowitsch   PetscCall(DMDAVecGetKokkosOffsetViewWrite(info->da, f, &fv));
55c5566c22SJunchao Zhang 
569371c9d4SSatish Balay   PetscCallCXX(Kokkos::parallel_for(
579371c9d4SSatish Balay     "FormFunctionLocalVec", MDRangePolicy<Rank<2, Iterate::Right, Iterate::Right>>({ys, xs}, {ys + ym, xs + xm}), KOKKOS_LAMBDA(PetscInt j, PetscInt i) {
58c5566c22SJunchao Zhang       DMDACoor2d  c;
59c5566c22SJunchao Zhang       PetscScalar u, ue, uw, un, us, uxx, uyy, mms_solution, mms_forcing;
60c5566c22SJunchao Zhang 
61c5566c22SJunchao Zhang       if (i == 0 || j == 0 || i == mx - 1 || j == my - 1) {
629371c9d4SSatish Balay         c.x = i * hx;
639371c9d4SSatish Balay         c.y = j * hy;
64*3ba16761SJacob Faibussowitsch         static_cast<void>(MMSSolution1(user, &c, &mms_solution));
65c5566c22SJunchao Zhang         fv(j, i) = 2.0 * (hydhx + hxdhy) * (xv(j, i) - mms_solution);
66c5566c22SJunchao Zhang       } else {
67c5566c22SJunchao Zhang         u  = xv(j, i);
68c5566c22SJunchao Zhang         uw = xv(j, i - 1);
69c5566c22SJunchao Zhang         ue = xv(j, i + 1);
70c5566c22SJunchao Zhang         un = xv(j - 1, i);
71c5566c22SJunchao Zhang         us = xv(j + 1, i);
72c5566c22SJunchao Zhang 
73c5566c22SJunchao Zhang         /* Enforce boundary conditions at neighboring points -- setting these values causes the Jacobian to be symmetric. */
749371c9d4SSatish Balay         if (i - 1 == 0) {
759371c9d4SSatish Balay           c.x = (i - 1) * hx;
769371c9d4SSatish Balay           c.y = j * hy;
77*3ba16761SJacob Faibussowitsch           static_cast<void>(MMSSolution1(user, &c, &uw));
789371c9d4SSatish Balay         }
799371c9d4SSatish Balay         if (i + 1 == mx - 1) {
809371c9d4SSatish Balay           c.x = (i + 1) * hx;
819371c9d4SSatish Balay           c.y = j * hy;
82*3ba16761SJacob Faibussowitsch           static_cast<void>(MMSSolution1(user, &c, &ue));
839371c9d4SSatish Balay         }
849371c9d4SSatish Balay         if (j - 1 == 0) {
859371c9d4SSatish Balay           c.x = i * hx;
869371c9d4SSatish Balay           c.y = (j - 1) * hy;
87*3ba16761SJacob Faibussowitsch           static_cast<void>(MMSSolution1(user, &c, &un));
889371c9d4SSatish Balay         }
899371c9d4SSatish Balay         if (j + 1 == my - 1) {
909371c9d4SSatish Balay           c.x = i * hx;
919371c9d4SSatish Balay           c.y = (j + 1) * hy;
92*3ba16761SJacob Faibussowitsch           static_cast<void>(MMSSolution1(user, &c, &us));
939371c9d4SSatish Balay         }
94c5566c22SJunchao Zhang 
95c5566c22SJunchao Zhang         uxx         = (2.0 * u - uw - ue) * hydhx;
96c5566c22SJunchao Zhang         uyy         = (2.0 * u - un - us) * hxdhy;
97c5566c22SJunchao Zhang         mms_forcing = 0;
989371c9d4SSatish Balay         c.x         = i * hx;
999371c9d4SSatish Balay         c.y         = j * hy;
100*3ba16761SJacob Faibussowitsch         static_cast<void>(MMSForcing1(user_param, &c, &mms_forcing));
101c5566c22SJunchao Zhang         fv(j, i) = uxx + uyy - hx * hy * (lambda * PetscExpScalar(u) + mms_forcing);
102c5566c22SJunchao Zhang       }
103c5566c22SJunchao Zhang     }));
104c5566c22SJunchao Zhang 
105*3ba16761SJacob Faibussowitsch   PetscCall(DMDAVecRestoreKokkosOffsetView(info->da, x, &xv));
106*3ba16761SJacob Faibussowitsch   PetscCall(DMDAVecRestoreKokkosOffsetViewWrite(info->da, f, &fv));
107c5566c22SJunchao Zhang 
108c5566c22SJunchao Zhang   PetscCall(PetscLogFlops(11.0 * info->ym * info->xm));
109*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
110c5566c22SJunchao Zhang }
111c5566c22SJunchao Zhang 
112d71ae5a4SJacob Faibussowitsch PetscErrorCode FormObjectiveLocalVec(DMDALocalInfo *info, Vec x, PetscReal *obj, AppCtx *user)
113d71ae5a4SJacob Faibussowitsch {
114c5566c22SJunchao Zhang   PetscInt  xs = info->xs, ys = info->ys, xm = info->xm, ym = info->ym, mx = info->mx, my = info->my;
115c5566c22SJunchao Zhang   PetscReal lambda, hx, hy, hxdhy, hydhx, sc, lobj = 0;
116c5566c22SJunchao Zhang   MPI_Comm  comm;
117c5566c22SJunchao Zhang 
118c5566c22SJunchao Zhang   ConstPetscScalarKokkosOffsetView2D xv;
119c5566c22SJunchao Zhang 
120c5566c22SJunchao Zhang   PetscFunctionBeginUser;
121c5566c22SJunchao Zhang   *obj = 0;
122c5566c22SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)info->da, &comm));
123c5566c22SJunchao Zhang   lambda = user->param;
124c5566c22SJunchao Zhang   hx     = 1.0 / (PetscReal)(mx - 1);
125c5566c22SJunchao Zhang   hy     = 1.0 / (PetscReal)(my - 1);
126c5566c22SJunchao Zhang   sc     = hx * hy * lambda;
127c5566c22SJunchao Zhang   hxdhy  = hx / hy;
128c5566c22SJunchao Zhang   hydhx  = hy / hx;
129c5566c22SJunchao Zhang   /*
130c5566c22SJunchao Zhang      Compute function over the locally owned part of the grid
131c5566c22SJunchao Zhang   */
132*3ba16761SJacob Faibussowitsch   PetscCall(DMDAVecGetKokkosOffsetView(info->da, x, &xv));
133c5566c22SJunchao Zhang 
1349371c9d4SSatish Balay   PetscCallCXX(Kokkos::parallel_reduce(
1359371c9d4SSatish Balay     "FormObjectiveLocalVec", MDRangePolicy<Rank<2, Iterate::Right, Iterate::Right>>({ys, xs}, {ys + ym, xs + xm}),
1369371c9d4SSatish Balay     KOKKOS_LAMBDA(PetscInt j, PetscInt i, PetscReal & update) {
137c5566c22SJunchao Zhang       PetscScalar u, ue, uw, un, us, uxux, uyuy;
138c5566c22SJunchao Zhang       if (i == 0 || j == 0 || i == mx - 1 || j == my - 1) {
139c5566c22SJunchao Zhang         update += PetscRealPart((hydhx + hxdhy) * xv(j, i) * xv(j, i));
140c5566c22SJunchao Zhang       } else {
141c5566c22SJunchao Zhang         u  = xv(j, i);
142c5566c22SJunchao Zhang         uw = xv(j, i - 1);
143c5566c22SJunchao Zhang         ue = xv(j, i + 1);
144c5566c22SJunchao Zhang         un = xv(j - 1, i);
145c5566c22SJunchao Zhang         us = xv(j + 1, i);
146c5566c22SJunchao Zhang 
147c5566c22SJunchao Zhang         if (i - 1 == 0) uw = 0.;
148c5566c22SJunchao Zhang         if (i + 1 == mx - 1) ue = 0.;
149c5566c22SJunchao Zhang         if (j - 1 == 0) un = 0.;
150c5566c22SJunchao Zhang         if (j + 1 == my - 1) us = 0.;
151c5566c22SJunchao Zhang 
152c5566c22SJunchao Zhang         /* F[u] = 1/2\int_{\omega}\nabla^2u(x)*u(x)*dx */
153c5566c22SJunchao Zhang 
154c5566c22SJunchao Zhang         uxux = u * (2. * u - ue - uw) * hydhx;
155c5566c22SJunchao Zhang         uyuy = u * (2. * u - un - us) * hxdhy;
156c5566c22SJunchao Zhang 
157c5566c22SJunchao Zhang         update += PetscRealPart(0.5 * (uxux + uyuy) - sc * PetscExpScalar(u));
158c5566c22SJunchao Zhang       }
1599371c9d4SSatish Balay     },
1609371c9d4SSatish Balay     lobj));
161c5566c22SJunchao Zhang 
162*3ba16761SJacob Faibussowitsch   PetscCall(DMDAVecRestoreKokkosOffsetView(info->da, x, &xv));
163c5566c22SJunchao Zhang   PetscCall(PetscLogFlops(12.0 * info->ym * info->xm));
164c5566c22SJunchao Zhang   PetscCallMPI(MPI_Allreduce(&lobj, obj, 1, MPIU_REAL, MPIU_SUM, comm));
165*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
166c5566c22SJunchao Zhang }
167c5566c22SJunchao Zhang 
168d71ae5a4SJacob Faibussowitsch PetscErrorCode FormJacobianLocalVec(DMDALocalInfo *info, Vec x, Mat jac, Mat jacpre, AppCtx *user)
169d71ae5a4SJacob Faibussowitsch {
170c5566c22SJunchao Zhang   PetscInt     i, j;
171c5566c22SJunchao Zhang   PetscInt     xs = info->xs, ys = info->ys, xm = info->xm, ym = info->ym, mx = info->mx, my = info->my;
172c5566c22SJunchao Zhang   MatStencil   col[5], row;
173c5566c22SJunchao Zhang   PetscScalar  lambda, hx, hy, hxdhy, hydhx, sc;
174c5566c22SJunchao Zhang   DM           coordDA;
175c5566c22SJunchao Zhang   Vec          coordinates;
176c5566c22SJunchao Zhang   DMDACoor2d **coords;
177c5566c22SJunchao Zhang 
178c5566c22SJunchao Zhang   PetscFunctionBeginUser;
179c5566c22SJunchao Zhang   lambda = user->param;
180c5566c22SJunchao Zhang   /* Extract coordinates */
181c5566c22SJunchao Zhang   PetscCall(DMGetCoordinateDM(info->da, &coordDA));
182c5566c22SJunchao Zhang   PetscCall(DMGetCoordinates(info->da, &coordinates));
183c5566c22SJunchao Zhang 
184c5566c22SJunchao Zhang   PetscCall(DMDAVecGetArray(coordDA, coordinates, &coords));
185c5566c22SJunchao Zhang   hx = xm > 1 ? PetscRealPart(coords[ys][xs + 1].x) - PetscRealPart(coords[ys][xs].x) : 1.0;
186c5566c22SJunchao Zhang   hy = ym > 1 ? PetscRealPart(coords[ys + 1][xs].y) - PetscRealPart(coords[ys][xs].y) : 1.0;
187c5566c22SJunchao Zhang   PetscCall(DMDAVecRestoreArray(coordDA, coordinates, &coords));
188c5566c22SJunchao Zhang 
189c5566c22SJunchao Zhang   hxdhy = hx / hy;
190c5566c22SJunchao Zhang   hydhx = hy / hx;
191c5566c22SJunchao Zhang   sc    = hx * hy * lambda;
192c5566c22SJunchao Zhang 
193c5566c22SJunchao Zhang   /* ----------------------------------------- */
194c5566c22SJunchao Zhang   /*  MatSetPreallocationCOO()                 */
195c5566c22SJunchao Zhang   /* ----------------------------------------- */
196c5566c22SJunchao Zhang   PetscCount ncoo = ((PetscCount)xm) * ((PetscCount)ym) * 5;
197c5566c22SJunchao Zhang   PetscInt  *coo_i, *coo_j, *ip, *jp;
198c5566c22SJunchao Zhang   PetscCall(PetscMalloc2(ncoo, &coo_i, ncoo, &coo_j)); /* 5-point stencil such that each row has at most 5 nonzeros */
199c5566c22SJunchao Zhang 
200c5566c22SJunchao Zhang   ip = coo_i;
201c5566c22SJunchao Zhang   jp = coo_j;
202c5566c22SJunchao Zhang   for (j = ys; j < ys + ym; j++) {
203c5566c22SJunchao Zhang     for (i = xs; i < xs + xm; i++) {
2049371c9d4SSatish Balay       row.j = j;
2059371c9d4SSatish Balay       row.i = i;
206c5566c22SJunchao Zhang       /* Initialize neighbors with negative indices */
207c5566c22SJunchao Zhang       col[0].j = col[1].j = col[3].j = col[4].j = -1;
208c5566c22SJunchao Zhang       /* boundary points */
209c5566c22SJunchao Zhang       if (i == 0 || j == 0 || i == mx - 1 || j == my - 1) {
210c5566c22SJunchao Zhang         col[2].j = row.j;
211c5566c22SJunchao Zhang         col[2].i = row.i;
212c5566c22SJunchao Zhang       } else {
213c5566c22SJunchao Zhang         /* interior grid points */
214c5566c22SJunchao Zhang         if (j - 1 != 0) {
215c5566c22SJunchao Zhang           col[0].j = j - 1;
216c5566c22SJunchao Zhang           col[0].i = i;
217c5566c22SJunchao Zhang         }
218c5566c22SJunchao Zhang 
219c5566c22SJunchao Zhang         if (i - 1 != 0) {
220c5566c22SJunchao Zhang           col[1].j = j;
221c5566c22SJunchao Zhang           col[1].i = i - 1;
222c5566c22SJunchao Zhang         }
223c5566c22SJunchao Zhang 
224c5566c22SJunchao Zhang         col[2].j = row.j;
225c5566c22SJunchao Zhang         col[2].i = row.i;
226c5566c22SJunchao Zhang 
227c5566c22SJunchao Zhang         if (i + 1 != mx - 1) {
228c5566c22SJunchao Zhang           col[3].j = j;
229c5566c22SJunchao Zhang           col[3].i = i + 1;
230c5566c22SJunchao Zhang         }
231c5566c22SJunchao Zhang 
232c5566c22SJunchao Zhang         if (j + 1 != mx - 1) {
233c5566c22SJunchao Zhang           col[4].j = j + 1;
234c5566c22SJunchao Zhang           col[4].i = i;
235c5566c22SJunchao Zhang         }
236c5566c22SJunchao Zhang       }
237c5566c22SJunchao Zhang       PetscCall(DMDAMapMatStencilToGlobal(info->da, 5, col, jp));
238c5566c22SJunchao Zhang       for (PetscInt k = 0; k < 5; k++) ip[k] = jp[2];
239c5566c22SJunchao Zhang       ip += 5;
240c5566c22SJunchao Zhang       jp += 5;
241c5566c22SJunchao Zhang     }
242c5566c22SJunchao Zhang   }
243c5566c22SJunchao Zhang 
244c5566c22SJunchao Zhang   PetscCall(MatSetPreallocationCOO(jacpre, ncoo, coo_i, coo_j));
245c5566c22SJunchao Zhang   PetscCall(PetscFree2(coo_i, coo_j));
246c5566c22SJunchao Zhang 
247c5566c22SJunchao Zhang   /* ----------------------------------------- */
248c5566c22SJunchao Zhang   /*  MatSetValuesCOO()                        */
249c5566c22SJunchao Zhang   /* ----------------------------------------- */
250c5566c22SJunchao Zhang   PetscScalarKokkosView              coo_v("coo_v", ncoo);
251c5566c22SJunchao Zhang   ConstPetscScalarKokkosOffsetView2D xv;
252c5566c22SJunchao Zhang 
253*3ba16761SJacob Faibussowitsch   PetscCall(DMDAVecGetKokkosOffsetView(info->da, x, &xv));
254c5566c22SJunchao Zhang 
2559371c9d4SSatish Balay   PetscCallCXX(Kokkos::parallel_for(
2569371c9d4SSatish Balay     "FormFunctionLocalVec", MDRangePolicy<Rank<2, Iterate::Right, Iterate::Right>>({ys, xs}, {ys + ym, xs + xm}), KOKKOS_LAMBDA(PetscCount j, PetscCount i) {
257c5566c22SJunchao Zhang       PetscInt p = ((j - ys) * xm + (i - xs)) * 5;
258c5566c22SJunchao Zhang       /* boundary points */
259c5566c22SJunchao Zhang       if (i == 0 || j == 0 || i == mx - 1 || j == my - 1) {
260c5566c22SJunchao Zhang         coo_v(p + 2) = 2.0 * (hydhx + hxdhy);
261c5566c22SJunchao Zhang       } else {
262c5566c22SJunchao Zhang         /* interior grid points */
263ad540459SPierre Jolivet         if (j - 1 != 0) coo_v(p + 0) = -hxdhy;
264ad540459SPierre Jolivet         if (i - 1 != 0) coo_v(p + 1) = -hydhx;
265c5566c22SJunchao Zhang 
266c5566c22SJunchao Zhang         coo_v(p + 2) = 2.0 * (hydhx + hxdhy) - sc * PetscExpScalar(xv(j, i));
267c5566c22SJunchao Zhang 
268ad540459SPierre Jolivet         if (i + 1 != mx - 1) coo_v(p + 3) = -hydhx;
269ad540459SPierre Jolivet         if (j + 1 != mx - 1) coo_v(p + 4) = -hxdhy;
270c5566c22SJunchao Zhang       }
271c5566c22SJunchao Zhang     }));
272c5566c22SJunchao Zhang   PetscCall(MatSetValuesCOO(jacpre, coo_v.data(), INSERT_VALUES));
273*3ba16761SJacob Faibussowitsch   PetscCall(DMDAVecRestoreKokkosOffsetView(info->da, x, &xv));
274*3ba16761SJacob Faibussowitsch   PetscFunctionReturn(PETSC_SUCCESS);
275c5566c22SJunchao Zhang }
276