xref: /petsc/src/ml/regressor/tests/ex2.c (revision 34b254c57d2aa195261fbc0db2d1455fb6d091da)
/* Help text shown for -help; handed to PetscInitialize() in main(). */
static char help[] = "Tests basic creation and destruction of PetscRegressor objects.\n"
                     "\n";
2*34b254c5SRichard Tran Mills 
/*
    Uses PetscRegressor to train a linear model (that is, linear in its coefficients)
    for a quadratic polynomial data-fitting problem. This is example 3.2 in the first (1997) edition of Michael
    T. Heath's "Scientific Computing: An Introductory Survey" textbook.
    This example and ex1.c are essentially the same, except that the input arrays are mean-centered in ex1.c
    but not in ex2.c. (The data in ex2.c are exactly the data as presented in Heath's example.)
*/
10*34b254c5SRichard Tran Mills 
11*34b254c5SRichard Tran Mills #include <petscregressor.h>
12*34b254c5SRichard Tran Mills 
13*34b254c5SRichard Tran Mills int main(int argc, char **args)
14*34b254c5SRichard Tran Mills {
15*34b254c5SRichard Tran Mills   PetscRegressor regressor;
16*34b254c5SRichard Tran Mills   PetscMPIInt    rank;
17*34b254c5SRichard Tran Mills   Mat            X;
18*34b254c5SRichard Tran Mills   Vec            y, y_predicted, coefficients;
19*34b254c5SRichard Tran Mills   PetscScalar    intercept;
20*34b254c5SRichard Tran Mills   /* y_array[] and X_array[] are NOT mean-centered; in ex1.c they are! */
21*34b254c5SRichard Tran Mills   PetscScalar y_array[5]  = {1.0, 0.5, 0, 0.5, 2};
22*34b254c5SRichard Tran Mills   PetscScalar X_array[10] = {-1.00000, 1.00000, -0.50000, 0.25000, 0.00000, 0.00000, 0.50000, 0.25000, 1.00000, 1.00000};
23*34b254c5SRichard Tran Mills   PetscInt    rows_ix[5]  = {0, 1, 2, 3, 4};
24*34b254c5SRichard Tran Mills   PetscInt    cols_ix[2]  = {0, 1};
25*34b254c5SRichard Tran Mills 
26*34b254c5SRichard Tran Mills   PetscCall(PetscInitialize(&argc, &args, (char *)0, help));
27*34b254c5SRichard Tran Mills   PetscCallMPI(MPI_Comm_rank(PETSC_COMM_WORLD, &rank));
28*34b254c5SRichard Tran Mills 
29*34b254c5SRichard Tran Mills   PetscCall(VecCreate(PETSC_COMM_WORLD, &y));
30*34b254c5SRichard Tran Mills   PetscCall(VecSetSizes(y, PETSC_DECIDE, 5));
31*34b254c5SRichard Tran Mills   PetscCall(VecSetFromOptions(y));
32*34b254c5SRichard Tran Mills   PetscCall(VecDuplicate(y, &y_predicted));
33*34b254c5SRichard Tran Mills   PetscCall(MatCreate(PETSC_COMM_WORLD, &X));
34*34b254c5SRichard Tran Mills   PetscCall(MatSetSizes(X, PETSC_DECIDE, PETSC_DECIDE, 5, 2));
35*34b254c5SRichard Tran Mills   PetscCall(MatSetFromOptions(X));
36*34b254c5SRichard Tran Mills   PetscCall(MatSetUp(X));
37*34b254c5SRichard Tran Mills 
38*34b254c5SRichard Tran Mills   if (!rank) {
39*34b254c5SRichard Tran Mills     PetscCall(VecSetValues(y, 5, rows_ix, y_array, INSERT_VALUES));
40*34b254c5SRichard Tran Mills     PetscCall(MatSetValues(X, 5, rows_ix, 2, cols_ix, X_array, ADD_VALUES));
41*34b254c5SRichard Tran Mills   }
42*34b254c5SRichard Tran Mills   PetscCall(VecAssemblyBegin(y));
43*34b254c5SRichard Tran Mills   PetscCall(VecAssemblyEnd(y));
44*34b254c5SRichard Tran Mills   PetscCall(MatAssemblyBegin(X, MAT_FINAL_ASSEMBLY));
45*34b254c5SRichard Tran Mills   PetscCall(MatAssemblyEnd(X, MAT_FINAL_ASSEMBLY));
46*34b254c5SRichard Tran Mills 
47*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorCreate(PETSC_COMM_WORLD, &regressor));
48*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorSetType(regressor, PETSCREGRESSORLINEAR));
49*34b254c5SRichard Tran Mills   PetscRegressorSetFromOptions(regressor);
50*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorFit(regressor, X, y));
51*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorPredict(regressor, X, y_predicted));
52*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorLinearGetIntercept(regressor, &intercept));
53*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorLinearGetCoefficients(regressor, &coefficients));
54*34b254c5SRichard Tran Mills 
55*34b254c5SRichard Tran Mills   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Intercept is %lf\n", intercept));
56*34b254c5SRichard Tran Mills   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Coefficients are\n"));
57*34b254c5SRichard Tran Mills   PetscCall(VecView(coefficients, PETSC_VIEWER_STDOUT_WORLD));
58*34b254c5SRichard Tran Mills   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Predicted values are\n"));
59*34b254c5SRichard Tran Mills   PetscCall(VecView(y_predicted, PETSC_VIEWER_STDOUT_WORLD));
60*34b254c5SRichard Tran Mills 
61*34b254c5SRichard Tran Mills   PetscCall(MatDestroy(&X));
62*34b254c5SRichard Tran Mills   PetscCall(VecDestroy(&y));
63*34b254c5SRichard Tran Mills   PetscCall(VecDestroy(&y_predicted));
64*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorDestroy(&regressor));
65*34b254c5SRichard Tran Mills 
66*34b254c5SRichard Tran Mills   PetscCall(PetscFinalize());
67*34b254c5SRichard Tran Mills   return 0;
68*34b254c5SRichard Tran Mills }
69