提交 336fca14 编写于 作者: O ougongchang

Fix collecting bert network name faild in MindInsight lineage.

1. collect the origin network in model, and set it to cb_params
2. collect the origin network name in SummaryCollector
3. Update the SummaryCollector API Doc
上级 92517266
......@@ -73,7 +73,8 @@ class SummaryCollector(Callback):
summary_dir (str): The collected data will be persisted to this directory.
If the directory does not exist, it will be created automatically.
collect_freq (int): Set the frequency of data collection, it should be greater then zero,
and the unit is `step`. Default: 10. The first step will be recorded at any time.
and the unit is `step`. Default: 10. If a frequency is set, we will collect data
at (current steps % freq) == 0, and the first step will be collected at any time.
It is important to note that if the data sink mode is used, the unit will become the `epoch`.
It is not recommended to collect data too frequently, which can affect performance.
collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None.
......@@ -593,7 +594,7 @@ class SummaryCollector(Callback):
else:
train_lineage[LineageMetadata.learning_rate] = None
train_lineage[LineageMetadata.optimizer] = type(optimizer).__name__ if optimizer else None
train_lineage[LineageMetadata.train_network] = self._get_backbone(cb_params.train_network)
train_lineage[LineageMetadata.train_network] = type(cb_params.network).__name__
loss_fn = self._get_loss_fn(cb_params)
train_lineage[LineageMetadata.loss_function] = type(loss_fn).__name__ if loss_fn else None
......@@ -750,30 +751,6 @@ class SummaryCollector(Callback):
return ckpt_file_path
@staticmethod
def _get_backbone(network):
"""
Get the name of backbone network.
Args:
network (Cell): The train network.
Returns:
Union[str, None], If parse success, will return the name of the backbone network, else return None.
"""
backbone_name = None
backbone_key = '_backbone'
for _, cell in network.cells_and_names():
if hasattr(cell, backbone_key):
backbone_network = getattr(cell, backbone_key)
backbone_name = type(backbone_network).__name__
if backbone_name is None and network is not None:
backbone_name = type(network).__name__
return backbone_name
@staticmethod
def _get_loss_fn(cb_params):
"""
......
......@@ -355,6 +355,7 @@ class Model:
cb_params.train_dataset = train_dataset
cb_params.list_callback = self._transform_callbacks(callbacks)
cb_params.train_dataset_element = None
cb_params.network = self._network
ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"):
epoch = 1
......@@ -660,6 +661,7 @@ class Model:
cb_params.mode = "eval"
cb_params.cur_step_num = 0
cb_params.list_callback = self._transform_callbacks(callbacks)
cb_params.network = self._network
self._eval_network.set_train(mode=False)
self._eval_network.phase = 'eval'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册