#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
from . import core
__all__ = ['DatasetFactory', 'InMemoryDataset', 'QueueDataset']


class DatasetFactory(object):
    """
    DatasetFactory is a factory which create dataset by its name,
    you can create "QueueDataset" or "InMemoryDataset",
    the default is "QueueDataset".

    Example:
        dataset = paddle.fluid.DatasetFactory.create_dataset("InMemoryDataset")
    """

    def __init__(self):
        """
        Init
        """
        pass

    def create_dataset(self, datafeed_class="QueueDataset"):
        """
        Create "QueueDataset" or "InMemoryDataset",
        the default is "QueueDataset".

        Examples:
            import paddle.fluid as fluid
            dataset = fluid.DatasetFactory().create_dataset()
        """
        try:
            dataset = globals()[datafeed_class]()
            return dataset
        except KeyError:
            raise ValueError("datafeed class %s does not exist" %
                             datafeed_class)
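
    # A minimal, hedged usage sketch (file names below are placeholders):
    #
    #   import paddle.fluid as fluid
    #   dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    #   dataset.set_batch_size(32)
    #   dataset.set_thread(4)
    #   dataset.set_filelist(["a.txt", "b.txt"])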


class DatasetBase(object):
    """
    Base dataset class
    """

    def __init__(self):
        """
        Init
        """
        # define class name here
        # to decide whether we need to create an in-memory instance
        self.proto_desc = data_feed_pb2.DataFeedDesc()
        self.proto_desc.pipe_command = "cat"
        self.dataset = core.Dataset("MultiSlotDataset")
        self.thread_num = 0

    def set_pipe_command(self, pipe_command):
        """
        Set pipe command of current dataset
        A pipe command is a UNIX pipeline command that can be used only

        Example:
            >>> dataset.set_pipe_command("python my_script.py")

        Args:
            pipe_command: pipe command

        """
        self.proto_desc.pipe_command = pipe_command
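
    # A hypothetical preprocessing script used as a pipe command (illustrative
    # only; the exact line format the pipe must emit depends on the configured
    # DataFeed):
    #
    #   # my_script.py -- reads raw samples from stdin, writes them to stdout
    #   import sys
    #   for line in sys.stdin:
    #       sys.stdout.write(line.lower())
    #
    #   dataset.set_pipe_command("python my_script.py")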

    def set_batch_size(self, batch_size):
        """
        Set the batch size. It will take effect during training.

        Example:
            >>> dataset.set_batch_size(128)

        Args:
            batch_size: batch size

        """
        self.proto_desc.batch_size = batch_size

    def set_thread(self, thread_num):
        """
        Set the thread number; it is the number of readers.

        Example:
            >>> dataset.set_thread(12)

        Args:
            thread_num: thread num
        """
        self.dataset.set_thread_num(thread_num)
        self.thread_num = thread_num

    def set_filelist(self, filelist):
        """
        Set the file list of the current worker.

        Example:
            >>> dataset.set_filelist(['a.txt', 'b.txt'])

        Args:
            filelist: file list
        """
        self.dataset.set_filelist(filelist)

    def set_use_var(self, var_list):
        """
        Set the Variables that will be used.

        Example:
            >>> dataset.set_use_var([data, label])

        Args:
            var_list: variable list
        """
        multi_slot = self.proto_desc.multi_slot_desc
        for var in var_list:
            slot_var = multi_slot.slots.add()
            slot_var.is_used = True
            slot_var.name = var.name
            if var.lod_level == 0:
                slot_var.is_dense = True
                slot_var.shape.extend(var.shape)
            if var.dtype == core.VarDesc.VarType.FP32:
                slot_var.type = "float"
            elif var.dtype == core.VarDesc.VarType.INT64:
                slot_var.type = "uint64"
            else:
                raise ValueError(
                    "Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
                )
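
    # A minimal sketch of declaring input Variables and registering them with
    # the dataset (names, shapes and dtypes below are illustrative only):
    #
    #   import paddle.fluid as fluid
    #   data = fluid.layers.data(name="words", shape=[1], dtype="int64", lod_level=1)
    #   label = fluid.layers.data(name="label", shape=[1], dtype="int64")
    #   dataset.set_use_var([data, label])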

    def set_hdfs_config(self, fs_name, fs_ugi):
        """
        Set hdfs config: fs name and ugi.

        Example:
            >>> dataset.set_hdfs_config("my_fs_name", "my_fs_ugi")

        Args:
            fs_name: fs name
            fs_ugi: fs ugi
        """
        self.dataset.set_hdfs_config(fs_name, fs_ugi)

    def _prepare_to_run(self):
        """
        Set data_feed_desc before loading or shuffling data;
        users do not need to call this function directly.
        """
        self.dataset.set_data_feed_desc(self.desc())

    def desc(self):
        """
        Returns a protobuf message for this DataFeedDesc

        Example:
            >>> print(dataset.desc())

        Returns:
            A string message
        """
        return text_format.MessageToString(self.proto_desc)
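
    # desc() renders self.proto_desc in protobuf text format; the output looks
    # roughly like the sketch below (field order and default values may vary):
    #
    #   name: "MultiSlotDataFeed"
    #   batch_size: 32
    #   pipe_command: "cat"
    #   multi_slot_desc {
    #     slots {
    #       name: "words"
    #       type: "uint64"
    #       is_dense: false
    #       is_used: true
    #     }
    #   }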


class InMemoryDataset(DatasetBase):
    """
    InMemoryDataset loads data into memory
    and shuffles the data before training.
    This class should be created by DatasetFactory

    Example:
        dataset = paddle.fluid.DatasetFactory().create_dataset("InMemoryDataset")
    """

    def __init__(self):
        """
        Init
        """
        super(InMemoryDataset, self).__init__()
        self.proto_desc.name = "MultiSlotInMemoryDataFeed"

    def load_into_memory(self):
        """
        Load data into memory

        Example:
            >>> import paddle.fluid as fluid
            >>> dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
            >>> filelist = ["a.txt", "b.txt"]
            >>> dataset.set_filelist(filelist)
            >>> dataset.load_into_memory()
        """
        self._prepare_to_run()
        self.dataset.load_into_memory()

    def local_shuffle(self):
        """
        Local shuffle

        Example:
            >>> import paddle.fluid as fluid
            >>> dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
            >>> filelist = ["a.txt", "b.txt"]
            >>> dataset.set_filelist(filelist)
            >>> dataset.local_shuffle()
        """
        self.dataset.local_shuffle()

    def global_shuffle(self, fleet=None):
        """
        Global shuffle.
        Global shuffle can be used only in distributed mode, i.e. multiple
        processes on a single machine or multiple machines training together.
        If you run in distributed mode, you should pass a fleet instance
        instead of None.

        Examples:
            >>> import paddle.fluid as fluid
            >>> import paddle.fluid.incubate.fleet.parameter_server as fleet
            >>> dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
            >>> filelist = ["a.txt", "b.txt"]
            >>> dataset.set_filelist(filelist)
            >>> dataset.global_shuffle(fleet)

        Args:
            fleet: fleet singleton. Default None.
        """
        trainer_num = 1
        fleet_send_batch_size = 80000
        if fleet is not None:
            fleet.fleet_instance.role_maker_._barrier_worker()
            trainer_num = fleet.worker_num()
        self.dataset.register_client2client_msg_handler()
        self.dataset.set_trainer_num(trainer_num)
        self.dataset.set_fleet_send_batch_size(fleet_send_batch_size)
        if fleet is not None:
            fleet.fleet_instance.role_maker_._barrier_worker()
        self.dataset.global_shuffle()
        if fleet is not None:
            fleet.fleet_instance.role_maker_._barrier_worker()
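
    # A hedged end-to-end sketch of the in-memory workflow (assumes an Executor
    # `exe`, a built program, and input Variables `data`/`label` created
    # elsewhere; train_from_dataset is the usual consumer of a Dataset):
    #
    #   dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    #   dataset.set_use_var([data, label])
    #   dataset.set_filelist(["a.txt", "b.txt"])
    #   dataset.load_into_memory()
    #   dataset.local_shuffle()      # or dataset.global_shuffle(fleet)
    #   exe.train_from_dataset(program=fluid.default_main_program(),
    #                          dataset=dataset)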


class QueueDataset(DatasetBase):
    """
    QueueDataset processes data in a streaming fashion.

    Example:
        import paddle.fluid as fluid
        dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
    """

    def __init__(self):
        """
        Initialize QueueDataset
        This class should be created by DatasetFactory
        """
        super(QueueDataset, self).__init__()
        self.proto_desc.name = "MultiSlotDataFeed"

    def local_shuffle(self):
        """
        Local shuffle

        Local shuffle is not supported in QueueDataset;
        a NotImplementedError will be raised.
        """
        raise NotImplementedError(
            "QueueDataset does not support local shuffle, "
            "please use InMemoryDataset for local_shuffle")

    def global_shuffle(self, fleet=None):
        """
        Global shuffle is not supported in QueueDataset;
        a NotImplementedError will be raised.
        """
        raise NotImplementedError(
            "QueueDataset does not support global shuffle, "
            "please use InMemoryDataset for global_shuffle")