134b254c5SRichard Tran Mills #include <petscregressor.h> 234b254c5SRichard Tran Mills 334b254c5SRichard Tran Mills static char help[] = "Tests some linear PetscRegressor types with different regularizers.\n\n"; 434b254c5SRichard Tran Mills 534b254c5SRichard Tran Mills typedef struct _AppCtx { 634b254c5SRichard Tran Mills Mat X; /* Training data */ 734b254c5SRichard Tran Mills Vec y; /* Target data */ 834b254c5SRichard Tran Mills Vec y_predicted; /* Target data */ 934b254c5SRichard Tran Mills Vec coefficients; 1034b254c5SRichard Tran Mills PetscInt N; /* Data size */ 1134b254c5SRichard Tran Mills PetscBool flg_string; 1234b254c5SRichard Tran Mills PetscBool flg_ascii; 1334b254c5SRichard Tran Mills PetscBool flg_view_sol; 1434b254c5SRichard Tran Mills PetscBool test_prefix; 1534b254c5SRichard Tran Mills } *AppCtx; 1634b254c5SRichard Tran Mills 1734b254c5SRichard Tran Mills static PetscErrorCode DestroyCtx(AppCtx *ctx) 1834b254c5SRichard Tran Mills { 1934b254c5SRichard Tran Mills PetscFunctionBegin; 2034b254c5SRichard Tran Mills PetscCall(MatDestroy(&(*ctx)->X)); 2134b254c5SRichard Tran Mills PetscCall(VecDestroy(&(*ctx)->y)); 2234b254c5SRichard Tran Mills PetscCall(VecDestroy(&(*ctx)->y_predicted)); 2334b254c5SRichard Tran Mills PetscCall(PetscFree(*ctx)); 2434b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 2534b254c5SRichard Tran Mills } 2634b254c5SRichard Tran Mills 2734b254c5SRichard Tran Mills static PetscErrorCode TestRegressorViews(PetscRegressor regressor, AppCtx ctx) 2834b254c5SRichard Tran Mills { 2934b254c5SRichard Tran Mills PetscRegressorType check_type; 3034b254c5SRichard Tran Mills PetscBool match; 3134b254c5SRichard Tran Mills 3234b254c5SRichard Tran Mills PetscFunctionBegin; 3334b254c5SRichard Tran Mills if (ctx->flg_view_sol) { 3434b254c5SRichard Tran Mills PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Training target vector is\n")); 3534b254c5SRichard Tran Mills PetscCall(VecView(ctx->y, PETSC_VIEWER_STDOUT_WORLD)); 3634b254c5SRichard Tran Mills PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Predicted values are\n")); 3734b254c5SRichard Tran Mills PetscCall(VecView(ctx->y_predicted, PETSC_VIEWER_STDOUT_WORLD)); 3834b254c5SRichard Tran Mills PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Coefficients are\n")); 3934b254c5SRichard Tran Mills PetscCall(VecView(ctx->coefficients, PETSC_VIEWER_STDOUT_WORLD)); 4034b254c5SRichard Tran Mills } 4134b254c5SRichard Tran Mills 4234b254c5SRichard Tran Mills if (ctx->flg_string) { 4334b254c5SRichard Tran Mills PetscViewer stringviewer; 4434b254c5SRichard Tran Mills char string[512]; 4534b254c5SRichard Tran Mills const char *outstring; 4634b254c5SRichard Tran Mills 4734b254c5SRichard Tran Mills PetscCall(PetscViewerStringOpen(PETSC_COMM_WORLD, string, sizeof(string), &stringviewer)); 4834b254c5SRichard Tran Mills PetscCall(PetscRegressorView(regressor, stringviewer)); 4934b254c5SRichard Tran Mills PetscCall(PetscViewerStringGetStringRead(stringviewer, &outstring, NULL)); 5034b254c5SRichard Tran Mills PetscCheck((char *)outstring == (char *)string, PETSC_COMM_WORLD, PETSC_ERR_PLIB, "String returned from viewer does not equal original string"); 5134b254c5SRichard Tran Mills PetscCall(PetscPrintf(PETSC_COMM_WORLD, "Output from string viewer:%s\n", outstring)); 5234b254c5SRichard Tran Mills PetscCall(PetscViewerDestroy(&stringviewer)); 5334b254c5SRichard Tran Mills } else if (ctx->flg_ascii) PetscCall(PetscRegressorView(regressor, PETSC_VIEWER_STDOUT_WORLD)); 5434b254c5SRichard Tran Mills 5534b254c5SRichard Tran Mills PetscCall(PetscRegressorGetType(regressor, &check_type)); 5634b254c5SRichard Tran Mills PetscCall(PetscStrcmp(check_type, PETSCREGRESSORLINEAR, &match)); 5734b254c5SRichard Tran Mills PetscCheck(match, PETSC_COMM_WORLD, PETSC_ERR_ARG_NOTSAMETYPE, "Regressor type is not Linear"); 5834b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 5934b254c5SRichard Tran Mills } 6034b254c5SRichard Tran Mills 6134b254c5SRichard Tran Mills static PetscErrorCode TestPrefixRegressor(PetscRegressor regressor, AppCtx ctx) 6234b254c5SRichard Tran Mills { 6334b254c5SRichard Tran Mills PetscFunctionBegin; 6434b254c5SRichard Tran Mills if (ctx->test_prefix) { 6534b254c5SRichard Tran Mills PetscCall(PetscRegressorSetOptionsPrefix(regressor, "sys1_")); 6634b254c5SRichard Tran Mills PetscCall(PetscRegressorAppendOptionsPrefix(regressor, "sys2_")); 6734b254c5SRichard Tran Mills } 6834b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 6934b254c5SRichard Tran Mills } 7034b254c5SRichard Tran Mills 7134b254c5SRichard Tran Mills static PetscErrorCode CreateData(AppCtx ctx) 7234b254c5SRichard Tran Mills { 7334b254c5SRichard Tran Mills PetscMPIInt rank; 7434b254c5SRichard Tran Mills PetscInt i; 7534b254c5SRichard Tran Mills PetscScalar mean; 7634b254c5SRichard Tran Mills 7734b254c5SRichard Tran Mills PetscFunctionBegin; 7834b254c5SRichard Tran Mills PetscCallMPI(MPI_Comm_rank(PETSC_COMM_WORLD, &rank)); 7934b254c5SRichard Tran Mills PetscCall(VecCreate(PETSC_COMM_WORLD, &ctx->y)); 8034b254c5SRichard Tran Mills PetscCall(VecSetSizes(ctx->y, PETSC_DECIDE, ctx->N)); 8134b254c5SRichard Tran Mills PetscCall(VecSetFromOptions(ctx->y)); 8234b254c5SRichard Tran Mills PetscCall(VecDuplicate(ctx->y, &ctx->y_predicted)); 8334b254c5SRichard Tran Mills PetscCall(MatCreate(PETSC_COMM_WORLD, &ctx->X)); 8434b254c5SRichard Tran Mills PetscCall(MatSetSizes(ctx->X, PETSC_DECIDE, PETSC_DECIDE, ctx->N, ctx->N)); 8534b254c5SRichard Tran Mills PetscCall(MatSetFromOptions(ctx->X)); 8634b254c5SRichard Tran Mills PetscCall(MatSetUp(ctx->X)); 8734b254c5SRichard Tran Mills 8834b254c5SRichard Tran Mills if (!rank) { 8934b254c5SRichard Tran Mills for (i = 0; i < ctx->N; i++) { 9034b254c5SRichard Tran Mills PetscCall(VecSetValue(ctx->y, i, (PetscScalar)i, INSERT_VALUES)); 9134b254c5SRichard Tran Mills PetscCall(MatSetValue(ctx->X, i, i, 1.0, INSERT_VALUES)); 9234b254c5SRichard Tran Mills } 9334b254c5SRichard Tran Mills } 9434b254c5SRichard Tran Mills /* Set up a training data matrix that is the identity. 9534b254c5SRichard Tran Mills * We do this because this gives us a special case in which we can analytically determine what the regression 9634b254c5SRichard Tran Mills * coefficients should be for ordinary least squares, LASSO (L1 regularized), and ridge (L2 regularized) regression. 9734b254c5SRichard Tran Mills * See details in section 6.2 of James et al.'s An Introduction to Statistical Learning (ISLR), in the subsection 9834b254c5SRichard Tran Mills * titled "A Simple Special Case for Ridge Regression and the Lasso". 9934b254c5SRichard Tran Mills * Note that the coefficients we generate with ridge regression (-regressor_linear_type ridge -regressor_regularizer_weight <lambda>, or, equivalently, 10034b254c5SRichard Tran Mills * -tao_brgn_regularization_type l2pure -tao_brgn_regularizer_weight <lambda>) match those of the ISLR formula exactly. 10134b254c5SRichard Tran Mills * For LASSO it does not match the ISLR formula: where they use lambda/2, we need to use lambda. 10234b254c5SRichard Tran Mills * It also doesn't match what Scikit-learn does; in that case their lambda is 1/n_samples of our lambda. Apparently everyone is scaling 10334b254c5SRichard Tran Mills * their loss function by a different value, hence the need to change what "lambda" is. But it's clear that ISLR, Scikit-learn, and we 10434b254c5SRichard Tran Mills * are basically doing the same thing otherwise. */ 10534b254c5SRichard Tran Mills PetscCall(VecAssemblyBegin(ctx->y)); 10634b254c5SRichard Tran Mills PetscCall(VecAssemblyEnd(ctx->y)); 10734b254c5SRichard Tran Mills PetscCall(MatAssemblyBegin(ctx->X, MAT_FINAL_ASSEMBLY)); 10834b254c5SRichard Tran Mills PetscCall(MatAssemblyEnd(ctx->X, MAT_FINAL_ASSEMBLY)); 10934b254c5SRichard Tran Mills /* Center the target vector we will train with. */ 11034b254c5SRichard Tran Mills PetscCall(VecMean(ctx->y, &mean)); 11134b254c5SRichard Tran Mills PetscCall(VecShift(ctx->y, -1.0 * mean)); 11234b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 11334b254c5SRichard Tran Mills } 11434b254c5SRichard Tran Mills 11534b254c5SRichard Tran Mills static PetscErrorCode ConfigureContext(AppCtx ctx) 11634b254c5SRichard Tran Mills { 11734b254c5SRichard Tran Mills PetscFunctionBegin; 11834b254c5SRichard Tran Mills ctx->flg_string = PETSC_FALSE; 11934b254c5SRichard Tran Mills ctx->flg_ascii = PETSC_FALSE; 12034b254c5SRichard Tran Mills ctx->flg_view_sol = PETSC_FALSE; 12134b254c5SRichard Tran Mills ctx->test_prefix = PETSC_FALSE; 12234b254c5SRichard Tran Mills ctx->N = 10; 12334b254c5SRichard Tran Mills PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Options for PetscRegressor ex3:", ""); 12434b254c5SRichard Tran Mills PetscCall(PetscOptionsInt("-N", "Dimension of the N x N data matrix", "ex3.c", ctx->N, &ctx->N, NULL)); 12534b254c5SRichard Tran Mills PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_string_viewer", &ctx->flg_string, NULL)); 12634b254c5SRichard Tran Mills PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_ascii_viewer", &ctx->flg_ascii, NULL)); 12734b254c5SRichard Tran Mills PetscCall(PetscOptionsGetBool(NULL, NULL, "-view_sols", &ctx->flg_view_sol, NULL)); 12834b254c5SRichard Tran Mills PetscCall(PetscOptionsGetBool(NULL, NULL, "-test_prefix", &ctx->test_prefix, NULL)); 12934b254c5SRichard Tran Mills PetscOptionsEnd(); 13034b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 13134b254c5SRichard Tran Mills } 13234b254c5SRichard Tran Mills 13334b254c5SRichard Tran Mills int main(int argc, char **args) 13434b254c5SRichard Tran Mills { 13534b254c5SRichard Tran Mills AppCtx ctx; 13634b254c5SRichard Tran Mills PetscRegressor regressor; 13734b254c5SRichard Tran Mills PetscScalar intercept; 13834b254c5SRichard Tran Mills 13934b254c5SRichard Tran Mills /* Initialize PETSc */ 14034b254c5SRichard Tran Mills PetscCall(PetscInitialize(&argc, &args, (char *)0, help)); 14134b254c5SRichard Tran Mills 14234b254c5SRichard Tran Mills /* Initialize problem parameters and data */ 14334b254c5SRichard Tran Mills PetscCall(PetscNew(&ctx)); 14434b254c5SRichard Tran Mills PetscCall(ConfigureContext(ctx)); 14534b254c5SRichard Tran Mills PetscCall(CreateData(ctx)); 14634b254c5SRichard Tran Mills 14734b254c5SRichard Tran Mills /* Create Regressor solver with desired type and options */ 14834b254c5SRichard Tran Mills PetscCall(PetscRegressorCreate(PETSC_COMM_WORLD, ®ressor)); 14934b254c5SRichard Tran Mills PetscCall(PetscRegressorSetType(regressor, PETSCREGRESSORLINEAR)); 15034b254c5SRichard Tran Mills PetscCall(PetscRegressorLinearSetType(regressor, REGRESSOR_LINEAR_OLS)); 15134b254c5SRichard Tran Mills PetscCall(PetscRegressorLinearSetFitIntercept(regressor, PETSC_FALSE)); 15234b254c5SRichard Tran Mills /* Testing prefix functions for Regressor */ 15334b254c5SRichard Tran Mills PetscCall(TestPrefixRegressor(regressor, ctx)); 15434b254c5SRichard Tran Mills /* Check for command line options */ 15534b254c5SRichard Tran Mills PetscCall(PetscRegressorSetFromOptions(regressor)); 15634b254c5SRichard Tran Mills /* Fit the regressor */ 15734b254c5SRichard Tran Mills PetscCall(PetscRegressorFit(regressor, ctx->X, ctx->y)); 15834b254c5SRichard Tran Mills /* Predict data with fitted regressor */ 15934b254c5SRichard Tran Mills PetscCall(PetscRegressorPredict(regressor, ctx->X, ctx->y_predicted)); 16034b254c5SRichard Tran Mills /* Get other desired output data */ 16134b254c5SRichard Tran Mills PetscCall(PetscRegressorLinearGetIntercept(regressor, &intercept)); 16234b254c5SRichard Tran Mills PetscCall(PetscRegressorLinearGetCoefficients(regressor, &ctx->coefficients)); 16334b254c5SRichard Tran Mills 16434b254c5SRichard Tran Mills /* Testing Views, and GetTypes */ 16534b254c5SRichard Tran Mills PetscCall(TestRegressorViews(regressor, ctx)); 16634b254c5SRichard Tran Mills PetscCall(PetscRegressorDestroy(®ressor)); 16734b254c5SRichard Tran Mills PetscCall(DestroyCtx(&ctx)); 16834b254c5SRichard Tran Mills PetscCall(PetscFinalize()); 16934b254c5SRichard Tran Mills return 0; 17034b254c5SRichard Tran Mills } 17134b254c5SRichard Tran Mills 17234b254c5SRichard Tran Mills /*TEST 17334b254c5SRichard Tran Mills 17434b254c5SRichard Tran Mills build: 17534b254c5SRichard Tran Mills requires: !complex !single !__float128 !defined(PETSC_USE_64BIT_INDICES) 17634b254c5SRichard Tran Mills 17734b254c5SRichard Tran Mills test: 17834b254c5SRichard Tran Mills suffix: prefix_tao 17934b254c5SRichard Tran Mills args: -sys1_sys2_regressor_view -test_prefix 18034b254c5SRichard Tran Mills 18134b254c5SRichard Tran Mills test: 18234b254c5SRichard Tran Mills suffix: prefix_ksp 183*789736e1SBarry Smith args: -sys1_sys2_regressor_view -test_prefix -sys1_sys2_regressor_linear_use_ksp -sys1_sys2_regressor_linear_ksp_monitor 184*789736e1SBarry Smith 185*789736e1SBarry Smith test: 186*789736e1SBarry Smith suffix: prefix_ksp_cholesky 187*789736e1SBarry Smith args: -sys1_sys2_regressor_view -test_prefix -sys1_sys2_regressor_linear_use_ksp -sys1_sys2_regressor_linear_pc_type cholesky 188*789736e1SBarry Smith TODO: Could not locate a solver type for factorization type CHOLESKY and matrix type normal 189*789736e1SBarry Smith 190*789736e1SBarry Smith test: 191*789736e1SBarry Smith suffix: prefix_ksp_suitesparse 192*789736e1SBarry Smith requires: suitesparse 193*789736e1SBarry Smith args: -sys1_sys2_regressor_view -test_prefix -sys1_sys2_regressor_linear_use_ksp -sys1_sys2_regressor_linear_pc_type qr -sys1_sys2_regressor_linear_pc_factor_mat_solver_type spqr -sys1_sys2_regressor_linear_ksp_monitor 19434b254c5SRichard Tran Mills 19534b254c5SRichard Tran Mills test: 19634b254c5SRichard Tran Mills suffix: asciiview 19734b254c5SRichard Tran Mills args: -test_ascii_viewer 19834b254c5SRichard Tran Mills 19934b254c5SRichard Tran Mills test: 20034b254c5SRichard Tran Mills suffix: stringview 20134b254c5SRichard Tran Mills args: -test_string_viewer 20234b254c5SRichard Tran Mills 20334b254c5SRichard Tran Mills test: 20434b254c5SRichard Tran Mills suffix: ksp_intercept 20534b254c5SRichard Tran Mills args: -regressor_linear_use_ksp -regressor_linear_fit_intercept -regressor_view 20634b254c5SRichard Tran Mills 20734b254c5SRichard Tran Mills test: 20834b254c5SRichard Tran Mills suffix: ksp_no_intercept 20934b254c5SRichard Tran Mills args: -regressor_linear_use_ksp -regressor_view 21034b254c5SRichard Tran Mills 21134b254c5SRichard Tran Mills test: 21234b254c5SRichard Tran Mills suffix: lasso_1 21334b254c5SRichard Tran Mills nsize: 1 21434b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type lasso -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols 21534b254c5SRichard Tran Mills 21634b254c5SRichard Tran Mills test: 21734b254c5SRichard Tran Mills suffix: lasso_2 21834b254c5SRichard Tran Mills nsize: 2 21934b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type lasso -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols 22034b254c5SRichard Tran Mills 22134b254c5SRichard Tran Mills test: 22234b254c5SRichard Tran Mills suffix: ridge_1 22334b254c5SRichard Tran Mills nsize: 1 22434b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type ridge -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols 22534b254c5SRichard Tran Mills 22634b254c5SRichard Tran Mills test: 22734b254c5SRichard Tran Mills suffix: ridge_2 22834b254c5SRichard Tran Mills nsize: 2 22934b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type ridge -regressor_regularizer_weight 2 -regressor_linear_fit_intercept -view_sols 23034b254c5SRichard Tran Mills 23134b254c5SRichard Tran Mills test: 23234b254c5SRichard Tran Mills suffix: ols_1 23334b254c5SRichard Tran Mills nsize: 1 23434b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type ols -regressor_linear_fit_intercept -view_sols 23534b254c5SRichard Tran Mills 23634b254c5SRichard Tran Mills test: 23734b254c5SRichard Tran Mills suffix: ols_2 23834b254c5SRichard Tran Mills nsize: 2 23934b254c5SRichard Tran Mills args: -regressor_type linear -regressor_linear_type ols -regressor_linear_fit_intercept -view_sols 24034b254c5SRichard Tran Mills 24134b254c5SRichard Tran Mills TEST*/ 242