提交 76220c0f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3368 Restore the code to collect the graph network

Merge pull request !3368 from LiHongzhang/oh_graph
...@@ -182,6 +182,7 @@ class SummaryCollector(Callback): ...@@ -182,6 +182,7 @@ class SummaryCollector(Callback):
self._custom_lineage_data = custom_lineage_data self._custom_lineage_data = custom_lineage_data
self._temp_optimizer = None self._temp_optimizer = None
self._has_saved_train_network = False
self._has_saved_custom_data = False self._has_saved_custom_data = False
self._is_parse_loss_success = True self._is_parse_loss_success = True
self._first_step = True self._first_step = True
...@@ -215,7 +216,7 @@ class SummaryCollector(Callback): ...@@ -215,7 +216,7 @@ class SummaryCollector(Callback):
@staticmethod @staticmethod
def _check_positive(name, value, allow_none=False): def _check_positive(name, value, allow_none=False):
"""Check if the value to be int type and positive.""" """Check if the value to be int type and positive."""
if allow_none: if allow_none and value is None:
return return
check_value_type(name, value, int) check_value_type(name, value, int)
if value <= 0: if value <= 0:
...@@ -294,8 +295,9 @@ class SummaryCollector(Callback): ...@@ -294,8 +295,9 @@ class SummaryCollector(Callback):
self._collect_dataset_graph(cb_params) self._collect_dataset_graph(cb_params)
if self._collect_tensor_freq is None: if self._collect_tensor_freq is None:
default_tensor_summary_limit = 50
total_step = cb_params.epoch_num * cb_params.batch_num total_step = cb_params.epoch_num * cb_params.batch_num
self._collect_tensor_freq = max(self._collect_freq, total_step // 50) self._collect_tensor_freq = max(self._collect_freq, total_step // default_tensor_summary_limit)
if self._custom_lineage_data and not self._has_saved_custom_data: if self._custom_lineage_data and not self._has_saved_custom_data:
packaged_custom_data = self._package_custom_lineage_data(self._custom_lineage_data) packaged_custom_data = self._package_custom_lineage_data(self._custom_lineage_data)
...@@ -309,6 +311,8 @@ class SummaryCollector(Callback): ...@@ -309,6 +311,8 @@ class SummaryCollector(Callback):
cb_params = run_context.original_args() cb_params = run_context.original_args()
if cb_params.mode != ModeEnum.TRAIN.value: if cb_params.mode != ModeEnum.TRAIN.value:
return return
if not self._has_saved_train_network:
self._collect_graphs(cb_params)
if self._first_step: if self._first_step:
# Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario # Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
self._dataset_sink_mode = cb_params.cur_step_num == cb_params.batch_num self._dataset_sink_mode = cb_params.cur_step_num == cb_params.batch_num
...@@ -424,6 +428,7 @@ class SummaryCollector(Callback): ...@@ -424,6 +428,7 @@ class SummaryCollector(Callback):
if graph_proto is None: if graph_proto is None:
return return
self._has_saved_train_network = True
self._record.add_value(PluginEnum.GRAPH.value, 'train_network/auto', graph_proto) self._record.add_value(PluginEnum.GRAPH.value, 'train_network/auto', graph_proto)
def _collect_metric(self, cb_params): def _collect_metric(self, cb_params):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册