diff --git a/mindinsight/lineagemgr/summary/_summary_adapter.py b/mindinsight/lineagemgr/summary/_summary_adapter.py index 21e1894970ec7e715d17992081d0554da221c3fd..501d626838c2c8880bcafc7ce6e806fdb12970ad 100644 --- a/mindinsight/lineagemgr/summary/_summary_adapter.py +++ b/mindinsight/lineagemgr/summary/_summary_adapter.py @@ -327,7 +327,9 @@ def _package_user_defined_info(user_defined_dict, user_defined_message): """ for key, value in user_defined_dict.items(): if not isinstance(key, str): - raise LineageParamTypeError("The key must be str.") + error_msg = f"Invalid key type in user defined info. The {key}'s type" \ + f"'{type(key).__name__}' is not supported. It should be str." + log.error(error_msg) if isinstance(value, int): attr_name = "map_int32" @@ -336,13 +338,12 @@ def _package_user_defined_info(user_defined_dict, user_defined_message): elif isinstance(value, str): attr_name = "map_str" else: - error_msg = "Value type {} is not supported in user defined event package." \ - "Only str, int and float are permitted now.".format(type(value)) - log.error(error_msg) - raise LineageParamTypeError(error_msg) + attr_name = "attr_name" add_user_defined_info = user_defined_message.user_info.add() try: getattr(add_user_defined_info, attr_name)[key] = value - except ValueError: - raise LineageParamValueError("Value is out of range or not be supported yet.") + except AttributeError: + error_msg = f"Invalid value type in user defined info. The {value}'s type" \ + f"'{type(value).__name__}' is not supported. It should be float, int or str." + log.error(error_msg) diff --git a/tests/st/func/lineagemgr/collection/model/test_model_lineage.py b/tests/st/func/lineagemgr/collection/model/test_model_lineage.py index 73f54dab1c07b2a4a1a0f14e3850abb782ec9270..83645d39f47ba35d3c812d8b8886ddbe8a9a2e47 100644 --- a/tests/st/func/lineagemgr/collection/model/test_model_lineage.py +++ b/tests/st/func/lineagemgr/collection/model/test_model_lineage.py @@ -88,6 +88,25 @@ class TestModelLineage(TestCase): lineage_log_path = self.summary_record.full_file_name + '_lineage' assert os.path.isfile(lineage_log_path) is True + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend_training + @pytest.mark.platform_x86_gpu_training + @pytest.mark.platform_x86_ascend_training + @pytest.mark.platform_x86_cpu + @pytest.mark.env_single + def test_train_begin_with_user_defined_info(self): + """Test TrainLineage with nested user defined info.""" + user_defined_info = {"info": {"version": "v1"}} + train_callback = TrainLineage( + self.summary_record, + False, + user_defined_info + ) + train_callback.begin(RunContext(self.run_context)) + assert train_callback.initial_learning_rate == 0.12 + lineage_log_path = self.summary_record.full_file_name + '_lineage' + assert os.path.isfile(lineage_log_path) is True + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_gpu_training