提交 ad186e79 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5599 Collect input data when `dataset_sink_mode` is set on GPU

Merge pull request !5599 from LiHongzhang/dataset_sink_mode
...@@ -414,11 +414,11 @@ class SummaryCollector(Callback): ...@@ -414,11 +414,11 @@ class SummaryCollector(Callback):
logger.info("The 'train_dataset_element' in cb_params is None, maybe there is dataset sink mode.") logger.info("The 'train_dataset_element' in cb_params is None, maybe there is dataset sink mode.")
return return
if isinstance(input_data, (list, tuple)): if isinstance(input_data, (list, tuple)) and input_data:
input_data = input_data[0] input_data = input_data[0]
try: try:
self._record.add_value(PluginEnum.IMAGE.value, 'input_data/auto', input_data) self._record.add_value(PluginEnum.IMAGE.value, 'input_data/auto', input_data)
except ValueError: except (TypeError, ValueError):
logger.warning('The input data of network are not image, so will not collect by SummaryCollector.') logger.warning('The input data of network are not image, so will not collect by SummaryCollector.')
self._collect_specified_data['collect_input_data'] = False self._collect_specified_data['collect_input_data'] = False
return return
......
...@@ -448,6 +448,7 @@ class Model: ...@@ -448,6 +448,7 @@ class Model:
for inputs in dataset_helper: for inputs in dataset_helper:
if _need_to_full() and context.get_context("device_target") == "GPU": if _need_to_full() and context.get_context("device_target") == "GPU":
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank) inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
cb_params.train_dataset_element = inputs
list_callback.step_begin(run_context) list_callback.step_begin(run_context)
outputs = self._train_network(*inputs) outputs = self._train_network(*inputs)
cb_params.cur_step_num += dataset_helper.sink_size() cb_params.cur_step_num += dataset_helper.sink_size()
...@@ -499,7 +500,6 @@ class Model: ...@@ -499,7 +500,6 @@ class Model:
raise ValueError("when loss_fn is not None, train_dataset should" raise ValueError("when loss_fn is not None, train_dataset should"
"return two elements, but got {}".format(len_element)) "return two elements, but got {}".format(len_element))
cb_params.cur_step_num += 1 cb_params.cur_step_num += 1
list_callback.step_begin(run_context)
overflow = False overflow = False
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
...@@ -507,6 +507,7 @@ class Model: ...@@ -507,6 +507,7 @@ class Model:
next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),) next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
cb_params.train_dataset_element = next_element cb_params.train_dataset_element = next_element
list_callback.step_begin(run_context)
outputs = self._train_network(*next_element) outputs = self._train_network(*next_element)
cb_params.net_outputs = outputs cb_params.net_outputs = outputs
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
......
...@@ -482,6 +482,7 @@ class Model: ...@@ -482,6 +482,7 @@ class Model:
for inputs in dataset_helper: for inputs in dataset_helper:
if _need_to_full(): if _need_to_full():
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank) inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
cb_params.train_dataset_element = inputs
list_callback.step_begin(run_context) list_callback.step_begin(run_context)
if switch_branch_one: if switch_branch_one:
cb_params.cur_step_num += dataset_helper.sink_size() cb_params.cur_step_num += dataset_helper.sink_size()
...@@ -546,7 +547,6 @@ class Model: ...@@ -546,7 +547,6 @@ class Model:
raise ValueError("when loss_fn is not None, train_dataset should" raise ValueError("when loss_fn is not None, train_dataset should"
"return two elements, but got {}".format(len_element)) "return two elements, but got {}".format(len_element))
cb_params.cur_step_num += 1 cb_params.cur_step_num += 1
list_callback.step_begin(run_context)
overflow = False overflow = False
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
...@@ -554,6 +554,7 @@ class Model: ...@@ -554,6 +554,7 @@ class Model:
next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),) next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
cb_params.train_dataset_element = next_element cb_params.train_dataset_element = next_element
list_callback.step_begin(run_context)
outputs = self._train_network(*next_element) outputs = self._train_network(*next_element)
cb_params.net_outputs = outputs cb_params.net_outputs = outputs
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
......
...@@ -454,7 +454,6 @@ class Model: ...@@ -454,7 +454,6 @@ class Model:
# In data sink mode, dataset_helper iterates only once; otherwise it iterates epoch_size times. # In data sink mode, dataset_helper iterates only once; otherwise it iterates epoch_size times.
for inputs in dataset_helper: for inputs in dataset_helper:
list_callback.step_begin(run_context)
if switch_branch_one: if switch_branch_one:
cb_params.cur_step_num += loop_size cb_params.cur_step_num += loop_size
self._train_network.add_flags_recursive(thor=True) self._train_network.add_flags_recursive(thor=True)
...@@ -467,6 +466,8 @@ class Model: ...@@ -467,6 +466,8 @@ class Model:
_exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset') _exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
self._has_do_dataset_init = True self._has_do_dataset_init = True
switch_branch_one = not switch_branch_one switch_branch_one = not switch_branch_one
cb_params.train_dataset_element = inputs
list_callback.step_begin(run_context)
outputs = self._train_network(*inputs) outputs = self._train_network(*inputs)
cb_params.net_outputs = outputs cb_params.net_outputs = outputs
list_callback.step_end(run_context) list_callback.step_end(run_context)
...@@ -514,13 +515,14 @@ class Model: ...@@ -514,13 +515,14 @@ class Model:
raise ValueError("when loss_fn is not None, train_dataset should" raise ValueError("when loss_fn is not None, train_dataset should"
"return two elements, but got {}".format(len_element)) "return two elements, but got {}".format(len_element))
cb_params.cur_step_num += 1 cb_params.cur_step_num += 1
list_callback.step_begin(run_context)
overflow = False overflow = False
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
scaling_sens = self._get_scaling_sens() scaling_sens = self._get_scaling_sens()
next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),) next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
cb_params.train_dataset_element = next_element
list_callback.step_begin(run_context)
outputs = self._train_network(*next_element) outputs = self._train_network(*next_element)
cb_params.net_outputs = outputs cb_params.net_outputs = outputs
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update(): if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
......
...@@ -242,13 +242,14 @@ class TestSummaryCollector: ...@@ -242,13 +242,14 @@ class TestSummaryCollector:
SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))._check_callbacks(cb_params) SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))._check_callbacks(cb_params)
assert f"more than one SummaryCollector instance in callback list" in str(exc.value) assert f"more than one SummaryCollector instance in callback list" in str(exc.value)
def test_collect_input_data_with_train_dataset_element_none(self): def test_collect_input_data_with_train_dataset_element_invalid(self):
"""Test the param 'train_dataset_element' in cb_params is none.""" """Test the param 'train_dataset_element' in cb_params is invalid."""
cb_params = _InternalCallbackParam() cb_params = _InternalCallbackParam()
cb_params.train_dataset_element = None for invalid in (), [], None, [None]:
summary_collector = SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir))) cb_params.train_dataset_element = invalid
summary_collector._collect_input_data(cb_params) with SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir)) as summary_collector:
assert not summary_collector._collect_specified_data['collect_input_data'] summary_collector._collect_input_data(cb_params)
assert not summary_collector._collect_specified_data['collect_input_data']
@mock.patch.object(SummaryRecord, 'add_value') @mock.patch.object(SummaryRecord, 'add_value')
def test_collect_input_data_success(self, mock_add_value): def test_collect_input_data_success(self, mock_add_value):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册