xref: /petsc/src/ml/regressor/impls/linear/linear.c (revision 540f39e10888403b6af8026dfa429a04c272f2c1)
134b254c5SRichard Tran Mills #include <../src/ml/regressor/impls/linear/linearimpl.h> /*I "petscregressor.h" I*/
234b254c5SRichard Tran Mills #include <../src/tao/leastsquares/impls/brgn/brgn.h>     /*I "petsctao.h" I*/
334b254c5SRichard Tran Mills 
434b254c5SRichard Tran Mills const char *const PetscRegressorLinearTypes[] = {"ols", "lasso", "ridge", "RegressorLinearType", "REGRESSOR_LINEAR_", NULL};
534b254c5SRichard Tran Mills 
634b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearSetFitIntercept_Linear(PetscRegressor regressor, PetscBool flg)
734b254c5SRichard Tran Mills {
834b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
934b254c5SRichard Tran Mills 
1034b254c5SRichard Tran Mills   PetscFunctionBegin;
1134b254c5SRichard Tran Mills   linear->fit_intercept = flg;
1234b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
1334b254c5SRichard Tran Mills }
1434b254c5SRichard Tran Mills 
1534b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearSetType_Linear(PetscRegressor regressor, PetscRegressorLinearType type)
1634b254c5SRichard Tran Mills {
1734b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
1834b254c5SRichard Tran Mills 
1934b254c5SRichard Tran Mills   PetscFunctionBegin;
2034b254c5SRichard Tran Mills   linear->type = type;
2134b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
2234b254c5SRichard Tran Mills }
2334b254c5SRichard Tran Mills 
2434b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearGetType_Linear(PetscRegressor regressor, PetscRegressorLinearType *type)
2534b254c5SRichard Tran Mills {
2634b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
2734b254c5SRichard Tran Mills 
2834b254c5SRichard Tran Mills   PetscFunctionBegin;
2934b254c5SRichard Tran Mills   *type = linear->type;
3034b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
3134b254c5SRichard Tran Mills }
3234b254c5SRichard Tran Mills 
3334b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearGetIntercept_Linear(PetscRegressor regressor, PetscScalar *intercept)
3434b254c5SRichard Tran Mills {
3534b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
3634b254c5SRichard Tran Mills 
3734b254c5SRichard Tran Mills   PetscFunctionBegin;
3834b254c5SRichard Tran Mills   *intercept = linear->intercept;
3934b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
4034b254c5SRichard Tran Mills }
4134b254c5SRichard Tran Mills 
4234b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearGetCoefficients_Linear(PetscRegressor regressor, Vec *coefficients)
4334b254c5SRichard Tran Mills {
4434b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
4534b254c5SRichard Tran Mills 
4634b254c5SRichard Tran Mills   PetscFunctionBegin;
4734b254c5SRichard Tran Mills   *coefficients = linear->coefficients;
4834b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
4934b254c5SRichard Tran Mills }
5034b254c5SRichard Tran Mills 
5134b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearGetKSP_Linear(PetscRegressor regressor, KSP *ksp)
5234b254c5SRichard Tran Mills {
5334b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
5434b254c5SRichard Tran Mills 
5534b254c5SRichard Tran Mills   PetscFunctionBegin;
5634b254c5SRichard Tran Mills   if (!linear->ksp) {
5734b254c5SRichard Tran Mills     PetscCall(KSPCreate(PetscObjectComm((PetscObject)regressor), &linear->ksp));
5834b254c5SRichard Tran Mills     PetscCall(PetscObjectIncrementTabLevel((PetscObject)linear->ksp, (PetscObject)regressor, 1));
5934b254c5SRichard Tran Mills     PetscCall(PetscObjectSetOptions((PetscObject)linear->ksp, ((PetscObject)regressor)->options));
6034b254c5SRichard Tran Mills   }
6134b254c5SRichard Tran Mills   *ksp = linear->ksp;
6234b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
6334b254c5SRichard Tran Mills }
6434b254c5SRichard Tran Mills 
6534b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearSetUseKSP_Linear(PetscRegressor regressor, PetscBool flg)
6634b254c5SRichard Tran Mills {
6734b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
6834b254c5SRichard Tran Mills 
6934b254c5SRichard Tran Mills   PetscFunctionBegin;
7034b254c5SRichard Tran Mills   linear->use_ksp = flg;
7134b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
7234b254c5SRichard Tran Mills }
7334b254c5SRichard Tran Mills 
7434b254c5SRichard Tran Mills static PetscErrorCode EvaluateResidual(Tao tao, Vec x, Vec f, void *ptr)
7534b254c5SRichard Tran Mills {
7634b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)ptr;
7734b254c5SRichard Tran Mills 
7834b254c5SRichard Tran Mills   PetscFunctionBegin;
7934b254c5SRichard Tran Mills   /* Evaluate f = A * x - b */
8034b254c5SRichard Tran Mills   PetscCall(MatMult(linear->X, x, f));
8134b254c5SRichard Tran Mills   PetscCall(VecAXPY(f, -1.0, linear->rhs));
8234b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
8334b254c5SRichard Tran Mills }
8434b254c5SRichard Tran Mills 
8534b254c5SRichard Tran Mills static PetscErrorCode EvaluateJacobian(Tao tao, Vec x, Mat J, Mat Jpre, void *ptr)
8634b254c5SRichard Tran Mills {
87*540f39e1SHansol Suh   /* The TAOBRGN API expects us to pass an EvaluateJacobian() routine to it, but in this case it is a dummy function.
88*540f39e1SHansol Suh      Denoting our data matrix as X, for linear least squares J[m][n] = df[m]/dx[n] = X[m][n]. */
8934b254c5SRichard Tran Mills   PetscFunctionBegin;
9034b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
9134b254c5SRichard Tran Mills }
9234b254c5SRichard Tran Mills 
9334b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorSetUp_Linear(PetscRegressor regressor)
9434b254c5SRichard Tran Mills {
9534b254c5SRichard Tran Mills   PetscInt               M, N;
9634b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
9734b254c5SRichard Tran Mills   KSP                    ksp;
9834b254c5SRichard Tran Mills   Tao                    tao;
9934b254c5SRichard Tran Mills 
10034b254c5SRichard Tran Mills   PetscFunctionBegin;
10134b254c5SRichard Tran Mills   PetscCall(MatGetSize(regressor->training, &M, &N));
10234b254c5SRichard Tran Mills 
10334b254c5SRichard Tran Mills   if (linear->fit_intercept) {
10434b254c5SRichard Tran Mills     /* If we are fitting the intercept, we need to make A a composite matrix using MATCENTERING to preserve sparsity.
10534b254c5SRichard Tran Mills      * Though there might be some cases we don't want to do this for, depending on what kind of matrix is passed in. (Probably bad idea for dense?)
10634b254c5SRichard Tran Mills      * We will also need to ensure that the right-hand side passed to the KSP is also mean-centered, since we
10734b254c5SRichard Tran Mills      * intend to compute the intercept separately from regression coefficients (that is, we will not be adding a
10834b254c5SRichard Tran Mills      * column of all 1s to our design matrix). */
10934b254c5SRichard Tran Mills     PetscCall(MatCreateCentering(PetscObjectComm((PetscObject)regressor), PETSC_DECIDE, M, &linear->C));
11034b254c5SRichard Tran Mills     PetscCall(MatCreate(PetscObjectComm((PetscObject)regressor), &linear->X));
11134b254c5SRichard Tran Mills     PetscCall(MatSetSizes(linear->X, PETSC_DECIDE, PETSC_DECIDE, M, N));
11234b254c5SRichard Tran Mills     PetscCall(MatSetType(linear->X, MATCOMPOSITE));
11334b254c5SRichard Tran Mills     PetscCall(MatCompositeSetType(linear->X, MAT_COMPOSITE_MULTIPLICATIVE));
11434b254c5SRichard Tran Mills     PetscCall(MatCompositeAddMat(linear->X, regressor->training));
11534b254c5SRichard Tran Mills     PetscCall(MatCompositeAddMat(linear->X, linear->C));
11634b254c5SRichard Tran Mills     PetscCall(VecDuplicate(regressor->target, &linear->rhs));
11734b254c5SRichard Tran Mills     PetscCall(MatMult(linear->C, regressor->target, linear->rhs));
11834b254c5SRichard Tran Mills   } else {
11934b254c5SRichard Tran Mills     // When not fitting intercept, we assume that the input data are already centered.
12034b254c5SRichard Tran Mills     linear->X   = regressor->training;
12134b254c5SRichard Tran Mills     linear->rhs = regressor->target;
12234b254c5SRichard Tran Mills 
12334b254c5SRichard Tran Mills     PetscCall(PetscObjectReference((PetscObject)linear->X));
12434b254c5SRichard Tran Mills     PetscCall(PetscObjectReference((PetscObject)linear->rhs));
12534b254c5SRichard Tran Mills   }
12634b254c5SRichard Tran Mills 
12734b254c5SRichard Tran Mills   if (linear->coefficients) PetscCall(VecDestroy(&linear->coefficients));
12834b254c5SRichard Tran Mills 
12934b254c5SRichard Tran Mills   if (linear->use_ksp) {
13034b254c5SRichard Tran Mills     PetscCheck(linear->type == REGRESSOR_LINEAR_OLS, PetscObjectComm((PetscObject)regressor), PETSC_ERR_ARG_WRONGSTATE, "KSP can be used to fit a linear regressor only when its type is OLS");
13134b254c5SRichard Tran Mills 
13234b254c5SRichard Tran Mills     if (!linear->ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp));
13334b254c5SRichard Tran Mills     ksp = linear->ksp;
13434b254c5SRichard Tran Mills 
13534b254c5SRichard Tran Mills     PetscCall(MatCreateVecs(linear->X, &linear->coefficients, NULL));
13634b254c5SRichard Tran Mills     /* Set up the KSP to solve the least squares problem (without solving for intercept, as this is done separately) using KSPLSQR. */
13734b254c5SRichard Tran Mills     PetscCall(MatCreateNormal(linear->X, &linear->XtX));
13834b254c5SRichard Tran Mills     PetscCall(KSPSetType(ksp, KSPLSQR));
13934b254c5SRichard Tran Mills     PetscCall(KSPSetOperators(ksp, linear->X, linear->XtX));
14034b254c5SRichard Tran Mills     PetscCall(KSPSetOptionsPrefix(ksp, ((PetscObject)regressor)->prefix));
14134b254c5SRichard Tran Mills     PetscCall(KSPAppendOptionsPrefix(ksp, "regressor_linear_"));
14234b254c5SRichard Tran Mills     PetscCall(KSPSetFromOptions(ksp));
14334b254c5SRichard Tran Mills   } else {
14434b254c5SRichard Tran Mills     /* Note: Currently implementation creates TAO inside of implementations.
14534b254c5SRichard Tran Mills       * Thus, all the prefix jobs are done inside implementations, not in interface */
14634b254c5SRichard Tran Mills     const char *prefix;
14734b254c5SRichard Tran Mills 
14834b254c5SRichard Tran Mills     if (!regressor->tao) PetscCall(PetscRegressorGetTao(regressor, &tao));
14934b254c5SRichard Tran Mills 
15034b254c5SRichard Tran Mills     PetscCall(MatCreateVecs(linear->X, &linear->coefficients, &linear->residual));
15134b254c5SRichard Tran Mills     /* Set up the TAO object to solve the (regularized) least squares problem (without solving for intercept, which is done separately) using TAOBRGN. */
15234b254c5SRichard Tran Mills     PetscCall(TaoSetType(tao, TAOBRGN));
15334b254c5SRichard Tran Mills     PetscCall(TaoSetSolution(tao, linear->coefficients));
15434b254c5SRichard Tran Mills     PetscCall(TaoSetResidualRoutine(tao, linear->residual, EvaluateResidual, linear));
15534b254c5SRichard Tran Mills     PetscCall(TaoSetJacobianResidualRoutine(tao, linear->X, linear->X, EvaluateJacobian, linear));
15634b254c5SRichard Tran Mills     if (!linear->use_ksp) PetscCall(TaoBRGNSetRegularizerWeight(tao, regressor->regularizer_weight));
15734b254c5SRichard Tran Mills     // Set the regularization type and weight for the BRGN as linear->type dictates:
15834b254c5SRichard Tran Mills     // TODO BRGN needs to be BRGNSetRegularizationType
15934b254c5SRichard Tran Mills     // PetscOptionsSetValue no longer works due to functioning prefix system
16034b254c5SRichard Tran Mills     PetscCall(PetscRegressorGetOptionsPrefix(regressor, &prefix));
16134b254c5SRichard Tran Mills     PetscCall(TaoSetOptionsPrefix(regressor->tao, prefix));
16234b254c5SRichard Tran Mills     PetscCall(TaoAppendOptionsPrefix(tao, "regressor_linear_"));
16334b254c5SRichard Tran Mills     {
16434b254c5SRichard Tran Mills       TAO_BRGN *gn = (TAO_BRGN *)regressor->tao->data;
16534b254c5SRichard Tran Mills 
16634b254c5SRichard Tran Mills       switch (linear->type) {
16734b254c5SRichard Tran Mills       case REGRESSOR_LINEAR_OLS:
16834b254c5SRichard Tran Mills         regressor->regularizer_weight = 0.0; // OLS, by definition, uses a regularizer weight of 0
16934b254c5SRichard Tran Mills         break;
17034b254c5SRichard Tran Mills       case REGRESSOR_LINEAR_LASSO:
17134b254c5SRichard Tran Mills         gn->reg_type = BRGN_REGULARIZATION_L1DICT;
17234b254c5SRichard Tran Mills         break;
17334b254c5SRichard Tran Mills       case REGRESSOR_LINEAR_RIDGE:
17434b254c5SRichard Tran Mills         gn->reg_type = BRGN_REGULARIZATION_L2PURE;
17534b254c5SRichard Tran Mills         break;
17634b254c5SRichard Tran Mills       default:
17734b254c5SRichard Tran Mills         break;
17834b254c5SRichard Tran Mills       }
17934b254c5SRichard Tran Mills     }
18034b254c5SRichard Tran Mills     PetscCall(TaoSetFromOptions(tao));
18134b254c5SRichard Tran Mills   }
18234b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
18334b254c5SRichard Tran Mills }
18434b254c5SRichard Tran Mills 
18534b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorReset_Linear(PetscRegressor regressor)
18634b254c5SRichard Tran Mills {
18734b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
18834b254c5SRichard Tran Mills 
18934b254c5SRichard Tran Mills   PetscFunctionBegin;
19034b254c5SRichard Tran Mills   /* Destroy the PETSc objects associated with the linear regressor implementation. */
19134b254c5SRichard Tran Mills   linear->ksp_its     = 0;
19234b254c5SRichard Tran Mills   linear->ksp_tot_its = 0;
19334b254c5SRichard Tran Mills 
19434b254c5SRichard Tran Mills   PetscCall(MatDestroy(&linear->X));
19534b254c5SRichard Tran Mills   PetscCall(MatDestroy(&linear->XtX));
19634b254c5SRichard Tran Mills   PetscCall(MatDestroy(&linear->C));
19734b254c5SRichard Tran Mills   PetscCall(KSPDestroy(&linear->ksp));
19834b254c5SRichard Tran Mills   PetscCall(VecDestroy(&linear->coefficients));
19934b254c5SRichard Tran Mills   PetscCall(VecDestroy(&linear->rhs));
20034b254c5SRichard Tran Mills   PetscCall(VecDestroy(&linear->residual));
20134b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
20234b254c5SRichard Tran Mills }
20334b254c5SRichard Tran Mills 
20434b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorDestroy_Linear(PetscRegressor regressor)
20534b254c5SRichard Tran Mills {
20634b254c5SRichard Tran Mills   PetscFunctionBegin;
20734b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", NULL));
20834b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", NULL));
20934b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", NULL));
21034b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", NULL));
21134b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", NULL));
21234b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", NULL));
21334b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", NULL));
21434b254c5SRichard Tran Mills   PetscCall(PetscRegressorReset_Linear(regressor));
21534b254c5SRichard Tran Mills   PetscCall(PetscFree(regressor->data));
21634b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
21734b254c5SRichard Tran Mills }
21834b254c5SRichard Tran Mills 
21934b254c5SRichard Tran Mills /*@
22034b254c5SRichard Tran Mills   PetscRegressorLinearSetFitIntercept - Set a flag to indicate that the intercept (also known as the "bias" or "offset") should
22134b254c5SRichard Tran Mills   be calculated; data are assumed to be mean-centered if false.
22234b254c5SRichard Tran Mills 
22334b254c5SRichard Tran Mills   Logically Collective
22434b254c5SRichard Tran Mills 
22534b254c5SRichard Tran Mills   Input Parameters:
22634b254c5SRichard Tran Mills + regressor - the `PetscRegressor` context
22734b254c5SRichard Tran Mills - flg       - `PETSC_TRUE` to calculate the intercept, `PETSC_FALSE` to assume mean-centered data (default is `PETSC_TRUE`)
22834b254c5SRichard Tran Mills 
22934b254c5SRichard Tran Mills   Level: intermediate
23034b254c5SRichard Tran Mills 
23134b254c5SRichard Tran Mills   Options Database Key:
23234b254c5SRichard Tran Mills . regressor_linear_fit_intercept <true,false> - fit the intercept
23334b254c5SRichard Tran Mills 
23434b254c5SRichard Tran Mills   Note:
23534b254c5SRichard Tran Mills   If the user indicates that the intercept should not be calculated, the intercept will be set to zero.
23634b254c5SRichard Tran Mills 
23734b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PetscRegressorFit()`
23834b254c5SRichard Tran Mills @*/
23934b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearSetFitIntercept(PetscRegressor regressor, PetscBool flg)
24034b254c5SRichard Tran Mills {
24134b254c5SRichard Tran Mills   PetscFunctionBegin;
24234b254c5SRichard Tran Mills   /* TODO: Add companion PetscRegressorLinearGetFitIntercept(), and put it in the .seealso: */
24334b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
24434b254c5SRichard Tran Mills   PetscValidLogicalCollectiveBool(regressor, flg, 2);
24534b254c5SRichard Tran Mills   PetscTryMethod(regressor, "PetscRegressorLinearSetFitIntercept_C", (PetscRegressor, PetscBool), (regressor, flg));
24634b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
24734b254c5SRichard Tran Mills }
24834b254c5SRichard Tran Mills 
24934b254c5SRichard Tran Mills /*@
25034b254c5SRichard Tran Mills   PetscRegressorLinearSetUseKSP - Set a flag to indicate that a `KSP` object, instead of a `Tao` one, should be used
25134b254c5SRichard Tran Mills   to fit the regressor
25234b254c5SRichard Tran Mills 
25334b254c5SRichard Tran Mills   Logically Collective
25434b254c5SRichard Tran Mills 
25534b254c5SRichard Tran Mills   Input Parameters:
25634b254c5SRichard Tran Mills + regressor - the `PetscRegressor` context
25734b254c5SRichard Tran Mills - flg       - `PETSC_TRUE` to use a `KSP`, `PETSC_FALSE` to use a `Tao` object (default is false)
25834b254c5SRichard Tran Mills 
25934b254c5SRichard Tran Mills   Options Database Key:
26034b254c5SRichard Tran Mills . regressor_linear_use_ksp <true,false> - use `KSP`
26134b254c5SRichard Tran Mills 
26234b254c5SRichard Tran Mills   Level: intermediate
26334b254c5SRichard Tran Mills 
26434b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PetscRegressorLinearGetKSP()`
26534b254c5SRichard Tran Mills @*/
26634b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearSetUseKSP(PetscRegressor regressor, PetscBool flg)
26734b254c5SRichard Tran Mills {
26834b254c5SRichard Tran Mills   PetscFunctionBegin;
26934b254c5SRichard Tran Mills   /* TODO: Add companion PetscRegressorLinearGetUseKSP(), and put it in the .seealso: */
27034b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
27134b254c5SRichard Tran Mills   PetscValidLogicalCollectiveBool(regressor, flg, 2);
27234b254c5SRichard Tran Mills   PetscTryMethod(regressor, "PetscRegressorLinearSetUseKSP_C", (PetscRegressor, PetscBool), (regressor, flg));
27334b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
27434b254c5SRichard Tran Mills }
27534b254c5SRichard Tran Mills 
27634b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorSetFromOptions_Linear(PetscRegressor regressor, PetscOptionItems PetscOptionsObject)
27734b254c5SRichard Tran Mills {
27834b254c5SRichard Tran Mills   PetscBool              set, flg = PETSC_FALSE;
27934b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
28034b254c5SRichard Tran Mills 
28134b254c5SRichard Tran Mills   PetscFunctionBegin;
28234b254c5SRichard Tran Mills   PetscOptionsHeadBegin(PetscOptionsObject, "PetscRegressor options for linear regressors");
28334b254c5SRichard Tran Mills   PetscCall(PetscOptionsBool("-regressor_linear_fit_intercept", "Calculate intercept for linear model", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set));
28434b254c5SRichard Tran Mills   if (set) PetscCall(PetscRegressorLinearSetFitIntercept(regressor, flg));
28534b254c5SRichard Tran Mills   PetscCall(PetscOptionsBool("-regressor_linear_use_ksp", "Use KSP instead of TAO for linear model fitting problem", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set));
28634b254c5SRichard Tran Mills   if (set) PetscCall(PetscRegressorLinearSetUseKSP(regressor, flg));
28734b254c5SRichard Tran Mills   PetscCall(PetscOptionsEnum("-regressor_linear_type", "Linear regression method", "PetscRegressorLinearTypes", PetscRegressorLinearTypes, (PetscEnum)linear->type, (PetscEnum *)&linear->type, NULL));
28834b254c5SRichard Tran Mills   PetscOptionsHeadEnd();
28934b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
29034b254c5SRichard Tran Mills }
29134b254c5SRichard Tran Mills 
29234b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorView_Linear(PetscRegressor regressor, PetscViewer viewer)
29334b254c5SRichard Tran Mills {
29434b254c5SRichard Tran Mills   PetscBool              isascii;
29534b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
29634b254c5SRichard Tran Mills 
29734b254c5SRichard Tran Mills   PetscFunctionBegin;
29834b254c5SRichard Tran Mills   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
29934b254c5SRichard Tran Mills   if (isascii) {
30034b254c5SRichard Tran Mills     PetscCall(PetscViewerASCIIPushTab(viewer));
30134b254c5SRichard Tran Mills     PetscCall(PetscViewerASCIIPrintf(viewer, "PetscRegressor Linear Type: %s\n", PetscRegressorLinearTypes[linear->type]));
30234b254c5SRichard Tran Mills     if (linear->ksp) {
30334b254c5SRichard Tran Mills       PetscCall(KSPView(linear->ksp, viewer));
30434b254c5SRichard Tran Mills       PetscCall(PetscViewerASCIIPrintf(viewer, "total KSP iterations: %" PetscInt_FMT "\n", linear->ksp_tot_its));
30534b254c5SRichard Tran Mills     }
30634b254c5SRichard Tran Mills     if (linear->fit_intercept) PetscCall(PetscViewerASCIIPrintf(viewer, "Intercept=%g\n", (double)linear->intercept));
30734b254c5SRichard Tran Mills     PetscCall(PetscViewerASCIIPopTab(viewer));
30834b254c5SRichard Tran Mills   }
30934b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
31034b254c5SRichard Tran Mills }
31134b254c5SRichard Tran Mills 
31234b254c5SRichard Tran Mills /*@
31334b254c5SRichard Tran Mills   PetscRegressorLinearGetKSP - Returns the `KSP` context for a `PETSCREGRESSORLINEAR` object.
31434b254c5SRichard Tran Mills 
31534b254c5SRichard Tran Mills   Not Collective, but if the `PetscRegressor` is parallel, then the `KSP` object is parallel
31634b254c5SRichard Tran Mills 
31734b254c5SRichard Tran Mills   Input Parameter:
31834b254c5SRichard Tran Mills . regressor - the `PetscRegressor` context
31934b254c5SRichard Tran Mills 
32034b254c5SRichard Tran Mills   Output Parameter:
32134b254c5SRichard Tran Mills . ksp - the `KSP` context
32234b254c5SRichard Tran Mills 
32334b254c5SRichard Tran Mills   Level: beginner
32434b254c5SRichard Tran Mills 
32534b254c5SRichard Tran Mills   Note:
32634b254c5SRichard Tran Mills   This routine will always return a `KSP`, but, depending on the type of the linear regressor and the options that are set, the regressor may actually use a `Tao` object instead of this `KSP`.
32734b254c5SRichard Tran Mills 
32834b254c5SRichard Tran Mills .seealso: `PetscRegressorGetTao()`
32934b254c5SRichard Tran Mills @*/
33034b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearGetKSP(PetscRegressor regressor, KSP *ksp)
33134b254c5SRichard Tran Mills {
33234b254c5SRichard Tran Mills   PetscFunctionBegin;
33334b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
33434b254c5SRichard Tran Mills   PetscAssertPointer(ksp, 2);
33534b254c5SRichard Tran Mills   PetscUseMethod(regressor, "PetscRegressorLinearGetKSP_C", (PetscRegressor, KSP *), (regressor, ksp));
33634b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
33734b254c5SRichard Tran Mills }
33834b254c5SRichard Tran Mills 
33934b254c5SRichard Tran Mills /*@
34034b254c5SRichard Tran Mills   PetscRegressorLinearGetCoefficients - Get a vector of the fitted coefficients from a linear regression model
34134b254c5SRichard Tran Mills 
34234b254c5SRichard Tran Mills   Not Collective but the vector is parallel
34334b254c5SRichard Tran Mills 
34434b254c5SRichard Tran Mills   Input Parameter:
34534b254c5SRichard Tran Mills . regressor - the `PetscRegressor` context
34634b254c5SRichard Tran Mills 
34734b254c5SRichard Tran Mills   Output Parameter:
34834b254c5SRichard Tran Mills . coefficients - the vector of the coefficients
34934b254c5SRichard Tran Mills 
35034b254c5SRichard Tran Mills   Level: beginner
35134b254c5SRichard Tran Mills 
35234b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PetscRegressorLinearGetIntercept()`, `PETSCREGRESSORLINEAR`, `Vec`
35334b254c5SRichard Tran Mills @*/
35434b254c5SRichard Tran Mills PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetCoefficients(PetscRegressor regressor, Vec *coefficients)
35534b254c5SRichard Tran Mills {
35634b254c5SRichard Tran Mills   PetscFunctionBegin;
35734b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
35834b254c5SRichard Tran Mills   PetscAssertPointer(coefficients, 2);
35934b254c5SRichard Tran Mills   PetscUseMethod(regressor, "PetscRegressorLinearGetCoefficients_C", (PetscRegressor, Vec *), (regressor, coefficients));
36034b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
36134b254c5SRichard Tran Mills }
36234b254c5SRichard Tran Mills 
36334b254c5SRichard Tran Mills /*@
36434b254c5SRichard Tran Mills   PetscRegressorLinearGetIntercept - Get the intercept from a linear regression model
36534b254c5SRichard Tran Mills 
36634b254c5SRichard Tran Mills   Not Collective
36734b254c5SRichard Tran Mills 
36834b254c5SRichard Tran Mills   Input Parameter:
36934b254c5SRichard Tran Mills . regressor - the `PetscRegressor` context
37034b254c5SRichard Tran Mills 
37134b254c5SRichard Tran Mills   Output Parameter:
37234b254c5SRichard Tran Mills . intercept - the intercept
37334b254c5SRichard Tran Mills 
37434b254c5SRichard Tran Mills   Level: beginner
37534b254c5SRichard Tran Mills 
37634b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PetscRegressorLinearSetFitIntercept()`, `PetscRegressorLinearGetCoefficients()`, `PETSCREGRESSORLINEAR`
37734b254c5SRichard Tran Mills @*/
37834b254c5SRichard Tran Mills PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetIntercept(PetscRegressor regressor, PetscScalar *intercept)
37934b254c5SRichard Tran Mills {
38034b254c5SRichard Tran Mills   PetscFunctionBegin;
38134b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
38234b254c5SRichard Tran Mills   PetscAssertPointer(intercept, 2);
38334b254c5SRichard Tran Mills   PetscUseMethod(regressor, "PetscRegressorLinearGetIntercept_C", (PetscRegressor, PetscScalar *), (regressor, intercept));
38434b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
38534b254c5SRichard Tran Mills }
38634b254c5SRichard Tran Mills 
38734b254c5SRichard Tran Mills /*@C
38834b254c5SRichard Tran Mills   PetscRegressorLinearSetType - Sets the type of linear regression to be performed
38934b254c5SRichard Tran Mills 
39034b254c5SRichard Tran Mills   Logically Collective
39134b254c5SRichard Tran Mills 
39234b254c5SRichard Tran Mills   Input Parameters:
39334b254c5SRichard Tran Mills + regressor - the `PetscRegressor` context (should be of type `PETSCREGRESSORLINEAR`)
39434b254c5SRichard Tran Mills - type      - a known linear regression method
39534b254c5SRichard Tran Mills 
39634b254c5SRichard Tran Mills   Options Database Key:
39734b254c5SRichard Tran Mills . -regressor_linear_type - Sets the linear regression method; use -help for a list of available methods
39834b254c5SRichard Tran Mills    (for instance "-regressor_linear_type ols" or "-regressor_linear_type lasso")
39934b254c5SRichard Tran Mills 
40034b254c5SRichard Tran Mills   Level: intermediate
40134b254c5SRichard Tran Mills 
40234b254c5SRichard Tran Mills .seealso: `PetscRegressorLinearGetType()`, `PetscRegressorLinearType`, `PetscRegressorSetType()`
40334b254c5SRichard Tran Mills @*/
40434b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearSetType(PetscRegressor regressor, PetscRegressorLinearType type)
40534b254c5SRichard Tran Mills {
40634b254c5SRichard Tran Mills   PetscFunctionBegin;
40734b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
40834b254c5SRichard Tran Mills   PetscValidLogicalCollectiveEnum(regressor, type, 2);
40934b254c5SRichard Tran Mills   PetscTryMethod(regressor, "PetscRegressorLinearSetType_C", (PetscRegressor, PetscRegressorLinearType), (regressor, type));
41034b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
41134b254c5SRichard Tran Mills }
41234b254c5SRichard Tran Mills 
41334b254c5SRichard Tran Mills /*@
41434b254c5SRichard Tran Mills   PetscRegressorLinearGetType - Return the type for the `PETSCREGRESSORLINEAR` solver
41534b254c5SRichard Tran Mills 
41634b254c5SRichard Tran Mills   Input Parameter:
41734b254c5SRichard Tran Mills . regressor - the `PetscRegressor` solver context
41834b254c5SRichard Tran Mills 
41934b254c5SRichard Tran Mills   Output Parameter:
42034b254c5SRichard Tran Mills . type - `PETSCREGRESSORLINEAR` type
42134b254c5SRichard Tran Mills 
42234b254c5SRichard Tran Mills   Level: advanced
42334b254c5SRichard Tran Mills 
42434b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PETSCREGRESSORLINEAR`, `PetscRegressorLinearSetType()`, `PetscRegressorLinearType`
42534b254c5SRichard Tran Mills @*/
42634b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearGetType(PetscRegressor regressor, PetscRegressorLinearType *type)
42734b254c5SRichard Tran Mills {
42834b254c5SRichard Tran Mills   PetscFunctionBegin;
42934b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
43034b254c5SRichard Tran Mills   PetscAssertPointer(type, 2);
43134b254c5SRichard Tran Mills   PetscUseMethod(regressor, "PetscRegressorLinearGetType_C", (PetscRegressor, PetscRegressorLinearType *), (regressor, type));
43234b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
43334b254c5SRichard Tran Mills }
43434b254c5SRichard Tran Mills 
43534b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorFit_Linear(PetscRegressor regressor)
43634b254c5SRichard Tran Mills {
43734b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
43834b254c5SRichard Tran Mills   KSP                    ksp;
43934b254c5SRichard Tran Mills   PetscScalar            target_mean, *column_means_global, *column_means_local, column_means_dot_coefficients;
44034b254c5SRichard Tran Mills   Vec                    column_means;
44134b254c5SRichard Tran Mills   PetscInt               m, N, istart, i, kspits;
44234b254c5SRichard Tran Mills 
44334b254c5SRichard Tran Mills   PetscFunctionBegin;
44434b254c5SRichard Tran Mills   if (linear->use_ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp));
44534b254c5SRichard Tran Mills   ksp = linear->ksp;
44634b254c5SRichard Tran Mills 
44734b254c5SRichard Tran Mills   /* Solve the least-squares problem (previously set up in PetscRegressorSetUp_Linear()) without finding the intercept. */
44834b254c5SRichard Tran Mills   if (linear->use_ksp) {
44934b254c5SRichard Tran Mills     PetscCall(KSPSolve(ksp, linear->rhs, linear->coefficients));
45034b254c5SRichard Tran Mills     PetscCall(KSPGetIterationNumber(ksp, &kspits));
45134b254c5SRichard Tran Mills     linear->ksp_its += kspits;
45234b254c5SRichard Tran Mills     linear->ksp_tot_its += kspits;
45334b254c5SRichard Tran Mills   } else {
45434b254c5SRichard Tran Mills     PetscCall(TaoSolve(regressor->tao));
45534b254c5SRichard Tran Mills   }
45634b254c5SRichard Tran Mills 
45734b254c5SRichard Tran Mills   /* Calculate the intercept. */
45834b254c5SRichard Tran Mills   if (linear->fit_intercept) {
45934b254c5SRichard Tran Mills     PetscCall(MatGetSize(regressor->training, NULL, &N));
46034b254c5SRichard Tran Mills     PetscCall(PetscMalloc1(N, &column_means_global));
46134b254c5SRichard Tran Mills     PetscCall(VecMean(regressor->target, &target_mean));
46234b254c5SRichard Tran Mills     /* We need the means of all columns of regressor->training, placed into a Vec compatible with linear->coefficients.
46334b254c5SRichard Tran Mills      * Note the potential scalability issue: MatGetColumnMeans() computes means of ALL colummns. */
46434b254c5SRichard Tran Mills     PetscCall(MatGetColumnMeans(regressor->training, column_means_global));
46534b254c5SRichard Tran Mills     /* TODO: Calculation of the Vec and matrix column means should probably go into the SetUp phase, and also be placed
46634b254c5SRichard Tran Mills      *       into a routine that is callable from outside of PetscRegressorFit_Linear(), because we'll want to do the same
46734b254c5SRichard Tran Mills      *       thing for other models, such as ridge and LASSO regression, and should avoid code duplication.
46834b254c5SRichard Tran Mills      *       What we are calling 'target_mean' and 'column_means' should be stashed in the base linear regressor struct,
46934b254c5SRichard Tran Mills      *       and perhaps renamed to make it clear they are offsets that should be applied (though the current naming
47034b254c5SRichard Tran Mills      *       makes sense since it makes it clear where these come from.) */
47134b254c5SRichard Tran Mills     PetscCall(VecDuplicate(linear->coefficients, &column_means));
47234b254c5SRichard Tran Mills     PetscCall(VecGetLocalSize(column_means, &m));
47334b254c5SRichard Tran Mills     PetscCall(VecGetOwnershipRange(column_means, &istart, NULL));
47434b254c5SRichard Tran Mills     PetscCall(VecGetArrayWrite(column_means, &column_means_local));
47534b254c5SRichard Tran Mills     for (i = 0; i < m; i++) column_means_local[i] = column_means_global[istart + i];
47634b254c5SRichard Tran Mills     PetscCall(VecRestoreArrayWrite(column_means, &column_means_local));
47734b254c5SRichard Tran Mills     PetscCall(VecDot(column_means, linear->coefficients, &column_means_dot_coefficients));
47834b254c5SRichard Tran Mills     PetscCall(VecDestroy(&column_means));
47934b254c5SRichard Tran Mills     PetscCall(PetscFree(column_means_global));
48034b254c5SRichard Tran Mills     linear->intercept = target_mean - column_means_dot_coefficients;
48134b254c5SRichard Tran Mills   } else {
48234b254c5SRichard Tran Mills     linear->intercept = 0.0;
48334b254c5SRichard Tran Mills   }
48434b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
48534b254c5SRichard Tran Mills }
48634b254c5SRichard Tran Mills 
48734b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorPredict_Linear(PetscRegressor regressor, Mat X, Vec y)
48834b254c5SRichard Tran Mills {
48934b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
49034b254c5SRichard Tran Mills 
49134b254c5SRichard Tran Mills   PetscFunctionBegin;
49234b254c5SRichard Tran Mills   PetscCall(MatMult(X, linear->coefficients, y));
49334b254c5SRichard Tran Mills   PetscCall(VecShift(y, linear->intercept));
49434b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
49534b254c5SRichard Tran Mills }
49634b254c5SRichard Tran Mills 
/*MC
     PETSCREGRESSORLINEAR - Linear regression model (ordinary least squares or regularized variants)

   Options Database:
+  -regressor_linear_fit_intercept - Calculate the intercept for the linear model
-  -regressor_linear_use_ksp       - Use `KSP` instead of `Tao` for linear model fitting (non-regularized variants only)

   Level: beginner

   Note:
   This is the default regressor in `PetscRegressor`.

.seealso: `PetscRegressorCreate()`, `PetscRegressor`, `PetscRegressorSetType()`
M*/
51134b254c5SRichard Tran Mills PETSC_EXTERN PetscErrorCode PetscRegressorCreate_Linear(PetscRegressor regressor)
51234b254c5SRichard Tran Mills {
51334b254c5SRichard Tran Mills   PetscRegressor_Linear *linear;
51434b254c5SRichard Tran Mills 
51534b254c5SRichard Tran Mills   PetscFunctionBegin;
51634b254c5SRichard Tran Mills   PetscCall(PetscNew(&linear));
51734b254c5SRichard Tran Mills   regressor->data = (void *)linear;
51834b254c5SRichard Tran Mills 
51934b254c5SRichard Tran Mills   regressor->ops->setup          = PetscRegressorSetUp_Linear;
52034b254c5SRichard Tran Mills   regressor->ops->reset          = PetscRegressorReset_Linear;
52134b254c5SRichard Tran Mills   regressor->ops->destroy        = PetscRegressorDestroy_Linear;
52234b254c5SRichard Tran Mills   regressor->ops->setfromoptions = PetscRegressorSetFromOptions_Linear;
52334b254c5SRichard Tran Mills   regressor->ops->view           = PetscRegressorView_Linear;
52434b254c5SRichard Tran Mills   regressor->ops->fit            = PetscRegressorFit_Linear;
52534b254c5SRichard Tran Mills   regressor->ops->predict        = PetscRegressorPredict_Linear;
52634b254c5SRichard Tran Mills 
52734b254c5SRichard Tran Mills   linear->intercept     = 0.0;
52834b254c5SRichard Tran Mills   linear->fit_intercept = PETSC_TRUE;  /* Default to calculating the intercept. */
52934b254c5SRichard Tran Mills   linear->use_ksp       = PETSC_FALSE; /* Do not default to using KSP for solving the model-fitting problem (use TAO instead). */
53034b254c5SRichard Tran Mills   linear->type          = REGRESSOR_LINEAR_OLS;
53134b254c5SRichard Tran Mills   /* Above, manually set the default linear regressor type.
53234b254c5SRichard Tran Mills        We don't use PetscRegressorLinearSetType() here, because that expects the SetUp event to already have happened. */
53334b254c5SRichard Tran Mills 
53434b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", PetscRegressorLinearSetFitIntercept_Linear));
53534b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", PetscRegressorLinearSetUseKSP_Linear));
53634b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", PetscRegressorLinearGetKSP_Linear));
53734b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", PetscRegressorLinearGetCoefficients_Linear));
53834b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", PetscRegressorLinearGetIntercept_Linear));
53934b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", PetscRegressorLinearSetType_Linear));
54034b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", PetscRegressorLinearGetType_Linear));
54134b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
54234b254c5SRichard Tran Mills }
543