提交 91c856e5 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2334 remove dataset send from data exec for r0.3

Merge pull request !2334 from wangnan39/do_not_send_data_duriing_model_init
......@@ -15,6 +15,7 @@
"""Dataset help for minddata dataset"""
from mindspore._checkparam import check_bool
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
from mindspore.train.dataset_helper import _send_data
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
_to_full_shapes
from mindspore.train.parallel_utils import ParallelMode
......@@ -67,7 +68,13 @@ class _DatasetIter:
self.loop_size = dataset.get_dataset_size()
else:
self.loop_size = dataset.__loop_size__
dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
if not hasattr(dataset, '__no_send__'):
_send_data(dataset)
else:
_send_data(dataset)
self.ind = 0
self.dataset = dataset
......
......@@ -16,11 +16,10 @@
import os
import numpy as np
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype
from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
from mindspore.common import dtype as mstype
from mindspore import log as logger
from mindspore.common.api import _executor
from mindspore.common.dtype import pytype_to_dtype
def _convert_type(types):
......@@ -63,9 +62,6 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
dataset_shapes,
input_indexs,
phase=phase)
# engine dataset to write data to tdt queue
exec_dataset.send()
return exec_dataset
......
......@@ -24,6 +24,14 @@ from ..nn.wrap import GetNextSingleOp
from ..parallel._utils import _get_device_num, _get_global_rank, _get_parallel_mode
def _send_data(dataset):
"""Engine dataset to write data to tdt queue."""
if not hasattr(dataset, '__has_sent__'):
exec_dataset = dataset.__TRANSFER_DATASET__
exec_dataset.send()
dataset.__has_sent__ = True
class DatasetHelper:
"""
Help function to use the Minddata dataset.
......@@ -82,7 +90,13 @@ class _DatasetIter:
self.loop_size = dataset.get_dataset_size()
else:
self.loop_size = dataset.__loop_size__
dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
if not hasattr(dataset, '__no_send__'):
_send_data(dataset)
else:
_send_data(dataset)
self.ind = 0
self.dataset = dataset
......
......@@ -278,7 +278,7 @@ class Model:
if self._parameter_broadcast:
self._train_network.set_broadcast_flag()
train_dataset.__no_send__ = True
train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
is_train=True,
phase='train',
......@@ -295,6 +295,7 @@ class Model:
self._eval_network.set_train(False)
self._eval_network.phase = 'eval'
valid_dataset.__no_send__ = True
valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
is_train=False,
phase='eval',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册