xref: /petsc/src/mat/impls/aij/mpi/fdmpiaij.c (revision e0f5bfbec699682fa3e8b8532b1176849ea4e12a)
1 #include <../src/mat/impls/sell/mpi/mpisell.h>
2 #include <../src/mat/impls/aij/mpi/mpiaij.h>
3 #include <../src/mat/impls/baij/mpi/mpibaij.h>
4 #include <petsc/private/isimpl.h>
5 
6 PetscErrorCode MatFDColoringApply_BAIJ(Mat J, MatFDColoring coloring, Vec x1, void *sctx) {
7   PetscErrorCode (*f)(void *, Vec, Vec, void *) = (PetscErrorCode(*)(void *, Vec, Vec, void *))coloring->f;
8   PetscInt           k, cstart, cend, l, row, col, nz, spidx, i, j;
9   PetscScalar        dx = 0.0, *w3_array, *dy_i, *dy = coloring->dy;
10   PetscScalar       *vscale_array;
11   const PetscScalar *xx;
12   PetscReal          epsilon = coloring->error_rel, umin = coloring->umin, unorm;
13   Vec                w1 = coloring->w1, w2 = coloring->w2, w3, vscale = coloring->vscale;
14   void              *fctx  = coloring->fctx;
15   PetscInt           ctype = coloring->ctype, nxloc, nrows_k;
16   PetscScalar       *valaddr;
17   MatEntry          *Jentry  = coloring->matentry;
18   MatEntry2         *Jentry2 = coloring->matentry2;
19   const PetscInt     ncolors = coloring->ncolors, *ncolumns = coloring->ncolumns, *nrows = coloring->nrows;
20   PetscInt           bs = J->rmap->bs;
21 
22   PetscFunctionBegin;
23   PetscCall(VecBindToCPU(x1, PETSC_TRUE));
24   /* (1) Set w1 = F(x1) */
25   if (!coloring->fset) {
26     PetscCall(PetscLogEventBegin(MAT_FDColoringFunction, coloring, 0, 0, 0));
27     PetscCall((*f)(sctx, x1, w1, fctx));
28     PetscCall(PetscLogEventEnd(MAT_FDColoringFunction, coloring, 0, 0, 0));
29   } else {
30     coloring->fset = PETSC_FALSE;
31   }
32 
33   /* (2) Compute vscale = 1./dx - the local scale factors, including ghost points */
34   PetscCall(VecGetLocalSize(x1, &nxloc));
35   if (coloring->htype[0] == 'w') {
36     /* vscale = dx is a constant scalar */
37     PetscCall(VecNorm(x1, NORM_2, &unorm));
38     dx = 1.0 / (PetscSqrtReal(1.0 + unorm) * epsilon);
39   } else {
40     PetscCall(VecGetArrayRead(x1, &xx));
41     PetscCall(VecGetArray(vscale, &vscale_array));
42     for (col = 0; col < nxloc; col++) {
43       dx = xx[col];
44       if (PetscAbsScalar(dx) < umin) {
45         if (PetscRealPart(dx) >= 0.0) dx = umin;
46         else if (PetscRealPart(dx) < 0.0) dx = -umin;
47       }
48       dx *= epsilon;
49       vscale_array[col] = 1.0 / dx;
50     }
51     PetscCall(VecRestoreArrayRead(x1, &xx));
52     PetscCall(VecRestoreArray(vscale, &vscale_array));
53   }
54   if (ctype == IS_COLORING_GLOBAL && coloring->htype[0] == 'd') {
55     PetscCall(VecGhostUpdateBegin(vscale, INSERT_VALUES, SCATTER_FORWARD));
56     PetscCall(VecGhostUpdateEnd(vscale, INSERT_VALUES, SCATTER_FORWARD));
57   }
58 
59   /* (3) Loop over each color */
60   if (!coloring->w3) {
61     PetscCall(VecDuplicate(x1, &coloring->w3));
62     /* Vec is used intensively in particular piece of scalar CPU code; won't benefit from bouncing back and forth to the GPU */
63     PetscCall(VecBindToCPU(coloring->w3, PETSC_TRUE));
64   }
65   w3 = coloring->w3;
66 
67   PetscCall(VecGetOwnershipRange(x1, &cstart, &cend)); /* used by ghosted vscale */
68   if (vscale) PetscCall(VecGetArray(vscale, &vscale_array));
69   nz = 0;
70   for (k = 0; k < ncolors; k++) {
71     coloring->currentcolor = k;
72 
73     /*
74       (3-1) Loop over each column associated with color
75       adding the perturbation to the vector w3 = x1 + dx.
76     */
77     PetscCall(VecCopy(x1, w3));
78     dy_i = dy;
79     for (i = 0; i < bs; i++) { /* Loop over a block of columns */
80       PetscCall(VecGetArray(w3, &w3_array));
81       if (ctype == IS_COLORING_GLOBAL) w3_array -= cstart; /* shift pointer so global index can be used */
82       if (coloring->htype[0] == 'w') {
83         for (l = 0; l < ncolumns[k]; l++) {
84           col = i + bs * coloring->columns[k][l]; /* local column (in global index!) of the matrix we are probing for */
85           w3_array[col] += 1.0 / dx;
86           if (i) w3_array[col - 1] -= 1.0 / dx; /* resume original w3[col-1] */
87         }
88       } else {                  /* htype == 'ds' */
89         vscale_array -= cstart; /* shift pointer so global index can be used */
90         for (l = 0; l < ncolumns[k]; l++) {
91           col = i + bs * coloring->columns[k][l]; /* local column (in global index!) of the matrix we are probing for */
92           w3_array[col] += 1.0 / vscale_array[col];
93           if (i) w3_array[col - 1] -= 1.0 / vscale_array[col - 1]; /* resume original w3[col-1] */
94         }
95         vscale_array += cstart;
96       }
97       if (ctype == IS_COLORING_GLOBAL) w3_array += cstart;
98       PetscCall(VecRestoreArray(w3, &w3_array));
99 
100       /*
101        (3-2) Evaluate function at w3 = x1 + dx (here dx is a vector of perturbations)
102                            w2 = F(x1 + dx) - F(x1)
103        */
104       PetscCall(PetscLogEventBegin(MAT_FDColoringFunction, 0, 0, 0, 0));
105       PetscCall(VecPlaceArray(w2, dy_i)); /* place w2 to the array dy_i */
106       PetscCall((*f)(sctx, w3, w2, fctx));
107       PetscCall(PetscLogEventEnd(MAT_FDColoringFunction, 0, 0, 0, 0));
108       PetscCall(VecAXPY(w2, -1.0, w1));
109       PetscCall(VecResetArray(w2));
110       dy_i += nxloc; /* points to dy+i*nxloc */
111     }
112 
113     /*
114      (3-3) Loop over rows of vector, putting results into Jacobian matrix
115     */
116     nrows_k = nrows[k];
117     if (coloring->htype[0] == 'w') {
118       for (l = 0; l < nrows_k; l++) {
119         row     = bs * Jentry2[nz].row; /* local row index */
120         valaddr = Jentry2[nz++].valaddr;
121         spidx   = 0;
122         dy_i    = dy;
123         for (i = 0; i < bs; i++) {   /* column of the block */
124           for (j = 0; j < bs; j++) { /* row of the block */
125             valaddr[spidx++] = dy_i[row + j] * dx;
126           }
127           dy_i += nxloc; /* points to dy+i*nxloc */
128         }
129       }
130     } else { /* htype == 'ds' */
131       for (l = 0; l < nrows_k; l++) {
132         row     = bs * Jentry[nz].row; /* local row index */
133         col     = bs * Jentry[nz].col; /* local column index */
134         valaddr = Jentry[nz++].valaddr;
135         spidx   = 0;
136         dy_i    = dy;
137         for (i = 0; i < bs; i++) {   /* column of the block */
138           for (j = 0; j < bs; j++) { /* row of the block */
139             valaddr[spidx++] = dy_i[row + j] * vscale_array[col + i];
140           }
141           dy_i += nxloc; /* points to dy+i*nxloc */
142         }
143       }
144     }
145   }
146   PetscCall(MatAssemblyBegin(J, MAT_FINAL_ASSEMBLY));
147   PetscCall(MatAssemblyEnd(J, MAT_FINAL_ASSEMBLY));
148   if (vscale) PetscCall(VecRestoreArray(vscale, &vscale_array));
149 
150   coloring->currentcolor = -1;
151   PetscCall(VecBindToCPU(x1, PETSC_FALSE));
152   PetscFunctionReturn(0);
153 }
154 
155 /* this is declared PETSC_EXTERN because it is used by MatFDColoringUseDM() which is in the DM library */
156 PetscErrorCode MatFDColoringApply_AIJ(Mat J, MatFDColoring coloring, Vec x1, void *sctx) {
157   PetscErrorCode (*f)(void *, Vec, Vec, void *) = (PetscErrorCode(*)(void *, Vec, Vec, void *))coloring->f;
158   PetscInt           k, cstart, cend, l, row, col, nz;
159   PetscScalar        dx = 0.0, *y, *w3_array;
160   const PetscScalar *xx;
161   PetscScalar       *vscale_array;
162   PetscReal          epsilon = coloring->error_rel, umin = coloring->umin, unorm;
163   Vec                w1 = coloring->w1, w2 = coloring->w2, w3, vscale = coloring->vscale;
164   void              *fctx  = coloring->fctx;
165   ISColoringType     ctype = coloring->ctype;
166   PetscInt           nxloc, nrows_k;
167   MatEntry          *Jentry  = coloring->matentry;
168   MatEntry2         *Jentry2 = coloring->matentry2;
169   const PetscInt     ncolors = coloring->ncolors, *ncolumns = coloring->ncolumns, *nrows = coloring->nrows;
170   PetscBool          alreadyboundtocpu;
171 
172   PetscFunctionBegin;
173   PetscCall(VecBoundToCPU(x1, &alreadyboundtocpu));
174   PetscCall(VecBindToCPU(x1, PETSC_TRUE));
175   PetscCheck(!(ctype == IS_COLORING_LOCAL) || !(J->ops->fdcoloringapply == MatFDColoringApply_AIJ), PetscObjectComm((PetscObject)J), PETSC_ERR_SUP, "Must call MatColoringUseDM() with IS_COLORING_LOCAL");
176   /* (1) Set w1 = F(x1) */
177   if (!coloring->fset) {
178     PetscCall(PetscLogEventBegin(MAT_FDColoringFunction, 0, 0, 0, 0));
179     PetscCall((*f)(sctx, x1, w1, fctx));
180     PetscCall(PetscLogEventEnd(MAT_FDColoringFunction, 0, 0, 0, 0));
181   } else {
182     coloring->fset = PETSC_FALSE;
183   }
184 
185   /* (2) Compute vscale = 1./dx - the local scale factors, including ghost points */
186   if (coloring->htype[0] == 'w') {
187     /* vscale = 1./dx is a constant scalar */
188     PetscCall(VecNorm(x1, NORM_2, &unorm));
189     dx = 1.0 / (PetscSqrtReal(1.0 + unorm) * epsilon);
190   } else {
191     PetscCall(VecGetLocalSize(x1, &nxloc));
192     PetscCall(VecGetArrayRead(x1, &xx));
193     PetscCall(VecGetArray(vscale, &vscale_array));
194     for (col = 0; col < nxloc; col++) {
195       dx = xx[col];
196       if (PetscAbsScalar(dx) < umin) {
197         if (PetscRealPart(dx) >= 0.0) dx = umin;
198         else if (PetscRealPart(dx) < 0.0) dx = -umin;
199       }
200       dx *= epsilon;
201       vscale_array[col] = 1.0 / dx;
202     }
203     PetscCall(VecRestoreArrayRead(x1, &xx));
204     PetscCall(VecRestoreArray(vscale, &vscale_array));
205   }
206   if (ctype == IS_COLORING_GLOBAL && coloring->htype[0] == 'd') {
207     PetscCall(VecGhostUpdateBegin(vscale, INSERT_VALUES, SCATTER_FORWARD));
208     PetscCall(VecGhostUpdateEnd(vscale, INSERT_VALUES, SCATTER_FORWARD));
209   }
210 
211   /* (3) Loop over each color */
212   if (!coloring->w3) { PetscCall(VecDuplicate(x1, &coloring->w3)); }
213   w3 = coloring->w3;
214 
215   PetscCall(VecGetOwnershipRange(x1, &cstart, &cend)); /* used by ghosted vscale */
216   if (vscale) PetscCall(VecGetArray(vscale, &vscale_array));
217   nz = 0;
218 
219   if (coloring->bcols > 1) { /* use blocked insertion of Jentry */
220     PetscInt     i, m = J->rmap->n, nbcols, bcols = coloring->bcols;
221     PetscScalar *dy = coloring->dy, *dy_k;
222 
223     nbcols = 0;
224     for (k = 0; k < ncolors; k += bcols) {
225       /*
226        (3-1) Loop over each column associated with color
227        adding the perturbation to the vector w3 = x1 + dx.
228        */
229 
230       dy_k = dy;
231       if (k + bcols > ncolors) bcols = ncolors - k;
232       for (i = 0; i < bcols; i++) {
233         coloring->currentcolor = k + i;
234 
235         PetscCall(VecCopy(x1, w3));
236         PetscCall(VecGetArray(w3, &w3_array));
237         if (ctype == IS_COLORING_GLOBAL) w3_array -= cstart; /* shift pointer so global index can be used */
238         if (coloring->htype[0] == 'w') {
239           for (l = 0; l < ncolumns[k + i]; l++) {
240             col = coloring->columns[k + i][l]; /* local column (in global index!) of the matrix we are probing for */
241             w3_array[col] += 1.0 / dx;
242           }
243         } else {                  /* htype == 'ds' */
244           vscale_array -= cstart; /* shift pointer so global index can be used */
245           for (l = 0; l < ncolumns[k + i]; l++) {
246             col = coloring->columns[k + i][l]; /* local column (in global index!) of the matrix we are probing for */
247             w3_array[col] += 1.0 / vscale_array[col];
248           }
249           vscale_array += cstart;
250         }
251         if (ctype == IS_COLORING_GLOBAL) w3_array += cstart;
252         PetscCall(VecRestoreArray(w3, &w3_array));
253 
254         /*
255          (3-2) Evaluate function at w3 = x1 + dx (here dx is a vector of perturbations)
256                            w2 = F(x1 + dx) - F(x1)
257          */
258         PetscCall(PetscLogEventBegin(MAT_FDColoringFunction, 0, 0, 0, 0));
259         PetscCall(VecPlaceArray(w2, dy_k)); /* place w2 to the array dy_i */
260         PetscCall((*f)(sctx, w3, w2, fctx));
261         PetscCall(PetscLogEventEnd(MAT_FDColoringFunction, 0, 0, 0, 0));
262         PetscCall(VecAXPY(w2, -1.0, w1));
263         PetscCall(VecResetArray(w2));
264         dy_k += m; /* points to dy+i*nxloc */
265       }
266 
267       /*
268        (3-3) Loop over block rows of vector, putting results into Jacobian matrix
269        */
270       nrows_k = nrows[nbcols++];
271 
272       if (coloring->htype[0] == 'w') {
273         for (l = 0; l < nrows_k; l++) {
274           row = Jentry2[nz].row; /* local row index */
275                                  /* The 'useless' ifdef is due to a bug in NVIDIA nvc 21.11, which triggers a segfault on this line. We write it in
276              another way, and it seems work. See https://lists.mcs.anl.gov/pipermail/petsc-users/2021-December/045158.html
277            */
278 #if defined(PETSC_USE_COMPLEX)
279           PetscScalar *tmp = Jentry2[nz].valaddr;
280           *tmp             = dy[row] * dx;
281 #else
282           *(Jentry2[nz].valaddr) = dy[row] * dx;
283 #endif
284           nz++;
285         }
286       } else { /* htype == 'ds' */
287         for (l = 0; l < nrows_k; l++) {
288           row = Jentry[nz].row; /* local row index */
289 #if defined(PETSC_USE_COMPLEX)  /* See https://lists.mcs.anl.gov/pipermail/petsc-users/2021-December/045158.html */
290           PetscScalar *tmp = Jentry[nz].valaddr;
291           *tmp             = dy[row] * vscale_array[Jentry[nz].col];
292 #else
293           *(Jentry[nz].valaddr)  = dy[row] * vscale_array[Jentry[nz].col];
294 #endif
295           nz++;
296         }
297       }
298     }
299   } else { /* bcols == 1 */
300     for (k = 0; k < ncolors; k++) {
301       coloring->currentcolor = k;
302 
303       /*
304        (3-1) Loop over each column associated with color
305        adding the perturbation to the vector w3 = x1 + dx.
306        */
307       PetscCall(VecCopy(x1, w3));
308       PetscCall(VecGetArray(w3, &w3_array));
309       if (ctype == IS_COLORING_GLOBAL) w3_array -= cstart; /* shift pointer so global index can be used */
310       if (coloring->htype[0] == 'w') {
311         for (l = 0; l < ncolumns[k]; l++) {
312           col = coloring->columns[k][l]; /* local column (in global index!) of the matrix we are probing for */
313           w3_array[col] += 1.0 / dx;
314         }
315       } else {                  /* htype == 'ds' */
316         vscale_array -= cstart; /* shift pointer so global index can be used */
317         for (l = 0; l < ncolumns[k]; l++) {
318           col = coloring->columns[k][l]; /* local column (in global index!) of the matrix we are probing for */
319           w3_array[col] += 1.0 / vscale_array[col];
320         }
321         vscale_array += cstart;
322       }
323       if (ctype == IS_COLORING_GLOBAL) w3_array += cstart;
324       PetscCall(VecRestoreArray(w3, &w3_array));
325 
326       /*
327        (3-2) Evaluate function at w3 = x1 + dx (here dx is a vector of perturbations)
328                            w2 = F(x1 + dx) - F(x1)
329        */
330       PetscCall(PetscLogEventBegin(MAT_FDColoringFunction, 0, 0, 0, 0));
331       PetscCall((*f)(sctx, w3, w2, fctx));
332       PetscCall(PetscLogEventEnd(MAT_FDColoringFunction, 0, 0, 0, 0));
333       PetscCall(VecAXPY(w2, -1.0, w1));
334 
335       /*
336        (3-3) Loop over rows of vector, putting results into Jacobian matrix
337        */
338       nrows_k = nrows[k];
339       PetscCall(VecGetArray(w2, &y));
340       if (coloring->htype[0] == 'w') {
341         for (l = 0; l < nrows_k; l++) {
342           row = Jentry2[nz].row; /* local row index */
343 #if defined(PETSC_USE_COMPLEX)   /* See https://lists.mcs.anl.gov/pipermail/petsc-users/2021-December/045158.html */
344           PetscScalar *tmp = Jentry2[nz].valaddr;
345           *tmp             = y[row] * dx;
346 #else
347           *(Jentry2[nz].valaddr) = y[row] * dx;
348 #endif
349           nz++;
350         }
351       } else { /* htype == 'ds' */
352         for (l = 0; l < nrows_k; l++) {
353           row = Jentry[nz].row; /* local row index */
354 #if defined(PETSC_USE_COMPLEX)  /* See https://lists.mcs.anl.gov/pipermail/petsc-users/2021-December/045158.html */
355           PetscScalar *tmp = Jentry[nz].valaddr;
356           *tmp             = y[row] * vscale_array[Jentry[nz].col];
357 #else
358           *(Jentry[nz].valaddr)  = y[row] * vscale_array[Jentry[nz].col];
359 #endif
360           nz++;
361         }
362       }
363       PetscCall(VecRestoreArray(w2, &y));
364     }
365   }
366 
367 #if defined(PETSC_HAVE_VIENNACL) || defined(PETSC_HAVE_CUDA)
368   if (J->offloadmask != PETSC_OFFLOAD_UNALLOCATED) J->offloadmask = PETSC_OFFLOAD_CPU;
369 #endif
370   PetscCall(MatAssemblyBegin(J, MAT_FINAL_ASSEMBLY));
371   PetscCall(MatAssemblyEnd(J, MAT_FINAL_ASSEMBLY));
372   if (vscale) PetscCall(VecRestoreArray(vscale, &vscale_array));
373   coloring->currentcolor = -1;
374   if (!alreadyboundtocpu) PetscCall(VecBindToCPU(x1, PETSC_FALSE));
375   PetscFunctionReturn(0);
376 }
377 
378 PetscErrorCode MatFDColoringSetUp_MPIXAIJ(Mat mat, ISColoring iscoloring, MatFDColoring c) {
379   PetscMPIInt            size, *ncolsonproc, *disp, nn;
380   PetscInt               i, n, nrows, nrows_i, j, k, m, ncols, col, *rowhit, cstart, cend, colb;
381   const PetscInt        *is, *A_ci, *A_cj, *B_ci, *B_cj, *row = NULL, *ltog = NULL;
382   PetscInt               nis = iscoloring->n, nctot, *cols, tmp = 0;
383   ISLocalToGlobalMapping map   = mat->cmap->mapping;
384   PetscInt               ctype = c->ctype, *spidxA, *spidxB, nz, bs, bs2, spidx;
385   Mat                    A, B;
386   PetscScalar           *A_val, *B_val, **valaddrhit;
387   MatEntry              *Jentry;
388   MatEntry2             *Jentry2;
389   PetscBool              isBAIJ, isSELL;
390   PetscInt               bcols = c->bcols;
391 #if defined(PETSC_USE_CTABLE)
392   PetscTable colmap = NULL;
393 #else
394   PetscInt *colmap = NULL;      /* local col number of off-diag col */
395 #endif
396 
397   PetscFunctionBegin;
398   if (ctype == IS_COLORING_LOCAL) {
399     PetscCheck(map, PetscObjectComm((PetscObject)mat), PETSC_ERR_ARG_INCOMP, "When using ghosted differencing matrix must have local to global mapping provided with MatSetLocalToGlobalMapping");
400     PetscCall(ISLocalToGlobalMappingGetIndices(map, &ltog));
401   }
402 
403   PetscCall(MatGetBlockSize(mat, &bs));
404   PetscCall(PetscObjectBaseTypeCompare((PetscObject)mat, MATMPIBAIJ, &isBAIJ));
405   PetscCall(PetscObjectTypeCompare((PetscObject)mat, MATMPISELL, &isSELL));
406   if (isBAIJ) {
407     Mat_MPIBAIJ *baij = (Mat_MPIBAIJ *)mat->data;
408     Mat_SeqBAIJ *spA, *spB;
409     A     = baij->A;
410     spA   = (Mat_SeqBAIJ *)A->data;
411     A_val = spA->a;
412     B     = baij->B;
413     spB   = (Mat_SeqBAIJ *)B->data;
414     B_val = spB->a;
415     nz    = spA->nz + spB->nz; /* total nonzero entries of mat */
416     if (!baij->colmap) PetscCall(MatCreateColmap_MPIBAIJ_Private(mat));
417     colmap = baij->colmap;
418     PetscCall(MatGetColumnIJ_SeqBAIJ_Color(A, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &A_ci, &A_cj, &spidxA, NULL));
419     PetscCall(MatGetColumnIJ_SeqBAIJ_Color(B, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &B_ci, &B_cj, &spidxB, NULL));
420 
421     if (ctype == IS_COLORING_GLOBAL && c->htype[0] == 'd') { /* create vscale for storing dx */
422       PetscInt *garray;
423       PetscCall(PetscMalloc1(B->cmap->n, &garray));
424       for (i = 0; i < baij->B->cmap->n / bs; i++) {
425         for (j = 0; j < bs; j++) garray[i * bs + j] = bs * baij->garray[i] + j;
426       }
427       PetscCall(VecCreateGhost(PetscObjectComm((PetscObject)mat), mat->cmap->n, PETSC_DETERMINE, B->cmap->n, garray, &c->vscale));
428       PetscCall(VecBindToCPU(c->vscale, PETSC_TRUE));
429       PetscCall(PetscFree(garray));
430     }
431   } else if (isSELL) {
432     Mat_MPISELL *sell = (Mat_MPISELL *)mat->data;
433     Mat_SeqSELL *spA, *spB;
434     A     = sell->A;
435     spA   = (Mat_SeqSELL *)A->data;
436     A_val = spA->val;
437     B     = sell->B;
438     spB   = (Mat_SeqSELL *)B->data;
439     B_val = spB->val;
440     nz    = spA->nz + spB->nz; /* total nonzero entries of mat */
441     if (!sell->colmap) {
442       /* Allow access to data structures of local part of matrix
443        - creates aij->colmap which maps global column number to local number in part B */
444       PetscCall(MatCreateColmap_MPISELL_Private(mat));
445     }
446     colmap = sell->colmap;
447     PetscCall(MatGetColumnIJ_SeqSELL_Color(A, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &A_ci, &A_cj, &spidxA, NULL));
448     PetscCall(MatGetColumnIJ_SeqSELL_Color(B, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &B_ci, &B_cj, &spidxB, NULL));
449 
450     bs = 1; /* only bs=1 is supported for non MPIBAIJ matrix */
451 
452     if (ctype == IS_COLORING_GLOBAL && c->htype[0] == 'd') { /* create vscale for storing dx */
453       PetscCall(VecCreateGhost(PetscObjectComm((PetscObject)mat), mat->cmap->n, PETSC_DETERMINE, B->cmap->n, sell->garray, &c->vscale));
454       PetscCall(VecBindToCPU(c->vscale, PETSC_TRUE));
455     }
456   } else {
457     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)mat->data;
458     Mat_SeqAIJ *spA, *spB;
459     A     = aij->A;
460     spA   = (Mat_SeqAIJ *)A->data;
461     A_val = spA->a;
462     B     = aij->B;
463     spB   = (Mat_SeqAIJ *)B->data;
464     B_val = spB->a;
465     nz    = spA->nz + spB->nz; /* total nonzero entries of mat */
466     if (!aij->colmap) {
467       /* Allow access to data structures of local part of matrix
468        - creates aij->colmap which maps global column number to local number in part B */
469       PetscCall(MatCreateColmap_MPIAIJ_Private(mat));
470     }
471     colmap = aij->colmap;
472     PetscCall(MatGetColumnIJ_SeqAIJ_Color(A, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &A_ci, &A_cj, &spidxA, NULL));
473     PetscCall(MatGetColumnIJ_SeqAIJ_Color(B, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &B_ci, &B_cj, &spidxB, NULL));
474 
475     bs = 1; /* only bs=1 is supported for non MPIBAIJ matrix */
476 
477     if (ctype == IS_COLORING_GLOBAL && c->htype[0] == 'd') { /* create vscale for storing dx */
478       PetscCall(VecCreateGhost(PetscObjectComm((PetscObject)mat), mat->cmap->n, PETSC_DETERMINE, B->cmap->n, aij->garray, &c->vscale));
479       PetscCall(VecBindToCPU(c->vscale, PETSC_TRUE));
480     }
481   }
482 
483   m      = mat->rmap->n / bs;
484   cstart = mat->cmap->rstart / bs;
485   cend   = mat->cmap->rend / bs;
486 
487   PetscCall(PetscMalloc2(nis, &c->ncolumns, nis, &c->columns));
488   PetscCall(PetscMalloc1(nis, &c->nrows));
489 
490   if (c->htype[0] == 'd') {
491     PetscCall(PetscMalloc1(nz, &Jentry));
492     c->matentry = Jentry;
493   } else if (c->htype[0] == 'w') {
494     PetscCall(PetscMalloc1(nz, &Jentry2));
495     c->matentry2 = Jentry2;
496   } else SETERRQ(PetscObjectComm((PetscObject)mat), PETSC_ERR_SUP, "htype is not supported");
497 
498   PetscCall(PetscMalloc2(m + 1, &rowhit, m + 1, &valaddrhit));
499   nz = 0;
500   PetscCall(ISColoringGetIS(iscoloring, PETSC_OWN_POINTER, PETSC_IGNORE, &c->isa));
501 
502   if (ctype == IS_COLORING_GLOBAL) {
503     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
504     PetscCall(PetscMalloc2(size, &ncolsonproc, size, &disp));
505   }
506 
507   for (i = 0; i < nis; i++) { /* for each local color */
508     PetscCall(ISGetLocalSize(c->isa[i], &n));
509     PetscCall(ISGetIndices(c->isa[i], &is));
510 
511     c->ncolumns[i] = n; /* local number of columns of this color on this process */
512     c->columns[i]  = (PetscInt *)is;
513 
514     if (ctype == IS_COLORING_GLOBAL) {
515       /* Determine nctot, the total (parallel) number of columns of this color */
516       /* ncolsonproc[j]: local ncolumns on proc[j] of this color */
517       PetscCall(PetscMPIIntCast(n, &nn));
518       PetscCallMPI(MPI_Allgather(&nn, 1, MPI_INT, ncolsonproc, 1, MPI_INT, PetscObjectComm((PetscObject)mat)));
519       nctot = 0;
520       for (j = 0; j < size; j++) nctot += ncolsonproc[j];
521       if (!nctot) PetscCall(PetscInfo(mat, "Coloring of matrix has some unneeded colors with no corresponding rows\n"));
522 
523       disp[0] = 0;
524       for (j = 1; j < size; j++) disp[j] = disp[j - 1] + ncolsonproc[j - 1];
525 
526       /* Get cols, the complete list of columns for this color on each process */
527       PetscCall(PetscMalloc1(nctot + 1, &cols));
528       PetscCallMPI(MPI_Allgatherv((void *)is, n, MPIU_INT, cols, ncolsonproc, disp, MPIU_INT, PetscObjectComm((PetscObject)mat)));
529     } else if (ctype == IS_COLORING_LOCAL) {
530       /* Determine local number of columns of this color on this process, including ghost points */
531       nctot = n;
532       cols  = (PetscInt *)is;
533     } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Not provided for this MatFDColoring type");
534 
535     /* Mark all rows affect by these columns */
536     PetscCall(PetscArrayzero(rowhit, m));
537     bs2     = bs * bs;
538     nrows_i = 0;
539     for (j = 0; j < nctot; j++) { /* loop over columns*/
540       if (ctype == IS_COLORING_LOCAL) {
541         col = ltog[cols[j]];
542       } else {
543         col = cols[j];
544       }
545       if (col >= cstart && col < cend) { /* column is in A, diagonal block of mat */
546         tmp   = A_ci[col - cstart];
547         row   = A_cj + tmp;
548         nrows = A_ci[col - cstart + 1] - tmp;
549         nrows_i += nrows;
550 
551         /* loop over columns of A marking them in rowhit */
552         for (k = 0; k < nrows; k++) {
553           /* set valaddrhit for part A */
554           spidx            = bs2 * spidxA[tmp + k];
555           valaddrhit[*row] = &A_val[spidx];
556           rowhit[*row++]   = col - cstart + 1; /* local column index */
557         }
558       } else { /* column is in B, off-diagonal block of mat */
559 #if defined(PETSC_USE_CTABLE)
560         PetscCall(PetscTableFind(colmap, col + 1, &colb));
561         colb--;
562 #else
563         colb = colmap[col] - 1; /* local column index */
564 #endif
565         if (colb == -1) {
566           nrows = 0;
567         } else {
568           colb  = colb / bs;
569           tmp   = B_ci[colb];
570           row   = B_cj + tmp;
571           nrows = B_ci[colb + 1] - tmp;
572         }
573         nrows_i += nrows;
574         /* loop over columns of B marking them in rowhit */
575         for (k = 0; k < nrows; k++) {
576           /* set valaddrhit for part B */
577           spidx            = bs2 * spidxB[tmp + k];
578           valaddrhit[*row] = &B_val[spidx];
579           rowhit[*row++]   = colb + 1 + cend - cstart; /* local column index */
580         }
581       }
582     }
583     c->nrows[i] = nrows_i;
584 
585     if (c->htype[0] == 'd') {
586       for (j = 0; j < m; j++) {
587         if (rowhit[j]) {
588           Jentry[nz].row     = j;             /* local row index */
589           Jentry[nz].col     = rowhit[j] - 1; /* local column index */
590           Jentry[nz].valaddr = valaddrhit[j]; /* address of mat value for this entry */
591           nz++;
592         }
593       }
594     } else { /* c->htype == 'wp' */
595       for (j = 0; j < m; j++) {
596         if (rowhit[j]) {
597           Jentry2[nz].row     = j;             /* local row index */
598           Jentry2[nz].valaddr = valaddrhit[j]; /* address of mat value for this entry */
599           nz++;
600         }
601       }
602     }
603     if (ctype == IS_COLORING_GLOBAL) PetscCall(PetscFree(cols));
604   }
605   if (ctype == IS_COLORING_GLOBAL) PetscCall(PetscFree2(ncolsonproc, disp));
606 
607   if (bcols > 1) { /* reorder Jentry for faster MatFDColoringApply() */
608     PetscCall(MatFDColoringSetUpBlocked_AIJ_Private(mat, c, nz));
609   }
610 
611   if (isBAIJ) {
612     PetscCall(MatRestoreColumnIJ_SeqBAIJ_Color(A, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &A_ci, &A_cj, &spidxA, NULL));
613     PetscCall(MatRestoreColumnIJ_SeqBAIJ_Color(B, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &B_ci, &B_cj, &spidxB, NULL));
614     PetscCall(PetscMalloc1(bs * mat->rmap->n, &c->dy));
615   } else if (isSELL) {
616     PetscCall(MatRestoreColumnIJ_SeqSELL_Color(A, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &A_ci, &A_cj, &spidxA, NULL));
617     PetscCall(MatRestoreColumnIJ_SeqSELL_Color(B, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &B_ci, &B_cj, &spidxB, NULL));
618   } else {
619     PetscCall(MatRestoreColumnIJ_SeqAIJ_Color(A, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &A_ci, &A_cj, &spidxA, NULL));
620     PetscCall(MatRestoreColumnIJ_SeqAIJ_Color(B, 0, PETSC_FALSE, PETSC_FALSE, &ncols, &B_ci, &B_cj, &spidxB, NULL));
621   }
622 
623   PetscCall(ISColoringRestoreIS(iscoloring, PETSC_OWN_POINTER, &c->isa));
624   PetscCall(PetscFree2(rowhit, valaddrhit));
625 
626   if (ctype == IS_COLORING_LOCAL) PetscCall(ISLocalToGlobalMappingRestoreIndices(map, &ltog));
627   PetscCall(PetscInfo(c, "ncolors %" PetscInt_FMT ", brows %" PetscInt_FMT " and bcols %" PetscInt_FMT " are used.\n", c->ncolors, c->brows, c->bcols));
628   PetscFunctionReturn(0);
629 }
630 
631 PetscErrorCode MatFDColoringCreate_MPIXAIJ(Mat mat, ISColoring iscoloring, MatFDColoring c) {
632   PetscInt  bs, nis = iscoloring->n, m = mat->rmap->n;
633   PetscBool isBAIJ, isSELL;
634 
635   PetscFunctionBegin;
636   /* set default brows and bcols for speedup inserting the dense matrix into sparse Jacobian;
637    bcols is chosen s.t. dy-array takes 50% of memory space as mat */
638   PetscCall(MatGetBlockSize(mat, &bs));
639   PetscCall(PetscObjectBaseTypeCompare((PetscObject)mat, MATMPIBAIJ, &isBAIJ));
640   PetscCall(PetscObjectTypeCompare((PetscObject)mat, MATMPISELL, &isSELL));
641   if (isBAIJ || m == 0) {
642     c->brows = m;
643     c->bcols = 1;
644   } else if (isSELL) {
645     /* bcols is chosen s.t. dy-array takes 50% of local memory space as mat */
646     Mat_MPISELL *sell = (Mat_MPISELL *)mat->data;
647     Mat_SeqSELL *spA, *spB;
648     Mat          A, B;
649     PetscInt     nz, brows, bcols;
650     PetscReal    mem;
651 
652     bs = 1; /* only bs=1 is supported for MPISELL matrix */
653 
654     A     = sell->A;
655     spA   = (Mat_SeqSELL *)A->data;
656     B     = sell->B;
657     spB   = (Mat_SeqSELL *)B->data;
658     nz    = spA->nz + spB->nz; /* total local nonzero entries of mat */
659     mem   = nz * (sizeof(PetscScalar) + sizeof(PetscInt)) + 3 * m * sizeof(PetscInt);
660     bcols = (PetscInt)(0.5 * mem / (m * sizeof(PetscScalar)));
661     brows = 1000 / bcols;
662     if (bcols > nis) bcols = nis;
663     if (brows == 0 || brows > m) brows = m;
664     c->brows = brows;
665     c->bcols = bcols;
666   } else { /* mpiaij matrix */
667     /* bcols is chosen s.t. dy-array takes 50% of local memory space as mat */
668     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)mat->data;
669     Mat_SeqAIJ *spA, *spB;
670     Mat         A, B;
671     PetscInt    nz, brows, bcols;
672     PetscReal   mem;
673 
674     bs = 1; /* only bs=1 is supported for MPIAIJ matrix */
675 
676     A     = aij->A;
677     spA   = (Mat_SeqAIJ *)A->data;
678     B     = aij->B;
679     spB   = (Mat_SeqAIJ *)B->data;
680     nz    = spA->nz + spB->nz; /* total local nonzero entries of mat */
681     mem   = nz * (sizeof(PetscScalar) + sizeof(PetscInt)) + 3 * m * sizeof(PetscInt);
682     bcols = (PetscInt)(0.5 * mem / (m * sizeof(PetscScalar)));
683     brows = 1000 / bcols;
684     if (bcols > nis) bcols = nis;
685     if (brows == 0 || brows > m) brows = m;
686     c->brows = brows;
687     c->bcols = bcols;
688   }
689 
690   c->M       = mat->rmap->N / bs; /* set the global rows and columns and local rows */
691   c->N       = mat->cmap->N / bs;
692   c->m       = mat->rmap->n / bs;
693   c->rstart  = mat->rmap->rstart / bs;
694   c->ncolors = nis;
695   PetscFunctionReturn(0);
696 }
697 
698 /*@C
699 
700     MatFDColoringSetValues - takes a matrix in compressed color format and enters the matrix into a PETSc `Mat`
701 
702    Collective on J
703 
704    Input Parameters:
705 +    J - the sparse matrix
706 .    coloring - created with `MatFDColoringCreate()` and a local coloring
707 -    y - column major storage of matrix values with one color of values per column, the number of rows of y should match
708          the number of local rows of J and the number of columns is the number of colors.
709 
710    Level: intermediate
711 
712    Notes: the matrix in compressed color format may come from an automatic differentiation code
713 
714    The code will be slightly faster if `MatFDColoringSetBlockSize`(coloring,`PETSC_DEFAULT`,nc); is called immediately after creating the coloring
715 
716 .seealso: `MatFDColoringCreate()`, `ISColoring`, `ISColoringCreate()`, `ISColoringSetType()`, `IS_COLORING_LOCAL`, `MatFDColoringSetBlockSize()`
717 @*/
718 PetscErrorCode MatFDColoringSetValues(Mat J, MatFDColoring coloring, const PetscScalar *y) {
719   MatEntry2      *Jentry2;
720   PetscInt        row, i, nrows_k, l, ncolors, nz = 0, bcols, nbcols = 0;
721   const PetscInt *nrows;
722   PetscBool       eq;
723 
724   PetscFunctionBegin;
725   PetscValidHeaderSpecific(J, MAT_CLASSID, 1);
726   PetscValidHeaderSpecific(coloring, MAT_FDCOLORING_CLASSID, 2);
727   PetscCall(PetscObjectCompareId((PetscObject)J, coloring->matid, &eq));
728   PetscCheck(eq, PetscObjectComm((PetscObject)J), PETSC_ERR_ARG_WRONG, "Matrix used with MatFDColoringSetValues() must be that used with MatFDColoringCreate()");
729   Jentry2 = coloring->matentry2;
730   nrows   = coloring->nrows;
731   ncolors = coloring->ncolors;
732   bcols   = coloring->bcols;
733 
734   for (i = 0; i < ncolors; i += bcols) {
735     nrows_k = nrows[nbcols++];
736     for (l = 0; l < nrows_k; l++) {
737       row                      = Jentry2[nz].row; /* local row index */
738       *(Jentry2[nz++].valaddr) = y[row];
739     }
740     y += bcols * coloring->m;
741   }
742   PetscCall(MatAssemblyBegin(J, MAT_FINAL_ASSEMBLY));
743   PetscCall(MatAssemblyEnd(J, MAT_FINAL_ASSEMBLY));
744   PetscFunctionReturn(0);
745 }
746