From c7e3076918396ac0b9aa8b901ba553c25715e8fb Mon Sep 17 00:00:00 2001 From: PyCaret Date: Mon, 15 Jun 2020 19:50:27 -0400 Subject: [PATCH] updated classification.py --- pycaret/classification.py | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/pycaret/classification.py b/pycaret/classification.py index 7c0d5ec..8893e03 100644 --- a/pycaret/classification.py +++ b/pycaret/classification.py @@ -2434,12 +2434,9 @@ def create_model(estimator = None, import secrets URI = secrets.token_hex(nbytes=4) - mlflow.set_tag("URI", URI) - + mlflow.set_tag("URI", URI) mlflow.set_tag("USI", USI) - mlflow.set_tag("Run Time", runtime) - mlflow.set_tag("Run ID", RunID) # Log training time in seconds @@ -8543,7 +8540,22 @@ def calibrate_model(estimator, def get_model_name(e): return str(e).split("(")[0] - mn = get_model_name(estimator) + if len(estimator.classes_) > 2: + + if hasattr(estimator, 'voting'): + mn = get_model_name(estimator) + else: + mn = get_model_name(estimator.estimator) + + else: + + if hasattr(estimator, 'base_estimator'): + mn = get_model_name(estimator.base_estimator) + else: + mn = get_model_name(estimator) + + if 'catboost' in mn: + mn = 'CatBoostClassifier' model_dict_logging = {'ExtraTreesClassifier' : 'Extra Trees Classifier', 'GradientBoostingClassifier' : 'Gradient Boosting Classifier', @@ -9106,13 +9118,18 @@ def finalize_model(estimator): if type(estimator) is not list: if len(estimator.classes_) > 2: - mn = get_model_name(estimator.estimator) - if hasattr(estimator, 'base_estimator'): - mn = get_model_name(estimator.base_estimator) + if hasattr(estimator, 'voting'): + mn = get_model_name(estimator) + else: + mn = get_model_name(estimator.estimator) else: - mn = get_model_name(estimator) + + if hasattr(estimator, 'base_estimator'): + mn = get_model_name(estimator.base_estimator) + else: + mn = get_model_name(estimator) if 'catboost' in mn: mn = 'CatBoostClassifier' -- GitLab