提交 10003c2e 编写于 作者: K kouzhenzhong

lineagemgr: user defined error not affect other infos record when not raise...

lineagemgr: user defined error not affect other infos record when not raise exception in lineage callback
上级 42cb780f
......@@ -327,7 +327,9 @@ def _package_user_defined_info(user_defined_dict, user_defined_message):
"""
for key, value in user_defined_dict.items():
if not isinstance(key, str):
raise LineageParamTypeError("The key must be str.")
error_msg = f"Invalid key type in user defined info. The {key}'s type" \
f"'{type(key).__name__}' is not supported. It should be str."
log.error(error_msg)
if isinstance(value, int):
attr_name = "map_int32"
......@@ -336,13 +338,12 @@ def _package_user_defined_info(user_defined_dict, user_defined_message):
elif isinstance(value, str):
attr_name = "map_str"
else:
error_msg = "Value type {} is not supported in user defined event package." \
"Only str, int and float are permitted now.".format(type(value))
log.error(error_msg)
raise LineageParamTypeError(error_msg)
attr_name = "attr_name"
add_user_defined_info = user_defined_message.user_info.add()
try:
getattr(add_user_defined_info, attr_name)[key] = value
except ValueError:
raise LineageParamValueError("Value is out of range or not be supported yet.")
except AttributeError:
error_msg = f"Invalid value type in user defined info. The {value}'s type" \
f"'{type(value).__name__}' is not supported. It should be float, int or str."
log.error(error_msg)
......@@ -88,6 +88,25 @@ class TestModelLineage(TestCase):
lineage_log_path = self.summary_record.full_file_name + '_lineage'
assert os.path.isfile(lineage_log_path) is True
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_train_begin_with_user_defined_info(self):
"""Test TrainLineage with nested user defined info."""
user_defined_info = {"info": {"version": "v1"}}
train_callback = TrainLineage(
self.summary_record,
False,
user_defined_info
)
train_callback.begin(RunContext(self.run_context))
assert train_callback.initial_learning_rate == 0.12
lineage_log_path = self.summary_record.full_file_name + '_lineage'
assert os.path.isfile(lineage_log_path) is True
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册