xref: /petsc/src/mat/tests/ex6k.kokkos.cxx (revision a90d8e383a2827d476809587898a1fbbc9581506)
/* Help text printed by PetscInitialize() for -help; read-only, hence const. */
static const char help[] = "Benchmarking MatProduct with AIJ and its subclass matrix types\n";
2*a90d8e38SSatish Balay 
3*a90d8e38SSatish Balay /*
4*a90d8e38SSatish Balay Usage:
5*a90d8e38SSatish Balay   mpirun -n <np> ./ex6k
6*a90d8e38SSatish Balay     -A <filename>     : input PETSc binary file for matrix A; one can convert a file from MatrixMarket using mat/tests/ex72.c
7*a90d8e38SSatish Balay     -P <filename>     : input PETSc binary file for matrix P; optional, if not given, P = A
8*a90d8e38SSatish Balay     -mat_type  <str>  : aij or its subclass. Default is aij.
9*a90d8e38SSatish Balay     -prod_type <str>  : AP, AtP, APt, PtAP or PAPt. Default is AP.
    -n <num>          : run MatProductNumeric() this many times and report the average time; must be positive.
                        Two extra untimed warm-up runs precede the timed ones. Default is 100.
11*a90d8e38SSatish Balay 
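Notes:
  Times are measured with the host (CPU) timer PetscTime(). Before each timer read the device is
  synchronized and all MPI ranks pass a barrier, so the reported times cover completed device work.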
14*a90d8e38SSatish Balay 
15*a90d8e38SSatish Balay Examples:
16*a90d8e38SSatish Balay   On OLCF Summit (with GPU-aware MPI)
17*a90d8e38SSatish Balay     # 6 MPI ranks:
18*a90d8e38SSatish Balay     # 6 resource sets (-n 6), 1 MPI rank per RS (-a 1), 7 CPU cores per RS (-c 7), and 1 GPU per RS (-g 1), 6 RSs per node (-r 6)
19*a90d8e38SSatish Balay     jsrun --smpiargs "-gpu" -n 6 -a 1 -c 7 -g 1 -r 6 ./ex6k -A cage12.aij -mat_type aijcusparse
20*a90d8e38SSatish Balay 
21*a90d8e38SSatish Balay     # 1 MPI rank
22*a90d8e38SSatish Balay     jsrun --smpiargs "-gpu" -n 1 -a 1 -c 7 -g 1 -r 1 ./ex6k -A cage12.aij -mat_type aijcusparse
23*a90d8e38SSatish Balay 
24*a90d8e38SSatish Balay   On OLCF Crusher:
25*a90d8e38SSatish Balay     # 1 MPI rank
26*a90d8e38SSatish Balay     # run with 1 node (-N1), 1 mpi rank (-n1), 2 hardware threads per rank (-c2)
27*a90d8e38SSatish Balay     srun -N1 -n1 -c2 --gpus-per-node=8 --gpu-bind=closest ./ex6k -A HV15R.aij -mat_type aijkokkos
28*a90d8e38SSatish Balay 
29*a90d8e38SSatish Balay     # 8 MPI ranks
30*a90d8e38SSatish Balay     srun -N1 -n8 -c2 --gpus-per-node=8 --gpu-bind=closest ./ex6k -A HV15R.aij -mat_type aijkokkos
31*a90d8e38SSatish Balay */
32*a90d8e38SSatish Balay #include <petscmat.h>
33*a90d8e38SSatish Balay #include <petscdevice.h>
34*a90d8e38SSatish Balay 
/*
  SyncDevice(): block the host until previously queued device work has finished.
  main() invokes it right before each PetscTime() read, so a measured interval covers
  completed work rather than only kernel launches. The first available backend is chosen
  in the order CUDA, HIP, Kokkos; with none of them it is an empty statement.
*/
#if defined(PETSC_HAVE_CUDA)
  #include <petscdevice_cuda.h>
  #define SyncDevice() PetscCallCUDA(cudaDeviceSynchronize())
#elif defined(PETSC_HAVE_HIP)
  #include <petscdevice_hip.h>
  #define SyncDevice() PetscCallHIP(hipDeviceSynchronize())
#elif defined(PETSC_HAVE_KOKKOS)
  #include <Kokkos_Core.hpp>
  #define SyncDevice() Kokkos::fence()
#else
  /* Host-only build: nothing to wait for */
  #define SyncDevice() do { } while (0)
#endif
47*a90d8e38SSatish Balay 
48*a90d8e38SSatish Balay int main(int argc, char **args)
49*a90d8e38SSatish Balay {
50*a90d8e38SSatish Balay   Mat            A, P, C;
51*a90d8e38SSatish Balay   Mat            A2, P2, C2; /* Shadow matrices (of MATAIJ) of A,P,C for initialization and validation */
52*a90d8e38SSatish Balay   char           matTypeStr[64], prodTypeStr[32];
53*a90d8e38SSatish Balay   char           fileA[PETSC_MAX_PATH_LEN], fileP[PETSC_MAX_PATH_LEN];
54*a90d8e38SSatish Balay   PetscViewer    fdA, fdP;
55*a90d8e38SSatish Balay   PetscBool      flg, flgA, flgP, equal = PETSC_FALSE;
56*a90d8e38SSatish Balay   PetscLogStage  stage;
57*a90d8e38SSatish Balay   PetscInt       i, n = 100, nskip = 2, M, N;
58*a90d8e38SSatish Balay   MatInfo        info;
59*a90d8e38SSatish Balay   PetscLogDouble tstart = 0, tend = 0, avgTime;
60*a90d8e38SSatish Balay   PetscMPIInt    size;
61*a90d8e38SSatish Balay   MatProductType prodType;
62*a90d8e38SSatish Balay   PetscBool      isAP, isAtP, isAPt, isPtAP, isPAPt;
63*a90d8e38SSatish Balay 
64*a90d8e38SSatish Balay   PetscFunctionBeginUser;
65*a90d8e38SSatish Balay   PetscCall(PetscInitialize(&argc, &args, nullptr, help));
66*a90d8e38SSatish Balay   PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD, &size));
67*a90d8e38SSatish Balay 
68*a90d8e38SSatish Balay   /* Read options -n */
69*a90d8e38SSatish Balay   PetscCall(PetscOptionsGetInt(NULL, NULL, "-n", &n, NULL));
70*a90d8e38SSatish Balay 
71*a90d8e38SSatish Balay   /* Load the matrix from a binary file */
72*a90d8e38SSatish Balay   PetscCall(PetscOptionsGetString(NULL, NULL, "-A", fileA, PETSC_MAX_PATH_LEN, &flgA));
73*a90d8e38SSatish Balay   PetscCall(PetscOptionsGetString(NULL, NULL, "-P", fileP, PETSC_MAX_PATH_LEN, &flgP));
74*a90d8e38SSatish Balay   PetscCheck(flgA, PETSC_COMM_WORLD, PETSC_ERR_USER_INPUT, "Must give a PETSc matrix binary file with the -A option");
75*a90d8e38SSatish Balay 
76*a90d8e38SSatish Balay   PetscCall(PetscOptionsGetString(NULL, NULL, "-mat_type", matTypeStr, sizeof(matTypeStr), &flg));
77*a90d8e38SSatish Balay   if (!flg) PetscCall(PetscStrncpy(matTypeStr, MATAIJ, sizeof(matTypeStr))); /* Inject the default if not provided */
78*a90d8e38SSatish Balay 
79*a90d8e38SSatish Balay   PetscCall(PetscOptionsGetString(NULL, NULL, "-prod_type", prodTypeStr, sizeof(prodTypeStr), &flg));
80*a90d8e38SSatish Balay   if (!flg) PetscCall(PetscStrncpy(prodTypeStr, "AP", sizeof(prodTypeStr))); /* Inject the default if not provided */
81*a90d8e38SSatish Balay 
82*a90d8e38SSatish Balay   PetscCall(PetscStrcmp(prodTypeStr, "AP", &isAP));
83*a90d8e38SSatish Balay   PetscCall(PetscStrcmp(prodTypeStr, "AtP", &isAtP));
84*a90d8e38SSatish Balay   PetscCall(PetscStrcmp(prodTypeStr, "APt", &isAPt));
85*a90d8e38SSatish Balay   PetscCall(PetscStrcmp(prodTypeStr, "PtAP", &isPtAP));
86*a90d8e38SSatish Balay   PetscCall(PetscStrcmp(prodTypeStr, "PAPt", &isPAPt));
87*a90d8e38SSatish Balay 
88*a90d8e38SSatish Balay   if (isAP) prodType = MATPRODUCT_AB;
89*a90d8e38SSatish Balay   else if (isAtP) prodType = MATPRODUCT_AtB;
90*a90d8e38SSatish Balay   else if (isAPt) prodType = MATPRODUCT_ABt;
91*a90d8e38SSatish Balay   else if (isPtAP) prodType = MATPRODUCT_PtAP;
92*a90d8e38SSatish Balay   else if (isPAPt) prodType = MATPRODUCT_RARt;
93*a90d8e38SSatish Balay   else SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_USER, "Unsupported product type %s", prodTypeStr);
94*a90d8e38SSatish Balay 
95*a90d8e38SSatish Balay   /* Read the matrix file to A2 */
96*a90d8e38SSatish Balay   PetscCall(PetscViewerBinaryOpen(PETSC_COMM_WORLD, fileA, FILE_MODE_READ, &fdA));
97*a90d8e38SSatish Balay   PetscCall(MatCreate(PETSC_COMM_WORLD, &A2));
98*a90d8e38SSatish Balay   PetscCall(MatSetType(A2, MATAIJ));
99*a90d8e38SSatish Balay   PetscCall(MatLoad(A2, fdA));
100*a90d8e38SSatish Balay   PetscCall(PetscViewerDestroy(&fdA));
101*a90d8e38SSatish Balay 
102*a90d8e38SSatish Balay   PetscCall(MatGetSize(A2, &M, &N));
103*a90d8e38SSatish Balay   PetscCall(MatGetInfo(A2, MAT_GLOBAL_SUM, &info));
104*a90d8e38SSatish Balay   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Input matrix A: %s, %" PetscInt_FMT " x %" PetscInt_FMT ", %lld nonzeros, %.1f per row\n", fileA, M, N, (long long)info.nz_used, (double)info.nz_used / (double)M));
105*a90d8e38SSatish Balay 
106*a90d8e38SSatish Balay   /* Copy A2 to A and convert A to the specified type */
107*a90d8e38SSatish Balay   PetscCall(MatDuplicate(A2, MAT_COPY_VALUES, &A));
108*a90d8e38SSatish Balay   PetscCall(MatConvert(A, matTypeStr, MAT_INPLACE_MATRIX, &A));
109*a90d8e38SSatish Balay 
110*a90d8e38SSatish Balay   /* Init P, P2 similarly */
111*a90d8e38SSatish Balay   if (flgP) { /* If user provided P */
112*a90d8e38SSatish Balay     PetscCall(PetscViewerBinaryOpen(PETSC_COMM_WORLD, fileP, FILE_MODE_READ, &fdP));
113*a90d8e38SSatish Balay     PetscCall(MatCreate(PETSC_COMM_WORLD, &P2));
114*a90d8e38SSatish Balay     PetscCall(MatSetType(P2, MATAIJ));
115*a90d8e38SSatish Balay     PetscCall(MatLoad(P2, fdP));
116*a90d8e38SSatish Balay     PetscCall(PetscViewerDestroy(&fdP));
117*a90d8e38SSatish Balay 
118*a90d8e38SSatish Balay     PetscCall(MatDuplicate(P2, MAT_COPY_VALUES, &P));
119*a90d8e38SSatish Balay     PetscCall(MatConvert(P, matTypeStr, MAT_INPLACE_MATRIX, &P));
120*a90d8e38SSatish Balay 
121*a90d8e38SSatish Balay     PetscCall(MatGetSize(P2, &M, &N));
122*a90d8e38SSatish Balay     PetscCall(MatGetInfo(P2, MAT_GLOBAL_SUM, &info));
123*a90d8e38SSatish Balay     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Input matrix P: %s, %" PetscInt_FMT " x %" PetscInt_FMT ", %lld nonzeros, %.1f per row\n", fileP, M, N, (long long)info.nz_used, (double)info.nz_used / (double)M));
124*a90d8e38SSatish Balay   } else { /* otherwise just let P = A */
125*a90d8e38SSatish Balay     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Input matrix P = A\n"));
126*a90d8e38SSatish Balay     P2 = A2;
127*a90d8e38SSatish Balay     P  = A;
128*a90d8e38SSatish Balay   }
129*a90d8e38SSatish Balay 
130*a90d8e38SSatish Balay   /* Compute the reference C2 */
131*a90d8e38SSatish Balay   PetscCall(MatProductCreate(A2, P2, NULL, &C2));
132*a90d8e38SSatish Balay   PetscCall(MatProductSetType(C2, prodType));
133*a90d8e38SSatish Balay   PetscCall(MatProductSetFill(C2, PETSC_DEFAULT));
134*a90d8e38SSatish Balay   PetscCall(MatProductSetFromOptions(C2));
135*a90d8e38SSatish Balay   PetscCall(MatProductSymbolic(C2));
136*a90d8e38SSatish Balay   PetscCall(MatProductNumeric(C2));
137*a90d8e38SSatish Balay   PetscCall(MatGetSize(C2, &M, &N));
138*a90d8e38SSatish Balay   PetscCall(MatGetInfo(C2, MAT_GLOBAL_SUM, &info));
139*a90d8e38SSatish Balay   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Mat product  C = %s: %" PetscInt_FMT " x %" PetscInt_FMT ", %lld nonzeros, %.1f per row\n", prodTypeStr, M, N, (long long)info.nz_used, (double)info.nz_used / (double)M));
140*a90d8e38SSatish Balay 
141*a90d8e38SSatish Balay   /* Compute C */
142*a90d8e38SSatish Balay   PetscCall(MatProductCreate(A, P, NULL, &C));
143*a90d8e38SSatish Balay   PetscCall(MatProductSetType(C, prodType));
144*a90d8e38SSatish Balay   PetscCall(MatProductSetAlgorithm(C, MATPRODUCTALGORITHMBACKEND));
145*a90d8e38SSatish Balay   PetscCall(MatProductSetFill(C, PETSC_DEFAULT));
146*a90d8e38SSatish Balay   PetscCall(MatProductSetFromOptions(C));
147*a90d8e38SSatish Balay 
148*a90d8e38SSatish Balay   /* Measure  MatProductSymbolic */
149*a90d8e38SSatish Balay   PetscCall(PetscLogStageRegister("MatProductSymbolic", &stage));
150*a90d8e38SSatish Balay   PetscCall(PetscLogStagePush(stage));
151*a90d8e38SSatish Balay   SyncDevice();
152*a90d8e38SSatish Balay   PetscCallMPI(MPI_Barrier(PETSC_COMM_WORLD));
153*a90d8e38SSatish Balay   PetscCall(PetscTime(&tstart));
154*a90d8e38SSatish Balay   PetscCall(MatProductSymbolic(C));
155*a90d8e38SSatish Balay   SyncDevice();
156*a90d8e38SSatish Balay   PetscCallMPI(MPI_Barrier(PETSC_COMM_WORLD));
157*a90d8e38SSatish Balay   PetscCall(PetscTime(&tend));
158*a90d8e38SSatish Balay   avgTime = (tend - tstart) * 1e6; /* microseconds */
159*a90d8e38SSatish Balay   PetscCall(PetscLogStagePop());
160*a90d8e38SSatish Balay   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\nMatProductSymbolic()         time (us) with %d MPI ranks = %8.2f\n", size, avgTime));
161*a90d8e38SSatish Balay 
162*a90d8e38SSatish Balay   /* Measure  MatProductNumeric */
163*a90d8e38SSatish Balay   PetscCall(PetscLogStageRegister("MatProductNumeric", &stage));
164*a90d8e38SSatish Balay   for (i = 0; i < n + nskip; i++) {
165*a90d8e38SSatish Balay     if (i == nskip) {
166*a90d8e38SSatish Balay       SyncDevice();
167*a90d8e38SSatish Balay       PetscCall(PetscLogStagePush(stage));
168*a90d8e38SSatish Balay       PetscCallMPI(MPI_Barrier(PETSC_COMM_WORLD));
169*a90d8e38SSatish Balay       PetscCall(PetscTime(&tstart));
170*a90d8e38SSatish Balay     }
171*a90d8e38SSatish Balay     PetscCall(MatProductReplaceMats(A, P, NULL, C));
172*a90d8e38SSatish Balay     PetscCall(MatProductNumeric(C));
173*a90d8e38SSatish Balay   }
174*a90d8e38SSatish Balay   SyncDevice();
175*a90d8e38SSatish Balay   PetscCallMPI(MPI_Barrier(PETSC_COMM_WORLD));
176*a90d8e38SSatish Balay   PetscCall(PetscTime(&tend));
177*a90d8e38SSatish Balay   avgTime = (tend - tstart) * 1e6 / n; /* microseconds */
178*a90d8e38SSatish Balay   PetscCall(PetscLogStagePop());
179*a90d8e38SSatish Balay 
180*a90d8e38SSatish Balay   PetscCall(MatMultEqual(C, C2, 8, &equal)); /* Not MatEqual() since C and C2 are not necessarily bitwise equal */
181*a90d8e38SSatish Balay 
182*a90d8e38SSatish Balay   PetscCheck(equal, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "Matrix production error");
183*a90d8e38SSatish Balay   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "MatProductNumeric()  average time (us) with %d MPI ranks = %8.2f\n", size, avgTime));
184*a90d8e38SSatish Balay 
185*a90d8e38SSatish Balay   PetscCall(MatDestroy(&A));
186*a90d8e38SSatish Balay   if (flgP) PetscCall(MatDestroy(&P));
187*a90d8e38SSatish Balay   PetscCall(MatDestroy(&C));
188*a90d8e38SSatish Balay 
189*a90d8e38SSatish Balay   PetscCall(MatDestroy(&A2));
190*a90d8e38SSatish Balay   if (flgP) PetscCall(MatDestroy(&P2));
191*a90d8e38SSatish Balay   PetscCall(MatDestroy(&C2));
192*a90d8e38SSatish Balay 
193*a90d8e38SSatish Balay   PetscCall(PetscFinalize());
194*a90d8e38SSatish Balay   return 0;
195*a90d8e38SSatish Balay }
196*a90d8e38SSatish Balay 
197*a90d8e38SSatish Balay /*TEST
198*a90d8e38SSatish Balay 
199*a90d8e38SSatish Balay   testset:
200*a90d8e38SSatish Balay     args: -n 2 -A ${DATAFILESPATH}/matrices/small
201*a90d8e38SSatish Balay     nsize: 1
202*a90d8e38SSatish Balay     filter: grep "DOES_NOT_EXIST"
203*a90d8e38SSatish Balay     output_file: output/empty.out
204*a90d8e38SSatish Balay     requires: datafilespath !complex double !defined(PETSC_USE_64BIT_INDICES) kokkos_kernels
205*a90d8e38SSatish Balay 
206*a90d8e38SSatish Balay     test:
207*a90d8e38SSatish Balay       suffix: 1
208*a90d8e38SSatish Balay       requires: cuda
209*a90d8e38SSatish Balay       args: -mat_type aijcusparse
210*a90d8e38SSatish Balay 
211*a90d8e38SSatish Balay     test:
212*a90d8e38SSatish Balay       suffix: 2
213*a90d8e38SSatish Balay       args: -mat_type aijkokkos
214*a90d8e38SSatish Balay 
215*a90d8e38SSatish Balay     test:
216*a90d8e38SSatish Balay       suffix: 3
217*a90d8e38SSatish Balay       requires: hip
218*a90d8e38SSatish Balay       args: -mat_type aijhipsparse
219*a90d8e38SSatish Balay 
220*a90d8e38SSatish Balay TEST*/
221