提交 20049bbe 编写于 作者: W wangnan39@huawei.com

send data after model init

上级 cb6211f2
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
"""Dataset help for minddata dataset""" """Dataset help for minddata dataset"""
from mindspore._checkparam import check_bool from mindspore._checkparam import check_bool
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
from mindspore.train.dataset_helper import _send_data
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \ from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
_to_full_shapes _to_full_shapes
from mindspore.train.parallel_utils import ParallelMode from mindspore.train.parallel_utils import ParallelMode
...@@ -67,7 +68,13 @@ class _DatasetIter: ...@@ -67,7 +68,13 @@ class _DatasetIter:
self.loop_size = dataset.get_dataset_size() self.loop_size = dataset.get_dataset_size()
else: else:
self.loop_size = dataset.__loop_size__ self.loop_size = dataset.__loop_size__
dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
if not hasattr(dataset, '__no_send__'):
_send_data(dataset)
else:
_send_data(dataset)
self.ind = 0 self.ind = 0
self.dataset = dataset self.dataset = dataset
......
...@@ -16,11 +16,10 @@ ...@@ -16,11 +16,10 @@
import os import os
import numpy as np import numpy as np
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore import log as logger from mindspore import log as logger
from mindspore.common.api import _executor from mindspore.common.api import _executor
from mindspore.common.dtype import pytype_to_dtype
def _convert_type(types): def _convert_type(types):
...@@ -63,9 +62,6 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): ...@@ -63,9 +62,6 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
dataset_shapes, dataset_shapes,
input_indexs, input_indexs,
phase=phase) phase=phase)
# engine dataset to write data to tdt queue
exec_dataset.send()
return exec_dataset return exec_dataset
......
...@@ -24,6 +24,14 @@ from ..nn.wrap import GetNextSingleOp ...@@ -24,6 +24,14 @@ from ..nn.wrap import GetNextSingleOp
from ..parallel._utils import _get_device_num, _get_global_rank, _get_parallel_mode from ..parallel._utils import _get_device_num, _get_global_rank, _get_parallel_mode
def _send_data(dataset):
"""Engine dataset to write data to tdt queue."""
if not hasattr(dataset, '__has_sent__'):
exec_dataset = dataset.__TRANSFER_DATASET__
exec_dataset.send()
dataset.__has_sent__ = True
class DatasetHelper: class DatasetHelper:
""" """
Help function to use the Minddata dataset. Help function to use the Minddata dataset.
...@@ -82,7 +90,13 @@ class _DatasetIter: ...@@ -82,7 +90,13 @@ class _DatasetIter:
self.loop_size = dataset.get_dataset_size() self.loop_size = dataset.get_dataset_size()
else: else:
self.loop_size = dataset.__loop_size__ self.loop_size = dataset.__loop_size__
dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
if not hasattr(dataset, '__no_send__'):
_send_data(dataset)
else:
_send_data(dataset)
self.ind = 0 self.ind = 0
self.dataset = dataset self.dataset = dataset
......
...@@ -278,7 +278,7 @@ class Model: ...@@ -278,7 +278,7 @@ class Model:
if self._parameter_broadcast: if self._parameter_broadcast:
self._train_network.set_broadcast_flag() self._train_network.set_broadcast_flag()
train_dataset.__no_send__ = True
train_dataset_helper, train_network = self._exec_preprocess(self._train_network, train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
is_train=True, is_train=True,
phase='train', phase='train',
...@@ -295,6 +295,7 @@ class Model: ...@@ -295,6 +295,7 @@ class Model:
self._eval_network.set_train(False) self._eval_network.set_train(False)
self._eval_network.phase = 'eval' self._eval_network.phase = 'eval'
valid_dataset.__no_send__ = True
valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network, valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
is_train=False, is_train=False,
phase='eval', phase='eval',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册