xref: /petsc/src/mat/impls/aij/mpi/strumpack/strumpack.c (revision ad0c5e61da9dd0152d1d25cbddc60ed07d23358c)
1 #include <../src/mat/impls/aij/seq/aij.h>
2 #include <../src/mat/impls/aij/mpi/mpiaij.h>
3 #include <StrumpackSparseSolver.h>
4 
/*
  Matrix input modes of this interface. They only matter for MATMPIAIJ matrices,
  not for MATSEQAIJ.
    REPLICATED  - STRUMPACK receives the complete sparse matrix and right-hand side on every process.
    DISTRIBUTED - STRUMPACK receives the sparse matrix and right-hand side distributed over the
                  entire MPI communicator.
*/
typedef enum {
  REPLICATED  = 0,
  DISTRIBUTED = 1
} STRUMPACK_MatInputMode;

/*
  Name list handed to PetscOptionsEnum(): the enumerator names in value order, followed
  by the enum type name, the prefix string and a terminating null entry (PETSc's
  enum-name list layout).
*/
const char *STRUMPACK_MatInputModes[] = {
  "REPLICATED",
  "DISTRIBUTED",
  "STRUMPACK_MatInputMode",
  "PETSC_",
  0
};
12 
13 typedef struct {
14   STRUMPACK_SparseSolver S;
15   STRUMPACK_MatInputMode MatInputMode;
16 } STRUMPACK_data;
17 
18 
19 #undef __FUNCT__
20 #define __FUNCT__ "MatGetDiagonal_STRUMPACK"
21 static PetscErrorCode MatGetDiagonal_STRUMPACK(Mat A,Vec v)
22 {
23   PetscFunctionBegin;
24   SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_SUP,"Mat type: STRUMPACK factor");
25   PetscFunctionReturn(0);
26 }
27 
28 #undef __FUNCT__
29 #define __FUNCT__ "MatDestroy_STRUMPACK"
30 static PetscErrorCode MatDestroy_STRUMPACK(Mat A)
31 {
32   STRUMPACK_data *sp = (STRUMPACK_data*)A->spptr;
33   PetscErrorCode ierr;
34   PetscBool      flg;
35 
36   PetscFunctionBegin;
37   /* Deallocate STRUMPACK storage */
38   PetscStackCall("STRUMPACK_destroy",STRUMPACK_destroy(&(sp->S)));
39   ierr = PetscFree(A->spptr);CHKERRQ(ierr);
40   ierr = PetscObjectTypeCompare((PetscObject)A,MATSEQAIJ,&flg);CHKERRQ(ierr);
41   if (flg) {
42     ierr = MatDestroy_SeqAIJ(A);CHKERRQ(ierr);
43   } else {
44     ierr = MatDestroy_MPIAIJ(A);CHKERRQ(ierr);
45   }
46   PetscFunctionReturn(0);
47 }
48 
49 #undef __FUNCT__
50 #define __FUNCT__ "MatSolve_STRUMPACK"
51 static PetscErrorCode MatSolve_STRUMPACK(Mat A,Vec b_mpi,Vec x)
52 {
53   STRUMPACK_data        *sp = (STRUMPACK_data*)A->spptr;
54   STRUMPACK_RETURN_CODE sp_err;
55   PetscErrorCode        ierr;
56   PetscMPIInt           size;
57   PetscInt              N=A->cmap->N;
58   const PetscScalar     *bptr;
59   PetscScalar           *xptr;
60   Vec                   x_seq,b_seq;
61   IS                    iden;
62   VecScatter            scat;
63 
64   PetscFunctionBegin;
65   ierr = MPI_Comm_size(PetscObjectComm((PetscObject)A),&size);CHKERRQ(ierr);
66   if (size > 1 && sp->MatInputMode == REPLICATED) {
67     ierr = VecCreateSeq(PETSC_COMM_SELF,N,&x_seq);CHKERRQ(ierr);
68     ierr = VecGetArray(x_seq,&xptr);CHKERRQ(ierr);
69     /* replicated mat input, convert b to b_seq */
70     ierr = VecCreateSeq(PETSC_COMM_SELF,N,&b_seq);CHKERRQ(ierr);
71     ierr = ISCreateStride(PETSC_COMM_SELF,N,0,1,&iden);CHKERRQ(ierr);
72     ierr = VecScatterCreate(b_mpi,iden,b_seq,iden,&scat);CHKERRQ(ierr);
73     ierr = ISDestroy(&iden);CHKERRQ(ierr);
74     ierr = VecScatterBegin(scat,b_mpi,b_seq,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
75     ierr = VecScatterEnd(scat,b_mpi,b_seq,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
76     ierr = VecGetArrayRead(b_seq,&bptr);CHKERRQ(ierr);
77   } else { /* size==1 || distributed mat input */
78     ierr = VecGetArray(x,&xptr);CHKERRQ(ierr);
79     ierr = VecGetArrayRead(b_mpi,&bptr);CHKERRQ(ierr);
80   }
81 
82   PetscStackCall("STRUMPACK_solve",sp_err = STRUMPACK_solve(sp->S,(PetscScalar*)bptr,xptr,0));
83 
84   if (sp_err != STRUMPACK_SUCCESS) {
85     if (sp_err == STRUMPACK_MATRIX_NOT_SET)        SETERRQ(PETSC_COMM_SELF,PETSC_ERR_LIB,"STRUMPACK error: matrix was not set");
86     else if (sp_err == STRUMPACK_REORDERING_ERROR) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_LIB,"STRUMPACK error: matrix reordering failed");
87     else                                           SETERRQ(PETSC_COMM_SELF,PETSC_ERR_LIB,"STRUMPACK error: solve failed");
88   }
89 
90   if (size > 1 && sp->MatInputMode == REPLICATED) {
91     ierr = VecRestoreArrayRead(b_seq,&bptr);CHKERRQ(ierr);
92     ierr = VecDestroy(&b_seq);CHKERRQ(ierr);
93     /* convert seq x to mpi x */
94     ierr = VecRestoreArray(x_seq,&xptr);CHKERRQ(ierr);
95     ierr = VecScatterBegin(scat,x_seq,x,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
96     ierr = VecScatterEnd(scat,x_seq,x,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr);
97     ierr = VecScatterDestroy(&scat);CHKERRQ(ierr);
98     ierr = VecDestroy(&x_seq);CHKERRQ(ierr);
99   } else {
100     ierr = VecRestoreArray(x,&xptr);CHKERRQ(ierr);
101     ierr = VecRestoreArrayRead(b_mpi,&bptr);CHKERRQ(ierr);
102   }
103 
104   PetscFunctionReturn(0);
105 }
106 
107 #undef __FUNCT__
108 #define __FUNCT__ "MatMatSolve_STRUMPACK"
109 static PetscErrorCode MatMatSolve_STRUMPACK(Mat A,Mat B_mpi,Mat X)
110 {
111   PetscErrorCode   ierr;
112   PetscBool        flg;
113 
114   PetscFunctionBegin;
115   ierr = PetscObjectTypeCompareAny((PetscObject)B_mpi,&flg,MATSEQDENSE,MATMPIDENSE,NULL);CHKERRQ(ierr);
116   if (!flg) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONG,"Matrix B must be MATDENSE matrix");
117   ierr = PetscObjectTypeCompareAny((PetscObject)X,&flg,MATSEQDENSE,MATMPIDENSE,NULL);CHKERRQ(ierr);
118   if (!flg) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONG,"Matrix X must be MATDENSE matrix");
119   SETERRQ(PETSC_COMM_SELF,PETSC_ERR_SUP,"MatMatSolve_STRUMPACK() is not implemented yet");
120   PetscFunctionReturn(0);
121 }
122 
123 #undef __FUNCT__
124 #define __FUNCT__ "MatFactorInfo_STRUMPACK"
125 static PetscErrorCode MatFactorInfo_STRUMPACK(Mat A,PetscViewer viewer)
126 {
127   PetscErrorCode  ierr;
128 
129   PetscFunctionBegin;
130   /* check if matrix is strumpack type */
131   if (A->ops->solve != MatSolve_STRUMPACK) PetscFunctionReturn(0);
132   ierr = PetscViewerASCIIPrintf(viewer,"STRUMPACK sparse solver!\n");CHKERRQ(ierr);
133   PetscFunctionReturn(0);
134 }
135 
136 #undef __FUNCT__
137 #define __FUNCT__ "MatView_STRUMPACK"
138 static PetscErrorCode MatView_STRUMPACK(Mat A,PetscViewer viewer)
139 {
140   PetscErrorCode    ierr;
141   PetscBool         iascii;
142   PetscViewerFormat format;
143 
144   PetscFunctionBegin;
145   ierr = PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&iascii);CHKERRQ(ierr);
146   if (iascii) {
147     ierr = PetscViewerGetFormat(viewer,&format);CHKERRQ(ierr);
148     if (format == PETSC_VIEWER_ASCII_INFO) {
149       ierr = MatFactorInfo_STRUMPACK(A,viewer);CHKERRQ(ierr);
150     }
151   }
152   PetscFunctionReturn(0);
153 }
154 
155 #undef __FUNCT__
156 #define __FUNCT__ "MatLUFactorNumeric_STRUMPACK"
157 static PetscErrorCode MatLUFactorNumeric_STRUMPACK(Mat F,Mat A,const MatFactorInfo *info)
158 {
159   STRUMPACK_data        *sp = (STRUMPACK_data*)F->spptr;
160   STRUMPACK_RETURN_CODE sp_err;
161   Mat                   *tseq,A_seq = NULL;
162   Mat_SeqAIJ            *A_d,*A_o;
163   Mat_MPIAIJ            *mat;
164   PetscErrorCode        ierr;
165   PetscInt              M=A->rmap->N,m=A->rmap->n,N=A->cmap->N;
166   PetscMPIInt           size;
167   IS                    isrow;
168   PetscBool             flg;
169 
170   PetscFunctionBegin;
171   ierr = MPI_Comm_size(PetscObjectComm((PetscObject)A),&size);CHKERRQ(ierr);
172 
173   ierr = PetscObjectTypeCompare((PetscObject)A,MATMPIAIJ,&flg);CHKERRQ(ierr);
174   if (flg) { /* A is MATMPIAIJ */
175     if (sp->MatInputMode == REPLICATED) {
176       if (size > 1) { /* convert mpi A to seq mat A */
177         ierr = ISCreateStride(PETSC_COMM_SELF,M,0,1,&isrow);CHKERRQ(ierr);
178         ierr = MatGetSubMatrices(A,1,&isrow,&isrow,MAT_INITIAL_MATRIX,&tseq);CHKERRQ(ierr);
179         ierr = ISDestroy(&isrow);CHKERRQ(ierr);
180         A_seq = *tseq;
181         ierr  = PetscFree(tseq);CHKERRQ(ierr);
182         A_d   = (Mat_SeqAIJ*)A_seq->data;
183       } else { /* size == 1 */
184         mat = (Mat_MPIAIJ*)A->data;
185         A_d = (Mat_SeqAIJ*)(mat->A)->data;
186       }
187       PetscStackCall("STRUMPACK_set_csr_matrix",STRUMPACK_set_csr_matrix(sp->S,&N,A_d->i,A_d->j,A_d->a,0));
188     } else { /* sp->MatInputMode == DISTRIBUTED */
189       mat = (Mat_MPIAIJ*)A->data;
190       A_d = (Mat_SeqAIJ*)(mat->A)->data;
191       A_o = (Mat_SeqAIJ*)(mat->B)->data;
192       PetscStackCall("STRUMPACK_set_MPIAIJ_matrix",STRUMPACK_set_MPIAIJ_matrix(sp->S,&m,A_d->i,A_d->j,A_d->a,A_o->i,A_o->j,A_o->a,mat->garray));
193     }
194   } else { /* A is MATSEQAIJ */
195     A_d = (Mat_SeqAIJ*)A->data;
196     PetscStackCall("STRUMPACK_set_csr_matrix",STRUMPACK_set_csr_matrix(sp->S,&N,A_d->i,A_d->j,A_d->a,0));
197   }
198 
199   /* Reorder and Factor the matrix. */
200   /* TODO figure out how to avoid reorder if the matrix values changed, but the pattern remains the same. */
201   PetscStackCall("STRUMPACK_reorder",sp_err = STRUMPACK_reorder(sp->S));
202   PetscStackCall("STRUMPACK_factor",sp_err = STRUMPACK_factor(sp->S));
203   if (sp_err != STRUMPACK_SUCCESS) {
204     if (sp_err == STRUMPACK_MATRIX_NOT_SET)        SETERRQ(PETSC_COMM_SELF,PETSC_ERR_LIB,"STRUMPACK error: matrix was not set");
205     else if (sp_err == STRUMPACK_REORDERING_ERROR) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_LIB,"STRUMPACK error: matrix reordering failed");
206     else                                           SETERRQ(PETSC_COMM_SELF,PETSC_ERR_LIB,"STRUMPACK error: factorization failed");
207   }
208   if (flg && sp->MatInputMode == REPLICATED && size > 1) {
209     ierr = MatDestroy(&A_seq);CHKERRQ(ierr);
210   }
211   PetscFunctionReturn(0);
212 }
213 
214 #undef __FUNCT__
215 #define __FUNCT__ "MatLUFactorSymbolic_STRUMPACK"
216 static PetscErrorCode MatLUFactorSymbolic_STRUMPACK(Mat F,Mat A,IS r,IS c,const MatFactorInfo *info)
217 {
218   PetscFunctionBegin;
219   F->ops->lufactornumeric = MatLUFactorNumeric_STRUMPACK;
220   F->ops->solve           = MatSolve_STRUMPACK;
221   F->ops->matsolve        = MatMatSolve_STRUMPACK;
222   PetscFunctionReturn(0);
223 }
224 
225 #undef __FUNCT__
226 #define __FUNCT__ "MatFactorGetSolverPackage_aij_strumpack"
227 static PetscErrorCode MatFactorGetSolverPackage_aij_strumpack(Mat A,const MatSolverPackage *type)
228 {
229   PetscFunctionBegin;
230   *type = MATSOLVERSTRUMPACK;
231   PetscFunctionReturn(0);
232 }
233 
234 #undef __FUNCT__
235 #define __FUNCT__ "MatGetFactor_aij_strumpack"
236 static PetscErrorCode MatGetFactor_aij_strumpack(Mat A,MatFactorType ftype,Mat *F)
237 {
238   Mat                 B;
239   STRUMPACK_data      *sp;
240   PetscErrorCode      ierr;
241   PetscInt            M=A->rmap->N,N=A->cmap->N;
242   PetscMPIInt         size;
243   STRUMPACK_INTERFACE iface;
244   PetscBool           verb,flg;
245   int                 argc;
246   char                **args;
247   const STRUMPACK_PRECISION table[2][2][2] = {{{STRUMPACK_FLOATCOMPLEX_64,STRUMPACK_DOUBLECOMPLEX_64},
248                                                {STRUMPACK_FLOAT_64,STRUMPACK_DOUBLE_64}},
249                                               {{STRUMPACK_FLOATCOMPLEX,STRUMPACK_DOUBLECOMPLEX},
250                                                {STRUMPACK_FLOAT,STRUMPACK_DOUBLE}}};
251   const STRUMPACK_PRECISION prec = table[(sizeof(PetscInt)==8)?0:1][(PETSC_SCALAR==PETSC_COMPLEX)?0:1][(PETSC_REAL==PETSC_FLOAT)?0:1];
252 
253   PetscFunctionBegin;
254   ierr = MPI_Comm_size(PetscObjectComm((PetscObject)A),&size);CHKERRQ(ierr);
255   /* Create the factorization matrix */
256   ierr = MatCreate(PetscObjectComm((PetscObject)A),&B);CHKERRQ(ierr);
257   ierr = MatSetSizes(B,A->rmap->n,A->cmap->n,M,N);CHKERRQ(ierr);
258   ierr = MatSetType(B,((PetscObject)A)->type_name);CHKERRQ(ierr);
259   ierr = MatSeqAIJSetPreallocation(B,0,NULL);
260   ierr = MatMPIAIJSetPreallocation(B,0,NULL,0,NULL);CHKERRQ(ierr);
261   B->ops->lufactorsymbolic = MatLUFactorSymbolic_STRUMPACK;
262   B->ops->view             = MatView_STRUMPACK;
263   B->ops->destroy          = MatDestroy_STRUMPACK;
264   B->ops->getdiagonal      = MatGetDiagonal_STRUMPACK;
265   ierr = PetscObjectComposeFunction((PetscObject)B,"MatFactorGetSolverPackage_C",MatFactorGetSolverPackage_aij_strumpack);CHKERRQ(ierr);
266   B->factortype = MAT_FACTOR_LU;
267   ierr     = PetscNewLog(B,&sp);CHKERRQ(ierr);
268   B->spptr = sp;
269 
270   ierr = PetscOptionsBegin(PetscObjectComm((PetscObject)A),((PetscObject)A)->prefix,"STRUMPACK Options","Mat");CHKERRQ(ierr);
271   sp->MatInputMode = DISTRIBUTED;
272   iface = STRUMPACK_MPI_DIST;
273   ierr = PetscOptionsEnum("-mat_strumpack_matinput","Matrix input mode (replicated or distributed)","None",STRUMPACK_MatInputModes,
274                           (PetscEnum)sp->MatInputMode,(PetscEnum*)&sp->MatInputMode,NULL);CHKERRQ(ierr);
275   if (sp->MatInputMode == DISTRIBUTED && size == 1) sp->MatInputMode = REPLICATED;
276   if (sp->MatInputMode == DISTRIBUTED)     iface = STRUMPACK_MPI_DIST;
277   else if (sp->MatInputMode == REPLICATED) iface = STRUMPACK_MPI;
278 
279   ierr = PetscObjectTypeCompare((PetscObject)A,MATSEQAIJ,&flg);
280   if (flg) iface = STRUMPACK_MT;
281 
282   if (PetscLogPrintInfo) verb = PETSC_TRUE;
283   else verb = PETSC_FALSE;
284   ierr = PetscOptionsBool("-mat_strumpack_verbose","Print STRUMPACK information","None",verb,&verb,NULL);CHKERRQ(ierr);
285   PetscOptionsEnd();
286 
287   ierr = PetscGetArgs(&argc,&args);CHKERRQ(ierr);
288   PetscStackCall("STRUMPACK_init",STRUMPACK_init(&(sp->S),PetscObjectComm((PetscObject)A),prec,iface,argc,args,verb));
289   PetscStackCall("STRUMPACK_set_from_options",STRUMPACK_set_from_options(sp->S));
290 
291   *F = B;
292   PetscFunctionReturn(0);
293 }
294 
295 #undef __FUNCT__
296 #define __FUNCT__ "MatSolverPackageRegister_STRUMPACK"
297 PETSC_EXTERN PetscErrorCode MatSolverPackageRegister_STRUMPACK(void)
298 {
299   PetscErrorCode ierr;
300 
301   PetscFunctionBegin;
302   ierr = MatSolverPackageRegister(MATSOLVERSTRUMPACK,MATMPIAIJ,MAT_FACTOR_LU,MatGetFactor_aij_strumpack);CHKERRQ(ierr);
303   ierr = MatSolverPackageRegister(MATSOLVERSTRUMPACK,MATSEQAIJ,MAT_FACTOR_LU,MatGetFactor_aij_strumpack);CHKERRQ(ierr);
304   PetscFunctionReturn(0);
305 }
306 
/*MC
  MATSOLVERSTRUMPACK - Parallel direct solver package for LU factorization

  Use ./configure --download-strumpack to have PETSc installed with STRUMPACK

  Use -pc_type lu -pc_factor_mat_solver_package strumpack to use this direct solver

   Works with AIJ matrices (MATSEQAIJ and MATMPIAIJ)

  Options Database Keys:
+ -mat_strumpack_matinput <REPLICATED,DISTRIBUTED> - matrix input mode for MATMPIAIJ matrices (default DISTRIBUTED)
- -mat_strumpack_verbose - print STRUMPACK information

.seealso: PCLU

.seealso: PCFactorSetMatSolverPackage(), MatSolverPackage

M*/
321 
322