1*34b254c5SRichard Tran Mills #include <petscregressor.h> 2*34b254c5SRichard Tran Mills 3*34b254c5SRichard Tran Mills static char help[] = "Tests some linear PetscRegressor types with different regularizers.\n\n"; 4*34b254c5SRichard Tran Mills 5*34b254c5SRichard Tran Mills typedef struct _AppCtx { 6*34b254c5SRichard Tran Mills Mat X; /* Training data */ 7*34b254c5SRichard Tran Mills Vec y; /* Target data */ 8*34b254c5SRichard Tran Mills Vec y_predicted; /* Target data */ 9*34b254c5SRichard Tran Mills Vec coefficients; 10*34b254c5SRichard Tran Mills PetscInt N; /* Data size */ 11*34b254c5SRichard Tran Mills PetscBool flg_string; 12*34b254c5SRichard Tran Mills PetscBool flg_ascii; 13*34b254c5SRichard Tran Mills PetscBool flg_view_sol; 14*34b254c5SRichard Tran Mills PetscBool test_prefix; 15*34b254c5SRichard Tran Mills } *AppCtx; 16*34b254c5SRichard Tran Mills 17*34b254c5SRichard Tran Mills static PetscErrorCode DestroyCtx(AppCtx *ctx) 18*34b254c5SRichard Tran Mills { 19*34b254c5SRichard Tran Mills PetscFunctionBegin; 20*34b254c5SRichard Tran Mills PetscCall(MatDestroy(&(*ctx)->X)); 21*34b254c5SRichard Tran Mills PetscCall(VecDestroy(&(*ctx)->y)); 22*34b254c5SRichard Tran Mills PetscCall(VecDestroy(&(*ctx)->y_predicted)); 23*34b254c5SRichard Tran Mills PetscCall(PetscFree(*ctx)); 24*34b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 25*34b254c5SRichard Tran Mills } 26*34b254c5SRichard Tran Mills 27*34b254c5SRichard Tran Mills static PetscErrorCode TestRegressorViews(PetscRegressor regressor, AppCtx ctx) 28*34b254c5SRichard Tran Mills { 29*34b254c5SRichard Tran Mills PetscRegressorType check_type; 30*34b254c5SRichard Tran Mills PetscBool match; 31*34b254c5SRichard Tran Mills 32*34b254c5SRichard Tran Mills PetscFunctionBegin; 33*34b254c5SRichard Tran Mills if (ctx->flg_view_sol) { 34*34b254c5SRichard Tran Mills PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Training target vector is\n")); 35*34b254c5SRichard Tran Mills PetscCall(VecView(ctx->y, PETSC_VIEWER_STDOUT_WORLD)); 36*34b254c5SRichard Tran Mills PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Predicted values are\n")); 37*34b254c5SRichard Tran Mills PetscCall(VecView(ctx->y_predicted, PETSC_VIEWER_STDOUT_WORLD)); 38*34b254c5SRichard Tran Mills PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Coefficients are\n")); 39*34b254c5SRichard Tran Mills PetscCall(VecView(ctx->coefficients, PETSC_VIEWER_STDOUT_WORLD)); 40*34b254c5SRichard Tran Mills } 41*34b254c5SRichard Tran Mills 42*34b254c5SRichard Tran Mills if (ctx->flg_string) { 43*34b254c5SRichard Tran Mills PetscViewer stringviewer; 44*34b254c5SRichard Tran Mills char string[512]; 45*34b254c5SRichard Tran Mills const char *outstring; 46*34b254c5SRichard Tran Mills 47*34b254c5SRichard Tran Mills PetscCall(PetscViewerStringOpen(PETSC_COMM_WORLD, string, sizeof(string), &stringviewer)); 48*34b254c5SRichard Tran Mills PetscCall(PetscRegressorView(regressor, stringviewer)); 49*34b254c5SRichard Tran Mills PetscCall(PetscViewerStringGetStringRead(stringviewer, &outstring, NULL)); 50*34b254c5SRichard Tran Mills PetscCheck((char *)outstring == (char *)string, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "String returned from viewer does not equal original string"); 51*34b254c5SRichard Tran Mills PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Output from string viewer:%s\n", outstring)); 52*34b254c5SRichard Tran Mills PetscCall(PetscViewerDestroy(&stringviewer)); 53*34b254c5SRichard Tran Mills } else if (ctx->flg_ascii) PetscCall(PetscRegressorView(regressor, PETSC_VIEWER_STDOUT_WORLD)); 54*34b254c5SRichard Tran Mills 55*34b254c5SRichard Tran Mills PetscCall(PetscRegressorGetType(regressor, &check_type)); 56*34b254c5SRichard Tran Mills PetscCall(PetscStrcmp(check_type, PETSCREGRESSORLINEAR, &match)); 57*34b254c5SRichard Tran Mills PetscCheck(match, PETSC_COMM_WORLD, PETSC_ERR_ARG_NOTSAMETYPE, "Regressor type is not Linear"); 58*34b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 59*34b254c5SRichard Tran Mills } 60*34b254c5SRichard Tran Mills 61*34b254c5SRichard Tran Mills static PetscErrorCode TestPrefixRegressor(PetscRegressor regressor, AppCtx ctx) 62*34b254c5SRichard Tran Mills { 63*34b254c5SRichard Tran Mills PetscFunctionBegin; 64*34b254c5SRichard Tran Mills if (ctx->test_prefix) { 65*34b254c5SRichard Tran Mills PetscCall(PetscRegressorSetOptionsPrefix(regressor, "sys1_")); 66*34b254c5SRichard Tran Mills PetscCall(PetscRegressorAppendOptionsPrefix(regressor, "sys2_")); 67*34b254c5SRichard Tran Mills } 68*34b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 69*34b254c5SRichard Tran Mills } 70*34b254c5SRichard Tran Mills 71*34b254c5SRichard Tran Mills static PetscErrorCode CreateData(AppCtx ctx) 72*34b254c5SRichard Tran Mills { 73*34b254c5SRichard Tran Mills PetscMPIInt rank; 74*34b254c5SRichard Tran Mills PetscInt i; 75*34b254c5SRichard Tran Mills PetscScalar mean; 76*34b254c5SRichard Tran Mills 77*34b254c5SRichard Tran Mills PetscFunctionBegin; 78*34b254c5SRichard Tran Mills PetscCallMPI(MPI_Comm_rank(PETSC_COMM_WORLD, &rank)); 79*34b254c5SRichard Tran Mills PetscCall(VecCreate(PETSC_COMM_WORLD, &ctx->y)); 80*34b254c5SRichard Tran Mills PetscCall(VecSetSizes(ctx->y, PETSC_DECIDE, ctx->N)); 81*34b254c5SRichard Tran Mills PetscCall(VecSetFromOptions(ctx->y)); 82*34b254c5SRichard Tran Mills PetscCall(VecDuplicate(ctx->y, &ctx->y_predicted)); 83*34b254c5SRichard Tran Mills PetscCall(MatCreate(PETSC_COMM_WORLD, &ctx->X)); 84*34b254c5SRichard Tran Mills PetscCall(MatSetSizes(ctx->X, PETSC_DECIDE, PETSC_DECIDE, ctx->N, ctx->N)); 85*34b254c5SRichard Tran Mills PetscCall(MatSetFromOptions(ctx->X)); 86*34b254c5SRichard Tran Mills PetscCall(MatSetUp(ctx->X)); 87*34b254c5SRichard Tran Mills 88*34b254c5SRichard Tran Mills if (!rank) { 89*34b254c5SRichard Tran Mills for (i = 0; i < ctx->N; i++) { 90*34b254c5SRichard Tran Mills PetscCall(VecSetValue(ctx->y, i, (PetscScalar)i, INSERT_VALUES)); 91*34b254c5SRichard Tran Mills PetscCall(MatSetValue(ctx->X, i, i, 1.0, INSERT_VALUES)); 92*34b254c5SRichard Tran Mills } 93*34b254c5SRichard Tran Mills } 94*34b254c5SRichard Tran Mills /* Set up a training data matrix that is the identity. 95*34b254c5SRichard Tran Mills * We do this because this gives us a special case in which we can analytically determine what the regression 96*34b254c5SRichard Tran Mills * coefficients should be for ordinary least squares, LASSO (L1 regularized), and ridge (L2 regularized) regression. 97*34b254c5SRichard Tran Mills * See details in section 6.2 of James et al.'s An Introduction to Statistical Learning (ISLR), in the subsection 98*34b254c5SRichard Tran Mills * titled "A Simple Special Case for Ridge Regression and the Lasso". 99*34b254c5SRichard Tran Mills * Note that the coefficients we generate with ridge regression (-regressor_linear_type ridge -regressor_regularizer_weight <lambda>, or, equivalently, 100*34b254c5SRichard Tran Mills * -tao_brgn_regularization_type l2pure -tao_brgn_regularizer_weight <lambda>) match those of the ISLR formula exactly. 101*34b254c5SRichard Tran Mills * For LASSO it does not match the ISLR formula: where they use lambda/2, we need to use lambda. 102*34b254c5SRichard Tran Mills * It also doesn't match what Scikit-learn does; in that case their lambda is 1/n_samples of our lambda. Apparently everyone is scaling 103*34b254c5SRichard Tran Mills * their loss function by a different value, hence the need to change what "lambda" is. But it's clear that ISLR, Scikit-learn, and we 104*34b254c5SRichard Tran Mills * are basically doing the same thing otherwise. */ 105*34b254c5SRichard Tran Mills PetscCall(VecAssemblyBegin(ctx->y)); 106*34b254c5SRichard Tran Mills PetscCall(VecAssemblyEnd(ctx->y)); 107*34b254c5SRichard Tran Mills PetscCall(MatAssemblyBegin(ctx->X, MAT_FINAL_ASSEMBLY)); 108*34b254c5SRichard Tran Mills PetscCall(MatAssemblyEnd(ctx->X, MAT_FINAL_ASSEMBLY)); 109*34b254c5SRichard Tran Mills /* Center the target vector we will train with. */ 110*34b254c5SRichard Tran Mills PetscCall(VecMean(ctx->y, &mean)); 111*34b254c5SRichard Tran Mills PetscCall(VecShift(ctx->y, -1.0 * mean)); 112*34b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 113*34b254c5SRichard Tran Mills } 114*34b254c5SRichard Tran Mills 115*34b254c5SRichard Tran Mills static PetscErrorCode ConfigureContext(AppCtx ctx) 116*34b254c5SRichard Tran Mills { 117*34b254c5SRichard Tran Mills PetscFunctionBegin; 118*34b254c5SRichard Tran Mills ctx->flg_string = PETSC_FALSE; 119*34b254c5SRichard Tran Mills ctx->flg_ascii = PETSC_FALSE; 120*34b254c5SRichard Tran Mills ctx->flg_view_sol = PETSC_FALSE; 121*34b254c5SRichard Tran Mills ctx->test_prefix = PETSC_FALSE; 122*34b254c5SRichard Tran Mills ctx->N = 10; 123*34b254c5SRichard Tran Mills PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Options for PetscRegressor ex3:", ""); 124*34b254c5SRichard Tran Mills PetscCall(PetscOptionsInt("-N", "Dimension of the N x N data matrix", "ex3.c", ctx->N, &ctx->N, NULL)); 125*34b254c5SRichard Tran Mills PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_string_viewer", &ctx->flg_string, NULL)); 126*34b254c5SRichard Tran Mills PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_ascii_viewer", &ctx->flg_ascii, NULL)); 127*34b254c5SRichard Tran Mills PetscCall(PetscOptionsGetBool(NULL, NULL, "-view_sols", &ctx->flg_view_sol, NULL)); 128*34b254c5SRichard Tran Mills PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_prefix", &ctx->test_prefix, NULL)); 129*34b254c5SRichard Tran Mills PetscOptionsEnd(); 130*34b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 131*34b254c5SRichard Tran Mills } 132*34b254c5SRichard Tran Mills 133*34b254c5SRichard Tran Mills int main(int argc, char **args) 134*34b254c5SRichard Tran Mills { 135*34b254c5SRichard Tran Mills AppCtx ctx; 136*34b254c5SRichard Tran Mills PetscRegressor regressor; 137*34b254c5SRichard Tran Mills PetscScalar intercept; 138*34b254c5SRichard Tran Mills 139*34b254c5SRichard Tran Mills /* Initialize PETSc */ 140*34b254c5SRichard Tran Mills PetscCall(PetscInitialize(&argc, &args, (char *)0, help)); 141*34b254c5SRichard Tran Mills 142*34b254c5SRichard Tran Mills /* Initialize problem parameters and data */ 143*34b254c5SRichard Tran Mills PetscCall(PetscNew(&ctx)); 144*34b254c5SRichard Tran Mills PetscCall(ConfigureContext(ctx)); 145*34b254c5SRichard Tran Mills PetscCall(CreateData(ctx)); 146*34b254c5SRichard Tran Mills 147*34b254c5SRichard Tran Mills /* Create Regressor solver with desired type and options */ 148*34b254c5SRichard Tran Mills PetscCall(PetscRegressorCreate(PETSC_COMM_WORLD, ®ressor)); 149*34b254c5SRichard Tran Mills PetscCall(PetscRegressorSetType(regressor, PETSCREGRESSORLINEAR)); 150*34b254c5SRichard Tran Mills PetscCall(PetscRegressorLinearSetType(regressor, REGRESSOR_LINEAR_OLS)); 151*34b254c5SRichard Tran Mills PetscCall(PetscRegressorLinearSetFitIntercept(regressor, PETSC_FALSE)); 152*34b254c5SRichard Tran Mills /* Testing prefix functions for Regressor */ 153*34b254c5SRichard Tran Mills PetscCall(TestPrefixRegressor(regressor, ctx)); 154*34b254c5SRichard Tran Mills /* Check for command line options */ 155*34b254c5SRichard Tran Mills PetscCall(PetscRegressorSetFromOptions(regressor)); 156*34b254c5SRichard Tran Mills /* Fit the regressor */ 157*34b254c5SRichard Tran Mills PetscCall(PetscRegressorFit(regressor, ctx->X, ctx->y)); 158*34b254c5SRichard Tran Mills /* Predict data with fitted regressor */ 159*34b254c5SRichard Tran Mills PetscCall(PetscRegressorPredict(regressor, ctx->X, ctx->y_predicted)); 160*34b254c5SRichard Tran Mills /* Get other desired output data */ 161*34b254c5SRichard Tran Mills PetscCall(PetscRegressorLinearGetIntercept(regressor, &intercept)); 162*34b254c5SRichard Tran Mills PetscCall(PetscRegressorLinearGetCoefficients(regressor, &ctx->coefficients)); 163*34b254c5SRichard Tran Mills 164*34b254c5SRichard Tran Mills /* Testing Views, and GetTypes */ 165*34b254c5SRichard Tran Mills PetscCall(TestRegressorViews(regressor, ctx)); 166*34b254c5SRichard Tran Mills PetscCall(PetscRegressorDestroy(®ressor)); 167*34b254c5SRichard Tran Mills PetscCall(DestroyCtx(&ctx)); 168*34b254c5SRichard Tran Mills PetscCall(PetscFinalize()); 169*34b254c5SRichard Tran Mills return 0; 170*34b254c5SRichard Tran Mills } 171*34b254c5SRichard Tran Mills 172*34b254c5SRichard Tran Mills /*TEST 173*34b254c5SRichard Tran Mills 174*34b254c5SRichard Tran Mills build: 175*34b254c5SRichard Tran Mills requires: !complex !single !__float128 !defined(PETSC_USE_64BIT_INDICES) 176*34b254c5SRichard Tran Mills 177*34b254c5SRichard Tran Mills test: 178*34b254c5SRichard Tran Mills suffix: prefix_tao 179*34b254c5SRichard Tran Mills args: -sys1_sys2_regressor_view -test_prefix 180*34b254c5SRichard Tran Mills 181*34b254c5SRichard Tran Mills test: 182*34b254c5SRichard Tran Mills suffix: prefix_ksp 183*34b254c5SRichard Tran Mills args: -sys1_sys2_regressor_view -test_prefix -sys1_sys2_regressor_linear_use_ksp 184*34b254c5SRichard Tran Mills 185*34b254c5SRichard Tran Mills test: 186*34b254c5SRichard Tran Mills suffix: asciiview 187*34b254c5SRichard Tran Mills args: -test_ascii_viewer 188*34b254c5SRichard Tran Mills 189*34b254c5SRichard Tran Mills test: 190*34b254c5SRichard Tran Mills suffix: stringview 191*34b254c5SRichard Tran Mills args: -test_string_viewer 192*34b254c5SRichard Tran Mills 193*34b254c5SRichard Tran Mills test: 194*34b254c5SRichard Tran Mills suffix: ksp_intercept 195*34b254c5SRichard Tran Mills args: -regressor_linear_use_ksp -regressor_linear_fit_intercept -regressor_view 196*34b254c5SRichard Tran Mills 197*34b254c5SRichard Tran Mills test: 198*34b254c5SRichard Tran Mills suffix: ksp_no_intercept 199*34b254c5SRichard Tran Mills args: -regressor_linear_use_ksp -regressor_view 200*34b254c5SRichard Tran Mills 201*34b254c5SRichard Tran Mills test: 202*34b254c5SRichard Tran Mills suffix: lasso_1 203*34b254c5SRichard Tran Mills nsize: 1 204*34b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type lasso -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols 205*34b254c5SRichard Tran Mills 206*34b254c5SRichard Tran Mills test: 207*34b254c5SRichard Tran Mills suffix: lasso_2 208*34b254c5SRichard Tran Mills nsize: 2 209*34b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type lasso -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols 210*34b254c5SRichard Tran Mills 211*34b254c5SRichard Tran Mills test: 212*34b254c5SRichard Tran Mills suffix: ridge_1 213*34b254c5SRichard Tran Mills nsize: 1 214*34b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type ridge -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols 215*34b254c5SRichard Tran Mills 216*34b254c5SRichard Tran Mills test: 217*34b254c5SRichard Tran Mills suffix: ridge_2 218*34b254c5SRichard Tran Mills nsize: 2 219*34b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type ridge -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols 220*34b254c5SRichard Tran Mills 221*34b254c5SRichard Tran Mills test: 222*34b254c5SRichard Tran Mills suffix: ols_1 223*34b254c5SRichard Tran Mills nsize: 1 224*34b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type ols -regressor_linear_fit_intercept -view_sols 225*34b254c5SRichard Tran Mills 226*34b254c5SRichard Tran Mills test: 227*34b254c5SRichard Tran Mills suffix: ols_2 228*34b254c5SRichard Tran Mills nsize: 2 229*34b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type ols -regressor_linear_fit_intercept -view_sols 230*34b254c5SRichard Tran Mills 231*34b254c5SRichard Tran Mills TEST*/ 232