提交 ad186e79 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5599 Collect input data when `dataset_sink_mode` set on GPU

Merge pull request !5599 from LiHongzhang/dataset_sink_mode
......@@ -414,11 +414,11 @@ class SummaryCollector(Callback):
logger.info("The 'train_dataset_element' in cb_params is None, maybe there is dataset sink mode.")
return
if isinstance(input_data, (list, tuple)):
if isinstance(input_data, (list, tuple)) and input_data:
input_data = input_data[0]
try:
self._record.add_value(PluginEnum.IMAGE.value, 'input_data/auto', input_data)
except ValueError:
except (TypeError, ValueError):
logger.warning('The input data of network are not image, so will not collect by SummaryCollector.')
self._collect_specified_data['collect_input_data'] = False
return
......
......@@ -448,6 +448,7 @@ class Model:
for inputs in dataset_helper:
if _need_to_full() and context.get_context("device_target") == "GPU":
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
cb_params.train_dataset_element = inputs
list_callback.step_begin(run_context)
outputs = self._train_network(*inputs)
cb_params.cur_step_num += dataset_helper.sink_size()
......@@ -499,7 +500,6 @@ class Model:
raise ValueError("when loss_fn is not None, train_dataset should"
"return two elements, but got {}".format(len_element))
cb_params.cur_step_num += 1
list_callback.step_begin(run_context)
overflow = False
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
......@@ -507,6 +507,7 @@ class Model:
next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
cb_params.train_dataset_element = next_element
list_callback.step_begin(run_context)
outputs = self._train_network(*next_element)
cb_params.net_outputs = outputs
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
......
......@@ -482,6 +482,7 @@ class Model:
for inputs in dataset_helper:
if _need_to_full():
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
cb_params.train_dataset_element = inputs
list_callback.step_begin(run_context)
if switch_branch_one:
cb_params.cur_step_num += dataset_helper.sink_size()
......@@ -546,7 +547,6 @@ class Model:
raise ValueError("when loss_fn is not None, train_dataset should"
"return two elements, but got {}".format(len_element))
cb_params.cur_step_num += 1
list_callback.step_begin(run_context)
overflow = False
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
......@@ -554,6 +554,7 @@ class Model:
next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
cb_params.train_dataset_element = next_element
list_callback.step_begin(run_context)
outputs = self._train_network(*next_element)
cb_params.net_outputs = outputs
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
......
......@@ -454,7 +454,6 @@ class Model:
# for data sink dataset_helper only iter once, other wise iter epoch_size times.
for inputs in dataset_helper:
list_callback.step_begin(run_context)
if switch_branch_one:
cb_params.cur_step_num += loop_size
self._train_network.add_flags_recursive(thor=True)
......@@ -467,6 +466,8 @@ class Model:
_exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
self._has_do_dataset_init = True
switch_branch_one = not switch_branch_one
cb_params.train_dataset_element = inputs
list_callback.step_begin(run_context)
outputs = self._train_network(*inputs)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
......@@ -514,13 +515,14 @@ class Model:
raise ValueError("when loss_fn is not None, train_dataset should"
"return two elements, but got {}".format(len_element))
cb_params.cur_step_num += 1
list_callback.step_begin(run_context)
overflow = False
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
scaling_sens = self._get_scaling_sens()
next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
cb_params.train_dataset_element = next_element
list_callback.step_begin(run_context)
outputs = self._train_network(*next_element)
cb_params.net_outputs = outputs
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
......
......@@ -242,11 +242,12 @@ class TestSummaryCollector:
SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))._check_callbacks(cb_params)
assert f"more than one SummaryCollector instance in callback list" in str(exc.value)
def test_collect_input_data_with_train_dataset_element_none(self):
"""Test the param 'train_dataset_element' in cb_params is none."""
def test_collect_input_data_with_train_dataset_element_invalid(self):
"""Test the param 'train_dataset_element' in cb_params is invalid."""
cb_params = _InternalCallbackParam()
cb_params.train_dataset_element = None
summary_collector = SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))
for invalid in (), [], None, [None]:
cb_params.train_dataset_element = invalid
with SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir)) as summary_collector:
summary_collector._collect_input_data(cb_params)
assert not summary_collector._collect_specified_data['collect_input_data']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册