dataset_helper.py 4.4 KB
Newer Older
Z
z00478463 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset help for minddata dataset"""
16
from mindspore._checkparam import check_bool
Z
z00478463 已提交
17 18 19
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
    _to_full_shapes
20 21 22
from mindspore.train.parallel_utils import ParallelMode


Z
z00478463 已提交
23 24 25
class DatasetHelper:
    """
    Help function to use the Minddata dataset.

    According to different context, change the iter of dataset, to use the same for loop in different context.

    Note:
        The iter of DatasetHelper will give one epoch data.

    Args:
        dataset (DataSet): The dataset.
        dataset_sink_mode (bool): Validated with `check_bool`. This helper always builds the
            loop-sink iterator, so the value does not select a different fetch path here.
            Default: True.
        iter_first_order (int): Forwarded to the loop-sink iterator, which adds it to
            `dataset.__loop_size__` when computing how many iterations one pass yields.
            Default: 0.

    Examples:
        >>> dataset_helper = DatasetHelper(dataset)
        >>> for inputs in dataset_helper:
        ...     outputs = network(*inputs)
    """

    def __init__(self, dataset, dataset_sink_mode=True, iter_first_order=0):
        # The flag is only type-checked; the iterator choice below does not depend on it.
        check_bool(dataset_sink_mode)
        self.iter = _DatasetIterMSLoopSink(dataset, iter_first_order)

    def __iter__(self):
        """Return the wrapped iterator, reset to the start of a pass."""
        return self.iter.__iter__()

    # A temp solution for loop sink. Delete later
    def types_shapes(self):
        """Get the types and shapes from dataset on current config."""
        return self.iter.types_shapes()

    def loop_size(self):
        """Get loop_size for every iteration."""
        return self.iter.loop_size
58 59


Z
z00478463 已提交
60 61
class _DatasetIter:
    """Base iter for dataset help"""
62

Z
z00478463 已提交
63 64 65 66 67 68 69 70
    def __init__(self, dataset):
        self.loop_size = 1
        if not hasattr(dataset, '__ME_INITED__'):
            if not hasattr(dataset, '__loop_size__'):
                self.loop_size = dataset.get_dataset_size()
            else:
                self.loop_size = dataset.__loop_size__
            dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
71

Z
z00478463 已提交
72 73 74 75
        self.ind = 0
        self.dataset = dataset
        dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
        self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes
76

Z
z00478463 已提交
77 78 79
    def __iter__(self):
        self.ind = 0
        return self
80

Z
z00478463 已提交
81 82 83 84 85
    def __next__(self):
        if self.ind >= self.loop_count:
            raise StopIteration()
        self.ind += 1
        return self.op()
86

Z
z00478463 已提交
87 88
    def types_shapes(self):
        return self.dataset_types, self.dataset_shapes
89

Z
z00478463 已提交
90 91 92 93
    def get_loop_count(self, dataset):
        loop_count = 1
        if hasattr(dataset, '__loop_size__'):
            loop_size = dataset.__loop_size__
Z
z00478463 已提交
94 95 96
            if dataset.get_dataset_size() % loop_size != 0:
                raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
                                 f'loop_size {loop_size} are not matched.')
97
            loop_count = int(dataset.get_dataset_size() / loop_size)
Z
z00478463 已提交
98
        return loop_count
99 100


Z
z00478463 已提交
101
class _DatasetIterMSLoopSink(_DatasetIter):
    """Iter for context (device_target=Ascend)"""

    def __init__(self, dataset, iter_first_order):
        """Set up the loop count and the per-step `op` for loop-sink iteration.

        Args:
            dataset: Dataset exposing `__loop_size__` and `get_dataset_size()`.
            iter_first_order (int): Added to `dataset.__loop_size__` to form the
                divisor used for the loop count.
        """
        super().__init__(dataset)
        loop_size = dataset.__loop_size__ + iter_first_order
        # Two iterations are yielded per whole `loop_size` chunk of the dataset.
        # Floor division keeps the quotient exact instead of going through float.
        self.loop_count = int(dataset.get_dataset_size() // loop_size) * 2
        # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
        # compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
        # times the batch dimension of tensors for run. Now only support LoopSink.
        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            device_num = _get_device_num()
            self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)

        def op():
            """Return an empty tuple; each step hands no host-side inputs to the caller."""
            return tuple()

        self.op = op