diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py
index 17766f14049b46fad84252864da0f9d1755e87b8..fe5580bd25e2d13dc11c43ba8dc776cd18219a7a 100644
--- a/mindspore/train/callback/_summary_collector.py
+++ b/mindspore/train/callback/_summary_collector.py
@@ -182,7 +182,7 @@ class SummaryCollector(Callback):
         self._custom_lineage_data = custom_lineage_data

         self._temp_optimizer = None
-        self._has_saved_train_network = False
+        self._has_saved_graph = False
         self._has_saved_custom_data = False
         self._is_parse_loss_success = True
         self._first_step = True
@@ -287,32 +287,30 @@ class SummaryCollector(Callback):
                              'but got `{cb_params.mode}` mode.')

         self._record.set_mode(cb_params.mode)
-        if cb_params.mode == ModeEnum.TRAIN.value:
-            # Note: if model.init is not executed then the computed graph will not be obtained here
-            # The purpose of recording the graph here was to collect_freq if it was set to a large size,
-            # but also want to see the graph as soon after compilation.
-            self._collect_graphs(cb_params)

-            self._collect_dataset_graph(cb_params)
+        if cb_params.mode == ModeEnum.TRAIN.value:
             if self._collect_tensor_freq is None:
                 default_tensor_summary_limit = 20
                 total_step = cb_params.epoch_num * cb_params.batch_num
                 self._collect_tensor_freq = max(self._collect_freq, total_step // default_tensor_summary_limit)

-        if self._custom_lineage_data and not self._has_saved_custom_data:
-            packaged_custom_data = self._package_custom_lineage_data(self._custom_lineage_data)
-            self._record.add_value('custom_lineage_data', 'custom_lineage_data', packaged_custom_data)
-            self._has_saved_custom_data = True
-
-        # There's nothing special about setting step to 0 here, just to satisfy the interface call
-        self._record.record(step=0)
-
     def step_end(self, run_context):
         cb_params = run_context.original_args()
         if cb_params.mode != ModeEnum.TRAIN.value:
             return
+
+        if not self._has_saved_graph:
+            self._collect_graphs(cb_params)
+            self._collect_dataset_graph(cb_params)
+            self._has_saved_graph = True
+            self._record.record(cb_params.cur_step_num)
+
+        if self._custom_lineage_data and not self._has_saved_custom_data:
+            packaged_custom_data = self._package_custom_lineage_data(self._custom_lineage_data)
+            self._record.add_value('custom_lineage_data', 'custom_lineage_data', packaged_custom_data)
+            self._has_saved_custom_data = True
+            self._record.record(cb_params.cur_step_num)
+
         if self._first_step:
             # Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
             self._dataset_sink_mode = cb_params.cur_step_num == cb_params.batch_num
@@ -327,14 +325,12 @@ class SummaryCollector(Callback):
         elif current % self._collect_freq == 0:
             self._collect_at_step_end(cb_params, lambda plugin: plugin != PluginEnum.TENSOR.value)

-
     def _collect_at_step_end(self, cb_params, plugin_filter):
         self._collect_input_data(cb_params)
         self._collect_metric(cb_params)
         self._collect_histogram(cb_params)
         self._record.record(cb_params.cur_step_num, plugin_filter=plugin_filter)

-
     def end(self, run_context):
         cb_params = run_context.original_args()
         if cb_params.mode == ModeEnum.TRAIN.value:
@@ -428,7 +424,6 @@ class SummaryCollector(Callback):
         if graph_proto is None:
             return

-        self._has_saved_train_network = True
         self._record.add_value(PluginEnum.GRAPH.value, 'train_network/auto', graph_proto)

     def _collect_metric(self, cb_params):
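
For readers skimming the patch, the core change is that graph and dataset-graph collection moves out of begin() (where it could run before graph compilation had produced anything) into step_end(), guarded by a one-shot boolean flag. Below is a minimal standalone sketch of that pattern, not part of the patch: Record, LazyGraphCollector, and the printed strings are simplified stand-ins invented for illustration, not MindSpore APIs.

    # Sketch of the guarded one-shot collection pattern this change adopts.
    # Everything here is a stand-in for illustration, not MindSpore internals.
    from types import SimpleNamespace


    class Record:
        """Stand-in for the summary record: prints what would be written."""

        def add_value(self, plugin, tag, value):
            print(f'add_value({plugin!r}, {tag!r}, {value!r})')

        def record(self, step, plugin_filter=None):
            print(f'record(step={step})')


    class LazyGraphCollector:
        """Mimics the patched step_end control flow."""

        def __init__(self, record):
            self._record = record
            self._has_saved_graph = False

        def step_end(self, cb_params):
            # Collect the graph exactly once, on the first training step;
            # by then graph compilation has finished even if model.init
            # was never called, which the old begin()-time collection
            # could not guarantee.
            if not self._has_saved_graph:
                self._record.add_value('graph', 'train_network/auto', '<graph proto>')
                self._has_saved_graph = True
                self._record.record(cb_params.cur_step_num)


    collector = LazyGraphCollector(Record())
    for step in (1, 2, 3):
        collector.step_end(SimpleNamespace(cur_step_num=step))
    # Only step 1 triggers add_value/record; later steps skip the guarded block,
    # and the recorded step number is the real step instead of a placeholder 0.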