diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py index 6d5ec45d5bfff418580bcc084994cb995d09fba7..cff03ca398f2be708ded48d943d902a52932b990 100644 --- a/mindspore/train/callback/_summary_collector.py +++ b/mindspore/train/callback/_summary_collector.py @@ -166,8 +166,11 @@ class SummaryCollector(Callback): self._has_saved_custom_data = False self._is_parse_loss_success = True self._first_step = True + self._dataset_sink_mode = True def __enter__(self): + self._first_step = True + self._dataset_sink_mode = True self._record = SummaryRecord(log_dir=self._summary_dir) return self @@ -279,15 +282,15 @@ class SummaryCollector(Callback): def step_end(self, run_context): cb_params = run_context.original_args() + if self._first_step: + # Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario + self._dataset_sink_mode = bool(cb_params.cur_step_num == cb_params.batch_num) if cb_params.mode == ModeEnum.TRAIN.value: - # Make sure the first step data is recorded - if not self._first_step and cb_params.cur_step_num % self._collect_freq: + if not self._is_collect_this_step(cb_params): return - self._first_step = False - if not self._has_saved_train_network: self._collect_graphs(cb_params) @@ -295,6 +298,7 @@ class SummaryCollector(Callback): self._collect_metric(cb_params) self._collect_histogram(cb_params) + self._first_step = False self._record.record(cb_params.cur_step_num) def end(self, run_context): @@ -320,6 +324,18 @@ class SummaryCollector(Callback): raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list," f"but expected only one {self.__class__.__name__} instance.") + def _is_collect_this_step(self, cb_params): + """Decide whether to collect data for the current step.""" + # Make sure the first step data is recorded + if not self._first_step: + if self._dataset_sink_mode: + if cb_params.cur_epoch_num % self._collect_freq: + return False + else: + if cb_params.cur_step_num % self._collect_freq: + return False + return True + @staticmethod def _package_custom_lineage_data(custom_lineage_data): """