提交 336fca14 编写于 作者: O ougongchang

Fix collecting bert network name faild in MindInsight lineage.

1. collect the origin network in model, and set it to cb_params
2. collect the origin network name in SummaryCollector
3. Update the SummaryCollector API Doc
上级 92517266
...@@ -73,7 +73,8 @@ class SummaryCollector(Callback): ...@@ -73,7 +73,8 @@ class SummaryCollector(Callback):
summary_dir (str): The collected data will be persisted to this directory. summary_dir (str): The collected data will be persisted to this directory.
If the directory does not exist, it will be created automatically. If the directory does not exist, it will be created automatically.
collect_freq (int): Set the frequency of data collection, it should be greater then zero, collect_freq (int): Set the frequency of data collection, it should be greater then zero,
and the unit is `step`. Default: 10. The first step will be recorded at any time. and the unit is `step`. Default: 10. If a frequency is set, we will collect data
at (current steps % freq) == 0, and the first step will be collected at any time.
It is important to note that if the data sink mode is used, the unit will become the `epoch`. It is important to note that if the data sink mode is used, the unit will become the `epoch`.
It is not recommended to collect data too frequently, which can affect performance. It is not recommended to collect data too frequently, which can affect performance.
collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None. collect_specified_data (Union[None, dict]): Perform custom operations on the collected data. Default: None.
...@@ -593,7 +594,7 @@ class SummaryCollector(Callback): ...@@ -593,7 +594,7 @@ class SummaryCollector(Callback):
else: else:
train_lineage[LineageMetadata.learning_rate] = None train_lineage[LineageMetadata.learning_rate] = None
train_lineage[LineageMetadata.optimizer] = type(optimizer).__name__ if optimizer else None train_lineage[LineageMetadata.optimizer] = type(optimizer).__name__ if optimizer else None
train_lineage[LineageMetadata.train_network] = self._get_backbone(cb_params.train_network) train_lineage[LineageMetadata.train_network] = type(cb_params.network).__name__
loss_fn = self._get_loss_fn(cb_params) loss_fn = self._get_loss_fn(cb_params)
train_lineage[LineageMetadata.loss_function] = type(loss_fn).__name__ if loss_fn else None train_lineage[LineageMetadata.loss_function] = type(loss_fn).__name__ if loss_fn else None
...@@ -750,30 +751,6 @@ class SummaryCollector(Callback): ...@@ -750,30 +751,6 @@ class SummaryCollector(Callback):
return ckpt_file_path return ckpt_file_path
@staticmethod
def _get_backbone(network):
"""
Get the name of backbone network.
Args:
network (Cell): The train network.
Returns:
Union[str, None], If parse success, will return the name of the backbone network, else return None.
"""
backbone_name = None
backbone_key = '_backbone'
for _, cell in network.cells_and_names():
if hasattr(cell, backbone_key):
backbone_network = getattr(cell, backbone_key)
backbone_name = type(backbone_network).__name__
if backbone_name is None and network is not None:
backbone_name = type(network).__name__
return backbone_name
@staticmethod @staticmethod
def _get_loss_fn(cb_params): def _get_loss_fn(cb_params):
""" """
......
...@@ -355,6 +355,7 @@ class Model: ...@@ -355,6 +355,7 @@ class Model:
cb_params.train_dataset = train_dataset cb_params.train_dataset = train_dataset
cb_params.list_callback = self._transform_callbacks(callbacks) cb_params.list_callback = self._transform_callbacks(callbacks)
cb_params.train_dataset_element = None cb_params.train_dataset_element = None
cb_params.network = self._network
ms_role = os.getenv("MS_ROLE") ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"): if ms_role in ("MS_PSERVER", "MS_SCHED"):
epoch = 1 epoch = 1
...@@ -660,6 +661,7 @@ class Model: ...@@ -660,6 +661,7 @@ class Model:
cb_params.mode = "eval" cb_params.mode = "eval"
cb_params.cur_step_num = 0 cb_params.cur_step_num = 0
cb_params.list_callback = self._transform_callbacks(callbacks) cb_params.list_callback = self._transform_callbacks(callbacks)
cb_params.network = self._network
self._eval_network.set_train(mode=False) self._eval_network.set_train(mode=False)
self._eval_network.phase = 'eval' self._eval_network.phase = 'eval'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册