xref: /petsc/src/ml/regressor/tests/ex3.c (revision 34b254c57d2aa195261fbc0db2d1455fb6d091da)
1*34b254c5SRichard Tran Mills #include <petscregressor.h>
2*34b254c5SRichard Tran Mills 
/* Help text handed to PetscInitialize(); it is never modified, so keep it read-only. */
static const char help[] = "Tests some linear PetscRegressor types with different regularizers.\n\n";
4*34b254c5SRichard Tran Mills 
5*34b254c5SRichard Tran Mills typedef struct _AppCtx {
6*34b254c5SRichard Tran Mills   Mat       X;           /* Training data */
7*34b254c5SRichard Tran Mills   Vec       y;           /* Target data   */
8*34b254c5SRichard Tran Mills   Vec       y_predicted; /* Target data   */
9*34b254c5SRichard Tran Mills   Vec       coefficients;
10*34b254c5SRichard Tran Mills   PetscInt  N; /* Data size     */
11*34b254c5SRichard Tran Mills   PetscBool flg_string;
12*34b254c5SRichard Tran Mills   PetscBool flg_ascii;
13*34b254c5SRichard Tran Mills   PetscBool flg_view_sol;
14*34b254c5SRichard Tran Mills   PetscBool test_prefix;
15*34b254c5SRichard Tran Mills } *AppCtx;
16*34b254c5SRichard Tran Mills 
17*34b254c5SRichard Tran Mills static PetscErrorCode DestroyCtx(AppCtx *ctx)
18*34b254c5SRichard Tran Mills {
19*34b254c5SRichard Tran Mills   PetscFunctionBegin;
20*34b254c5SRichard Tran Mills   PetscCall(MatDestroy(&(*ctx)->X));
21*34b254c5SRichard Tran Mills   PetscCall(VecDestroy(&(*ctx)->y));
22*34b254c5SRichard Tran Mills   PetscCall(VecDestroy(&(*ctx)->y_predicted));
23*34b254c5SRichard Tran Mills   PetscCall(PetscFree(*ctx));
24*34b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
25*34b254c5SRichard Tran Mills }
26*34b254c5SRichard Tran Mills 
27*34b254c5SRichard Tran Mills static PetscErrorCode TestRegressorViews(PetscRegressor regressor, AppCtx ctx)
28*34b254c5SRichard Tran Mills {
29*34b254c5SRichard Tran Mills   PetscRegressorType check_type;
30*34b254c5SRichard Tran Mills   PetscBool          match;
31*34b254c5SRichard Tran Mills 
32*34b254c5SRichard Tran Mills   PetscFunctionBegin;
33*34b254c5SRichard Tran Mills   if (ctx->flg_view_sol) {
34*34b254c5SRichard Tran Mills     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Training target vector is\n"));
35*34b254c5SRichard Tran Mills     PetscCall(VecView(ctx->y, PETSC_VIEWER_STDOUT_WORLD));
36*34b254c5SRichard Tran Mills     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Predicted values are\n"));
37*34b254c5SRichard Tran Mills     PetscCall(VecView(ctx->y_predicted, PETSC_VIEWER_STDOUT_WORLD));
38*34b254c5SRichard Tran Mills     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Coefficients are\n"));
39*34b254c5SRichard Tran Mills     PetscCall(VecView(ctx->coefficients, PETSC_VIEWER_STDOUT_WORLD));
40*34b254c5SRichard Tran Mills   }
41*34b254c5SRichard Tran Mills 
42*34b254c5SRichard Tran Mills   if (ctx->flg_string) {
43*34b254c5SRichard Tran Mills     PetscViewer stringviewer;
44*34b254c5SRichard Tran Mills     char        string[512];
45*34b254c5SRichard Tran Mills     const char *outstring;
46*34b254c5SRichard Tran Mills 
47*34b254c5SRichard Tran Mills     PetscCall(PetscViewerStringOpen(PETSC_COMM_WORLD, string, sizeof(string), &stringviewer));
48*34b254c5SRichard Tran Mills     PetscCall(PetscRegressorView(regressor, stringviewer));
49*34b254c5SRichard Tran Mills     PetscCall(PetscViewerStringGetStringRead(stringviewer, &outstring, NULL));
50*34b254c5SRichard Tran Mills     PetscCheck((char *)outstring == (char *)string, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "String returned from viewer does not equal original string");
51*34b254c5SRichard Tran Mills     PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Output from string viewer:%s\n", outstring));
52*34b254c5SRichard Tran Mills     PetscCall(PetscViewerDestroy(&stringviewer));
53*34b254c5SRichard Tran Mills   } else if (ctx->flg_ascii) PetscCall(PetscRegressorView(regressor, PETSC_VIEWER_STDOUT_WORLD));
54*34b254c5SRichard Tran Mills 
55*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorGetType(regressor, &check_type));
56*34b254c5SRichard Tran Mills   PetscCall(PetscStrcmp(check_type, PETSCREGRESSORLINEAR, &match));
57*34b254c5SRichard Tran Mills   PetscCheck(match, PETSC_COMM_WORLD, PETSC_ERR_ARG_NOTSAMETYPE, "Regressor type is not Linear");
58*34b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
59*34b254c5SRichard Tran Mills }
60*34b254c5SRichard Tran Mills 
61*34b254c5SRichard Tran Mills static PetscErrorCode TestPrefixRegressor(PetscRegressor regressor, AppCtx ctx)
62*34b254c5SRichard Tran Mills {
63*34b254c5SRichard Tran Mills   PetscFunctionBegin;
64*34b254c5SRichard Tran Mills   if (ctx->test_prefix) {
65*34b254c5SRichard Tran Mills     PetscCall(PetscRegressorSetOptionsPrefix(regressor, "sys1_"));
66*34b254c5SRichard Tran Mills     PetscCall(PetscRegressorAppendOptionsPrefix(regressor, "sys2_"));
67*34b254c5SRichard Tran Mills   }
68*34b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
69*34b254c5SRichard Tran Mills }
70*34b254c5SRichard Tran Mills 
71*34b254c5SRichard Tran Mills static PetscErrorCode CreateData(AppCtx ctx)
72*34b254c5SRichard Tran Mills {
73*34b254c5SRichard Tran Mills   PetscMPIInt rank;
74*34b254c5SRichard Tran Mills   PetscInt    i;
75*34b254c5SRichard Tran Mills   PetscScalar mean;
76*34b254c5SRichard Tran Mills 
77*34b254c5SRichard Tran Mills   PetscFunctionBegin;
78*34b254c5SRichard Tran Mills   PetscCallMPI(MPI_Comm_rank(PETSC_COMM_WORLD, &rank));
79*34b254c5SRichard Tran Mills   PetscCall(VecCreate(PETSC_COMM_WORLD, &ctx->y));
80*34b254c5SRichard Tran Mills   PetscCall(VecSetSizes(ctx->y, PETSC_DECIDE, ctx->N));
81*34b254c5SRichard Tran Mills   PetscCall(VecSetFromOptions(ctx->y));
82*34b254c5SRichard Tran Mills   PetscCall(VecDuplicate(ctx->y, &ctx->y_predicted));
83*34b254c5SRichard Tran Mills   PetscCall(MatCreate(PETSC_COMM_WORLD, &ctx->X));
84*34b254c5SRichard Tran Mills   PetscCall(MatSetSizes(ctx->X, PETSC_DECIDE, PETSC_DECIDE, ctx->N, ctx->N));
85*34b254c5SRichard Tran Mills   PetscCall(MatSetFromOptions(ctx->X));
86*34b254c5SRichard Tran Mills   PetscCall(MatSetUp(ctx->X));
87*34b254c5SRichard Tran Mills 
88*34b254c5SRichard Tran Mills   if (!rank) {
89*34b254c5SRichard Tran Mills     for (i = 0; i < ctx->N; i++) {
90*34b254c5SRichard Tran Mills       PetscCall(VecSetValue(ctx->y, i, (PetscScalar)i, INSERT_VALUES));
91*34b254c5SRichard Tran Mills       PetscCall(MatSetValue(ctx->X, i, i, 1.0, INSERT_VALUES));
92*34b254c5SRichard Tran Mills     }
93*34b254c5SRichard Tran Mills   }
94*34b254c5SRichard Tran Mills   /* Set up a training data matrix that is the identity.
95*34b254c5SRichard Tran Mills    * We do this because this gives us a special case in which we can analytically determine what the regression
96*34b254c5SRichard Tran Mills    * coefficients should be for ordinary least squares, LASSO (L1 regularized), and ridge (L2 regularized) regression.
97*34b254c5SRichard Tran Mills    * See details in section 6.2 of James et al.'s An Introduction to Statistical Learning (ISLR), in the subsection
98*34b254c5SRichard Tran Mills    * titled "A Simple Special Case for Ridge Regression and the Lasso".
99*34b254c5SRichard Tran Mills    * Note that the coefficients we generate with ridge regression (-regressor_linear_type ridge -regressor_regularizer_weight <lambda>, or, equivalently,
100*34b254c5SRichard Tran Mills    * -tao_brgn_regularization_type l2pure -tao_brgn_regularizer_weight <lambda>) match those of the ISLR formula exactly.
101*34b254c5SRichard Tran Mills    * For LASSO it does not match the ISLR formula: where they use lambda/2, we need to use lambda.
102*34b254c5SRichard Tran Mills    * It also doesn't match what Scikit-learn does; in that case their lambda is 1/n_samples of our lambda. Apparently everyone is scaling
103*34b254c5SRichard Tran Mills    * their loss function by a different value, hence the need to change what "lambda" is. But it's clear that ISLR, Scikit-learn, and we
104*34b254c5SRichard Tran Mills    * are basically doing the same thing otherwise. */
105*34b254c5SRichard Tran Mills   PetscCall(VecAssemblyBegin(ctx->y));
106*34b254c5SRichard Tran Mills   PetscCall(VecAssemblyEnd(ctx->y));
107*34b254c5SRichard Tran Mills   PetscCall(MatAssemblyBegin(ctx->X, MAT_FINAL_ASSEMBLY));
108*34b254c5SRichard Tran Mills   PetscCall(MatAssemblyEnd(ctx->X, MAT_FINAL_ASSEMBLY));
109*34b254c5SRichard Tran Mills   /* Center the target vector we will train with. */
110*34b254c5SRichard Tran Mills   PetscCall(VecMean(ctx->y, &mean));
111*34b254c5SRichard Tran Mills   PetscCall(VecShift(ctx->y, -1.0 * mean));
112*34b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
113*34b254c5SRichard Tran Mills }
114*34b254c5SRichard Tran Mills 
115*34b254c5SRichard Tran Mills static PetscErrorCode ConfigureContext(AppCtx ctx)
116*34b254c5SRichard Tran Mills {
117*34b254c5SRichard Tran Mills   PetscFunctionBegin;
118*34b254c5SRichard Tran Mills   ctx->flg_string   = PETSC_FALSE;
119*34b254c5SRichard Tran Mills   ctx->flg_ascii    = PETSC_FALSE;
120*34b254c5SRichard Tran Mills   ctx->flg_view_sol = PETSC_FALSE;
121*34b254c5SRichard Tran Mills   ctx->test_prefix  = PETSC_FALSE;
122*34b254c5SRichard Tran Mills   ctx->N            = 10;
123*34b254c5SRichard Tran Mills   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Options for PetscRegressor ex3:", "");
124*34b254c5SRichard Tran Mills   PetscCall(PetscOptionsInt("-N", "Dimension of the N x N data matrix", "ex3.c", ctx->N, &ctx->N, NULL));
125*34b254c5SRichard Tran Mills   PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_string_viewer", &ctx->flg_string, NULL));
126*34b254c5SRichard Tran Mills   PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_ascii_viewer", &ctx->flg_ascii, NULL));
127*34b254c5SRichard Tran Mills   PetscCall(PetscOptionsGetBool(NULL, NULL, "-view_sols", &ctx->flg_view_sol, NULL));
128*34b254c5SRichard Tran Mills   PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_prefix", &ctx->test_prefix, NULL));
129*34b254c5SRichard Tran Mills   PetscOptionsEnd();
130*34b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
131*34b254c5SRichard Tran Mills }
132*34b254c5SRichard Tran Mills 
133*34b254c5SRichard Tran Mills int main(int argc, char **args)
134*34b254c5SRichard Tran Mills {
135*34b254c5SRichard Tran Mills   AppCtx         ctx;
136*34b254c5SRichard Tran Mills   PetscRegressor regressor;
137*34b254c5SRichard Tran Mills   PetscScalar    intercept;
138*34b254c5SRichard Tran Mills 
139*34b254c5SRichard Tran Mills   /* Initialize PETSc */
140*34b254c5SRichard Tran Mills   PetscCall(PetscInitialize(&argc, &args, (char *)0, help));
141*34b254c5SRichard Tran Mills 
142*34b254c5SRichard Tran Mills   /* Initialize problem parameters and data */
143*34b254c5SRichard Tran Mills   PetscCall(PetscNew(&ctx));
144*34b254c5SRichard Tran Mills   PetscCall(ConfigureContext(ctx));
145*34b254c5SRichard Tran Mills   PetscCall(CreateData(ctx));
146*34b254c5SRichard Tran Mills 
147*34b254c5SRichard Tran Mills   /* Create Regressor solver with desired type and options */
148*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorCreate(PETSC_COMM_WORLD, &regressor));
149*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorSetType(regressor, PETSCREGRESSORLINEAR));
150*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorLinearSetType(regressor, REGRESSOR_LINEAR_OLS));
151*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorLinearSetFitIntercept(regressor, PETSC_FALSE));
152*34b254c5SRichard Tran Mills   /* Testing prefix functions for Regressor */
153*34b254c5SRichard Tran Mills   PetscCall(TestPrefixRegressor(regressor, ctx));
154*34b254c5SRichard Tran Mills   /* Check for command line options */
155*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorSetFromOptions(regressor));
156*34b254c5SRichard Tran Mills   /* Fit the regressor */
157*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorFit(regressor, ctx->X, ctx->y));
158*34b254c5SRichard Tran Mills   /* Predict data with fitted regressor */
159*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorPredict(regressor, ctx->X, ctx->y_predicted));
160*34b254c5SRichard Tran Mills   /* Get other desired output data */
161*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorLinearGetIntercept(regressor, &intercept));
162*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorLinearGetCoefficients(regressor, &ctx->coefficients));
163*34b254c5SRichard Tran Mills 
164*34b254c5SRichard Tran Mills   /* Testing Views, and GetTypes */
165*34b254c5SRichard Tran Mills   PetscCall(TestRegressorViews(regressor, ctx));
166*34b254c5SRichard Tran Mills   PetscCall(PetscRegressorDestroy(&regressor));
167*34b254c5SRichard Tran Mills   PetscCall(DestroyCtx(&ctx));
168*34b254c5SRichard Tran Mills   PetscCall(PetscFinalize());
169*34b254c5SRichard Tran Mills   return 0;
170*34b254c5SRichard Tran Mills }
171*34b254c5SRichard Tran Mills 
172*34b254c5SRichard Tran Mills /*TEST
173*34b254c5SRichard Tran Mills 
174*34b254c5SRichard Tran Mills    build:
175*34b254c5SRichard Tran Mills       requires: !complex !single !__float128 !defined(PETSC_USE_64BIT_INDICES)
176*34b254c5SRichard Tran Mills 
177*34b254c5SRichard Tran Mills    test:
178*34b254c5SRichard Tran Mills       suffix: prefix_tao
179*34b254c5SRichard Tran Mills       args: -sys1_sys2_regressor_view -test_prefix
180*34b254c5SRichard Tran Mills 
181*34b254c5SRichard Tran Mills    test:
182*34b254c5SRichard Tran Mills       suffix: prefix_ksp
183*34b254c5SRichard Tran Mills       args: -sys1_sys2_regressor_view -test_prefix -sys1_sys2_regressor_linear_use_ksp
184*34b254c5SRichard Tran Mills 
185*34b254c5SRichard Tran Mills    test:
186*34b254c5SRichard Tran Mills       suffix: asciiview
187*34b254c5SRichard Tran Mills       args: -test_ascii_viewer
188*34b254c5SRichard Tran Mills 
189*34b254c5SRichard Tran Mills    test:
190*34b254c5SRichard Tran Mills        suffix: stringview
191*34b254c5SRichard Tran Mills        args: -test_string_viewer
192*34b254c5SRichard Tran Mills 
193*34b254c5SRichard Tran Mills    test:
194*34b254c5SRichard Tran Mills       suffix: ksp_intercept
195*34b254c5SRichard Tran Mills       args: -regressor_linear_use_ksp -regressor_linear_fit_intercept -regressor_view
196*34b254c5SRichard Tran Mills 
197*34b254c5SRichard Tran Mills    test:
198*34b254c5SRichard Tran Mills       suffix: ksp_no_intercept
199*34b254c5SRichard Tran Mills       args: -regressor_linear_use_ksp -regressor_view
200*34b254c5SRichard Tran Mills 
201*34b254c5SRichard Tran Mills    test:
202*34b254c5SRichard Tran Mills       suffix: lasso_1
203*34b254c5SRichard Tran Mills       nsize: 1
204*34b254c5SRichard Tran Mills       args: -regressor_type linear -regressor_linear_type lasso -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols
205*34b254c5SRichard Tran Mills 
206*34b254c5SRichard Tran Mills    test:
207*34b254c5SRichard Tran Mills       suffix: lasso_2
208*34b254c5SRichard Tran Mills       nsize: 2
209*34b254c5SRichard Tran Mills       args: -regressor_type linear -regressor_linear_type lasso -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols
210*34b254c5SRichard Tran Mills 
211*34b254c5SRichard Tran Mills    test:
212*34b254c5SRichard Tran Mills       suffix: ridge_1
213*34b254c5SRichard Tran Mills       nsize: 1
214*34b254c5SRichard Tran Mills       args: -regressor_type linear -regressor_linear_type ridge -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols
215*34b254c5SRichard Tran Mills 
216*34b254c5SRichard Tran Mills    test:
217*34b254c5SRichard Tran Mills       suffix: ridge_2
218*34b254c5SRichard Tran Mills       nsize: 2
219*34b254c5SRichard Tran Mills       args: -regressor_type linear -regressor_linear_type ridge -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols
220*34b254c5SRichard Tran Mills 
221*34b254c5SRichard Tran Mills    test:
222*34b254c5SRichard Tran Mills       suffix: ols_1
223*34b254c5SRichard Tran Mills       nsize: 1
224*34b254c5SRichard Tran Mills       args: -regressor_type linear -regressor_linear_type ols -regressor_linear_fit_intercept -view_sols
225*34b254c5SRichard Tran Mills 
226*34b254c5SRichard Tran Mills    test:
227*34b254c5SRichard Tran Mills       suffix: ols_2
228*34b254c5SRichard Tran Mills       nsize: 2
229*34b254c5SRichard Tran Mills       args: -regressor_type linear -regressor_linear_type ols -regressor_linear_fit_intercept -view_sols
230*34b254c5SRichard Tran Mills 
231*34b254c5SRichard Tran Mills TEST*/
232