dataset.py 14.9 KB
Newer Older
D
dongdaxiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
from . import core
D
dongdaxiang 已提交
18
__all__ = ['DatasetFactory', 'InMemoryDataset', 'QueueDataset']
D
dongdaxiang 已提交
19 20 21


class DatasetFactory(object):
    """
    DatasetFactory is a factory which create dataset by its name,
    you can create "QueueDataset" or "InMemoryDataset", or "FileInstantDataset",
    the default is "QueueDataset".

    Example:
        .. code-block:: python

          import paddle.fluid as fluid
          dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")

    """

    def __init__(self):
        """ Init. """
        pass

    def create_dataset(self, datafeed_class="QueueDataset"):
        """
        Create "QueueDataset" or "InMemoryDataset", or "FileInstantDataset",
        the default is "QueueDataset".

        Args:
            datafeed_class(str): datafeed class name, QueueDataset or InMemoryDataset.
                                 Default is QueueDataset.

        Returns:
            An instance of the requested dataset class.

        Raises:
            ValueError: if no dataset class with the given name exists.

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              dataset = fluid.DatasetFactory().create_dataset()

        """
        # Catch only the failed name lookup: a bare `except` here would also
        # swallow exceptions raised by the dataset constructor itself (and
        # even KeyboardInterrupt), misreporting them as a missing class.
        try:
            dataset_cls = globals()[datafeed_class]
        except KeyError:
            raise ValueError("datafeed class %s does not exist" %
                             datafeed_class)
        return dataset_cls()


class DatasetBase(object):
    """
    Base class shared by all dataset implementations.

    Keeps the DataFeedDesc protobuf message that describes how input data
    is parsed, together with the underlying core dataset instance.
    """

    def __init__(self):
        """ Init. """
        # The protobuf name field (set by subclasses) decides which kind of
        # data feed instance is created; "cat" streams each file unchanged.
        self.proto_desc = data_feed_pb2.DataFeedDesc()
        self.proto_desc.pipe_command = "cat"
        self.dataset = core.Dataset("MultiSlotDataset")
        self.thread_num = 0

    def set_pipe_command(self, pipe_command):
        """
        Set pipe command of current dataset.
        A pipe command is a UNIX pipeline command that preprocesses each
        input file.

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              dataset = fluid.DatasetFactory().create_dataset()
              dataset.set_pipe_command("python my_script.py")

        Args:
            pipe_command(str): pipe command

        """
        self.proto_desc.pipe_command = pipe_command

    def set_batch_size(self, batch_size):
        """
        Set batch size. Will be effective during training.

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              dataset = fluid.DatasetFactory().create_dataset()
              dataset.set_batch_size(128)

        Args:
            batch_size(int): batch size

        """
        self.proto_desc.batch_size = batch_size

    def set_thread(self, thread_num):
        """
        Set thread num, it is the num of readers.

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              dataset = fluid.DatasetFactory().create_dataset()
              dataset.set_thread(12)

        Args:
            thread_num(int): thread num
        """
        self.dataset.set_thread_num(thread_num)
        self.thread_num = thread_num

    def set_filelist(self, filelist):
        """
        Set file list in current worker.

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              dataset = fluid.DatasetFactory().create_dataset()
              dataset.set_filelist(['a.txt', 'b.txt'])

        Args:
            filelist(list): file list
        """
        self.dataset.set_filelist(filelist)

    def set_use_var(self, var_list):
        """
        Set Variables which you will use.

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              dataset = fluid.DatasetFactory().create_dataset()
              dataset.set_use_var([data, label])

        Args:
            var_list(list): variable list

        Raises:
            ValueError: for any variable whose dtype is neither float32
                nor int64.
        """
        slots = self.proto_desc.multi_slot_desc.slots
        for variable in var_list:
            slot = slots.add()
            slot.is_used = True
            slot.name = variable.name
            if variable.lod_level == 0:
                # Only dense (non-LoD) variables carry a fixed shape.
                slot.is_dense = True
                slot.shape.extend(variable.shape)
            if variable.dtype == core.VarDesc.VarType.FP32:
                slot.type = "float"
            elif variable.dtype == core.VarDesc.VarType.INT64:
                slot.type = "uint64"
            else:
                raise ValueError(
                    "Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
                )

    def set_hdfs_config(self, fs_name, fs_ugi):
        """
        Set hdfs config: fs name and ugi.

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              dataset = fluid.DatasetFactory().create_dataset()
              dataset.set_hdfs_config("my_fs_name", "my_fs_ugi")

        Args:
            fs_name(str): fs name
            fs_ugi(str): fs ugi
        """
        self.dataset.set_hdfs_config(fs_name, fs_ugi)

    def _prepare_to_run(self):
        """
        Set data_feed_desc before load or shuffle;
        users do not need to call this function themselves.
        """
        self.dataset.set_data_feed_desc(self.desc())

    def desc(self):
        """
        Returns a protobuf message for this DataFeedDesc.

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              dataset = fluid.DatasetFactory().create_dataset()
              print(dataset.desc())

        Returns:
            A string message
        """
        return text_format.MessageToString(self.proto_desc)


class InMemoryDataset(DatasetBase):
    """
    InMemoryDataset, it will load data into memory
    and shuffle data before training.
    This class should be created by DatasetFactory

    Example:
        dataset = paddle.fluid.DatasetFactory().create_dataset("InMemoryDataset")
    """

    def __init__(self):
        """ Init. """
        super(InMemoryDataset, self).__init__()
        self.proto_desc.name = "MultiSlotInMemoryDataFeed"

    def load_into_memory(self):
        """
        Load data into memory

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
              filelist = ["a.txt", "b.txt"]
              dataset.set_filelist(filelist)
              dataset.load_into_memory()
        """
        # The data feed desc must be pushed down before reading any file.
        self._prepare_to_run()
        self.dataset.load_into_memory()

    def local_shuffle(self):
        """
        Local shuffle

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
              filelist = ["a.txt", "b.txt"]
              dataset.set_filelist(filelist)
              dataset.load_into_memory()
              dataset.local_shuffle()
        """
        self.dataset.local_shuffle()

    def global_shuffle(self, fleet=None, fleet_send_batch_size=80000):
        """
        Global shuffle.
        Global shuffle can be used only in distributed mode. i.e. multiple
        processes on single machine or multiple machines training together.
        If you run in distributed mode, you should pass fleet instead of None.

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
              dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
              filelist = ["a.txt", "b.txt"]
              dataset.set_filelist(filelist)
              dataset.load_into_memory()
              dataset.global_shuffle(fleet)

        Args:
            fleet(Fleet): fleet singleton. Default None.
            fleet_send_batch_size(int): batch size used when sending
                instances between workers. Default 80000 (the value that
                was previously hard-coded).

        """
        trainer_num = 1
        if fleet is not None:
            fleet._role_maker._barrier_worker()
            trainer_num = fleet.worker_num()
        self.dataset.register_client2client_msg_handler()
        self.dataset.set_trainer_num(trainer_num)
        self.dataset.set_fleet_send_batch_size(fleet_send_batch_size)
        if fleet is not None:
            fleet._role_maker._barrier_worker()
        self.dataset.global_shuffle()
        if fleet is not None:
            # Keep all workers in step until everyone finished shuffling.
            fleet._role_maker._barrier_worker()

    def release_memory(self):
        """
        Release InMemoryDataset memory data, when data will not be used again.

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
              dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
              filelist = ["a.txt", "b.txt"]
              dataset.set_filelist(filelist)
              dataset.load_into_memory()
              dataset.global_shuffle(fleet)
              exe = fluid.Executor(fluid.CPUPlace())
              exe.run(fluid.default_startup_program())
              exe.train_from_dataset(fluid.default_main_program(), dataset)
              dataset.release_memory()

        """
        self.dataset.release_memory()

    def _allreduce_data_size(self, local_data_size, fleet):
        """
        Sum a per-worker instance count across all workers.

        Shared by get_memory_data_size and get_shuffle_data_size, which
        previously duplicated this logic verbatim.

        Args:
            local_data_size(int): this worker's instance count.
            fleet(Fleet): fleet singleton, or None for local-only.

        Returns:
            The global sum when fleet is given, else the local count.
        """
        import numpy as np
        local_data_size = np.array([local_data_size])
        if fleet is not None:
            global_data_size = local_data_size * 0
            fleet._role_maker._node_type_comm.Allreduce(local_data_size,
                                                        global_data_size)
            return global_data_size[0]
        return local_data_size[0]

    def get_memory_data_size(self, fleet=None):
        """
        Get memory data size, user can call this function to know the num
        of ins in all workers after load into memory.

        Note:
            This function may cause bad performance, because it has barrier

        Args:
            fleet(Fleet): Fleet Object.

        Returns:
            The size of memory data.

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
              dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
              filelist = ["a.txt", "b.txt"]
              dataset.set_filelist(filelist)
              dataset.load_into_memory()
              print(dataset.get_memory_data_size(fleet))

        """
        return self._allreduce_data_size(self.dataset.get_memory_data_size(),
                                         fleet)

    def get_shuffle_data_size(self, fleet=None):
        """
        Get shuffle data size, user can call this function to know the num
        of ins in all workers after local/global shuffle.

        Note:
            This function may cause bad performance to local shuffle,
            because it has barrier. It does not affect global shuffle.

        Args:
            fleet(Fleet): Fleet Object.

        Returns:
            The size of shuffle data.

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
              dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
              filelist = ["a.txt", "b.txt"]
              dataset.set_filelist(filelist)
              dataset.load_into_memory()
              dataset.global_shuffle(fleet)
              print(dataset.get_shuffle_data_size(fleet))

        """
        return self._allreduce_data_size(self.dataset.get_shuffle_data_size(),
                                         fleet)

X
xjqbest 已提交
394

D
dongdaxiang 已提交
395
class QueueDataset(DatasetBase):
    """
    QueueDataset, it will process data streamly.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          dataset = fluid.DatasetFactory().create_dataset("QueueDataset")

    """

    def __init__(self):
        """
        Initialize QueueDataset.
        This class should be created by DatasetFactory.
        """
        super(QueueDataset, self).__init__()
        self.proto_desc.name = "MultiSlotDataFeed"

    def local_shuffle(self):
        """
        Local shuffle data.

        QueueDataset streams its data and therefore cannot shuffle it;
        calling this always raises NotImplementedError.

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
              dataset.local_shuffle()

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("QueueDataset does not support local shuffle, "
                                  "please use InMemoryDataset for local_shuffle")

    def global_shuffle(self, fleet=None):
        """
        Global shuffle data.

        Global shuffle is not supported in QueueDataset;
        calling this always raises NotImplementedError.

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
              dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
              dataset.global_shuffle(fleet)

        Args:
            fleet(Fleet): fleet singleton. Default None; ignored.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("QueueDataset does not support global shuffle, "
                                  "please use InMemoryDataset for global_shuffle")
H
hutuxian 已提交
453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485


class FileInstantDataset(DatasetBase):
    """
    FileInstantDataset, it will process data streamly.

    Example:
        import paddle.fluid as fluid
        dataset = fluid.DatasetFactory.create_dataset("FileInstantDataset")
    """

    def __init__(self):
        """
        Init: select the file-instant data feed on the core side.
        """
        super(FileInstantDataset, self).__init__()
        self.proto_desc.name = "MultiSlotFileInstantDataFeed"

    def local_shuffle(self):
        """
        Local shuffle.

        FileInstantDataset streams its data and does not support local
        shuffle; calling this always raises NotImplementedError.
        """
        raise NotImplementedError("FileInstantDataset does not support local shuffle, "
                                  "please use InMemoryDataset for local_shuffle")

    def global_shuffle(self, fleet=None):
        """
        Global shuffle.

        Not supported by FileInstantDataset; always raises
        NotImplementedError.
        """
        raise NotImplementedError("FileInstantDataset does not support global shuffle, "
                                  "please use InMemoryDataset for global_shuffle")