提交 c7e30769 编写于 作者: P PyCaret

updated classification.py

上级 d3f02786
......@@ -2434,12 +2434,9 @@ def create_model(estimator = None,
import secrets
URI = secrets.token_hex(nbytes=4)
mlflow.set_tag("URI", URI)
mlflow.set_tag("URI", URI)
mlflow.set_tag("USI", USI)
mlflow.set_tag("Run Time", runtime)
mlflow.set_tag("Run ID", RunID)
# Log training time in seconds
......@@ -8543,7 +8540,22 @@ def calibrate_model(estimator,
def get_model_name(e):
return str(e).split("(")[0]
mn = get_model_name(estimator)
if len(estimator.classes_) > 2:
if hasattr(estimator, 'voting'):
mn = get_model_name(estimator)
else:
mn = get_model_name(estimator.estimator)
else:
if hasattr(estimator, 'base_estimator'):
mn = get_model_name(estimator.base_estimator)
else:
mn = get_model_name(estimator)
if 'catboost' in mn:
mn = 'CatBoostClassifier'
model_dict_logging = {'ExtraTreesClassifier' : 'Extra Trees Classifier',
'GradientBoostingClassifier' : 'Gradient Boosting Classifier',
......@@ -9106,13 +9118,18 @@ def finalize_model(estimator):
if type(estimator) is not list:
if len(estimator.classes_) > 2:
mn = get_model_name(estimator.estimator)
if hasattr(estimator, 'base_estimator'):
mn = get_model_name(estimator.base_estimator)
if hasattr(estimator, 'voting'):
mn = get_model_name(estimator)
else:
mn = get_model_name(estimator.estimator)
else:
mn = get_model_name(estimator)
if hasattr(estimator, 'base_estimator'):
mn = get_model_name(estimator.base_estimator)
else:
mn = get_model_name(estimator)
if 'catboost' in mn:
mn = 'CatBoostClassifier'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册