xref: /petsc/src/ml/regressor/impls/linear/linear.c (revision c0b7dd1945badc50da999927ed61eaf839ea6a8c)
134b254c5SRichard Tran Mills #include <../src/ml/regressor/impls/linear/linearimpl.h> /*I "petscregressor.h" I*/
234b254c5SRichard Tran Mills 
/* Option/viewer names for PetscRegressorLinearType: the value names (indexed by the enum value when viewing),
   then the enum type name, the common value prefix, and a NULL terminator. */
const char *const PetscRegressorLinearTypes[] = {
  "ols",
  "lasso",
  "ridge",
  "RegressorLinearType",
  "REGRESSOR_LINEAR_",
  NULL,
};
434b254c5SRichard Tran Mills 
534b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearSetFitIntercept_Linear(PetscRegressor regressor, PetscBool flg)
634b254c5SRichard Tran Mills {
734b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
834b254c5SRichard Tran Mills 
934b254c5SRichard Tran Mills   PetscFunctionBegin;
1034b254c5SRichard Tran Mills   linear->fit_intercept = flg;
1134b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
1234b254c5SRichard Tran Mills }
1334b254c5SRichard Tran Mills 
1434b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearSetType_Linear(PetscRegressor regressor, PetscRegressorLinearType type)
1534b254c5SRichard Tran Mills {
1634b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
1734b254c5SRichard Tran Mills 
1834b254c5SRichard Tran Mills   PetscFunctionBegin;
1934b254c5SRichard Tran Mills   linear->type = type;
2034b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
2134b254c5SRichard Tran Mills }
2234b254c5SRichard Tran Mills 
2334b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearGetType_Linear(PetscRegressor regressor, PetscRegressorLinearType *type)
2434b254c5SRichard Tran Mills {
2534b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
2634b254c5SRichard Tran Mills 
2734b254c5SRichard Tran Mills   PetscFunctionBegin;
2834b254c5SRichard Tran Mills   *type = linear->type;
2934b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
3034b254c5SRichard Tran Mills }
3134b254c5SRichard Tran Mills 
3234b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearGetIntercept_Linear(PetscRegressor regressor, PetscScalar *intercept)
3334b254c5SRichard Tran Mills {
3434b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
3534b254c5SRichard Tran Mills 
3634b254c5SRichard Tran Mills   PetscFunctionBegin;
3734b254c5SRichard Tran Mills   *intercept = linear->intercept;
3834b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
3934b254c5SRichard Tran Mills }
4034b254c5SRichard Tran Mills 
4134b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearGetCoefficients_Linear(PetscRegressor regressor, Vec *coefficients)
4234b254c5SRichard Tran Mills {
4334b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
4434b254c5SRichard Tran Mills 
4534b254c5SRichard Tran Mills   PetscFunctionBegin;
4634b254c5SRichard Tran Mills   *coefficients = linear->coefficients;
4734b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
4834b254c5SRichard Tran Mills }
4934b254c5SRichard Tran Mills 
5034b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearGetKSP_Linear(PetscRegressor regressor, KSP *ksp)
5134b254c5SRichard Tran Mills {
5234b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
5334b254c5SRichard Tran Mills 
5434b254c5SRichard Tran Mills   PetscFunctionBegin;
5534b254c5SRichard Tran Mills   if (!linear->ksp) {
5634b254c5SRichard Tran Mills     PetscCall(KSPCreate(PetscObjectComm((PetscObject)regressor), &linear->ksp));
5734b254c5SRichard Tran Mills     PetscCall(PetscObjectIncrementTabLevel((PetscObject)linear->ksp, (PetscObject)regressor, 1));
5834b254c5SRichard Tran Mills     PetscCall(PetscObjectSetOptions((PetscObject)linear->ksp, ((PetscObject)regressor)->options));
5934b254c5SRichard Tran Mills   }
6034b254c5SRichard Tran Mills   *ksp = linear->ksp;
6134b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
6234b254c5SRichard Tran Mills }
6334b254c5SRichard Tran Mills 
6434b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearSetUseKSP_Linear(PetscRegressor regressor, PetscBool flg)
6534b254c5SRichard Tran Mills {
6634b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
6734b254c5SRichard Tran Mills 
6834b254c5SRichard Tran Mills   PetscFunctionBegin;
6934b254c5SRichard Tran Mills   linear->use_ksp = flg;
7034b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
7134b254c5SRichard Tran Mills }
7234b254c5SRichard Tran Mills 
7334b254c5SRichard Tran Mills static PetscErrorCode EvaluateResidual(Tao tao, Vec x, Vec f, void *ptr)
7434b254c5SRichard Tran Mills {
7534b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)ptr;
7634b254c5SRichard Tran Mills 
7734b254c5SRichard Tran Mills   PetscFunctionBegin;
7834b254c5SRichard Tran Mills   /* Evaluate f = A * x - b */
7934b254c5SRichard Tran Mills   PetscCall(MatMult(linear->X, x, f));
8034b254c5SRichard Tran Mills   PetscCall(VecAXPY(f, -1.0, linear->rhs));
8134b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
8234b254c5SRichard Tran Mills }
8334b254c5SRichard Tran Mills 
8434b254c5SRichard Tran Mills static PetscErrorCode EvaluateJacobian(Tao tao, Vec x, Mat J, Mat Jpre, void *ptr)
8534b254c5SRichard Tran Mills {
86540f39e1SHansol Suh   /* The TAOBRGN API expects us to pass an EvaluateJacobian() routine to it, but in this case it is a dummy function.
87540f39e1SHansol Suh      Denoting our data matrix as X, for linear least squares J[m][n] = df[m]/dx[n] = X[m][n]. */
8834b254c5SRichard Tran Mills   PetscFunctionBegin;
8934b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
9034b254c5SRichard Tran Mills }
9134b254c5SRichard Tran Mills 
9234b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorSetUp_Linear(PetscRegressor regressor)
9334b254c5SRichard Tran Mills {
9434b254c5SRichard Tran Mills   PetscInt               M, N;
9534b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
9634b254c5SRichard Tran Mills   KSP                    ksp;
9734b254c5SRichard Tran Mills   Tao                    tao;
9834b254c5SRichard Tran Mills 
9934b254c5SRichard Tran Mills   PetscFunctionBegin;
10034b254c5SRichard Tran Mills   PetscCall(MatGetSize(regressor->training, &M, &N));
10134b254c5SRichard Tran Mills 
10234b254c5SRichard Tran Mills   if (linear->fit_intercept) {
10334b254c5SRichard Tran Mills     /* If we are fitting the intercept, we need to make A a composite matrix using MATCENTERING to preserve sparsity.
10434b254c5SRichard Tran Mills      * Though there might be some cases we don't want to do this for, depending on what kind of matrix is passed in. (Probably bad idea for dense?)
10534b254c5SRichard Tran Mills      * We will also need to ensure that the right-hand side passed to the KSP is also mean-centered, since we
10634b254c5SRichard Tran Mills      * intend to compute the intercept separately from regression coefficients (that is, we will not be adding a
10734b254c5SRichard Tran Mills      * column of all 1s to our design matrix). */
10834b254c5SRichard Tran Mills     PetscCall(MatCreateCentering(PetscObjectComm((PetscObject)regressor), PETSC_DECIDE, M, &linear->C));
10934b254c5SRichard Tran Mills     PetscCall(MatCreate(PetscObjectComm((PetscObject)regressor), &linear->X));
11034b254c5SRichard Tran Mills     PetscCall(MatSetSizes(linear->X, PETSC_DECIDE, PETSC_DECIDE, M, N));
11134b254c5SRichard Tran Mills     PetscCall(MatSetType(linear->X, MATCOMPOSITE));
11234b254c5SRichard Tran Mills     PetscCall(MatCompositeSetType(linear->X, MAT_COMPOSITE_MULTIPLICATIVE));
11334b254c5SRichard Tran Mills     PetscCall(MatCompositeAddMat(linear->X, regressor->training));
11434b254c5SRichard Tran Mills     PetscCall(MatCompositeAddMat(linear->X, linear->C));
11534b254c5SRichard Tran Mills     PetscCall(VecDuplicate(regressor->target, &linear->rhs));
11634b254c5SRichard Tran Mills     PetscCall(MatMult(linear->C, regressor->target, linear->rhs));
11734b254c5SRichard Tran Mills   } else {
11834b254c5SRichard Tran Mills     // When not fitting intercept, we assume that the input data are already centered.
11934b254c5SRichard Tran Mills     linear->X   = regressor->training;
12034b254c5SRichard Tran Mills     linear->rhs = regressor->target;
12134b254c5SRichard Tran Mills 
12234b254c5SRichard Tran Mills     PetscCall(PetscObjectReference((PetscObject)linear->X));
12334b254c5SRichard Tran Mills     PetscCall(PetscObjectReference((PetscObject)linear->rhs));
12434b254c5SRichard Tran Mills   }
12534b254c5SRichard Tran Mills 
12634b254c5SRichard Tran Mills   if (linear->coefficients) PetscCall(VecDestroy(&linear->coefficients));
12734b254c5SRichard Tran Mills 
12834b254c5SRichard Tran Mills   if (linear->use_ksp) {
12934b254c5SRichard Tran Mills     PetscCheck(linear->type == REGRESSOR_LINEAR_OLS, PetscObjectComm((PetscObject)regressor), PETSC_ERR_ARG_WRONGSTATE, "KSP can be used to fit a linear regressor only when its type is OLS");
13034b254c5SRichard Tran Mills 
13134b254c5SRichard Tran Mills     if (!linear->ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp));
13234b254c5SRichard Tran Mills     ksp = linear->ksp;
13334b254c5SRichard Tran Mills 
13434b254c5SRichard Tran Mills     PetscCall(MatCreateVecs(linear->X, &linear->coefficients, NULL));
13534b254c5SRichard Tran Mills     /* Set up the KSP to solve the least squares problem (without solving for intercept, as this is done separately) using KSPLSQR. */
13634b254c5SRichard Tran Mills     PetscCall(MatCreateNormal(linear->X, &linear->XtX));
13734b254c5SRichard Tran Mills     PetscCall(KSPSetType(ksp, KSPLSQR));
13834b254c5SRichard Tran Mills     PetscCall(KSPSetOperators(ksp, linear->X, linear->XtX));
13934b254c5SRichard Tran Mills     PetscCall(KSPSetOptionsPrefix(ksp, ((PetscObject)regressor)->prefix));
14034b254c5SRichard Tran Mills     PetscCall(KSPAppendOptionsPrefix(ksp, "regressor_linear_"));
14134b254c5SRichard Tran Mills     PetscCall(KSPSetFromOptions(ksp));
14234b254c5SRichard Tran Mills   } else {
14334b254c5SRichard Tran Mills     /* Note: Currently implementation creates TAO inside of implementations.
14434b254c5SRichard Tran Mills       * Thus, all the prefix jobs are done inside implementations, not in interface */
14534b254c5SRichard Tran Mills     const char *prefix;
14634b254c5SRichard Tran Mills 
14734b254c5SRichard Tran Mills     if (!regressor->tao) PetscCall(PetscRegressorGetTao(regressor, &tao));
14834b254c5SRichard Tran Mills 
14934b254c5SRichard Tran Mills     PetscCall(MatCreateVecs(linear->X, &linear->coefficients, &linear->residual));
15034b254c5SRichard Tran Mills     /* Set up the TAO object to solve the (regularized) least squares problem (without solving for intercept, which is done separately) using TAOBRGN. */
15134b254c5SRichard Tran Mills     PetscCall(TaoSetType(tao, TAOBRGN));
15234b254c5SRichard Tran Mills     PetscCall(TaoSetSolution(tao, linear->coefficients));
15334b254c5SRichard Tran Mills     PetscCall(TaoSetResidualRoutine(tao, linear->residual, EvaluateResidual, linear));
15434b254c5SRichard Tran Mills     PetscCall(TaoSetJacobianResidualRoutine(tao, linear->X, linear->X, EvaluateJacobian, linear));
15534b254c5SRichard Tran Mills     if (!linear->use_ksp) PetscCall(TaoBRGNSetRegularizerWeight(tao, regressor->regularizer_weight));
15634b254c5SRichard Tran Mills     // Set the regularization type and weight for the BRGN as linear->type dictates:
15734b254c5SRichard Tran Mills     // TODO BRGN needs to be BRGNSetRegularizationType
15834b254c5SRichard Tran Mills     // PetscOptionsSetValue no longer works due to functioning prefix system
15934b254c5SRichard Tran Mills     PetscCall(PetscRegressorGetOptionsPrefix(regressor, &prefix));
16034b254c5SRichard Tran Mills     PetscCall(TaoSetOptionsPrefix(regressor->tao, prefix));
16134b254c5SRichard Tran Mills     PetscCall(TaoAppendOptionsPrefix(tao, "regressor_linear_"));
16234b254c5SRichard Tran Mills     switch (linear->type) {
16334b254c5SRichard Tran Mills     case REGRESSOR_LINEAR_OLS:
16434b254c5SRichard Tran Mills       regressor->regularizer_weight = 0.0; // OLS, by definition, uses a regularizer weight of 0
16534b254c5SRichard Tran Mills       break;
16634b254c5SRichard Tran Mills     case REGRESSOR_LINEAR_LASSO:
167*c0b7dd19SHansol Suh       PetscCall(TaoBRGNSetRegularizationType(regressor->tao, TAOBRGN_REGULARIZATION_L1DICT));
16834b254c5SRichard Tran Mills       break;
16934b254c5SRichard Tran Mills     case REGRESSOR_LINEAR_RIDGE:
170*c0b7dd19SHansol Suh       PetscCall(TaoBRGNSetRegularizationType(regressor->tao, TAOBRGN_REGULARIZATION_L2PURE));
17134b254c5SRichard Tran Mills       break;
17234b254c5SRichard Tran Mills     default:
17334b254c5SRichard Tran Mills       break;
17434b254c5SRichard Tran Mills     }
17534b254c5SRichard Tran Mills     PetscCall(TaoSetFromOptions(tao));
17634b254c5SRichard Tran Mills   }
17734b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
17834b254c5SRichard Tran Mills }
17934b254c5SRichard Tran Mills 
18034b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorReset_Linear(PetscRegressor regressor)
18134b254c5SRichard Tran Mills {
18234b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
18334b254c5SRichard Tran Mills 
18434b254c5SRichard Tran Mills   PetscFunctionBegin;
18534b254c5SRichard Tran Mills   /* Destroy the PETSc objects associated with the linear regressor implementation. */
18634b254c5SRichard Tran Mills   linear->ksp_its     = 0;
18734b254c5SRichard Tran Mills   linear->ksp_tot_its = 0;
18834b254c5SRichard Tran Mills 
18934b254c5SRichard Tran Mills   PetscCall(MatDestroy(&linear->X));
19034b254c5SRichard Tran Mills   PetscCall(MatDestroy(&linear->XtX));
19134b254c5SRichard Tran Mills   PetscCall(MatDestroy(&linear->C));
19234b254c5SRichard Tran Mills   PetscCall(KSPDestroy(&linear->ksp));
19334b254c5SRichard Tran Mills   PetscCall(VecDestroy(&linear->coefficients));
19434b254c5SRichard Tran Mills   PetscCall(VecDestroy(&linear->rhs));
19534b254c5SRichard Tran Mills   PetscCall(VecDestroy(&linear->residual));
19634b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
19734b254c5SRichard Tran Mills }
19834b254c5SRichard Tran Mills 
19934b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorDestroy_Linear(PetscRegressor regressor)
20034b254c5SRichard Tran Mills {
20134b254c5SRichard Tran Mills   PetscFunctionBegin;
20234b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", NULL));
20334b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", NULL));
20434b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", NULL));
20534b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", NULL));
20634b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", NULL));
20734b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", NULL));
20834b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", NULL));
20934b254c5SRichard Tran Mills   PetscCall(PetscRegressorReset_Linear(regressor));
21034b254c5SRichard Tran Mills   PetscCall(PetscFree(regressor->data));
21134b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
21234b254c5SRichard Tran Mills }
21334b254c5SRichard Tran Mills 
21434b254c5SRichard Tran Mills /*@
21534b254c5SRichard Tran Mills   PetscRegressorLinearSetFitIntercept - Set a flag to indicate that the intercept (also known as the "bias" or "offset") should
21634b254c5SRichard Tran Mills   be calculated; data are assumed to be mean-centered if false.
21734b254c5SRichard Tran Mills 
21834b254c5SRichard Tran Mills   Logically Collective
21934b254c5SRichard Tran Mills 
22034b254c5SRichard Tran Mills   Input Parameters:
22134b254c5SRichard Tran Mills + regressor - the `PetscRegressor` context
22234b254c5SRichard Tran Mills - flg       - `PETSC_TRUE` to calculate the intercept, `PETSC_FALSE` to assume mean-centered data (default is `PETSC_TRUE`)
22334b254c5SRichard Tran Mills 
22434b254c5SRichard Tran Mills   Level: intermediate
22534b254c5SRichard Tran Mills 
22634b254c5SRichard Tran Mills   Options Database Key:
22734b254c5SRichard Tran Mills . regressor_linear_fit_intercept <true,false> - fit the intercept
22834b254c5SRichard Tran Mills 
22934b254c5SRichard Tran Mills   Note:
23034b254c5SRichard Tran Mills   If the user indicates that the intercept should not be calculated, the intercept will be set to zero.
23134b254c5SRichard Tran Mills 
23234b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PetscRegressorFit()`
23334b254c5SRichard Tran Mills @*/
23434b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearSetFitIntercept(PetscRegressor regressor, PetscBool flg)
23534b254c5SRichard Tran Mills {
23634b254c5SRichard Tran Mills   PetscFunctionBegin;
23734b254c5SRichard Tran Mills   /* TODO: Add companion PetscRegressorLinearGetFitIntercept(), and put it in the .seealso: */
23834b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
23934b254c5SRichard Tran Mills   PetscValidLogicalCollectiveBool(regressor, flg, 2);
24034b254c5SRichard Tran Mills   PetscTryMethod(regressor, "PetscRegressorLinearSetFitIntercept_C", (PetscRegressor, PetscBool), (regressor, flg));
24134b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
24234b254c5SRichard Tran Mills }
24334b254c5SRichard Tran Mills 
24434b254c5SRichard Tran Mills /*@
24534b254c5SRichard Tran Mills   PetscRegressorLinearSetUseKSP - Set a flag to indicate that a `KSP` object, instead of a `Tao` one, should be used
24634b254c5SRichard Tran Mills   to fit the regressor
24734b254c5SRichard Tran Mills 
24834b254c5SRichard Tran Mills   Logically Collective
24934b254c5SRichard Tran Mills 
25034b254c5SRichard Tran Mills   Input Parameters:
25134b254c5SRichard Tran Mills + regressor - the `PetscRegressor` context
25234b254c5SRichard Tran Mills - flg       - `PETSC_TRUE` to use a `KSP`, `PETSC_FALSE` to use a `Tao` object (default is false)
25334b254c5SRichard Tran Mills 
25434b254c5SRichard Tran Mills   Options Database Key:
25534b254c5SRichard Tran Mills . regressor_linear_use_ksp <true,false> - use `KSP`
25634b254c5SRichard Tran Mills 
25734b254c5SRichard Tran Mills   Level: intermediate
25834b254c5SRichard Tran Mills 
25934b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PetscRegressorLinearGetKSP()`
26034b254c5SRichard Tran Mills @*/
26134b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearSetUseKSP(PetscRegressor regressor, PetscBool flg)
26234b254c5SRichard Tran Mills {
26334b254c5SRichard Tran Mills   PetscFunctionBegin;
26434b254c5SRichard Tran Mills   /* TODO: Add companion PetscRegressorLinearGetUseKSP(), and put it in the .seealso: */
26534b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
26634b254c5SRichard Tran Mills   PetscValidLogicalCollectiveBool(regressor, flg, 2);
26734b254c5SRichard Tran Mills   PetscTryMethod(regressor, "PetscRegressorLinearSetUseKSP_C", (PetscRegressor, PetscBool), (regressor, flg));
26834b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
26934b254c5SRichard Tran Mills }
27034b254c5SRichard Tran Mills 
27134b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorSetFromOptions_Linear(PetscRegressor regressor, PetscOptionItems PetscOptionsObject)
27234b254c5SRichard Tran Mills {
27334b254c5SRichard Tran Mills   PetscBool              set, flg = PETSC_FALSE;
27434b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
27534b254c5SRichard Tran Mills 
27634b254c5SRichard Tran Mills   PetscFunctionBegin;
27734b254c5SRichard Tran Mills   PetscOptionsHeadBegin(PetscOptionsObject, "PetscRegressor options for linear regressors");
27834b254c5SRichard Tran Mills   PetscCall(PetscOptionsBool("-regressor_linear_fit_intercept", "Calculate intercept for linear model", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set));
27934b254c5SRichard Tran Mills   if (set) PetscCall(PetscRegressorLinearSetFitIntercept(regressor, flg));
28034b254c5SRichard Tran Mills   PetscCall(PetscOptionsBool("-regressor_linear_use_ksp", "Use KSP instead of TAO for linear model fitting problem", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set));
28134b254c5SRichard Tran Mills   if (set) PetscCall(PetscRegressorLinearSetUseKSP(regressor, flg));
28234b254c5SRichard Tran Mills   PetscCall(PetscOptionsEnum("-regressor_linear_type", "Linear regression method", "PetscRegressorLinearTypes", PetscRegressorLinearTypes, (PetscEnum)linear->type, (PetscEnum *)&linear->type, NULL));
28334b254c5SRichard Tran Mills   PetscOptionsHeadEnd();
28434b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
28534b254c5SRichard Tran Mills }
28634b254c5SRichard Tran Mills 
28734b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorView_Linear(PetscRegressor regressor, PetscViewer viewer)
28834b254c5SRichard Tran Mills {
28934b254c5SRichard Tran Mills   PetscBool              isascii;
29034b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
29134b254c5SRichard Tran Mills 
29234b254c5SRichard Tran Mills   PetscFunctionBegin;
29334b254c5SRichard Tran Mills   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
29434b254c5SRichard Tran Mills   if (isascii) {
29534b254c5SRichard Tran Mills     PetscCall(PetscViewerASCIIPushTab(viewer));
29634b254c5SRichard Tran Mills     PetscCall(PetscViewerASCIIPrintf(viewer, "PetscRegressor Linear Type: %s\n", PetscRegressorLinearTypes[linear->type]));
29734b254c5SRichard Tran Mills     if (linear->ksp) {
29834b254c5SRichard Tran Mills       PetscCall(KSPView(linear->ksp, viewer));
29934b254c5SRichard Tran Mills       PetscCall(PetscViewerASCIIPrintf(viewer, "total KSP iterations: %" PetscInt_FMT "\n", linear->ksp_tot_its));
30034b254c5SRichard Tran Mills     }
30134b254c5SRichard Tran Mills     if (linear->fit_intercept) PetscCall(PetscViewerASCIIPrintf(viewer, "Intercept=%g\n", (double)linear->intercept));
30234b254c5SRichard Tran Mills     PetscCall(PetscViewerASCIIPopTab(viewer));
30334b254c5SRichard Tran Mills   }
30434b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
30534b254c5SRichard Tran Mills }
30634b254c5SRichard Tran Mills 
30734b254c5SRichard Tran Mills /*@
30834b254c5SRichard Tran Mills   PetscRegressorLinearGetKSP - Returns the `KSP` context for a `PETSCREGRESSORLINEAR` object.
30934b254c5SRichard Tran Mills 
31034b254c5SRichard Tran Mills   Not Collective, but if the `PetscRegressor` is parallel, then the `KSP` object is parallel
31134b254c5SRichard Tran Mills 
31234b254c5SRichard Tran Mills   Input Parameter:
31334b254c5SRichard Tran Mills . regressor - the `PetscRegressor` context
31434b254c5SRichard Tran Mills 
31534b254c5SRichard Tran Mills   Output Parameter:
31634b254c5SRichard Tran Mills . ksp - the `KSP` context
31734b254c5SRichard Tran Mills 
31834b254c5SRichard Tran Mills   Level: beginner
31934b254c5SRichard Tran Mills 
32034b254c5SRichard Tran Mills   Note:
32134b254c5SRichard Tran Mills   This routine will always return a `KSP`, but, depending on the type of the linear regressor and the options that are set, the regressor may actually use a `Tao` object instead of this `KSP`.
32234b254c5SRichard Tran Mills 
32334b254c5SRichard Tran Mills .seealso: `PetscRegressorGetTao()`
32434b254c5SRichard Tran Mills @*/
32534b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearGetKSP(PetscRegressor regressor, KSP *ksp)
32634b254c5SRichard Tran Mills {
32734b254c5SRichard Tran Mills   PetscFunctionBegin;
32834b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
32934b254c5SRichard Tran Mills   PetscAssertPointer(ksp, 2);
33034b254c5SRichard Tran Mills   PetscUseMethod(regressor, "PetscRegressorLinearGetKSP_C", (PetscRegressor, KSP *), (regressor, ksp));
33134b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
33234b254c5SRichard Tran Mills }
33334b254c5SRichard Tran Mills 
33434b254c5SRichard Tran Mills /*@
33534b254c5SRichard Tran Mills   PetscRegressorLinearGetCoefficients - Get a vector of the fitted coefficients from a linear regression model
33634b254c5SRichard Tran Mills 
33734b254c5SRichard Tran Mills   Not Collective but the vector is parallel
33834b254c5SRichard Tran Mills 
33934b254c5SRichard Tran Mills   Input Parameter:
34034b254c5SRichard Tran Mills . regressor - the `PetscRegressor` context
34134b254c5SRichard Tran Mills 
34234b254c5SRichard Tran Mills   Output Parameter:
34334b254c5SRichard Tran Mills . coefficients - the vector of the coefficients
34434b254c5SRichard Tran Mills 
34534b254c5SRichard Tran Mills   Level: beginner
34634b254c5SRichard Tran Mills 
34734b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PetscRegressorLinearGetIntercept()`, `PETSCREGRESSORLINEAR`, `Vec`
34834b254c5SRichard Tran Mills @*/
34934b254c5SRichard Tran Mills PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetCoefficients(PetscRegressor regressor, Vec *coefficients)
35034b254c5SRichard Tran Mills {
35134b254c5SRichard Tran Mills   PetscFunctionBegin;
35234b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
35334b254c5SRichard Tran Mills   PetscAssertPointer(coefficients, 2);
35434b254c5SRichard Tran Mills   PetscUseMethod(regressor, "PetscRegressorLinearGetCoefficients_C", (PetscRegressor, Vec *), (regressor, coefficients));
35534b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
35634b254c5SRichard Tran Mills }
35734b254c5SRichard Tran Mills 
35834b254c5SRichard Tran Mills /*@
35934b254c5SRichard Tran Mills   PetscRegressorLinearGetIntercept - Get the intercept from a linear regression model
36034b254c5SRichard Tran Mills 
36134b254c5SRichard Tran Mills   Not Collective
36234b254c5SRichard Tran Mills 
36334b254c5SRichard Tran Mills   Input Parameter:
36434b254c5SRichard Tran Mills . regressor - the `PetscRegressor` context
36534b254c5SRichard Tran Mills 
36634b254c5SRichard Tran Mills   Output Parameter:
36734b254c5SRichard Tran Mills . intercept - the intercept
36834b254c5SRichard Tran Mills 
36934b254c5SRichard Tran Mills   Level: beginner
37034b254c5SRichard Tran Mills 
37134b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PetscRegressorLinearSetFitIntercept()`, `PetscRegressorLinearGetCoefficients()`, `PETSCREGRESSORLINEAR`
37234b254c5SRichard Tran Mills @*/
37334b254c5SRichard Tran Mills PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetIntercept(PetscRegressor regressor, PetscScalar *intercept)
37434b254c5SRichard Tran Mills {
37534b254c5SRichard Tran Mills   PetscFunctionBegin;
37634b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
37734b254c5SRichard Tran Mills   PetscAssertPointer(intercept, 2);
37834b254c5SRichard Tran Mills   PetscUseMethod(regressor, "PetscRegressorLinearGetIntercept_C", (PetscRegressor, PetscScalar *), (regressor, intercept));
37934b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
38034b254c5SRichard Tran Mills }
38134b254c5SRichard Tran Mills 
38234b254c5SRichard Tran Mills /*@C
38334b254c5SRichard Tran Mills   PetscRegressorLinearSetType - Sets the type of linear regression to be performed
38434b254c5SRichard Tran Mills 
38534b254c5SRichard Tran Mills   Logically Collective
38634b254c5SRichard Tran Mills 
38734b254c5SRichard Tran Mills   Input Parameters:
38834b254c5SRichard Tran Mills + regressor - the `PetscRegressor` context (should be of type `PETSCREGRESSORLINEAR`)
38934b254c5SRichard Tran Mills - type      - a known linear regression method
39034b254c5SRichard Tran Mills 
39134b254c5SRichard Tran Mills   Options Database Key:
39234b254c5SRichard Tran Mills . -regressor_linear_type - Sets the linear regression method; use -help for a list of available methods
39334b254c5SRichard Tran Mills    (for instance "-regressor_linear_type ols" or "-regressor_linear_type lasso")
39434b254c5SRichard Tran Mills 
39534b254c5SRichard Tran Mills   Level: intermediate
39634b254c5SRichard Tran Mills 
39734b254c5SRichard Tran Mills .seealso: `PetscRegressorLinearGetType()`, `PetscRegressorLinearType`, `PetscRegressorSetType()`
39834b254c5SRichard Tran Mills @*/
39934b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearSetType(PetscRegressor regressor, PetscRegressorLinearType type)
40034b254c5SRichard Tran Mills {
40134b254c5SRichard Tran Mills   PetscFunctionBegin;
40234b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
40334b254c5SRichard Tran Mills   PetscValidLogicalCollectiveEnum(regressor, type, 2);
40434b254c5SRichard Tran Mills   PetscTryMethod(regressor, "PetscRegressorLinearSetType_C", (PetscRegressor, PetscRegressorLinearType), (regressor, type));
40534b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
40634b254c5SRichard Tran Mills }
40734b254c5SRichard Tran Mills 
40834b254c5SRichard Tran Mills /*@
40934b254c5SRichard Tran Mills   PetscRegressorLinearGetType - Return the type for the `PETSCREGRESSORLINEAR` solver
41034b254c5SRichard Tran Mills 
41134b254c5SRichard Tran Mills   Input Parameter:
41234b254c5SRichard Tran Mills . regressor - the `PetscRegressor` solver context
41334b254c5SRichard Tran Mills 
41434b254c5SRichard Tran Mills   Output Parameter:
41534b254c5SRichard Tran Mills . type - `PETSCREGRESSORLINEAR` type
41634b254c5SRichard Tran Mills 
41734b254c5SRichard Tran Mills   Level: advanced
41834b254c5SRichard Tran Mills 
41934b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PETSCREGRESSORLINEAR`, `PetscRegressorLinearSetType()`, `PetscRegressorLinearType`
42034b254c5SRichard Tran Mills @*/
42134b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearGetType(PetscRegressor regressor, PetscRegressorLinearType *type)
42234b254c5SRichard Tran Mills {
42334b254c5SRichard Tran Mills   PetscFunctionBegin;
42434b254c5SRichard Tran Mills   PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1);
42534b254c5SRichard Tran Mills   PetscAssertPointer(type, 2);
42634b254c5SRichard Tran Mills   PetscUseMethod(regressor, "PetscRegressorLinearGetType_C", (PetscRegressor, PetscRegressorLinearType *), (regressor, type));
42734b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
42834b254c5SRichard Tran Mills }
42934b254c5SRichard Tran Mills 
43034b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorFit_Linear(PetscRegressor regressor)
43134b254c5SRichard Tran Mills {
43234b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
43334b254c5SRichard Tran Mills   KSP                    ksp;
43434b254c5SRichard Tran Mills   PetscScalar            target_mean, *column_means_global, *column_means_local, column_means_dot_coefficients;
43534b254c5SRichard Tran Mills   Vec                    column_means;
43634b254c5SRichard Tran Mills   PetscInt               m, N, istart, i, kspits;
43734b254c5SRichard Tran Mills 
43834b254c5SRichard Tran Mills   PetscFunctionBegin;
43934b254c5SRichard Tran Mills   if (linear->use_ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp));
44034b254c5SRichard Tran Mills   ksp = linear->ksp;
44134b254c5SRichard Tran Mills 
44234b254c5SRichard Tran Mills   /* Solve the least-squares problem (previously set up in PetscRegressorSetUp_Linear()) without finding the intercept. */
44334b254c5SRichard Tran Mills   if (linear->use_ksp) {
44434b254c5SRichard Tran Mills     PetscCall(KSPSolve(ksp, linear->rhs, linear->coefficients));
44534b254c5SRichard Tran Mills     PetscCall(KSPGetIterationNumber(ksp, &kspits));
44634b254c5SRichard Tran Mills     linear->ksp_its += kspits;
44734b254c5SRichard Tran Mills     linear->ksp_tot_its += kspits;
44834b254c5SRichard Tran Mills   } else {
44934b254c5SRichard Tran Mills     PetscCall(TaoSolve(regressor->tao));
45034b254c5SRichard Tran Mills   }
45134b254c5SRichard Tran Mills 
45234b254c5SRichard Tran Mills   /* Calculate the intercept. */
45334b254c5SRichard Tran Mills   if (linear->fit_intercept) {
45434b254c5SRichard Tran Mills     PetscCall(MatGetSize(regressor->training, NULL, &N));
45534b254c5SRichard Tran Mills     PetscCall(PetscMalloc1(N, &column_means_global));
45634b254c5SRichard Tran Mills     PetscCall(VecMean(regressor->target, &target_mean));
45734b254c5SRichard Tran Mills     /* We need the means of all columns of regressor->training, placed into a Vec compatible with linear->coefficients.
45834b254c5SRichard Tran Mills      * Note the potential scalability issue: MatGetColumnMeans() computes means of ALL colummns. */
45934b254c5SRichard Tran Mills     PetscCall(MatGetColumnMeans(regressor->training, column_means_global));
46034b254c5SRichard Tran Mills     /* TODO: Calculation of the Vec and matrix column means should probably go into the SetUp phase, and also be placed
46134b254c5SRichard Tran Mills      *       into a routine that is callable from outside of PetscRegressorFit_Linear(), because we'll want to do the same
46234b254c5SRichard Tran Mills      *       thing for other models, such as ridge and LASSO regression, and should avoid code duplication.
46334b254c5SRichard Tran Mills      *       What we are calling 'target_mean' and 'column_means' should be stashed in the base linear regressor struct,
46434b254c5SRichard Tran Mills      *       and perhaps renamed to make it clear they are offsets that should be applied (though the current naming
46534b254c5SRichard Tran Mills      *       makes sense since it makes it clear where these come from.) */
46634b254c5SRichard Tran Mills     PetscCall(VecDuplicate(linear->coefficients, &column_means));
46734b254c5SRichard Tran Mills     PetscCall(VecGetLocalSize(column_means, &m));
46834b254c5SRichard Tran Mills     PetscCall(VecGetOwnershipRange(column_means, &istart, NULL));
46934b254c5SRichard Tran Mills     PetscCall(VecGetArrayWrite(column_means, &column_means_local));
47034b254c5SRichard Tran Mills     for (i = 0; i < m; i++) column_means_local[i] = column_means_global[istart + i];
47134b254c5SRichard Tran Mills     PetscCall(VecRestoreArrayWrite(column_means, &column_means_local));
47234b254c5SRichard Tran Mills     PetscCall(VecDot(column_means, linear->coefficients, &column_means_dot_coefficients));
47334b254c5SRichard Tran Mills     PetscCall(VecDestroy(&column_means));
47434b254c5SRichard Tran Mills     PetscCall(PetscFree(column_means_global));
47534b254c5SRichard Tran Mills     linear->intercept = target_mean - column_means_dot_coefficients;
47634b254c5SRichard Tran Mills   } else {
47734b254c5SRichard Tran Mills     linear->intercept = 0.0;
47834b254c5SRichard Tran Mills   }
47934b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
48034b254c5SRichard Tran Mills }
48134b254c5SRichard Tran Mills 
48234b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorPredict_Linear(PetscRegressor regressor, Mat X, Vec y)
48334b254c5SRichard Tran Mills {
48434b254c5SRichard Tran Mills   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
48534b254c5SRichard Tran Mills 
48634b254c5SRichard Tran Mills   PetscFunctionBegin;
48734b254c5SRichard Tran Mills   PetscCall(MatMult(X, linear->coefficients, y));
48834b254c5SRichard Tran Mills   PetscCall(VecShift(y, linear->intercept));
48934b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
49034b254c5SRichard Tran Mills }
49134b254c5SRichard Tran Mills 
49234b254c5SRichard Tran Mills /*MC
49334b254c5SRichard Tran Mills      PETSCREGRESSORLINEAR - Linear regression model (ordinary least squares or regularized variants)
49434b254c5SRichard Tran Mills 
   Options Database Keys:
49634b254c5SRichard Tran Mills +  -regressor_linear_fit_intercept - Calculate the intercept for the linear model
49734b254c5SRichard Tran Mills -  -regressor_linear_use_ksp       - Use `KSP` instead of `Tao` for linear model fitting (non-regularized variants only)
49834b254c5SRichard Tran Mills 
49934b254c5SRichard Tran Mills    Level: beginner
50034b254c5SRichard Tran Mills 
50134b254c5SRichard Tran Mills    Note:
50234b254c5SRichard Tran Mills    This is the default regressor in `PetscRegressor`.
50334b254c5SRichard Tran Mills 
50434b254c5SRichard Tran Mills .seealso: `PetscRegressorCreate()`, `PetscRegressor`, `PetscRegressorSetType()`
50534b254c5SRichard Tran Mills M*/
50634b254c5SRichard Tran Mills PETSC_EXTERN PetscErrorCode PetscRegressorCreate_Linear(PetscRegressor regressor)
50734b254c5SRichard Tran Mills {
50834b254c5SRichard Tran Mills   PetscRegressor_Linear *linear;
50934b254c5SRichard Tran Mills 
51034b254c5SRichard Tran Mills   PetscFunctionBegin;
51134b254c5SRichard Tran Mills   PetscCall(PetscNew(&linear));
51234b254c5SRichard Tran Mills   regressor->data = (void *)linear;
51334b254c5SRichard Tran Mills 
51434b254c5SRichard Tran Mills   regressor->ops->setup          = PetscRegressorSetUp_Linear;
51534b254c5SRichard Tran Mills   regressor->ops->reset          = PetscRegressorReset_Linear;
51634b254c5SRichard Tran Mills   regressor->ops->destroy        = PetscRegressorDestroy_Linear;
51734b254c5SRichard Tran Mills   regressor->ops->setfromoptions = PetscRegressorSetFromOptions_Linear;
51834b254c5SRichard Tran Mills   regressor->ops->view           = PetscRegressorView_Linear;
51934b254c5SRichard Tran Mills   regressor->ops->fit            = PetscRegressorFit_Linear;
52034b254c5SRichard Tran Mills   regressor->ops->predict        = PetscRegressorPredict_Linear;
52134b254c5SRichard Tran Mills 
52234b254c5SRichard Tran Mills   linear->intercept     = 0.0;
52334b254c5SRichard Tran Mills   linear->fit_intercept = PETSC_TRUE;  /* Default to calculating the intercept. */
52434b254c5SRichard Tran Mills   linear->use_ksp       = PETSC_FALSE; /* Do not default to using KSP for solving the model-fitting problem (use TAO instead). */
52534b254c5SRichard Tran Mills   linear->type          = REGRESSOR_LINEAR_OLS;
52634b254c5SRichard Tran Mills   /* Above, manually set the default linear regressor type.
52734b254c5SRichard Tran Mills        We don't use PetscRegressorLinearSetType() here, because that expects the SetUp event to already have happened. */
52834b254c5SRichard Tran Mills 
52934b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", PetscRegressorLinearSetFitIntercept_Linear));
53034b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", PetscRegressorLinearSetUseKSP_Linear));
53134b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", PetscRegressorLinearGetKSP_Linear));
53234b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", PetscRegressorLinearGetCoefficients_Linear));
53334b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", PetscRegressorLinearGetIntercept_Linear));
53434b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", PetscRegressorLinearSetType_Linear));
53534b254c5SRichard Tran Mills   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", PetscRegressorLinearGetType_Linear));
53634b254c5SRichard Tran Mills   PetscFunctionReturn(PETSC_SUCCESS);
53734b254c5SRichard Tran Mills }
538