xref: /petsc/src/snes/tutorials/ex55k.kokkos.cxx (revision c5566c22a574ae02fd64a900a65fba926662d66a)
1*c5566c22SJunchao Zhang #include <Kokkos_Core.hpp>
2*c5566c22SJunchao Zhang #include <petscdmda_kokkos.hpp>
3*c5566c22SJunchao Zhang 
4*c5566c22SJunchao Zhang #include <petscdm.h>
5*c5566c22SJunchao Zhang #include <petscdmda.h>
6*c5566c22SJunchao Zhang #include <petscsnes.h>
7*c5566c22SJunchao Zhang #include "ex55.h"
8*c5566c22SJunchao Zhang 
9*c5566c22SJunchao Zhang using namespace Kokkos;
10*c5566c22SJunchao Zhang using DefaultMemorySpace                 = Kokkos::DefaultExecutionSpace::memory_space;
11*c5566c22SJunchao Zhang using ConstPetscScalarKokkosOffsetView2D = Kokkos::Experimental::OffsetView<const PetscScalar**,Kokkos::LayoutRight,DefaultMemorySpace>;
12*c5566c22SJunchao Zhang using PetscScalarKokkosOffsetView2D      = Kokkos::Experimental::OffsetView<PetscScalar**,Kokkos::LayoutRight,DefaultMemorySpace>;
13*c5566c22SJunchao Zhang 
14*c5566c22SJunchao Zhang using PetscCountKokkosView           = Kokkos::View<PetscCount*,DefaultMemorySpace>;
15*c5566c22SJunchao Zhang using PetscIntKokkosView             = Kokkos::View<PetscInt*,DefaultMemorySpace>;
16*c5566c22SJunchao Zhang using PetscCountKokkosViewHost       = Kokkos::View<PetscCount*,Kokkos::HostSpace>;
17*c5566c22SJunchao Zhang using PetscScalarKokkosView          = Kokkos::View<PetscScalar*,DefaultMemorySpace>;
18*c5566c22SJunchao Zhang 
19*c5566c22SJunchao Zhang KOKKOS_INLINE_FUNCTION PetscErrorCode MMSSolution1(AppCtx *user,const DMDACoor2d *c,PetscScalar *u)
20*c5566c22SJunchao Zhang {
21*c5566c22SJunchao Zhang   PetscReal x = PetscRealPart(c->x), y = PetscRealPart(c->y);
22*c5566c22SJunchao Zhang   u[0] = x*(1 - x)*y*(1 - y);
23*c5566c22SJunchao Zhang   return 0;
24*c5566c22SJunchao Zhang }
25*c5566c22SJunchao Zhang 
26*c5566c22SJunchao Zhang KOKKOS_INLINE_FUNCTION PetscErrorCode MMSForcing1(PetscReal user_param,const DMDACoor2d *c,PetscScalar *f)
27*c5566c22SJunchao Zhang {
28*c5566c22SJunchao Zhang   PetscReal x = PetscRealPart(c->x), y = PetscRealPart(c->y);
29*c5566c22SJunchao Zhang   f[0] = 2*x*(1 - x) + 2*y*(1 - y) - user_param*PetscExpReal(x*(1 - x)*y*(1 - y));
30*c5566c22SJunchao Zhang   return 0;
31*c5566c22SJunchao Zhang }
32*c5566c22SJunchao Zhang 
33*c5566c22SJunchao Zhang PetscErrorCode FormFunctionLocalVec(DMDALocalInfo *info,Vec x,Vec f,AppCtx *user)
34*c5566c22SJunchao Zhang {
35*c5566c22SJunchao Zhang   PetscReal      lambda,hx,hy,hxdhy,hydhx;
36*c5566c22SJunchao Zhang   PetscInt       xs = info->xs,ys = info->ys,xm = info->xm,ym = info->ym,mx = info->mx,my = info->my;
37*c5566c22SJunchao Zhang   PetscReal      user_param = user->param;
38*c5566c22SJunchao Zhang 
39*c5566c22SJunchao Zhang   ConstPetscScalarKokkosOffsetView2D xv;
40*c5566c22SJunchao Zhang   PetscScalarKokkosOffsetView2D      fv;
41*c5566c22SJunchao Zhang 
42*c5566c22SJunchao Zhang   PetscFunctionBeginUser;
43*c5566c22SJunchao Zhang   lambda = user->param;
44*c5566c22SJunchao Zhang   hx     = 1.0/(PetscReal)(info->mx-1);
45*c5566c22SJunchao Zhang   hy     = 1.0/(PetscReal)(info->my-1);
46*c5566c22SJunchao Zhang   hxdhy  = hx/hy;
47*c5566c22SJunchao Zhang   hydhx  = hy/hx;
48*c5566c22SJunchao Zhang   /*
49*c5566c22SJunchao Zhang      Compute function over the locally owned part of the grid
50*c5566c22SJunchao Zhang   */
51*c5566c22SJunchao Zhang   PetscCallCXX(DMDAVecGetKokkosOffsetView(info->da,x,&xv));
52*c5566c22SJunchao Zhang   PetscCallCXX(DMDAVecGetKokkosOffsetViewWrite(info->da,f,&fv));
53*c5566c22SJunchao Zhang 
54*c5566c22SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for ("FormFunctionLocalVec",
55*c5566c22SJunchao Zhang     MDRangePolicy <Rank<2,Iterate::Right,Iterate::Right>>({ys,xs},{ys+ym,xs+xm}),
56*c5566c22SJunchao Zhang     KOKKOS_LAMBDA (PetscInt j,PetscInt i)
57*c5566c22SJunchao Zhang   {
58*c5566c22SJunchao Zhang     DMDACoor2d   c;
59*c5566c22SJunchao Zhang     PetscScalar  u,ue,uw,un,us,uxx,uyy,mms_solution,mms_forcing;
60*c5566c22SJunchao Zhang 
61*c5566c22SJunchao Zhang     if (i == 0 || j == 0 || i == mx-1 || j == my-1) {
62*c5566c22SJunchao Zhang       c.x = i*hx; c.y = j*hy;
63*c5566c22SJunchao Zhang       MMSSolution1(user,&c,&mms_solution);
64*c5566c22SJunchao Zhang       fv(j,i) = 2.0*(hydhx+hxdhy)*(xv(j,i) - mms_solution);
65*c5566c22SJunchao Zhang     } else {
66*c5566c22SJunchao Zhang       u  = xv(j,i);
67*c5566c22SJunchao Zhang       uw = xv(j,i-1);
68*c5566c22SJunchao Zhang       ue = xv(j,i+1);
69*c5566c22SJunchao Zhang       un = xv(j-1,i);
70*c5566c22SJunchao Zhang       us = xv(j+1,i);
71*c5566c22SJunchao Zhang 
72*c5566c22SJunchao Zhang       /* Enforce boundary conditions at neighboring points -- setting these values causes the Jacobian to be symmetric. */
73*c5566c22SJunchao Zhang       if (i-1 == 0) {c.x = (i-1)*hx; c.y = j*hy; MMSSolution1(user,&c,&uw);}
74*c5566c22SJunchao Zhang       if (i+1 == mx-1) {c.x = (i+1)*hx; c.y = j*hy; MMSSolution1(user,&c,&ue);}
75*c5566c22SJunchao Zhang       if (j-1 == 0) {c.x = i*hx; c.y = (j-1)*hy; MMSSolution1(user,&c,&un);}
76*c5566c22SJunchao Zhang       if (j+1 == my-1) {c.x = i*hx; c.y = (j+1)*hy; MMSSolution1(user,&c,&us);}
77*c5566c22SJunchao Zhang 
78*c5566c22SJunchao Zhang       uxx     = (2.0*u - uw - ue)*hydhx;
79*c5566c22SJunchao Zhang       uyy     = (2.0*u - un - us)*hxdhy;
80*c5566c22SJunchao Zhang       mms_forcing = 0;
81*c5566c22SJunchao Zhang       c.x = i*hx; c.y = j*hy;
82*c5566c22SJunchao Zhang       MMSForcing1(user_param,&c,&mms_forcing);
83*c5566c22SJunchao Zhang       fv(j,i) = uxx + uyy - hx*hy*(lambda*PetscExpScalar(u) + mms_forcing);
84*c5566c22SJunchao Zhang     }
85*c5566c22SJunchao Zhang   }));
86*c5566c22SJunchao Zhang 
87*c5566c22SJunchao Zhang   PetscCallCXX(DMDAVecRestoreKokkosOffsetView(info->da,x,&xv));
88*c5566c22SJunchao Zhang   PetscCallCXX(DMDAVecRestoreKokkosOffsetViewWrite(info->da,f,&fv));
89*c5566c22SJunchao Zhang 
90*c5566c22SJunchao Zhang   PetscCall(PetscLogFlops(11.0*info->ym*info->xm));
91*c5566c22SJunchao Zhang   PetscFunctionReturn(0);
92*c5566c22SJunchao Zhang }
93*c5566c22SJunchao Zhang 
94*c5566c22SJunchao Zhang PetscErrorCode FormObjectiveLocalVec(DMDALocalInfo *info,Vec x,PetscReal *obj,AppCtx *user)
95*c5566c22SJunchao Zhang {
96*c5566c22SJunchao Zhang   PetscInt       xs = info->xs,ys = info->ys,xm = info->xm,ym = info->ym,mx = info->mx,my = info->my;
97*c5566c22SJunchao Zhang   PetscReal      lambda,hx,hy,hxdhy,hydhx,sc,lobj=0;
98*c5566c22SJunchao Zhang   MPI_Comm       comm;
99*c5566c22SJunchao Zhang 
100*c5566c22SJunchao Zhang   ConstPetscScalarKokkosOffsetView2D xv;
101*c5566c22SJunchao Zhang 
102*c5566c22SJunchao Zhang   PetscFunctionBeginUser;
103*c5566c22SJunchao Zhang   *obj   = 0;
104*c5566c22SJunchao Zhang   PetscCall(PetscObjectGetComm((PetscObject)info->da,&comm));
105*c5566c22SJunchao Zhang   lambda = user->param;
106*c5566c22SJunchao Zhang   hx     = 1.0/(PetscReal)(mx-1);
107*c5566c22SJunchao Zhang   hy     = 1.0/(PetscReal)(my-1);
108*c5566c22SJunchao Zhang   sc     = hx*hy*lambda;
109*c5566c22SJunchao Zhang   hxdhy  = hx/hy;
110*c5566c22SJunchao Zhang   hydhx  = hy/hx;
111*c5566c22SJunchao Zhang   /*
112*c5566c22SJunchao Zhang      Compute function over the locally owned part of the grid
113*c5566c22SJunchao Zhang   */
114*c5566c22SJunchao Zhang   PetscCallCXX(DMDAVecGetKokkosOffsetView(info->da,x,&xv));
115*c5566c22SJunchao Zhang 
116*c5566c22SJunchao Zhang   PetscCallCXX(Kokkos::parallel_reduce("FormObjectiveLocalVec",
117*c5566c22SJunchao Zhang     MDRangePolicy <Rank<2,Iterate::Right,Iterate::Right>>({ys,xs},{ys+ym,xs+xm}),
118*c5566c22SJunchao Zhang     KOKKOS_LAMBDA (PetscInt j,PetscInt i,PetscReal& update)
119*c5566c22SJunchao Zhang   {
120*c5566c22SJunchao Zhang     PetscScalar    u,ue,uw,un,us,uxux,uyuy;
121*c5566c22SJunchao Zhang     if (i == 0 || j == 0 || i == mx-1 || j == my-1) {
122*c5566c22SJunchao Zhang       update += PetscRealPart((hydhx + hxdhy)*xv(j,i)*xv(j,i));
123*c5566c22SJunchao Zhang     } else {
124*c5566c22SJunchao Zhang       u  = xv(j,i);
125*c5566c22SJunchao Zhang       uw = xv(j,i-1);
126*c5566c22SJunchao Zhang       ue = xv(j,i+1);
127*c5566c22SJunchao Zhang       un = xv(j-1,i);
128*c5566c22SJunchao Zhang       us = xv(j+1,i);
129*c5566c22SJunchao Zhang 
130*c5566c22SJunchao Zhang       if (i-1 == 0)    uw = 0.;
131*c5566c22SJunchao Zhang       if (i+1 == mx-1) ue = 0.;
132*c5566c22SJunchao Zhang       if (j-1 == 0)    un = 0.;
133*c5566c22SJunchao Zhang       if (j+1 == my-1) us = 0.;
134*c5566c22SJunchao Zhang 
135*c5566c22SJunchao Zhang       /* F[u] = 1/2\int_{\omega}\nabla^2u(x)*u(x)*dx */
136*c5566c22SJunchao Zhang 
137*c5566c22SJunchao Zhang       uxux = u*(2.*u - ue - uw)*hydhx;
138*c5566c22SJunchao Zhang       uyuy = u*(2.*u - un - us)*hxdhy;
139*c5566c22SJunchao Zhang 
140*c5566c22SJunchao Zhang       update += PetscRealPart(0.5*(uxux + uyuy) - sc*PetscExpScalar(u));
141*c5566c22SJunchao Zhang     }
142*c5566c22SJunchao Zhang   },lobj));
143*c5566c22SJunchao Zhang 
144*c5566c22SJunchao Zhang   PetscCallCXX(DMDAVecRestoreKokkosOffsetView(info->da,x,&xv));
145*c5566c22SJunchao Zhang   PetscCall(PetscLogFlops(12.0*info->ym*info->xm));
146*c5566c22SJunchao Zhang   PetscCallMPI(MPI_Allreduce(&lobj,obj,1,MPIU_REAL,MPIU_SUM,comm));
147*c5566c22SJunchao Zhang   PetscFunctionReturn(0);
148*c5566c22SJunchao Zhang }
149*c5566c22SJunchao Zhang 
150*c5566c22SJunchao Zhang PetscErrorCode FormJacobianLocalVec(DMDALocalInfo *info,Vec x,Mat jac,Mat jacpre,AppCtx *user)
151*c5566c22SJunchao Zhang {
152*c5566c22SJunchao Zhang   PetscInt       i,j;
153*c5566c22SJunchao Zhang   PetscInt       xs = info->xs,ys = info->ys,xm = info->xm,ym = info->ym,mx = info->mx,my = info->my;
154*c5566c22SJunchao Zhang   MatStencil     col[5],row;
155*c5566c22SJunchao Zhang   PetscScalar    lambda,hx,hy,hxdhy,hydhx,sc;
156*c5566c22SJunchao Zhang   DM             coordDA;
157*c5566c22SJunchao Zhang   Vec            coordinates;
158*c5566c22SJunchao Zhang   DMDACoor2d     **coords;
159*c5566c22SJunchao Zhang 
160*c5566c22SJunchao Zhang   PetscFunctionBeginUser;
161*c5566c22SJunchao Zhang   lambda = user->param;
162*c5566c22SJunchao Zhang   /* Extract coordinates */
163*c5566c22SJunchao Zhang   PetscCall(DMGetCoordinateDM(info->da, &coordDA));
164*c5566c22SJunchao Zhang   PetscCall(DMGetCoordinates(info->da, &coordinates));
165*c5566c22SJunchao Zhang 
166*c5566c22SJunchao Zhang   PetscCall(DMDAVecGetArray(coordDA, coordinates, &coords));
167*c5566c22SJunchao Zhang   hx     = xm > 1 ? PetscRealPart(coords[ys][xs+1].x) - PetscRealPart(coords[ys][xs].x) : 1.0;
168*c5566c22SJunchao Zhang   hy     = ym > 1 ? PetscRealPart(coords[ys+1][xs].y) - PetscRealPart(coords[ys][xs].y) : 1.0;
169*c5566c22SJunchao Zhang   PetscCall(DMDAVecRestoreArray(coordDA, coordinates, &coords));
170*c5566c22SJunchao Zhang 
171*c5566c22SJunchao Zhang   hxdhy  = hx/hy;
172*c5566c22SJunchao Zhang   hydhx  = hy/hx;
173*c5566c22SJunchao Zhang   sc     = hx*hy*lambda;
174*c5566c22SJunchao Zhang 
175*c5566c22SJunchao Zhang   /* ----------------------------------------- */
176*c5566c22SJunchao Zhang   /*  MatSetPreallocationCOO()                 */
177*c5566c22SJunchao Zhang   /* ----------------------------------------- */
178*c5566c22SJunchao Zhang   PetscCount ncoo = ((PetscCount)xm)*((PetscCount)ym)*5;
179*c5566c22SJunchao Zhang   PetscInt   *coo_i,*coo_j,*ip,*jp;
180*c5566c22SJunchao Zhang   PetscCall(PetscMalloc2(ncoo,&coo_i,ncoo,&coo_j)); /* 5-point stencil such that each row has at most 5 nonzeros */
181*c5566c22SJunchao Zhang 
182*c5566c22SJunchao Zhang   ip = coo_i;
183*c5566c22SJunchao Zhang   jp = coo_j;
184*c5566c22SJunchao Zhang   for (j=ys; j<ys+ym; j++) {
185*c5566c22SJunchao Zhang     for (i=xs; i<xs+xm; i++) {
186*c5566c22SJunchao Zhang       row.j = j; row.i = i;
187*c5566c22SJunchao Zhang       /* Initialize neighbors with negative indices */
188*c5566c22SJunchao Zhang       col[0].j = col[1].j = col[3].j = col[4].j = -1;
189*c5566c22SJunchao Zhang       /* boundary points */
190*c5566c22SJunchao Zhang       if (i == 0 || j == 0 || i == mx-1 || j == my-1) {
191*c5566c22SJunchao Zhang         col[2].j = row.j;
192*c5566c22SJunchao Zhang         col[2].i = row.i;
193*c5566c22SJunchao Zhang       } else {
194*c5566c22SJunchao Zhang         /* interior grid points */
195*c5566c22SJunchao Zhang         if (j-1 != 0) {
196*c5566c22SJunchao Zhang           col[0].j = j - 1;
197*c5566c22SJunchao Zhang           col[0].i = i;
198*c5566c22SJunchao Zhang         }
199*c5566c22SJunchao Zhang 
200*c5566c22SJunchao Zhang         if (i-1 != 0) {
201*c5566c22SJunchao Zhang           col[1].j = j;
202*c5566c22SJunchao Zhang           col[1].i = i-1;
203*c5566c22SJunchao Zhang         }
204*c5566c22SJunchao Zhang 
205*c5566c22SJunchao Zhang         col[2].j = row.j;
206*c5566c22SJunchao Zhang         col[2].i = row.i;
207*c5566c22SJunchao Zhang 
208*c5566c22SJunchao Zhang         if (i+1 != mx-1) {
209*c5566c22SJunchao Zhang           col[3].j = j;
210*c5566c22SJunchao Zhang           col[3].i = i+1;
211*c5566c22SJunchao Zhang         }
212*c5566c22SJunchao Zhang 
213*c5566c22SJunchao Zhang         if (j+1 != mx-1) {
214*c5566c22SJunchao Zhang           col[4].j = j + 1;
215*c5566c22SJunchao Zhang           col[4].i = i;
216*c5566c22SJunchao Zhang         }
217*c5566c22SJunchao Zhang       }
218*c5566c22SJunchao Zhang       PetscCall(DMDAMapMatStencilToGlobal(info->da,5,col,jp));
219*c5566c22SJunchao Zhang       for (PetscInt k=0; k<5; k++) ip[k] = jp[2];
220*c5566c22SJunchao Zhang       ip += 5;
221*c5566c22SJunchao Zhang       jp += 5;
222*c5566c22SJunchao Zhang     }
223*c5566c22SJunchao Zhang   }
224*c5566c22SJunchao Zhang 
225*c5566c22SJunchao Zhang   PetscCall(MatSetPreallocationCOO(jacpre,ncoo,coo_i,coo_j));
226*c5566c22SJunchao Zhang   PetscCall(PetscFree2(coo_i,coo_j));
227*c5566c22SJunchao Zhang 
228*c5566c22SJunchao Zhang   /* ----------------------------------------- */
229*c5566c22SJunchao Zhang   /*  MatSetValuesCOO()                        */
230*c5566c22SJunchao Zhang   /* ----------------------------------------- */
231*c5566c22SJunchao Zhang   PetscScalarKokkosView              coo_v("coo_v",ncoo);
232*c5566c22SJunchao Zhang   ConstPetscScalarKokkosOffsetView2D xv;
233*c5566c22SJunchao Zhang 
234*c5566c22SJunchao Zhang   PetscCallCXX(DMDAVecGetKokkosOffsetView(info->da,x,&xv));
235*c5566c22SJunchao Zhang 
236*c5566c22SJunchao Zhang   PetscCallCXX(Kokkos::parallel_for ("FormFunctionLocalVec",
237*c5566c22SJunchao Zhang     MDRangePolicy <Rank<2,Iterate::Right,Iterate::Right>>({ys,xs},{ys+ym,xs+xm}),
238*c5566c22SJunchao Zhang     KOKKOS_LAMBDA (PetscCount j,PetscCount i)
239*c5566c22SJunchao Zhang   {
240*c5566c22SJunchao Zhang     PetscInt p = ((j-ys)*xm + (i-xs))*5;
241*c5566c22SJunchao Zhang     /* boundary points */
242*c5566c22SJunchao Zhang     if (i == 0 || j == 0 || i == mx-1 || j == my-1) {
243*c5566c22SJunchao Zhang       coo_v(p+2) =  2.0*(hydhx + hxdhy);
244*c5566c22SJunchao Zhang     } else {
245*c5566c22SJunchao Zhang       /* interior grid points */
246*c5566c22SJunchao Zhang       if (j-1 != 0) {
247*c5566c22SJunchao Zhang         coo_v(p+0)     = -hxdhy;
248*c5566c22SJunchao Zhang       }
249*c5566c22SJunchao Zhang       if (i-1 != 0) {
250*c5566c22SJunchao Zhang         coo_v(p+1)     = -hydhx;
251*c5566c22SJunchao Zhang       }
252*c5566c22SJunchao Zhang 
253*c5566c22SJunchao Zhang       coo_v(p+2) = 2.0*(hydhx + hxdhy) - sc*PetscExpScalar(xv(j,i));
254*c5566c22SJunchao Zhang 
255*c5566c22SJunchao Zhang       if (i+1 != mx-1) {
256*c5566c22SJunchao Zhang         coo_v(p+3)     = -hydhx;
257*c5566c22SJunchao Zhang       }
258*c5566c22SJunchao Zhang       if (j+1 != mx-1) {
259*c5566c22SJunchao Zhang         coo_v(p+4)     = -hxdhy;
260*c5566c22SJunchao Zhang       }
261*c5566c22SJunchao Zhang     }
262*c5566c22SJunchao Zhang   }));
263*c5566c22SJunchao Zhang   PetscCall(MatSetValuesCOO(jacpre,coo_v.data(),INSERT_VALUES));
264*c5566c22SJunchao Zhang   PetscCallCXX(DMDAVecRestoreKokkosOffsetView(info->da,x,&xv));
265*c5566c22SJunchao Zhang   PetscFunctionReturn(0);
266*c5566c22SJunchao Zhang }
267