From 10003c2e196eea9b73e3cd435ce3b9df993f5ad8 Mon Sep 17 00:00:00 2001 From: kouzhenzhong Date: Wed, 22 Apr 2020 12:52:44 +0800 Subject: [PATCH] lineagemgr: user defined error not affect other infos record when not raise exception in lineage callback --- .../lineagemgr/summary/_summary_adapter.py | 15 ++++++++------- .../collection/model/test_model_lineage.py | 19 +++++++++++++++++++ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/mindinsight/lineagemgr/summary/_summary_adapter.py b/mindinsight/lineagemgr/summary/_summary_adapter.py index 21e1894..501d626 100644 --- a/mindinsight/lineagemgr/summary/_summary_adapter.py +++ b/mindinsight/lineagemgr/summary/_summary_adapter.py @@ -327,7 +327,9 @@ def _package_user_defined_info(user_defined_dict, user_defined_message): """ for key, value in user_defined_dict.items(): if not isinstance(key, str): - raise LineageParamTypeError("The key must be str.") + error_msg = f"Invalid key type in user defined info. The {key}'s type" \ + f"'{type(key).__name__}' is not supported. It should be str." + log.error(error_msg) if isinstance(value, int): attr_name = "map_int32" @@ -336,13 +338,12 @@ def _package_user_defined_info(user_defined_dict, user_defined_message): elif isinstance(value, str): attr_name = "map_str" else: - error_msg = "Value type {} is not supported in user defined event package." \ - "Only str, int and float are permitted now.".format(type(value)) - log.error(error_msg) - raise LineageParamTypeError(error_msg) + attr_name = "attr_name" add_user_defined_info = user_defined_message.user_info.add() try: getattr(add_user_defined_info, attr_name)[key] = value - except ValueError: - raise LineageParamValueError("Value is out of range or not be supported yet.") + except AttributeError: + error_msg = f"Invalid value type in user defined info. The {value}'s type" \ + f"'{type(value).__name__}' is not supported. It should be float, int or str." + log.error(error_msg) diff --git a/tests/st/func/lineagemgr/collection/model/test_model_lineage.py b/tests/st/func/lineagemgr/collection/model/test_model_lineage.py index 73f54da..83645d3 100644 --- a/tests/st/func/lineagemgr/collection/model/test_model_lineage.py +++ b/tests/st/func/lineagemgr/collection/model/test_model_lineage.py @@ -88,6 +88,25 @@ class TestModelLineage(TestCase): lineage_log_path = self.summary_record.full_file_name + '_lineage' assert os.path.isfile(lineage_log_path) is True + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend_training + @pytest.mark.platform_x86_gpu_training + @pytest.mark.platform_x86_ascend_training + @pytest.mark.platform_x86_cpu + @pytest.mark.env_single + def test_train_begin_with_user_defined_info(self): + """Test TrainLineage with nested user defined info.""" + user_defined_info = {"info": {"version": "v1"}} + train_callback = TrainLineage( + self.summary_record, + False, + user_defined_info + ) + train_callback.begin(RunContext(self.run_context)) + assert train_callback.initial_learning_rate == 0.12 + lineage_log_path = self.summary_record.full_file_name + '_lineage' + assert os.path.isfile(lineage_log_path) is True + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_gpu_training -- GitLab