134b254c5SRichard Tran Mills #include <../src/ml/regressor/impls/linear/linearimpl.h> /*I "petscregressor.h" I*/ 234b254c5SRichard Tran Mills #include <../src/tao/leastsquares/impls/brgn/brgn.h> /*I "petsctao.h" I*/ 334b254c5SRichard Tran Mills 434b254c5SRichard Tran Mills const char *const PetscRegressorLinearTypes[] = {"ols", "lasso", "ridge", "RegressorLinearType", "REGRESSOR_LINEAR_", NULL}; 534b254c5SRichard Tran Mills 634b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearSetFitIntercept_Linear(PetscRegressor regressor, PetscBool flg) 734b254c5SRichard Tran Mills { 834b254c5SRichard Tran Mills PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 934b254c5SRichard Tran Mills 1034b254c5SRichard Tran Mills PetscFunctionBegin; 1134b254c5SRichard Tran Mills linear->fit_intercept = flg; 1234b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 1334b254c5SRichard Tran Mills } 1434b254c5SRichard Tran Mills 1534b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearSetType_Linear(PetscRegressor regressor, PetscRegressorLinearType type) 1634b254c5SRichard Tran Mills { 1734b254c5SRichard Tran Mills PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 1834b254c5SRichard Tran Mills 1934b254c5SRichard Tran Mills PetscFunctionBegin; 2034b254c5SRichard Tran Mills linear->type = type; 2134b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 2234b254c5SRichard Tran Mills } 2334b254c5SRichard Tran Mills 2434b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearGetType_Linear(PetscRegressor regressor, PetscRegressorLinearType *type) 2534b254c5SRichard Tran Mills { 2634b254c5SRichard Tran Mills PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 2734b254c5SRichard Tran Mills 2834b254c5SRichard Tran Mills PetscFunctionBegin; 2934b254c5SRichard Tran Mills *type = linear->type; 3034b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 3134b254c5SRichard Tran Mills } 3234b254c5SRichard Tran Mills 3334b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearGetIntercept_Linear(PetscRegressor regressor, PetscScalar *intercept) 3434b254c5SRichard Tran Mills { 3534b254c5SRichard Tran Mills PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 3634b254c5SRichard Tran Mills 3734b254c5SRichard Tran Mills PetscFunctionBegin; 3834b254c5SRichard Tran Mills *intercept = linear->intercept; 3934b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 4034b254c5SRichard Tran Mills } 4134b254c5SRichard Tran Mills 4234b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearGetCoefficients_Linear(PetscRegressor regressor, Vec *coefficients) 4334b254c5SRichard Tran Mills { 4434b254c5SRichard Tran Mills PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 4534b254c5SRichard Tran Mills 4634b254c5SRichard Tran Mills PetscFunctionBegin; 4734b254c5SRichard Tran Mills *coefficients = linear->coefficients; 4834b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 4934b254c5SRichard Tran Mills } 5034b254c5SRichard Tran Mills 5134b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearGetKSP_Linear(PetscRegressor regressor, KSP *ksp) 5234b254c5SRichard Tran Mills { 5334b254c5SRichard Tran Mills PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 5434b254c5SRichard Tran Mills 5534b254c5SRichard Tran Mills PetscFunctionBegin; 5634b254c5SRichard Tran Mills if (!linear->ksp) { 5734b254c5SRichard Tran Mills PetscCall(KSPCreate(PetscObjectComm((PetscObject)regressor), &linear->ksp)); 5834b254c5SRichard Tran Mills PetscCall(PetscObjectIncrementTabLevel((PetscObject)linear->ksp, (PetscObject)regressor, 1)); 5934b254c5SRichard Tran Mills PetscCall(PetscObjectSetOptions((PetscObject)linear->ksp, ((PetscObject)regressor)->options)); 6034b254c5SRichard Tran Mills } 6134b254c5SRichard Tran Mills *ksp = linear->ksp; 6234b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 6334b254c5SRichard Tran Mills } 6434b254c5SRichard Tran Mills 6534b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorLinearSetUseKSP_Linear(PetscRegressor regressor, PetscBool flg) 6634b254c5SRichard Tran Mills { 6734b254c5SRichard Tran Mills PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 6834b254c5SRichard Tran Mills 6934b254c5SRichard Tran Mills PetscFunctionBegin; 7034b254c5SRichard Tran Mills linear->use_ksp = flg; 7134b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 7234b254c5SRichard Tran Mills } 7334b254c5SRichard Tran Mills 7434b254c5SRichard Tran Mills static PetscErrorCode EvaluateResidual(Tao tao, Vec x, Vec f, void *ptr) 7534b254c5SRichard Tran Mills { 7634b254c5SRichard Tran Mills PetscRegressor_Linear *linear = (PetscRegressor_Linear *)ptr; 7734b254c5SRichard Tran Mills 7834b254c5SRichard Tran Mills PetscFunctionBegin; 7934b254c5SRichard Tran Mills /* Evaluate f = A * x - b */ 8034b254c5SRichard Tran Mills PetscCall(MatMult(linear->X, x, f)); 8134b254c5SRichard Tran Mills PetscCall(VecAXPY(f, -1.0, linear->rhs)); 8234b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 8334b254c5SRichard Tran Mills } 8434b254c5SRichard Tran Mills 8534b254c5SRichard Tran Mills static PetscErrorCode EvaluateJacobian(Tao tao, Vec x, Mat J, Mat Jpre, void *ptr) 8634b254c5SRichard Tran Mills { 87*540f39e1SHansol Suh /* The TAOBRGN API expects us to pass an EvaluateJacobian() routine to it, but in this case it is a dummy function. 88*540f39e1SHansol Suh Denoting our data matrix as X, for linear least squares J[m][n] = df[m]/dx[n] = X[m][n]. */ 8934b254c5SRichard Tran Mills PetscFunctionBegin; 9034b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 9134b254c5SRichard Tran Mills } 9234b254c5SRichard Tran Mills 9334b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorSetUp_Linear(PetscRegressor regressor) 9434b254c5SRichard Tran Mills { 9534b254c5SRichard Tran Mills PetscInt M, N; 9634b254c5SRichard Tran Mills PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 9734b254c5SRichard Tran Mills KSP ksp; 9834b254c5SRichard Tran Mills Tao tao; 9934b254c5SRichard Tran Mills 10034b254c5SRichard Tran Mills PetscFunctionBegin; 10134b254c5SRichard Tran Mills PetscCall(MatGetSize(regressor->training, &M, &N)); 10234b254c5SRichard Tran Mills 10334b254c5SRichard Tran Mills if (linear->fit_intercept) { 10434b254c5SRichard Tran Mills /* If we are fitting the intercept, we need to make A a composite matrix using MATCENTERING to preserve sparsity. 10534b254c5SRichard Tran Mills * Though there might be some cases we don't want to do this for, depending on what kind of matrix is passed in. (Probably bad idea for dense?) 10634b254c5SRichard Tran Mills * We will also need to ensure that the right-hand side passed to the KSP is also mean-centered, since we 10734b254c5SRichard Tran Mills * intend to compute the intercept separately from regression coefficients (that is, we will not be adding a 10834b254c5SRichard Tran Mills * column of all 1s to our design matrix). */ 10934b254c5SRichard Tran Mills PetscCall(MatCreateCentering(PetscObjectComm((PetscObject)regressor), PETSC_DECIDE, M, &linear->C)); 11034b254c5SRichard Tran Mills PetscCall(MatCreate(PetscObjectComm((PetscObject)regressor), &linear->X)); 11134b254c5SRichard Tran Mills PetscCall(MatSetSizes(linear->X, PETSC_DECIDE, PETSC_DECIDE, M, N)); 11234b254c5SRichard Tran Mills PetscCall(MatSetType(linear->X, MATCOMPOSITE)); 11334b254c5SRichard Tran Mills PetscCall(MatCompositeSetType(linear->X, MAT_COMPOSITE_MULTIPLICATIVE)); 11434b254c5SRichard Tran Mills PetscCall(MatCompositeAddMat(linear->X, regressor->training)); 11534b254c5SRichard Tran Mills PetscCall(MatCompositeAddMat(linear->X, linear->C)); 11634b254c5SRichard Tran Mills PetscCall(VecDuplicate(regressor->target, &linear->rhs)); 11734b254c5SRichard Tran Mills PetscCall(MatMult(linear->C, regressor->target, linear->rhs)); 11834b254c5SRichard Tran Mills } else { 11934b254c5SRichard Tran Mills // When not fitting intercept, we assume that the input data are already centered. 12034b254c5SRichard Tran Mills linear->X = regressor->training; 12134b254c5SRichard Tran Mills linear->rhs = regressor->target; 12234b254c5SRichard Tran Mills 12334b254c5SRichard Tran Mills PetscCall(PetscObjectReference((PetscObject)linear->X)); 12434b254c5SRichard Tran Mills PetscCall(PetscObjectReference((PetscObject)linear->rhs)); 12534b254c5SRichard Tran Mills } 12634b254c5SRichard Tran Mills 12734b254c5SRichard Tran Mills if (linear->coefficients) PetscCall(VecDestroy(&linear->coefficients)); 12834b254c5SRichard Tran Mills 12934b254c5SRichard Tran Mills if (linear->use_ksp) { 13034b254c5SRichard Tran Mills PetscCheck(linear->type == REGRESSOR_LINEAR_OLS, PetscObjectComm((PetscObject)regressor), PETSC_ERR_ARG_WRONGSTATE, "KSP can be used to fit a linear regressor only when its type is OLS"); 13134b254c5SRichard Tran Mills 13234b254c5SRichard Tran Mills if (!linear->ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp)); 13334b254c5SRichard Tran Mills ksp = linear->ksp; 13434b254c5SRichard Tran Mills 13534b254c5SRichard Tran Mills PetscCall(MatCreateVecs(linear->X, &linear->coefficients, NULL)); 13634b254c5SRichard Tran Mills /* Set up the KSP to solve the least squares problem (without solving for intercept, as this is done separately) using KSPLSQR. */ 13734b254c5SRichard Tran Mills PetscCall(MatCreateNormal(linear->X, &linear->XtX)); 13834b254c5SRichard Tran Mills PetscCall(KSPSetType(ksp, KSPLSQR)); 13934b254c5SRichard Tran Mills PetscCall(KSPSetOperators(ksp, linear->X, linear->XtX)); 14034b254c5SRichard Tran Mills PetscCall(KSPSetOptionsPrefix(ksp, ((PetscObject)regressor)->prefix)); 14134b254c5SRichard Tran Mills PetscCall(KSPAppendOptionsPrefix(ksp, "regressor_linear_")); 14234b254c5SRichard Tran Mills PetscCall(KSPSetFromOptions(ksp)); 14334b254c5SRichard Tran Mills } else { 14434b254c5SRichard Tran Mills /* Note: Currently implementation creates TAO inside of implementations. 14534b254c5SRichard Tran Mills * Thus, all the prefix jobs are done inside implementations, not in interface */ 14634b254c5SRichard Tran Mills const char *prefix; 14734b254c5SRichard Tran Mills 14834b254c5SRichard Tran Mills if (!regressor->tao) PetscCall(PetscRegressorGetTao(regressor, &tao)); 14934b254c5SRichard Tran Mills 15034b254c5SRichard Tran Mills PetscCall(MatCreateVecs(linear->X, &linear->coefficients, &linear->residual)); 15134b254c5SRichard Tran Mills /* Set up the TAO object to solve the (regularized) least squares problem (without solving for intercept, which is done separately) using TAOBRGN. */ 15234b254c5SRichard Tran Mills PetscCall(TaoSetType(tao, TAOBRGN)); 15334b254c5SRichard Tran Mills PetscCall(TaoSetSolution(tao, linear->coefficients)); 15434b254c5SRichard Tran Mills PetscCall(TaoSetResidualRoutine(tao, linear->residual, EvaluateResidual, linear)); 15534b254c5SRichard Tran Mills PetscCall(TaoSetJacobianResidualRoutine(tao, linear->X, linear->X, EvaluateJacobian, linear)); 15634b254c5SRichard Tran Mills if (!linear->use_ksp) PetscCall(TaoBRGNSetRegularizerWeight(tao, regressor->regularizer_weight)); 15734b254c5SRichard Tran Mills // Set the regularization type and weight for the BRGN as linear->type dictates: 15834b254c5SRichard Tran Mills // TODO BRGN needs to be BRGNSetRegularizationType 15934b254c5SRichard Tran Mills // PetscOptionsSetValue no longer works due to functioning prefix system 16034b254c5SRichard Tran Mills PetscCall(PetscRegressorGetOptionsPrefix(regressor, &prefix)); 16134b254c5SRichard Tran Mills PetscCall(TaoSetOptionsPrefix(regressor->tao, prefix)); 16234b254c5SRichard Tran Mills PetscCall(TaoAppendOptionsPrefix(tao, "regressor_linear_")); 16334b254c5SRichard Tran Mills { 16434b254c5SRichard Tran Mills TAO_BRGN *gn = (TAO_BRGN *)regressor->tao->data; 16534b254c5SRichard Tran Mills 16634b254c5SRichard Tran Mills switch (linear->type) { 16734b254c5SRichard Tran Mills case REGRESSOR_LINEAR_OLS: 16834b254c5SRichard Tran Mills regressor->regularizer_weight = 0.0; // OLS, by definition, uses a regularizer weight of 0 16934b254c5SRichard Tran Mills break; 17034b254c5SRichard Tran Mills case REGRESSOR_LINEAR_LASSO: 17134b254c5SRichard Tran Mills gn->reg_type = BRGN_REGULARIZATION_L1DICT; 17234b254c5SRichard Tran Mills break; 17334b254c5SRichard Tran Mills case REGRESSOR_LINEAR_RIDGE: 17434b254c5SRichard Tran Mills gn->reg_type = BRGN_REGULARIZATION_L2PURE; 17534b254c5SRichard Tran Mills break; 17634b254c5SRichard Tran Mills default: 17734b254c5SRichard Tran Mills break; 17834b254c5SRichard Tran Mills } 17934b254c5SRichard Tran Mills } 18034b254c5SRichard Tran Mills PetscCall(TaoSetFromOptions(tao)); 18134b254c5SRichard Tran Mills } 18234b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 18334b254c5SRichard Tran Mills } 18434b254c5SRichard Tran Mills 18534b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorReset_Linear(PetscRegressor regressor) 18634b254c5SRichard Tran Mills { 18734b254c5SRichard Tran Mills PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 18834b254c5SRichard Tran Mills 18934b254c5SRichard Tran Mills PetscFunctionBegin; 19034b254c5SRichard Tran Mills /* Destroy the PETSc objects associated with the linear regressor implementation. */ 19134b254c5SRichard Tran Mills linear->ksp_its = 0; 19234b254c5SRichard Tran Mills linear->ksp_tot_its = 0; 19334b254c5SRichard Tran Mills 19434b254c5SRichard Tran Mills PetscCall(MatDestroy(&linear->X)); 19534b254c5SRichard Tran Mills PetscCall(MatDestroy(&linear->XtX)); 19634b254c5SRichard Tran Mills PetscCall(MatDestroy(&linear->C)); 19734b254c5SRichard Tran Mills PetscCall(KSPDestroy(&linear->ksp)); 19834b254c5SRichard Tran Mills PetscCall(VecDestroy(&linear->coefficients)); 19934b254c5SRichard Tran Mills PetscCall(VecDestroy(&linear->rhs)); 20034b254c5SRichard Tran Mills PetscCall(VecDestroy(&linear->residual)); 20134b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 20234b254c5SRichard Tran Mills } 20334b254c5SRichard Tran Mills 20434b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorDestroy_Linear(PetscRegressor regressor) 20534b254c5SRichard Tran Mills { 20634b254c5SRichard Tran Mills PetscFunctionBegin; 20734b254c5SRichard Tran Mills PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", NULL)); 20834b254c5SRichard Tran Mills PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", NULL)); 20934b254c5SRichard Tran Mills PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", NULL)); 21034b254c5SRichard Tran Mills PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", NULL)); 21134b254c5SRichard Tran Mills PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", NULL)); 21234b254c5SRichard Tran Mills PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", NULL)); 21334b254c5SRichard Tran Mills PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", NULL)); 21434b254c5SRichard Tran Mills PetscCall(PetscRegressorReset_Linear(regressor)); 21534b254c5SRichard Tran Mills PetscCall(PetscFree(regressor->data)); 21634b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 21734b254c5SRichard Tran Mills } 21834b254c5SRichard Tran Mills 21934b254c5SRichard Tran Mills /*@ 22034b254c5SRichard Tran Mills PetscRegressorLinearSetFitIntercept - Set a flag to indicate that the intercept (also known as the "bias" or "offset") should 22134b254c5SRichard Tran Mills be calculated; data are assumed to be mean-centered if false. 22234b254c5SRichard Tran Mills 22334b254c5SRichard Tran Mills Logically Collective 22434b254c5SRichard Tran Mills 22534b254c5SRichard Tran Mills Input Parameters: 22634b254c5SRichard Tran Mills + regressor - the `PetscRegressor` context 22734b254c5SRichard Tran Mills - flg - `PETSC_TRUE` to calculate the intercept, `PETSC_FALSE` to assume mean-centered data (default is `PETSC_TRUE`) 22834b254c5SRichard Tran Mills 22934b254c5SRichard Tran Mills Level: intermediate 23034b254c5SRichard Tran Mills 23134b254c5SRichard Tran Mills Options Database Key: 23234b254c5SRichard Tran Mills . regressor_linear_fit_intercept <true,false> - fit the intercept 23334b254c5SRichard Tran Mills 23434b254c5SRichard Tran Mills Note: 23534b254c5SRichard Tran Mills If the user indicates that the intercept should not be calculated, the intercept will be set to zero. 23634b254c5SRichard Tran Mills 23734b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PetscRegressorFit()` 23834b254c5SRichard Tran Mills @*/ 23934b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearSetFitIntercept(PetscRegressor regressor, PetscBool flg) 24034b254c5SRichard Tran Mills { 24134b254c5SRichard Tran Mills PetscFunctionBegin; 24234b254c5SRichard Tran Mills /* TODO: Add companion PetscRegressorLinearGetFitIntercept(), and put it in the .seealso: */ 24334b254c5SRichard Tran Mills PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1); 24434b254c5SRichard Tran Mills PetscValidLogicalCollectiveBool(regressor, flg, 2); 24534b254c5SRichard Tran Mills PetscTryMethod(regressor, "PetscRegressorLinearSetFitIntercept_C", (PetscRegressor, PetscBool), (regressor, flg)); 24634b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 24734b254c5SRichard Tran Mills } 24834b254c5SRichard Tran Mills 24934b254c5SRichard Tran Mills /*@ 25034b254c5SRichard Tran Mills PetscRegressorLinearSetUseKSP - Set a flag to indicate that a `KSP` object, instead of a `Tao` one, should be used 25134b254c5SRichard Tran Mills to fit the regressor 25234b254c5SRichard Tran Mills 25334b254c5SRichard Tran Mills Logically Collective 25434b254c5SRichard Tran Mills 25534b254c5SRichard Tran Mills Input Parameters: 25634b254c5SRichard Tran Mills + regressor - the `PetscRegressor` context 25734b254c5SRichard Tran Mills - flg - `PETSC_TRUE` to use a `KSP`, `PETSC_FALSE` to use a `Tao` object (default is false) 25834b254c5SRichard Tran Mills 25934b254c5SRichard Tran Mills Options Database Key: 26034b254c5SRichard Tran Mills . regressor_linear_use_ksp <true,false> - use `KSP` 26134b254c5SRichard Tran Mills 26234b254c5SRichard Tran Mills Level: intermediate 26334b254c5SRichard Tran Mills 26434b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PetscRegressorLinearGetKSP()` 26534b254c5SRichard Tran Mills @*/ 26634b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearSetUseKSP(PetscRegressor regressor, PetscBool flg) 26734b254c5SRichard Tran Mills { 26834b254c5SRichard Tran Mills PetscFunctionBegin; 26934b254c5SRichard Tran Mills /* TODO: Add companion PetscRegressorLinearGetUseKSP(), and put it in the .seealso: */ 27034b254c5SRichard Tran Mills PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1); 27134b254c5SRichard Tran Mills PetscValidLogicalCollectiveBool(regressor, flg, 2); 27234b254c5SRichard Tran Mills PetscTryMethod(regressor, "PetscRegressorLinearSetUseKSP_C", (PetscRegressor, PetscBool), (regressor, flg)); 27334b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 27434b254c5SRichard Tran Mills } 27534b254c5SRichard Tran Mills 27634b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorSetFromOptions_Linear(PetscRegressor regressor, PetscOptionItems PetscOptionsObject) 27734b254c5SRichard Tran Mills { 27834b254c5SRichard Tran Mills PetscBool set, flg = PETSC_FALSE; 27934b254c5SRichard Tran Mills PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 28034b254c5SRichard Tran Mills 28134b254c5SRichard Tran Mills PetscFunctionBegin; 28234b254c5SRichard Tran Mills PetscOptionsHeadBegin(PetscOptionsObject, "PetscRegressor options for linear regressors"); 28334b254c5SRichard Tran Mills PetscCall(PetscOptionsBool("-regressor_linear_fit_intercept", "Calculate intercept for linear model", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set)); 28434b254c5SRichard Tran Mills if (set) PetscCall(PetscRegressorLinearSetFitIntercept(regressor, flg)); 28534b254c5SRichard Tran Mills PetscCall(PetscOptionsBool("-regressor_linear_use_ksp", "Use KSP instead of TAO for linear model fitting problem", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set)); 28634b254c5SRichard Tran Mills if (set) PetscCall(PetscRegressorLinearSetUseKSP(regressor, flg)); 28734b254c5SRichard Tran Mills PetscCall(PetscOptionsEnum("-regressor_linear_type", "Linear regression method", "PetscRegressorLinearTypes", PetscRegressorLinearTypes, (PetscEnum)linear->type, (PetscEnum *)&linear->type, NULL)); 28834b254c5SRichard Tran Mills PetscOptionsHeadEnd(); 28934b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 29034b254c5SRichard Tran Mills } 29134b254c5SRichard Tran Mills 29234b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorView_Linear(PetscRegressor regressor, PetscViewer viewer) 29334b254c5SRichard Tran Mills { 29434b254c5SRichard Tran Mills PetscBool isascii; 29534b254c5SRichard Tran Mills PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 29634b254c5SRichard Tran Mills 29734b254c5SRichard Tran Mills PetscFunctionBegin; 29834b254c5SRichard Tran Mills PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii)); 29934b254c5SRichard Tran Mills if (isascii) { 30034b254c5SRichard Tran Mills PetscCall(PetscViewerASCIIPushTab(viewer)); 30134b254c5SRichard Tran Mills PetscCall(PetscViewerASCIIPrintf(viewer, "PetscRegressor Linear Type: %s\n", PetscRegressorLinearTypes[linear->type])); 30234b254c5SRichard Tran Mills if (linear->ksp) { 30334b254c5SRichard Tran Mills PetscCall(KSPView(linear->ksp, viewer)); 30434b254c5SRichard Tran Mills PetscCall(PetscViewerASCIIPrintf(viewer, "total KSP iterations: %" PetscInt_FMT "\n", linear->ksp_tot_its)); 30534b254c5SRichard Tran Mills } 30634b254c5SRichard Tran Mills if (linear->fit_intercept) PetscCall(PetscViewerASCIIPrintf(viewer, "Intercept=%g\n", (double)linear->intercept)); 30734b254c5SRichard Tran Mills PetscCall(PetscViewerASCIIPopTab(viewer)); 30834b254c5SRichard Tran Mills } 30934b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 31034b254c5SRichard Tran Mills } 31134b254c5SRichard Tran Mills 31234b254c5SRichard Tran Mills /*@ 31334b254c5SRichard Tran Mills PetscRegressorLinearGetKSP - Returns the `KSP` context for a `PETSCREGRESSORLINEAR` object. 31434b254c5SRichard Tran Mills 31534b254c5SRichard Tran Mills Not Collective, but if the `PetscRegressor` is parallel, then the `KSP` object is parallel 31634b254c5SRichard Tran Mills 31734b254c5SRichard Tran Mills Input Parameter: 31834b254c5SRichard Tran Mills . regressor - the `PetscRegressor` context 31934b254c5SRichard Tran Mills 32034b254c5SRichard Tran Mills Output Parameter: 32134b254c5SRichard Tran Mills . ksp - the `KSP` context 32234b254c5SRichard Tran Mills 32334b254c5SRichard Tran Mills Level: beginner 32434b254c5SRichard Tran Mills 32534b254c5SRichard Tran Mills Note: 32634b254c5SRichard Tran Mills This routine will always return a `KSP`, but, depending on the type of the linear regressor and the options that are set, the regressor may actually use a `Tao` object instead of this `KSP`. 32734b254c5SRichard Tran Mills 32834b254c5SRichard Tran Mills .seealso: `PetscRegressorGetTao()` 32934b254c5SRichard Tran Mills @*/ 33034b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearGetKSP(PetscRegressor regressor, KSP *ksp) 33134b254c5SRichard Tran Mills { 33234b254c5SRichard Tran Mills PetscFunctionBegin; 33334b254c5SRichard Tran Mills PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1); 33434b254c5SRichard Tran Mills PetscAssertPointer(ksp, 2); 33534b254c5SRichard Tran Mills PetscUseMethod(regressor, "PetscRegressorLinearGetKSP_C", (PetscRegressor, KSP *), (regressor, ksp)); 33634b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 33734b254c5SRichard Tran Mills } 33834b254c5SRichard Tran Mills 33934b254c5SRichard Tran Mills /*@ 34034b254c5SRichard Tran Mills PetscRegressorLinearGetCoefficients - Get a vector of the fitted coefficients from a linear regression model 34134b254c5SRichard Tran Mills 34234b254c5SRichard Tran Mills Not Collective but the vector is parallel 34334b254c5SRichard Tran Mills 34434b254c5SRichard Tran Mills Input Parameter: 34534b254c5SRichard Tran Mills . regressor - the `PetscRegressor` context 34634b254c5SRichard Tran Mills 34734b254c5SRichard Tran Mills Output Parameter: 34834b254c5SRichard Tran Mills . coefficients - the vector of the coefficients 34934b254c5SRichard Tran Mills 35034b254c5SRichard Tran Mills Level: beginner 35134b254c5SRichard Tran Mills 35234b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PetscRegressorLinearGetIntercept()`, `PETSCREGRESSORLINEAR`, `Vec` 35334b254c5SRichard Tran Mills @*/ 35434b254c5SRichard Tran Mills PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetCoefficients(PetscRegressor regressor, Vec *coefficients) 35534b254c5SRichard Tran Mills { 35634b254c5SRichard Tran Mills PetscFunctionBegin; 35734b254c5SRichard Tran Mills PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1); 35834b254c5SRichard Tran Mills PetscAssertPointer(coefficients, 2); 35934b254c5SRichard Tran Mills PetscUseMethod(regressor, "PetscRegressorLinearGetCoefficients_C", (PetscRegressor, Vec *), (regressor, coefficients)); 36034b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 36134b254c5SRichard Tran Mills } 36234b254c5SRichard Tran Mills 36334b254c5SRichard Tran Mills /*@ 36434b254c5SRichard Tran Mills PetscRegressorLinearGetIntercept - Get the intercept from a linear regression model 36534b254c5SRichard Tran Mills 36634b254c5SRichard Tran Mills Not Collective 36734b254c5SRichard Tran Mills 36834b254c5SRichard Tran Mills Input Parameter: 36934b254c5SRichard Tran Mills . regressor - the `PetscRegressor` context 37034b254c5SRichard Tran Mills 37134b254c5SRichard Tran Mills Output Parameter: 37234b254c5SRichard Tran Mills . intercept - the intercept 37334b254c5SRichard Tran Mills 37434b254c5SRichard Tran Mills Level: beginner 37534b254c5SRichard Tran Mills 37634b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PetscRegressorLinearSetFitIntercept()`, `PetscRegressorLinearGetCoefficients()`, `PETSCREGRESSORLINEAR` 37734b254c5SRichard Tran Mills @*/ 37834b254c5SRichard Tran Mills PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetIntercept(PetscRegressor regressor, PetscScalar *intercept) 37934b254c5SRichard Tran Mills { 38034b254c5SRichard Tran Mills PetscFunctionBegin; 38134b254c5SRichard Tran Mills PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1); 38234b254c5SRichard Tran Mills PetscAssertPointer(intercept, 2); 38334b254c5SRichard Tran Mills PetscUseMethod(regressor, "PetscRegressorLinearGetIntercept_C", (PetscRegressor, PetscScalar *), (regressor, intercept)); 38434b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 38534b254c5SRichard Tran Mills } 38634b254c5SRichard Tran Mills 38734b254c5SRichard Tran Mills /*@C 38834b254c5SRichard Tran Mills PetscRegressorLinearSetType - Sets the type of linear regression to be performed 38934b254c5SRichard Tran Mills 39034b254c5SRichard Tran Mills Logically Collective 39134b254c5SRichard Tran Mills 39234b254c5SRichard Tran Mills Input Parameters: 39334b254c5SRichard Tran Mills + regressor - the `PetscRegressor` context (should be of type `PETSCREGRESSORLINEAR`) 39434b254c5SRichard Tran Mills - type - a known linear regression method 39534b254c5SRichard Tran Mills 39634b254c5SRichard Tran Mills Options Database Key: 39734b254c5SRichard Tran Mills . -regressor_linear_type - Sets the linear regression method; use -help for a list of available methods 39834b254c5SRichard Tran Mills (for instance "-regressor_linear_type ols" or "-regressor_linear_type lasso") 39934b254c5SRichard Tran Mills 40034b254c5SRichard Tran Mills Level: intermediate 40134b254c5SRichard Tran Mills 40234b254c5SRichard Tran Mills .seealso: `PetscRegressorLinearGetType()`, `PetscRegressorLinearType`, `PetscRegressorSetType()` 40334b254c5SRichard Tran Mills @*/ 40434b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearSetType(PetscRegressor regressor, PetscRegressorLinearType type) 40534b254c5SRichard Tran Mills { 40634b254c5SRichard Tran Mills PetscFunctionBegin; 40734b254c5SRichard Tran Mills PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1); 40834b254c5SRichard Tran Mills PetscValidLogicalCollectiveEnum(regressor, type, 2); 40934b254c5SRichard Tran Mills PetscTryMethod(regressor, "PetscRegressorLinearSetType_C", (PetscRegressor, PetscRegressorLinearType), (regressor, type)); 41034b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 41134b254c5SRichard Tran Mills } 41234b254c5SRichard Tran Mills 41334b254c5SRichard Tran Mills /*@ 41434b254c5SRichard Tran Mills PetscRegressorLinearGetType - Return the type for the `PETSCREGRESSORLINEAR` solver 41534b254c5SRichard Tran Mills 41634b254c5SRichard Tran Mills Input Parameter: 41734b254c5SRichard Tran Mills . regressor - the `PetscRegressor` solver context 41834b254c5SRichard Tran Mills 41934b254c5SRichard Tran Mills Output Parameter: 42034b254c5SRichard Tran Mills . type - `PETSCREGRESSORLINEAR` type 42134b254c5SRichard Tran Mills 42234b254c5SRichard Tran Mills Level: advanced 42334b254c5SRichard Tran Mills 42434b254c5SRichard Tran Mills .seealso: `PetscRegressor`, `PETSCREGRESSORLINEAR`, `PetscRegressorLinearSetType()`, `PetscRegressorLinearType` 42534b254c5SRichard Tran Mills @*/ 42634b254c5SRichard Tran Mills PetscErrorCode PetscRegressorLinearGetType(PetscRegressor regressor, PetscRegressorLinearType *type) 42734b254c5SRichard Tran Mills { 42834b254c5SRichard Tran Mills PetscFunctionBegin; 42934b254c5SRichard Tran Mills PetscValidHeaderSpecific(regressor, PETSCREGRESSOR_CLASSID, 1); 43034b254c5SRichard Tran Mills PetscAssertPointer(type, 2); 43134b254c5SRichard Tran Mills PetscUseMethod(regressor, "PetscRegressorLinearGetType_C", (PetscRegressor, PetscRegressorLinearType *), (regressor, type)); 43234b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 43334b254c5SRichard Tran Mills } 43434b254c5SRichard Tran Mills 43534b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorFit_Linear(PetscRegressor regressor) 43634b254c5SRichard Tran Mills { 43734b254c5SRichard Tran Mills PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 43834b254c5SRichard Tran Mills KSP ksp; 43934b254c5SRichard Tran Mills PetscScalar target_mean, *column_means_global, *column_means_local, column_means_dot_coefficients; 44034b254c5SRichard Tran Mills Vec column_means; 44134b254c5SRichard Tran Mills PetscInt m, N, istart, i, kspits; 44234b254c5SRichard Tran Mills 44334b254c5SRichard Tran Mills PetscFunctionBegin; 44434b254c5SRichard Tran Mills if (linear->use_ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp)); 44534b254c5SRichard Tran Mills ksp = linear->ksp; 44634b254c5SRichard Tran Mills 44734b254c5SRichard Tran Mills /* Solve the least-squares problem (previously set up in PetscRegressorSetUp_Linear()) without finding the intercept. */ 44834b254c5SRichard Tran Mills if (linear->use_ksp) { 44934b254c5SRichard Tran Mills PetscCall(KSPSolve(ksp, linear->rhs, linear->coefficients)); 45034b254c5SRichard Tran Mills PetscCall(KSPGetIterationNumber(ksp, &kspits)); 45134b254c5SRichard Tran Mills linear->ksp_its += kspits; 45234b254c5SRichard Tran Mills linear->ksp_tot_its += kspits; 45334b254c5SRichard Tran Mills } else { 45434b254c5SRichard Tran Mills PetscCall(TaoSolve(regressor->tao)); 45534b254c5SRichard Tran Mills } 45634b254c5SRichard Tran Mills 45734b254c5SRichard Tran Mills /* Calculate the intercept. */ 45834b254c5SRichard Tran Mills if (linear->fit_intercept) { 45934b254c5SRichard Tran Mills PetscCall(MatGetSize(regressor->training, NULL, &N)); 46034b254c5SRichard Tran Mills PetscCall(PetscMalloc1(N, &column_means_global)); 46134b254c5SRichard Tran Mills PetscCall(VecMean(regressor->target, &target_mean)); 46234b254c5SRichard Tran Mills /* We need the means of all columns of regressor->training, placed into a Vec compatible with linear->coefficients. 46334b254c5SRichard Tran Mills * Note the potential scalability issue: MatGetColumnMeans() computes means of ALL colummns. */ 46434b254c5SRichard Tran Mills PetscCall(MatGetColumnMeans(regressor->training, column_means_global)); 46534b254c5SRichard Tran Mills /* TODO: Calculation of the Vec and matrix column means should probably go into the SetUp phase, and also be placed 46634b254c5SRichard Tran Mills * into a routine that is callable from outside of PetscRegressorFit_Linear(), because we'll want to do the same 46734b254c5SRichard Tran Mills * thing for other models, such as ridge and LASSO regression, and should avoid code duplication. 46834b254c5SRichard Tran Mills * What we are calling 'target_mean' and 'column_means' should be stashed in the base linear regressor struct, 46934b254c5SRichard Tran Mills * and perhaps renamed to make it clear they are offsets that should be applied (though the current naming 47034b254c5SRichard Tran Mills * makes sense since it makes it clear where these come from.) */ 47134b254c5SRichard Tran Mills PetscCall(VecDuplicate(linear->coefficients, &column_means)); 47234b254c5SRichard Tran Mills PetscCall(VecGetLocalSize(column_means, &m)); 47334b254c5SRichard Tran Mills PetscCall(VecGetOwnershipRange(column_means, &istart, NULL)); 47434b254c5SRichard Tran Mills PetscCall(VecGetArrayWrite(column_means, &column_means_local)); 47534b254c5SRichard Tran Mills for (i = 0; i < m; i++) column_means_local[i] = column_means_global[istart + i]; 47634b254c5SRichard Tran Mills PetscCall(VecRestoreArrayWrite(column_means, &column_means_local)); 47734b254c5SRichard Tran Mills PetscCall(VecDot(column_means, linear->coefficients, &column_means_dot_coefficients)); 47834b254c5SRichard Tran Mills PetscCall(VecDestroy(&column_means)); 47934b254c5SRichard Tran Mills PetscCall(PetscFree(column_means_global)); 48034b254c5SRichard Tran Mills linear->intercept = target_mean - column_means_dot_coefficients; 48134b254c5SRichard Tran Mills } else { 48234b254c5SRichard Tran Mills linear->intercept = 0.0; 48334b254c5SRichard Tran Mills } 48434b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 48534b254c5SRichard Tran Mills } 48634b254c5SRichard Tran Mills 48734b254c5SRichard Tran Mills static PetscErrorCode PetscRegressorPredict_Linear(PetscRegressor regressor, Mat X, Vec y) 48834b254c5SRichard Tran Mills { 48934b254c5SRichard Tran Mills PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data; 49034b254c5SRichard Tran Mills 49134b254c5SRichard Tran Mills PetscFunctionBegin; 49234b254c5SRichard Tran Mills PetscCall(MatMult(X, linear->coefficients, y)); 49334b254c5SRichard Tran Mills PetscCall(VecShift(y, linear->intercept)); 49434b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 49534b254c5SRichard Tran Mills } 49634b254c5SRichard Tran Mills 49734b254c5SRichard Tran Mills /*MC 49834b254c5SRichard Tran Mills PETSCREGRESSORLINEAR - Linear regression model (ordinary least squares or regularized variants) 49934b254c5SRichard Tran Mills 50034b254c5SRichard Tran Mills Options Database: 50134b254c5SRichard Tran Mills + -regressor_linear_fit_intercept - Calculate the intercept for the linear model 50234b254c5SRichard Tran Mills - -regressor_linear_use_ksp - Use `KSP` instead of `Tao` for linear model fitting (non-regularized variants only) 50334b254c5SRichard Tran Mills 50434b254c5SRichard Tran Mills Level: beginner 50534b254c5SRichard Tran Mills 50634b254c5SRichard Tran Mills Note: 50734b254c5SRichard Tran Mills This is the default regressor in `PetscRegressor`. 50834b254c5SRichard Tran Mills 50934b254c5SRichard Tran Mills .seealso: `PetscRegressorCreate()`, `PetscRegressor`, `PetscRegressorSetType()` 51034b254c5SRichard Tran Mills M*/ 51134b254c5SRichard Tran Mills PETSC_EXTERN PetscErrorCode PetscRegressorCreate_Linear(PetscRegressor regressor) 51234b254c5SRichard Tran Mills { 51334b254c5SRichard Tran Mills PetscRegressor_Linear *linear; 51434b254c5SRichard Tran Mills 51534b254c5SRichard Tran Mills PetscFunctionBegin; 51634b254c5SRichard Tran Mills PetscCall(PetscNew(&linear)); 51734b254c5SRichard Tran Mills regressor->data = (void *)linear; 51834b254c5SRichard Tran Mills 51934b254c5SRichard Tran Mills regressor->ops->setup = PetscRegressorSetUp_Linear; 52034b254c5SRichard Tran Mills regressor->ops->reset = PetscRegressorReset_Linear; 52134b254c5SRichard Tran Mills regressor->ops->destroy = PetscRegressorDestroy_Linear; 52234b254c5SRichard Tran Mills regressor->ops->setfromoptions = PetscRegressorSetFromOptions_Linear; 52334b254c5SRichard Tran Mills regressor->ops->view = PetscRegressorView_Linear; 52434b254c5SRichard Tran Mills regressor->ops->fit = PetscRegressorFit_Linear; 52534b254c5SRichard Tran Mills regressor->ops->predict = PetscRegressorPredict_Linear; 52634b254c5SRichard Tran Mills 52734b254c5SRichard Tran Mills linear->intercept = 0.0; 52834b254c5SRichard Tran Mills linear->fit_intercept = PETSC_TRUE; /* Default to calculating the intercept. */ 52934b254c5SRichard Tran Mills linear->use_ksp = PETSC_FALSE; /* Do not default to using KSP for solving the model-fitting problem (use TAO instead). */ 53034b254c5SRichard Tran Mills linear->type = REGRESSOR_LINEAR_OLS; 53134b254c5SRichard Tran Mills /* Above, manually set the default linear regressor type. 53234b254c5SRichard Tran Mills We don't use PetscRegressorLinearSetType() here, because that expects the SetUp event to already have happened. */ 53334b254c5SRichard Tran Mills 53434b254c5SRichard Tran Mills PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", PetscRegressorLinearSetFitIntercept_Linear)); 53534b254c5SRichard Tran Mills PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", PetscRegressorLinearSetUseKSP_Linear)); 53634b254c5SRichard Tran Mills PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", PetscRegressorLinearGetKSP_Linear)); 53734b254c5SRichard Tran Mills PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", PetscRegressorLinearGetCoefficients_Linear)); 53834b254c5SRichard Tran Mills PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", PetscRegressorLinearGetIntercept_Linear)); 53934b254c5SRichard Tran Mills PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", PetscRegressorLinearSetType_Linear)); 54034b254c5SRichard Tran Mills PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", PetscRegressorLinearGetType_Linear)); 54134b254c5SRichard Tran Mills PetscFunctionReturn(PETSC_SUCCESS); 54234b254c5SRichard Tran Mills } 543