# Source code for autoflow.pipeline.components.classification.catboost

from copy import deepcopy
from typing import Dict

from autoflow.pipeline.components.classification_base import AutoFlowClassificationAlgorithm
from autoflow.pipeline.components.utils import get_categorical_features_indices

__all__ = ["CatBoostClassifier"]


class CatBoostClassifier(AutoFlowClassificationAlgorithm):
    """AutoFlow classification component backed by CatBoost.

    ``module__`` / ``class__`` name the wrapped estimator
    (``catboost.CatBoostClassifier``). NOTE(review): presumably the base class
    uses these two attributes to import and instantiate the estimator --
    confirm in ``AutoFlowClassificationAlgorithm``.
    """

    class__ = "CatBoostClassifier"
    module__ = "catboost"

    # Capability flags read elsewhere in the framework. NOTE(review): their exact
    # meaning is defined by the base class / callers, presumably "boosting model"
    # and "tree-based model" markers -- confirm there.
    boost_model = True
    tree_model = True

    def core_fit(self, estimator, X, y=None, X_valid=None, y_valid=None,
                 X_test=None, y_test=None, feature_groups=None,
                 columns_metadata=None):
        """Fit the CatBoost estimator on ``X`` / ``y``.

        Categorical columns are identified with
        ``get_categorical_features_indices(X, columns_metadata)`` and passed to
        CatBoost as ``cat_features``. A validation set is passed as
        ``eval_set`` only when both ``X_valid`` and ``y_valid`` are given.
        Training output is suppressed with ``silent=True``.

        Args:
            estimator: The estimator to fit. If ``None``, ``self.estimator`` is
                used instead.
            X: Training features.
            y: Training targets.
            X_valid: Optional validation features.
            y_valid: Optional validation targets.
            X_test: Unused; accepted for interface compatibility.
            y_test: Unused; accepted for interface compatibility.
            feature_groups: Unused; accepted for interface compatibility.
            columns_metadata: Column metadata forwarded to
                ``get_categorical_features_indices``.

        Returns:
            Whatever ``estimator.fit`` returns.
        """
        # Honour the estimator supplied by the caller, as the signature promises;
        # fall back to the instance's estimator only when none was given.
        model = estimator if estimator is not None else self.estimator
        categorical_features_indices = get_categorical_features_indices(X, columns_metadata)
        # A half-specified validation set (only X or only y) is ignored.
        if X_valid is not None and y_valid is not None:
            eval_set = (X_valid, y_valid)
        else:
            eval_set = None
        return model.fit(
            X, y,
            cat_features=categorical_features_indices,
            eval_set=eval_set,
            silent=True,
        )

    def after_process_hyperparams(self, hyperparams) -> Dict:
        """Translate generic hyperparameter names into CatBoost's names.

        The input mapping is not mutated; a deep copy is returned. The
        sklearn-style ``n_jobs`` key is renamed to CatBoost's ``thread_count``.
        If both keys are present, the ``n_jobs`` value overwrites
        ``thread_count``.

        Args:
            hyperparams: Mapping of hyperparameter names to values.

        Returns:
            A new dict with ``n_jobs`` (if present) renamed to ``thread_count``.
        """
        processed = deepcopy(hyperparams)
        if "n_jobs" in processed:
            processed["thread_count"] = processed.pop("n_jobs")
        return processed