提交 e63cc721 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!60 lineagemgr: user defined error not affect other lineage information record...

!60 lineagemgr: user defined error not affect other lineage information record when not raise exception in lineage callback
Merge pull request !60 from kouzhenzhong/user_defined_bug_fix
...@@ -327,7 +327,9 @@ def _package_user_defined_info(user_defined_dict, user_defined_message): ...@@ -327,7 +327,9 @@ def _package_user_defined_info(user_defined_dict, user_defined_message):
""" """
for key, value in user_defined_dict.items(): for key, value in user_defined_dict.items():
if not isinstance(key, str): if not isinstance(key, str):
raise LineageParamTypeError("The key must be str.") error_msg = f"Invalid key type in user defined info. The {key}'s type" \
f"'{type(key).__name__}' is not supported. It should be str."
log.error(error_msg)
if isinstance(value, int): if isinstance(value, int):
attr_name = "map_int32" attr_name = "map_int32"
...@@ -336,13 +338,12 @@ def _package_user_defined_info(user_defined_dict, user_defined_message): ...@@ -336,13 +338,12 @@ def _package_user_defined_info(user_defined_dict, user_defined_message):
elif isinstance(value, str): elif isinstance(value, str):
attr_name = "map_str" attr_name = "map_str"
else: else:
error_msg = "Value type {} is not supported in user defined event package." \ attr_name = "attr_name"
"Only str, int and float are permitted now.".format(type(value))
log.error(error_msg)
raise LineageParamTypeError(error_msg)
add_user_defined_info = user_defined_message.user_info.add() add_user_defined_info = user_defined_message.user_info.add()
try: try:
getattr(add_user_defined_info, attr_name)[key] = value getattr(add_user_defined_info, attr_name)[key] = value
except ValueError: except AttributeError:
raise LineageParamValueError("Value is out of range or not be supported yet.") error_msg = f"Invalid value type in user defined info. The {value}'s type" \
f"'{type(value).__name__}' is not supported. It should be float, int or str."
log.error(error_msg)
...@@ -88,6 +88,25 @@ class TestModelLineage(TestCase): ...@@ -88,6 +88,25 @@ class TestModelLineage(TestCase):
lineage_log_path = self.summary_record.full_file_name + '_lineage' lineage_log_path = self.summary_record.full_file_name + '_lineage'
assert os.path.isfile(lineage_log_path) is True assert os.path.isfile(lineage_log_path) is True
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_train_begin_with_user_defined_info(self):
"""Test TrainLineage with nested user defined info."""
user_defined_info = {"info": {"version": "v1"}}
train_callback = TrainLineage(
self.summary_record,
False,
user_defined_info
)
train_callback.begin(RunContext(self.run_context))
assert train_callback.initial_learning_rate == 0.12
lineage_log_path = self.summary_record.full_file_name + '_lineage'
assert os.path.isfile(lineage_log_path) is True
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册