提交 0934281a 编写于 作者: O ougongchang

Decide whether to collect data by dataset sink mode and current step in SummaryCollector.

Before, we only decide whether to collect data by current step,
it will not work well in dataset sink mode, so we check to see
if it's a dataset sink mode, and decide whether to collect data.
上级 1cadea12
......@@ -166,8 +166,11 @@ class SummaryCollector(Callback):
self._has_saved_custom_data = False
self._is_parse_loss_success = True
self._first_step = True
self._dataset_sink_mode = True
def __enter__(self):
self._first_step = True
self._dataset_sink_mode = True
self._record = SummaryRecord(log_dir=self._summary_dir)
return self
......@@ -279,15 +282,15 @@ class SummaryCollector(Callback):
def step_end(self, run_context):
cb_params = run_context.original_args()
if self._first_step:
# Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
self._dataset_sink_mode = bool(cb_params.cur_step_num == cb_params.batch_num)
if cb_params.mode == ModeEnum.TRAIN.value:
# Make sure the first step data is recorded
if not self._first_step and cb_params.cur_step_num % self._collect_freq:
if not self._is_collect_this_step(cb_params):
return
self._first_step = False
if not self._has_saved_train_network:
self._collect_graphs(cb_params)
......@@ -295,6 +298,7 @@ class SummaryCollector(Callback):
self._collect_metric(cb_params)
self._collect_histogram(cb_params)
self._first_step = False
self._record.record(cb_params.cur_step_num)
def end(self, run_context):
......@@ -320,6 +324,18 @@ class SummaryCollector(Callback):
raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list,"
f"but expected only one {self.__class__.__name__} instance.")
def _is_collect_this_step(self, cb_params):
"""Decide whether to collect data for the current step."""
# Make sure the first step data is recorded
if not self._first_step:
if self._dataset_sink_mode:
if cb_params.cur_epoch_num % self._collect_freq:
return False
else:
if cb_params.cur_step_num % self._collect_freq:
return False
return True
@staticmethod
def _package_custom_lineage_data(custom_lineage_data):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册