xref: /petsc/src/snes/tutorials/ex55k.kokkos.cxx (revision 33d966b40488ab9db89c850499939de39321db8e)
1c5566c22SJunchao Zhang #include <Kokkos_Core.hpp>
2c5566c22SJunchao Zhang #include <petscdmda_kokkos.hpp>
3c5566c22SJunchao Zhang 
4c5566c22SJunchao Zhang #include <petscdm.h>
5c5566c22SJunchao Zhang #include <petscdmda.h>
6c5566c22SJunchao Zhang #include <petscsnes.h>
7c5566c22SJunchao Zhang #include "ex55.h"
8c5566c22SJunchao Zhang 
9c5566c22SJunchao Zhang using DefaultMemorySpace                 = Kokkos::DefaultExecutionSpace::memory_space;
10c5566c22SJunchao Zhang using ConstPetscScalarKokkosOffsetView2D = Kokkos::Experimental::OffsetView<const PetscScalar**,Kokkos::LayoutRight,DefaultMemorySpace>;
11c5566c22SJunchao Zhang using PetscScalarKokkosOffsetView2D      = Kokkos::Experimental::OffsetView<PetscScalar**,Kokkos::LayoutRight,DefaultMemorySpace>;
12c5566c22SJunchao Zhang 
13c5566c22SJunchao Zhang using PetscCountKokkosView           = Kokkos::View<PetscCount*,DefaultMemorySpace>;
14c5566c22SJunchao Zhang using PetscIntKokkosView             = Kokkos::View<PetscInt*,DefaultMemorySpace>;
15c5566c22SJunchao Zhang using PetscCountKokkosViewHost       = Kokkos::View<PetscCount*,Kokkos::HostSpace>;
16c5566c22SJunchao Zhang using PetscScalarKokkosView          = Kokkos::View<PetscScalar*,DefaultMemorySpace>;
17*33d966b4SJunchao Zhang using Kokkos::Iterate;
18*33d966b4SJunchao Zhang using Kokkos::Rank;
19*33d966b4SJunchao Zhang using Kokkos::MDRangePolicy;
20c5566c22SJunchao Zhang 
21c5566c22SJunchao Zhang KOKKOS_INLINE_FUNCTION PetscErrorCode MMSSolution1(AppCtx *user,const DMDACoor2d *c,PetscScalar *u)
22c5566c22SJunchao Zhang {
23c5566c22SJunchao Zhang   PetscReal x = PetscRealPart(c->x), y = PetscRealPart(c->y);
24c5566c22SJunchao Zhang   u[0] = x*(1 - x)*y*(1 - y);
25c5566c22SJunchao Zhang   return 0;
26c5566c22SJunchao Zhang }
27c5566c22SJunchao Zhang 
28c5566c22SJunchao Zhang KOKKOS_INLINE_FUNCTION PetscErrorCode MMSForcing1(PetscReal user_param,const DMDACoor2d *c,PetscScalar *f)
29c5566c22SJunchao Zhang {
30c5566c22SJunchao Zhang   PetscReal x = PetscRealPart(c->x), y = PetscRealPart(c->y);
31c5566c22SJunchao Zhang   f[0] = 2*x*(1 - x) + 2*y*(1 - y) - user_param*PetscExpReal(x*(1 - x)*y*(1 - y));
32c5566c22SJunchao Zhang   return 0;
33c5566c22SJunchao Zhang }
34c5566c22SJunchao Zhang 
35c5566c22SJunchao Zhang PetscErrorCode FormFunctionLocalVec(DMDALocalInfo *info,Vec x,Vec f,AppCtx *user)
36c5566c22SJunchao Zhang {
37c5566c22SJunchao Zhang   PetscReal      lambda,hx,hy,hxdhy,hydhx;
38c5566c22SJunchao Zhang   PetscInt       xs = info->xs,ys = info->ys,xm = info->xm,ym = info->ym,mx = info->mx,my = info->my;
39c5566c22SJunchao Zhang   PetscReal      user_param = user->param;
40c5566c22SJunchao Zhang 
41c5566c22SJunchao Zhang   ConstPetscScalarKokkosOffsetView2D xv;
42c5566c22SJunchao Zhang   PetscScalarKokkosOffsetView2D      fv;
43c5566c22SJunchao Zhang 
44c5566c22SJunchao Zhang   PetscFunctionBeginUser;
45c5566c22SJunchao Zhang   lambda = user->param;
46c5566c22SJunchao Zhang   hx     = 1.0/(PetscReal)(info->mx-1);
47c5566c22SJunchao Zhang   hy     = 1.0/(PetscReal)(info->my-1);
48c5566c22SJunchao Zhang   hxdhy  = hx/hy;
49c5566c22SJunchao Zhang   hydhx  = hy/hx;
50c5566c22SJunchao Zhang   /*
51c5566c22SJunchao Zhang      Compute function over the locally owned part of the grid
52c5566c22SJunchao Zhang   */
53c5566c22SJunchao Zhang   PetscCallCXX(DMDAVecGetKokkosOffsetView(info->da,x,&xv));
54c5566c22SJunchao Zhang   PetscCallCXX(DMDAVecGetKokkosOffsetViewWrite(info->da,f,&fv));
55c5566c22SJunchao Zhang 
56c5566c22SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for ("FormFunctionLocalVec",
57c5566c22SJunchao Zhang     MDRangePolicy <Rank<2,Iterate::Right,Iterate::Right>>({ys,xs},{ys+ym,xs+xm}),
58c5566c22SJunchao Zhang     KOKKOS_LAMBDA (PetscInt j,PetscInt i)
59c5566c22SJunchao Zhang   {
60c5566c22SJunchao Zhang     DMDACoor2d   c;
61c5566c22SJunchao Zhang     PetscScalar  u,ue,uw,un,us,uxx,uyy,mms_solution,mms_forcing;
62c5566c22SJunchao Zhang 
63c5566c22SJunchao Zhang     if (i == 0 || j == 0 || i == mx-1 || j == my-1) {
64c5566c22SJunchao Zhang       c.x = i*hx; c.y = j*hy;
65c5566c22SJunchao Zhang       MMSSolution1(user,&c,&mms_solution);
66c5566c22SJunchao Zhang       fv(j,i) = 2.0*(hydhx+hxdhy)*(xv(j,i) - mms_solution);
67c5566c22SJunchao Zhang     } else {
68c5566c22SJunchao Zhang       u  = xv(j,i);
69c5566c22SJunchao Zhang       uw = xv(j,i-1);
70c5566c22SJunchao Zhang       ue = xv(j,i+1);
71c5566c22SJunchao Zhang       un = xv(j-1,i);
72c5566c22SJunchao Zhang       us = xv(j+1,i);
73c5566c22SJunchao Zhang 
74c5566c22SJunchao Zhang       /* Enforce boundary conditions at neighboring points -- setting these values causes the Jacobian to be symmetric. */
75c5566c22SJunchao Zhang       if (i-1 == 0) {c.x = (i-1)*hx; c.y = j*hy; MMSSolution1(user,&c,&uw);}
76c5566c22SJunchao Zhang       if (i+1 == mx-1) {c.x = (i+1)*hx; c.y = j*hy; MMSSolution1(user,&c,&ue);}
77c5566c22SJunchao Zhang       if (j-1 == 0) {c.x = i*hx; c.y = (j-1)*hy; MMSSolution1(user,&c,&un);}
78c5566c22SJunchao Zhang       if (j+1 == my-1) {c.x = i*hx; c.y = (j+1)*hy; MMSSolution1(user,&c,&us);}
79c5566c22SJunchao Zhang 
80c5566c22SJunchao Zhang       uxx     = (2.0*u - uw - ue)*hydhx;
81c5566c22SJunchao Zhang       uyy     = (2.0*u - un - us)*hxdhy;
82c5566c22SJunchao Zhang       mms_forcing = 0;
83c5566c22SJunchao Zhang       c.x = i*hx; c.y = j*hy;
84c5566c22SJunchao Zhang       MMSForcing1(user_param,&c,&mms_forcing);
85c5566c22SJunchao Zhang       fv(j,i) = uxx + uyy - hx*hy*(lambda*PetscExpScalar(u) + mms_forcing);
86c5566c22SJunchao Zhang     }
87c5566c22SJunchao Zhang   }));
88c5566c22SJunchao Zhang 
89c5566c22SJunchao Zhang   PetscCallCXX(DMDAVecRestoreKokkosOffsetView(info->da,x,&xv));
90c5566c22SJunchao Zhang   PetscCallCXX(DMDAVecRestoreKokkosOffsetViewWrite(info->da,f,&fv));
91c5566c22SJunchao Zhang 
92c5566c22SJunchao Zhang   PetscCall(PetscLogFlops(11.0*info->ym*info->xm));
93c5566c22SJunchao Zhang   PetscFunctionReturn(0);
94c5566c22SJunchao Zhang }
95c5566c22SJunchao Zhang 
96c5566c22SJunchao Zhang PetscErrorCode FormObjectiveLocalVec(DMDALocalInfo *info,Vec x,PetscReal *obj,AppCtx *user)
97c5566c22SJunchao Zhang {
98c5566c22SJunchao Zhang   PetscInt       xs = info->xs,ys = info->ys,xm = info->xm,ym = info->ym,mx = info->mx,my = info->my;
99c5566c22SJunchao Zhang   PetscReal      lambda,hx,hy,hxdhy,hydhx,sc,lobj=0;
100c5566c22SJunchao Zhang   MPI_Comm       comm;
101c5566c22SJunchao Zhang 
102c5566c22SJunchao Zhang   ConstPetscScalarKokkosOffsetView2D xv;
103c5566c22SJunchao Zhang 
104c5566c22SJunchao Zhang   PetscFunctionBeginUser;
105c5566c22SJunchao Zhang   *obj   = 0;
106c5566c22SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)info->da,&comm));
107c5566c22SJunchao Zhang   lambda = user->param;
108c5566c22SJunchao Zhang   hx     = 1.0/(PetscReal)(mx-1);
109c5566c22SJunchao Zhang   hy     = 1.0/(PetscReal)(my-1);
110c5566c22SJunchao Zhang   sc     = hx*hy*lambda;
111c5566c22SJunchao Zhang   hxdhy  = hx/hy;
112c5566c22SJunchao Zhang   hydhx  = hy/hx;
113c5566c22SJunchao Zhang   /*
114c5566c22SJunchao Zhang      Compute function over the locally owned part of the grid
115c5566c22SJunchao Zhang   */
116c5566c22SJunchao Zhang   PetscCallCXX(DMDAVecGetKokkosOffsetView(info->da,x,&xv));
117c5566c22SJunchao Zhang 
118c5566c22SJunchao Zhang   PetscCallCXX(Kokkos::parallel_reduce("FormObjectiveLocalVec",
119c5566c22SJunchao Zhang     MDRangePolicy <Rank<2,Iterate::Right,Iterate::Right>>({ys,xs},{ys+ym,xs+xm}),
120c5566c22SJunchao Zhang     KOKKOS_LAMBDA (PetscInt j,PetscInt i,PetscReal& update)
121c5566c22SJunchao Zhang   {
122c5566c22SJunchao Zhang     PetscScalar    u,ue,uw,un,us,uxux,uyuy;
123c5566c22SJunchao Zhang     if (i == 0 || j == 0 || i == mx-1 || j == my-1) {
124c5566c22SJunchao Zhang       update += PetscRealPart((hydhx + hxdhy)*xv(j,i)*xv(j,i));
125c5566c22SJunchao Zhang     } else {
126c5566c22SJunchao Zhang       u  = xv(j,i);
127c5566c22SJunchao Zhang       uw = xv(j,i-1);
128c5566c22SJunchao Zhang       ue = xv(j,i+1);
129c5566c22SJunchao Zhang       un = xv(j-1,i);
130c5566c22SJunchao Zhang       us = xv(j+1,i);
131c5566c22SJunchao Zhang 
132c5566c22SJunchao Zhang       if (i-1 == 0)    uw = 0.;
133c5566c22SJunchao Zhang       if (i+1 == mx-1) ue = 0.;
134c5566c22SJunchao Zhang       if (j-1 == 0)    un = 0.;
135c5566c22SJunchao Zhang       if (j+1 == my-1) us = 0.;
136c5566c22SJunchao Zhang 
137c5566c22SJunchao Zhang       /* F[u] = 1/2\int_{\omega}\nabla^2u(x)*u(x)*dx */
138c5566c22SJunchao Zhang 
139c5566c22SJunchao Zhang       uxux = u*(2.*u - ue - uw)*hydhx;
140c5566c22SJunchao Zhang       uyuy = u*(2.*u - un - us)*hxdhy;
141c5566c22SJunchao Zhang 
142c5566c22SJunchao Zhang       update += PetscRealPart(0.5*(uxux + uyuy) - sc*PetscExpScalar(u));
143c5566c22SJunchao Zhang     }
144c5566c22SJunchao Zhang   },lobj));
145c5566c22SJunchao Zhang 
146c5566c22SJunchao Zhang   PetscCallCXX(DMDAVecRestoreKokkosOffsetView(info->da,x,&xv));
147c5566c22SJunchao Zhang   PetscCall(PetscLogFlops(12.0*info->ym*info->xm));
148c5566c22SJunchao Zhang   PetscCallMPI(MPI_Allreduce(&lobj,obj,1,MPIU_REAL,MPIU_SUM,comm));
149c5566c22SJunchao Zhang   PetscFunctionReturn(0);
150c5566c22SJunchao Zhang }
151c5566c22SJunchao Zhang 
152c5566c22SJunchao Zhang PetscErrorCode FormJacobianLocalVec(DMDALocalInfo *info,Vec x,Mat jac,Mat jacpre,AppCtx *user)
153c5566c22SJunchao Zhang {
154c5566c22SJunchao Zhang   PetscInt       i,j;
155c5566c22SJunchao Zhang   PetscInt       xs = info->xs,ys = info->ys,xm = info->xm,ym = info->ym,mx = info->mx,my = info->my;
156c5566c22SJunchao Zhang   MatStencil     col[5],row;
157c5566c22SJunchao Zhang   PetscScalar    lambda,hx,hy,hxdhy,hydhx,sc;
158c5566c22SJunchao Zhang   DM             coordDA;
159c5566c22SJunchao Zhang   Vec            coordinates;
160c5566c22SJunchao Zhang   DMDACoor2d     **coords;
161c5566c22SJunchao Zhang 
162c5566c22SJunchao Zhang   PetscFunctionBeginUser;
163c5566c22SJunchao Zhang   lambda = user->param;
164c5566c22SJunchao Zhang   /* Extract coordinates */
165c5566c22SJunchao Zhang   PetscCall(DMGetCoordinateDM(info->da, &coordDA));
166c5566c22SJunchao Zhang   PetscCall(DMGetCoordinates(info->da, &coordinates));
167c5566c22SJunchao Zhang 
168c5566c22SJunchao Zhang   PetscCall(DMDAVecGetArray(coordDA, coordinates, &coords));
169c5566c22SJunchao Zhang   hx     = xm > 1 ? PetscRealPart(coords[ys][xs+1].x) - PetscRealPart(coords[ys][xs].x) : 1.0;
170c5566c22SJunchao Zhang   hy     = ym > 1 ? PetscRealPart(coords[ys+1][xs].y) - PetscRealPart(coords[ys][xs].y) : 1.0;
171c5566c22SJunchao Zhang   PetscCall(DMDAVecRestoreArray(coordDA, coordinates, &coords));
172c5566c22SJunchao Zhang 
173c5566c22SJunchao Zhang   hxdhy  = hx/hy;
174c5566c22SJunchao Zhang   hydhx  = hy/hx;
175c5566c22SJunchao Zhang   sc     = hx*hy*lambda;
176c5566c22SJunchao Zhang 
177c5566c22SJunchao Zhang   /* ----------------------------------------- */
178c5566c22SJunchao Zhang   /*  MatSetPreallocationCOO()                 */
179c5566c22SJunchao Zhang   /* ----------------------------------------- */
180c5566c22SJunchao Zhang   PetscCount ncoo = ((PetscCount)xm)*((PetscCount)ym)*5;
181c5566c22SJunchao Zhang   PetscInt   *coo_i,*coo_j,*ip,*jp;
182c5566c22SJunchao Zhang   PetscCall(PetscMalloc2(ncoo,&coo_i,ncoo,&coo_j)); /* 5-point stencil such that each row has at most 5 nonzeros */
183c5566c22SJunchao Zhang 
184c5566c22SJunchao Zhang   ip = coo_i;
185c5566c22SJunchao Zhang   jp = coo_j;
186c5566c22SJunchao Zhang   for (j=ys; j<ys+ym; j++) {
187c5566c22SJunchao Zhang     for (i=xs; i<xs+xm; i++) {
188c5566c22SJunchao Zhang       row.j = j; row.i = i;
189c5566c22SJunchao Zhang       /* Initialize neighbors with negative indices */
190c5566c22SJunchao Zhang       col[0].j = col[1].j = col[3].j = col[4].j = -1;
191c5566c22SJunchao Zhang       /* boundary points */
192c5566c22SJunchao Zhang       if (i == 0 || j == 0 || i == mx-1 || j == my-1) {
193c5566c22SJunchao Zhang         col[2].j = row.j;
194c5566c22SJunchao Zhang         col[2].i = row.i;
195c5566c22SJunchao Zhang       } else {
196c5566c22SJunchao Zhang         /* interior grid points */
197c5566c22SJunchao Zhang         if (j-1 != 0) {
198c5566c22SJunchao Zhang           col[0].j = j - 1;
199c5566c22SJunchao Zhang           col[0].i = i;
200c5566c22SJunchao Zhang         }
201c5566c22SJunchao Zhang 
202c5566c22SJunchao Zhang         if (i-1 != 0) {
203c5566c22SJunchao Zhang           col[1].j = j;
204c5566c22SJunchao Zhang           col[1].i = i-1;
205c5566c22SJunchao Zhang         }
206c5566c22SJunchao Zhang 
207c5566c22SJunchao Zhang         col[2].j = row.j;
208c5566c22SJunchao Zhang         col[2].i = row.i;
209c5566c22SJunchao Zhang 
210c5566c22SJunchao Zhang         if (i+1 != mx-1) {
211c5566c22SJunchao Zhang           col[3].j = j;
212c5566c22SJunchao Zhang           col[3].i = i+1;
213c5566c22SJunchao Zhang         }
214c5566c22SJunchao Zhang 
215c5566c22SJunchao Zhang         if (j+1 != mx-1) {
216c5566c22SJunchao Zhang           col[4].j = j + 1;
217c5566c22SJunchao Zhang           col[4].i = i;
218c5566c22SJunchao Zhang         }
219c5566c22SJunchao Zhang       }
220c5566c22SJunchao Zhang       PetscCall(DMDAMapMatStencilToGlobal(info->da,5,col,jp));
221c5566c22SJunchao Zhang       for (PetscInt k=0; k<5; k++) ip[k] = jp[2];
222c5566c22SJunchao Zhang       ip += 5;
223c5566c22SJunchao Zhang       jp += 5;
224c5566c22SJunchao Zhang     }
225c5566c22SJunchao Zhang   }
226c5566c22SJunchao Zhang 
227c5566c22SJunchao Zhang   PetscCall(MatSetPreallocationCOO(jacpre,ncoo,coo_i,coo_j));
228c5566c22SJunchao Zhang   PetscCall(PetscFree2(coo_i,coo_j));
229c5566c22SJunchao Zhang 
230c5566c22SJunchao Zhang   /* ----------------------------------------- */
231c5566c22SJunchao Zhang   /*  MatSetValuesCOO()                        */
232c5566c22SJunchao Zhang   /* ----------------------------------------- */
233c5566c22SJunchao Zhang   PetscScalarKokkosView              coo_v("coo_v",ncoo);
234c5566c22SJunchao Zhang   ConstPetscScalarKokkosOffsetView2D xv;
235c5566c22SJunchao Zhang 
236c5566c22SJunchao Zhang   PetscCallCXX(DMDAVecGetKokkosOffsetView(info->da,x,&xv));
237c5566c22SJunchao Zhang 
238c5566c22SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for ("FormFunctionLocalVec",
239c5566c22SJunchao Zhang     MDRangePolicy <Rank<2,Iterate::Right,Iterate::Right>>({ys,xs},{ys+ym,xs+xm}),
240c5566c22SJunchao Zhang     KOKKOS_LAMBDA (PetscCount j,PetscCount i)
241c5566c22SJunchao Zhang   {
242c5566c22SJunchao Zhang     PetscInt p = ((j-ys)*xm + (i-xs))*5;
243c5566c22SJunchao Zhang     /* boundary points */
244c5566c22SJunchao Zhang     if (i == 0 || j == 0 || i == mx-1 || j == my-1) {
245c5566c22SJunchao Zhang       coo_v(p+2) =  2.0*(hydhx + hxdhy);
246c5566c22SJunchao Zhang     } else {
247c5566c22SJunchao Zhang       /* interior grid points */
248c5566c22SJunchao Zhang       if (j-1 != 0) {
249c5566c22SJunchao Zhang         coo_v(p+0)     = -hxdhy;
250c5566c22SJunchao Zhang       }
251c5566c22SJunchao Zhang       if (i-1 != 0) {
252c5566c22SJunchao Zhang         coo_v(p+1)     = -hydhx;
253c5566c22SJunchao Zhang       }
254c5566c22SJunchao Zhang 
255c5566c22SJunchao Zhang       coo_v(p+2) = 2.0*(hydhx + hxdhy) - sc*PetscExpScalar(xv(j,i));
256c5566c22SJunchao Zhang 
257c5566c22SJunchao Zhang       if (i+1 != mx-1) {
258c5566c22SJunchao Zhang         coo_v(p+3)     = -hydhx;
259c5566c22SJunchao Zhang       }
260c5566c22SJunchao Zhang       if (j+1 != mx-1) {
261c5566c22SJunchao Zhang         coo_v(p+4)     = -hxdhy;
262c5566c22SJunchao Zhang       }
263c5566c22SJunchao Zhang     }
264c5566c22SJunchao Zhang   }));
265c5566c22SJunchao Zhang   PetscCall(MatSetValuesCOO(jacpre,coo_v.data(),INSERT_VALUES));
266c5566c22SJunchao Zhang   PetscCallCXX(DMDAVecRestoreKokkosOffsetView(info->da,x,&xv));
267c5566c22SJunchao Zhang   PetscFunctionReturn(0);
268c5566c22SJunchao Zhang }
269