#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
from . import core
__all__ = ['DatasetFactory', 'InMemoryDataset', 'QueueDataset']


class DatasetFactory(object):
    """
    DatasetFactory is a factory that creates a dataset by name;
    you can create a "QueueDataset" or an "InMemoryDataset",
    the default being "QueueDataset".

    Example:
        dataset = paddle.fluid.DatasetFactory().create_dataset("InMemoryDataset")
    """

    def __init__(self):
        """ Init. """
        pass

    def create_dataset(self, datafeed_class="QueueDataset"):
        """
        Create "QueueDataset" or "InMemoryDataset",
        the default is "QueueDataset".

        Args:
            datafeed_class(str): datafeed class name, QueueDataset or InMemoryDataset.
                                 Default is QueueDataset.

        Examples:
            import paddle.fluid as fluid
            dataset = fluid.DatasetFactory().create_dataset()
        """
        try:
            dataset = globals()[datafeed_class]()
            return dataset
        except KeyError:
            raise ValueError("datafeed class %s does not exist" %
                             datafeed_class)
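

# A minimal sketch of the factory in use (hypothetical helper, not part of
# the public API); both dataset types are created by name at call time.
def _example_create_dataset():
    factory = DatasetFactory()
    queue_dataset = factory.create_dataset("QueueDataset")
    in_memory_dataset = factory.create_dataset("InMemoryDataset")
    return queue_dataset, in_memory_dataset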


class DatasetBase(object):
    """ Base dataset class. """

    def __init__(self):
        """ Init. """
        # the DataFeed class name set here decides whether
        # an in-memory instance needs to be created
        self.proto_desc = data_feed_pb2.DataFeedDesc()
        self.proto_desc.pipe_command = "cat"
        self.dataset = core.Dataset("MultiSlotDataset")
        self.thread_num = 0

    def set_pipe_command(self, pipe_command):
        """
        Set the pipe command of the current dataset. A pipe command
        is a UNIX pipeline command through which the input data is
        streamed before being fed to the trainer.

        Example:
            >>> dataset.set_pipe_command("python my_script.py")
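            >>> # the default is "cat", which passes each input file through unchanged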

        Args:
            pipe_command(str): pipe command

        """
        self.proto_desc.pipe_command = pipe_command

    def set_batch_size(self, batch_size):
        """
        Set the batch size. It takes effect during training.

        Example:
            >>> dataset.set_batch_size(128)

        Args:
            batch_size(int): batch size

        """
        self.proto_desc.batch_size = batch_size

    def set_thread(self, thread_num):
        """
        Set the thread number, i.e. the number of data readers.

        Example:
            >>> dataset.set_thread(12)

        Args:
            thread_num(int): thread num
        """
        self.dataset.set_thread_num(thread_num)
        self.thread_num = thread_num

    def set_filelist(self, filelist):
        """
        Set the file list for the current worker.

        Example:
            >>> dataset.set_filelist(['a.txt', 'b.txt'])

        Args:
            filelist(list): file list
        """
        self.dataset.set_filelist(filelist)

    def set_use_var(self, var_list):
        """
        Set the Variables that will be used.

        Example:
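            >>> # "data" and "label" are assumed to have been created
            >>> # beforehand, e.g. as fluid.layers.data Variables:
            >>> data = fluid.layers.data(name="data", shape=[1], dtype="int64", lod_level=1)
            >>> label = fluid.layers.data(name="label", shape=[1], dtype="int64")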
            >>> dataset.set_use_var([data, label])

        Args:
            var_list(list): variable list
        """
        multi_slot = self.proto_desc.multi_slot_desc
        for var in var_list:
            slot_var = multi_slot.slots.add()
            slot_var.is_used = True
            slot_var.name = var.name
            if var.lod_level == 0:
                slot_var.is_dense = True
                slot_var.shape.extend(var.shape)
            if var.dtype == core.VarDesc.VarType.FP32:
                slot_var.type = "float"
            elif var.dtype == core.VarDesc.VarType.INT64:
                slot_var.type = "uint64"
            else:
                raise ValueError(
                    "Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
                )

    def set_hdfs_config(self, fs_name, fs_ugi):
        """
        Set the HDFS config: fs name and ugi.

        Example:
            >>> dataset.set_hdfs_config("my_fs_name", "my_fs_ugi")

        Args:
            fs_name(str): fs name
            fs_ugi(str): fs ugi
        """
        self.dataset.set_hdfs_config(fs_name, fs_ugi)

    def _prepare_to_run(self):
        """
        Set data_feed_desc before loading or shuffling;
        users do not need to call this function.
        """
        self.dataset.set_data_feed_desc(self.desc())

    def desc(self):
        """
        Returns the text-format protobuf message of this DataFeedDesc.

        Example:
            >>> print(dataset.desc())
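            >>> # prints the DataFeedDesc protobuf in text format; the exact
            >>> # fields depend on the setters called so far, e.g.:
            >>> #   name: "MultiSlotDataFeed"
            >>> #   pipe_command: "cat"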

        Returns:
            A string message
        """
        return text_format.MessageToString(self.proto_desc)


class InMemoryDataset(DatasetBase):
    """
    InMemoryDataset loads data into memory and shuffles
    the data before training.
    This class should be created by DatasetFactory.

    Example:
        dataset = paddle.fluid.DatasetFactory().create_dataset("InMemoryDataset")
    """

    def __init__(self):
        """ Init. """
        super(InMemoryDataset, self).__init__()
        self.proto_desc.name = "MultiSlotInMemoryDataFeed"

    def load_into_memory(self):
        """
        Load data into memory

        Example:
            >>> import paddle.fluid as fluid
            >>> dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
            >>> filelist = ["a.txt", "b.txt"]
            >>> dataset.set_filelist(filelist)
            >>> dataset.load_into_memory()
        """
        self._prepare_to_run()
        self.dataset.load_into_memory()

    def local_shuffle(self):
        """
        Local shuffle

        Example:
            >>> import paddle.fluid as fluid
            >>> dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
            >>> filelist = ["a.txt", "b.txt"]
            >>> dataset.set_filelist(filelist)
            >>> dataset.load_into_memory()
            >>> dataset.local_shuffle()
        """
        self.dataset.local_shuffle()

    def global_shuffle(self, fleet=None):
        """
        Global shuffle.
        Global shuffle can be used only in distributed mode, i.e. multiple
        processes on a single machine or multiple machines training together.
        If you run in distributed mode, you should pass fleet instead of None.

        Examples:
            >>> import paddle.fluid as fluid
            >>> from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
            >>> dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
            >>> filelist = ["a.txt", "b.txt"]
            >>> dataset.set_filelist(filelist)
            >>> dataset.load_into_memory()
            >>> dataset.global_shuffle(fleet)

        Args:
            fleet(Fleet): fleet singleton. Default None.

        """
        trainer_num = 1
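        # number of instances sent per batch when workers exchange data
        # during global shuffle (presumably a throughput/memory trade-off)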
        fleet_send_batch_size = 80000
        if fleet is not None:
            fleet._role_maker._barrier_worker()
            trainer_num = fleet.worker_num()
        self.dataset.register_client2client_msg_handler()
        self.dataset.set_trainer_num(trainer_num)
        self.dataset.set_fleet_send_batch_size(fleet_send_batch_size)
        if fleet is not None:
            fleet._role_maker._barrier_worker()
        self.dataset.global_shuffle()
        if fleet is not None:
            fleet._role_maker._barrier_worker()

    def release_memory(self):
        """
        Release the memory data of InMemoryDataset when it will not be used again.

        Example:
            >>> import paddle.fluid as fluid
            >>> from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
            >>> dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
            >>> filelist = ["a.txt", "b.txt"]
            >>> dataset.set_filelist(filelist)
            >>> dataset.load_into_memory()
            >>> dataset.global_shuffle(fleet)
            >>> exe = fluid.Executor(fluid.CPUPlace())
            >>> exe.run(fluid.default_startup_program())
            >>> exe.train_from_dataset(fluid.default_main_program(), dataset)
            >>> dataset.release_memory()
        """
        self.dataset.release_memory()

    def get_memory_data_size(self, fleet=None):
        """
        Get the memory data size; users can call this function to get
        the number of instances in all workers after loading into memory.

        Note:
            This function may cause bad performance because it uses a barrier.

        Args:
            fleet(Fleet): Fleet Object.

        Returns:
            The size of memory data.

        Example:
            >>> import paddle.fluid as fluid
            >>> from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
            >>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
            >>> filelist = ["a.txt", "b.txt"]
            >>> dataset.set_filelist(filelist)
            >>> dataset.load_into_memory()
            >>> print(dataset.get_memory_data_size(fleet))

        """
        import numpy as np
        local_data_size = self.dataset.get_memory_data_size()
        local_data_size = np.array([local_data_size])
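        # Allreduce sums the per-worker counts across all nodes;
        # _node_type_comm is an MPI-style communicator from the role maker.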
        if fleet is not None:
            global_data_size = local_data_size * 0
            fleet._role_maker._node_type_comm.Allreduce(local_data_size,
                                                        global_data_size)
            return global_data_size[0]
        return local_data_size[0]

    def get_shuffle_data_size(self, fleet=None):
        """
        Get the shuffle data size; users can call this function to get
        the number of instances in all workers after local/global shuffle.

        Note:
            This function may hurt the performance of local shuffle
            because it uses a barrier. It does not affect global shuffle.

        Args:
            fleet(Fleet): Fleet Object.

        Returns:
            The size of shuffle data.

        Example:
            >>> import paddle.fluid as fluid
            >>> from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
            >>> dataset = fluid.DatasetFactory.create_dataset("InMemoryDataset")
            >>> filelist = ["a.txt", "b.txt"]
            >>> dataset.set_filelist(filelist)
            >>> dataset.load_into_memory()
            >>> dataset.global_shuffle(fleet)
            >>> print(dataset.get_shuffle_data_size(fleet))

        """
        import numpy as np
        local_data_size = self.dataset.get_shuffle_data_size()
        local_data_size = np.array([local_data_size])
        if fleet is not None:
            global_data_size = local_data_size * 0
            fleet._role_maker._node_type_comm.Allreduce(local_data_size,
                                                        global_data_size)
            return global_data_size[0]
        return local_data_size[0]
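

# An end-to-end sketch of the in-memory pipeline (hypothetical helper,
# mirroring the docstring examples above; "a.txt"/"b.txt", the slot shapes
# and the pslib fleet import are assumptions):
def _example_in_memory_pipeline():
    import paddle.fluid as fluid
    from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
    data = fluid.layers.data(name="data", shape=[1], dtype="int64", lod_level=1)
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")
    dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    dataset.set_batch_size(32)
    dataset.set_thread(2)
    dataset.set_filelist(["a.txt", "b.txt"])
    dataset.set_use_var([data, label])
    dataset.load_into_memory()
    dataset.global_shuffle(fleet)
    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    exe.train_from_dataset(fluid.default_main_program(), dataset)
    dataset.release_memory()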


class QueueDataset(DatasetBase):
    """
    QueueDataset, which processes data in a streaming fashion.

    Example:
        import paddle.fluid as fluid
        dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
    """

    def __init__(self):
        """
        Initialize QueueDataset.
        This class should be created by DatasetFactory.
        """
        super(QueueDataset, self).__init__()
        self.proto_desc.name = "MultiSlotDataFeed"

    def local_shuffle(self):
        """
        Local shuffle

        Local shuffle is not supported in QueueDataset;
        a NotImplementedError will be raised.
        """
        raise NotImplementedError(
            "QueueDataset does not support local shuffle, "
            "please use InMemoryDataset for local_shuffle")

    def global_shuffle(self, fleet=None):
        """
        Global shuffle is not supported in QueueDataset;
        a NotImplementedError will be raised.
        """
        raise NotImplementedError(
            "QueueDataset does not support global shuffle, "
            "please use InMemoryDataset for global_shuffle")