#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from . import core
import numpy as np
import os
import six
from six.moves import zip, range, xrange
import multiprocessing
import warnings
import struct

from .framework import Variable, default_main_program, _current_expected_place, _non_static_mode, _in_eager_without_dygraph_check
from .framework import _cpu_num, _cuda_ids

__all__ = ['DataFeeder']

_PADDLE_DTYPE_2_NUMPY_DTYPE = {
    core.VarDesc.VarType.BOOL: 'bool',
    core.VarDesc.VarType.FP16: 'float16',
    core.VarDesc.VarType.BF16: 'uint16',
    core.VarDesc.VarType.FP32: 'float32',
    core.VarDesc.VarType.FP64: 'float64',
    core.VarDesc.VarType.INT8: 'int8',
    core.VarDesc.VarType.INT16: 'int16',
    core.VarDesc.VarType.INT32: 'int32',
    core.VarDesc.VarType.INT64: 'int64',
    core.VarDesc.VarType.UINT8: 'uint8',
    core.VarDesc.VarType.COMPLEX64: 'complex64',
    core.VarDesc.VarType.COMPLEX128: 'complex128',
}


def copy_bits_from_float_to_uint16(f):
    return struct.unpack('<I', struct.pack('<f', f))[0] >> 16


def convert_float_to_uint16(data, data_format="NCHW"):
    if data.size == 0:
        return data.view(np.uint16)

    if data_format == "NHWC":
        data = np.transpose(data, [0, 3, 1, 2])

    new_data = []
    for x in np.nditer(data):
        new_data.append(np.uint16(copy_bits_from_float_to_uint16(x)))
    new_data = np.reshape(new_data, data.shape).view(np.uint16)

    if data_format == "NHWC":
        new_data = np.transpose(new_data, [0, 2, 3, 1])
    return new_data
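# Illustrative note (added commentary, not part of the original logic): the
# helpers above keep only the upper 16 bits of each float32 value, which is
# exactly the bfloat16 bit pattern. For example, float32 1.0 is 0x3F800000,
# so it maps to the uint16 value 0x3F80 (16256):
#
#     x = np.ones([1], dtype='float32')
#     convert_float_to_uint16(x)   # -> array([16256], dtype=uint16)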


def convert_dtype(dtype):
    if isinstance(dtype, core.VarDesc.VarType):
        if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE:
            return _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]
    elif isinstance(dtype, type):
        if dtype in [
                bool, np.float16, np.uint16, np.float32, np.float64, np.int8,
                np.int16, np.int32, np.int64, np.uint8, np.complex64,
                np.complex128
        ]:
            return dtype.__name__
    else:
        if dtype in [
                'bool', 'float16', 'uint16', 'float32', 'float64', 'int8',
                'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128',
                u'bool', u'float16', u'uint16', u'float32', u'float64', u'int8',
                u'int16', u'int32', u'int64', u'uint8', u'complex64',
                u'complex128'
        ]:
            # This is slightly risky, since an error could occur when casting
            # a non-ascii string to str in python2. The set of accepted names
            # is limited, though, so it is good enough for now; jointly
            # supporting python2 and python3 (and maybe python4 one day) may
            # remain a long-lasting problem.
            return str(dtype)
        # NOTE(zhangbo): numpy does not support bfloat16 yet, so numpy.uint16 is
        # used to represent paddle.bfloat16; their binary layouts are consistent.
        # Do not cast an ndarray to uint16 with ndarray.astype('uint16') before
        # turning it into a tensor; use 'convert_float_to_uint16' above instead,
        # otherwise the bits will be wrong.
        if dtype in ['bfloat16']:
            return 'uint16'

    raise TypeError(
        "dtype must be any of [bool, float16, uint16, float32, float64, int8, int16, "
        "int32, int64, uint8, complex64, complex128], but received %s" % dtype)
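# Illustrative note (added commentary): convert_dtype normalizes a
# VarDesc.VarType enum value, a numpy type object, or a (unicode) string to a
# numpy dtype name, e.g.
#
#     convert_dtype(core.VarDesc.VarType.FP32)   # -> 'float32'
#     convert_dtype(np.int64)                    # -> 'int64'
#     convert_dtype('bfloat16')                  # -> 'uint16' (bfloat16 is stored as uint16)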


def check_variable_and_dtype(input,
                             input_name,
                             expected_dtype,
                             op_name,
                             extra_message=''):
    check_type(input, input_name, Variable, op_name, extra_message)
    check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message)


def check_type(input, input_name, expected_type, op_name, extra_message=''):
    # NOTE [ Why skip dynamic graph check ]:
    # 1. If the input type / dtype of a layer is wrong, it will be reported
    # directly on that line, and the user can easily print the relevant
    # information there. Debugging is easier, so there is no need to check
    # in dynamic graph mode.
    # 2. Performance considerations. Because these checks are executed at
    # each step in dynamic graph mode, they would bring a heavy performance burden.
    if _non_static_mode():
        return

    # NOTE: `in_declarative_mode` is used to determine whether this op is called under
    # @declarative in the transformation from dygraph to static layers. We add VarBase to
    # expected_type to skip the check because VarBase may be created and used in unusual ways.
    from .dygraph.base import in_declarative_mode
    # Need a better design to fix this.
    if in_declarative_mode():
        if not isinstance(expected_type, tuple):
            expected_type = (expected_type, )
        expected_type += (core.VarBase, )
        if _in_eager_without_dygraph_check():
            expected_type += (core.eager.Tensor, )
    elif isinstance(input, core.VarBase):
        raise TypeError(
            "Please use `with fluid.dygraph.guard()` as context or `fluid.enable_dygraph()` to switch to imperative mode first. "
            "Because received '{}' in {} is an imperative Variable.".format(
                input_name, op_name))
    elif hasattr(core, "eager"):
        if isinstance(input, core.eager.Tensor):
            raise TypeError(
                "Please use `with fluid.dygraph.guard()` as context or `fluid.enable_dygraph()` to switch to imperative mode first. "
                "Because received '{}' in {} is an imperative Variable.".format(
                    input_name, op_name))
    if not isinstance(input, expected_type):
        raise TypeError(
            "The type of '%s' in %s must be %s, but received %s. %s" %
            (input_name, op_name, expected_type, type(input), extra_message))


def check_dtype(input_dtype,
                input_name,
                expected_dtype,
                op_name,
                extra_message=''):
    # See NOTE [ Why skip dynamic graph check ]
    if _non_static_mode():
        return
    if convert_dtype(input_dtype) in ['float16']:
        warnings.warn(
            "The data type of '%s' in %s only supports float16 on GPU now. %s" %
            (input_name, op_name, extra_message))
    if convert_dtype(input_dtype) in ['uint16'] and op_name not in [
            'reshape', 'lookup_table', 'scale'
    ]:
        warnings.warn(
            "The data type of '%s' in %s only supports bfloat16 with OneDNN now. %s"
            % (input_name, op_name, extra_message))
    if convert_dtype(input_dtype) not in expected_dtype:
        raise TypeError(
            "The data type of '%s' in %s must be %s, but received %s. %s" %
            (input_name, op_name, expected_dtype, convert_dtype(input_dtype),
             extra_message))


def check_shape(shape,
                op_name,
                expected_shape_type=(list, tuple, Variable),
                expected_element_type=(int, Variable),
                expected_tensor_dtype=('int32', 'int64')):
    # See NOTE [ Why skip dynamic graph check ]
    if _non_static_mode():
        return
    check_type(shape, 'shape', expected_shape_type, op_name)
    if expected_element_type is not None and not isinstance(shape, Variable):
        for item in shape:
            check_type(item, 'element of shape', expected_element_type, op_name)
            if expected_tensor_dtype is not None and isinstance(item, Variable):
                check_dtype(
                    item.dtype, 'element of shape', expected_tensor_dtype,
                    op_name,
                    'If element of shape is Tensor, its data type should be {}'.
                    format(', '.join(expected_tensor_dtype)))
    if expected_tensor_dtype is not None and isinstance(shape, Variable):
        check_dtype(shape.dtype, 'shape', expected_tensor_dtype, op_name)
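# Illustrative usage sketch (hypothetical layer, not defined in this file): a
# static-graph op implementation would typically validate its inputs with the
# helpers above before building the op, e.g.
#
#     def my_layer(x):
#         check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'my_layer')
#         ...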


class DataToLoDTensorConverter(object):

    def __init__(self, place, lod_level, shape, dtype):
        self.place = place
        self.lod_level = lod_level
        self.shape = shape
        negative_count = 0
        for s in self.shape:
            if s < 0:
                negative_count += 1
            if negative_count > 1:
                self.shape = None
                break
        self.dtype = convert_dtype(dtype)
        self._reset()

    def _reset(self):
        self.data = []
        self.lod = [[] for _ in six.moves.range(self.lod_level)]

    def feed(self, data):
        self._feed_impl_(data, self.lod, self.lod_level)

    def _feed_impl_(self, data, lod, lod_level):
        if lod_level == 0:
            self.data.append(data)
        else:
            lod[0].append(len(data))
            for each_data in data:
                self._feed_impl_(each_data, lod[1:], lod_level - 1)

    def _check_shape(self, shape):
        for s1, s2 in zip(self.shape, shape):
            if s1 != s2 and s1 >= 0 and s2 >= 0:
                raise ValueError(
                    "Shape does not match. What is defined in the data layer is {}, but received {}"
                    .format(self.shape, shape))

    def done(self):
        arr = np.array(self.data, dtype=self.dtype)
        if self.shape:
            if len(arr.shape) != len(self.shape):
                try:
                    arr = arr.reshape(self.shape)
                except ValueError:
                    raise ValueError(
                        "Reshape error. What is defined in the data layer is {}, but received {}"
                        .format(self.shape, arr.shape))
        t = core.LoDTensor()
        t.set(arr, self.place)
        if self.lod_level > 0:
            t.set_recursive_sequence_lengths(self.lod)
        self._reset()
        return t
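    # Illustrative note (added commentary): for a converter with lod_level=1,
    # feeding the two samples [1, 2] and [3] appends their lengths to self.lod
    # and flattens the values into self.data, so done() returns a LoDTensor
    # holding [1, 2, 3] with recursive sequence lengths [[2, 1]].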


class BatchedTensorProvider(object):

    def __init__(self, feed_list, place, batch_size, generator, drop_last):
        self.place = place
        self.batch_size = batch_size
        self.generator = generator
        self.converters = []
        self.drop_last = drop_last

        for var in feed_list:
            assert var.lod_level == 0, "lod_level must be 0"
            self.converters.append(
                DataToLoDTensorConverter(place=self.place,
                                         lod_level=0,
                                         shape=var.shape,
                                         dtype=var.dtype))

    def _done(self):
        return [c.done() for c in self.converters]

    def __call__(self):
        idx = 0
        for each_sample in self.generator():
            for each_slot, each_converter in six.moves.zip(
                    each_sample, self.converters):
                each_converter.data.append(each_slot)

            idx += 1
            if idx == self.batch_size:
                idx = 0
                yield self._done()

        if not self.drop_last and idx > 0:
            yield self._done()
        else:
            [c._reset() for c in self.converters]
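    # Illustrative note (added commentary): __call__ pulls single samples from
    # self.generator, groups every `batch_size` of them into one batch of
    # LoDTensors via the converters, and, unless drop_last is set, also yields
    # the final partial batch.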


class DataFeeder(object):
    """
    :api_attr: Static Graph
    
    DataFeeder converts the data returned by a reader into a data
    structure that can be fed into Executor. The reader is usually a
    python generator that returns a list of mini-batch data entries.

    Parameters:
        feed_list (list): Variables or names of Variables that need
            to be fed.
        place (:ref:`api_fluid_CPUPlace` | :ref:`api_fluid_CUDAPlace` ):
            place indicates the device (CPU | GPU) the data will be fed into. If
            you want to feed data into GPU, please use :code:`fluid.CUDAPlace(i)`
            (:code:`i` represents the GPU id); if you want to feed data into CPU,
            please use :code:`fluid.CPUPlace()`.
        program (:ref:`api_fluid_Program` , optional): The Program that the data
            will be fed into. If program is None, it will use default_main_program().
            Default None.

    Raises:
        :code:`ValueError` - If some Variables are not in this Program.

    Example:
        ..  code-block:: python

            import numpy as np
            import paddle
            import paddle.fluid as fluid

            place = fluid.CPUPlace()
            def reader():
                for _ in range(4):
                    yield np.random.random([4]).astype('float32'), np.random.random([3]).astype('float32'),

            main_program = fluid.Program()
            startup_program = fluid.Program()

            with fluid.program_guard(main_program, startup_program):
                data_1 = fluid.data(name='data_1', shape=[None, 2, 2], dtype='float32')
                data_2 = fluid.data(name='data_2', shape=[None, 1, 3], dtype='float32')
                out = fluid.layers.fc(input=[data_1, data_2], size=2)
                # ...
            feeder = fluid.DataFeeder([data_1, data_2], place)

            exe = fluid.Executor(place)
            exe.run(startup_program)

            feed_data = feeder.feed(reader())

            # print feed_data to view feed results
            # print(feed_data['data_1'])
            # print(feed_data['data_2'])

            outs = exe.run(program=main_program,
                            feed=feed_data,
                            fetch_list=[out])
            print(outs)

    """

    def __init__(self, feed_list, place, program=None):
        self.feed_dtypes = []
        self.feed_names = []
        self.feed_shapes = []
        self.feed_lod_level = []
        if program is None:
            program = default_main_program()
        for each_var in feed_list:
            if isinstance(each_var, six.string_types):
                each_var = program.block(0).var(each_var)
            if not isinstance(each_var, Variable):
                raise TypeError("Feed list should contain a list of Variables")
            self.feed_dtypes.append(each_var.dtype)
            self.feed_names.append(each_var.name)
            self.feed_lod_level.append(each_var.lod_level)
            self.feed_shapes.append(each_var.shape)

        self.place = place

    def feed(self, iterable):
        """
        According to :code:`feed_list` of :code:`DataFeeder` and :code:`iterable` , converts
        the input into a data structure that can be fed into Executor.

        Parameters:
            iterable (generator): user defined python generator to read the raw input data

        Returns:
            :code:`dict`: a :code:`dict` that contains (variable name - converted tensor) pairs

        Example:
            ..  code-block:: python

                # In this example, reader - generator will return a list of ndarray of 3 elements
                # feed API will convert each ndarray input into a tensor
                # the return result is a dict with keys: data_1, data_2, data_3
                # result['data_1'] is a LoD-Tensor with shape of [5, 2, 1, 3]. 5 is batch size, and [2, 1, 3] is the real shape of data_1.
                # result['data_2'], result['data_3'] are similar.
                import numpy as np
                import paddle.fluid as fluid

                def reader(limit=5):
                    for i in range(1, limit + 1):
                        yield np.ones([6]).astype('float32') * i , np.ones([1]).astype('int64') * i, np.random.random([9]).astype('float32')

                data_1 = fluid.data(name='data_1', shape=[None, 2, 1, 3])
                data_2 = fluid.data(name='data_2', shape=[None, 1], dtype='int64')
                data_3 = fluid.data(name='data_3', shape=[None, 3, 3], dtype='float32')
                feeder = fluid.DataFeeder(['data_1','data_2', 'data_3'], fluid.CPUPlace())

                result = feeder.feed(reader())
                print(result['data_1'])
                print(result['data_2'])
                print(result['data_3'])

        """
        converter = []
        for lod_level, shape, dtype in six.moves.zip(self.feed_lod_level,
                                                     self.feed_shapes,
                                                     self.feed_dtypes):
            converter.append(
                DataToLoDTensorConverter(place=self.place,
                                         lod_level=lod_level,
                                         shape=shape,
                                         dtype=dtype))

        for each_sample in iterable:
            assert len(each_sample) == len(converter), (
                "The number of fields in data (%d) does not match " +
                "len(feed_list) (%d)") % (len(each_sample), len(converter))
            for each_converter, each_slot in six.moves.zip(
                    converter, each_sample):
                each_converter.feed(each_slot)
        ret_dict = {}
        for each_name, each_converter in six.moves.zip(self.feed_names,
                                                       converter):
            ret_dict[each_name] = each_converter.done()
        return ret_dict

    def feed_parallel(self, iterable, num_places=None):
        """
        Similar to the feed function, feed_parallel is used with multiple devices (CPU|GPU).
        Here :code:`iterable` is a list of python generators. The data returned by each
        generator in the list will be fed into a separate device.

        Parameters:
            iterable (list|tuple): list of user-defined python generators. The element 
                number should match the :code:`num_places`.
            num_places (int, optional): the number of devices. If not provided (None), 
                all available devices on the machine will be used. Default None.

        Returns:
            :code:`generator`: a :code:`generator` that generates dicts containing (variable name - converted tensor) pairs;
            the total number of dicts generated matches :code:`num_places`

        .. note::
            The number of devices :code:`num_places` should equal the number of generators (elements of :code:`iterable` )

        Example:
            ..  code-block:: python

                import numpy as np
                import paddle.fluid as fluid

                def generate_reader(batch_size, base=0, factor=1):
                    def _reader():
                        for i in range(batch_size):
                            yield np.ones([4]) * factor + base, np.ones([4]) * factor + base + 5
                    return _reader()

                x = fluid.data(name='x', shape=[None, 2, 2])
                y = fluid.data(name='y', shape=[None, 2, 2], dtype='float32')

                z = fluid.layers.elementwise_add(x, y)

                feeder = fluid.DataFeeder(['x','y'], fluid.CPUPlace())
                place_num = 2
                places = [fluid.CPUPlace() for x in range(place_num)]
                data = []
                exe = fluid.Executor(fluid.CPUPlace())
                exe.run(fluid.default_startup_program())
                program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(places=places)

                # print a sample feed_parallel result
                # for item in list(feeder.feed_parallel([generate_reader(5, 0, 1), generate_reader(3, 10, 2)], 2)):
                #     print(item['x'])
                #     print(item['y'])

                reader_list = [generate_reader(5, 0, 1), generate_reader(3, 10, 2)]
                res = exe.run(program=program, feed=list(feeder.feed_parallel(reader_list, 2)), fetch_list=[z])
                print(res)
        """
        if isinstance(self.place, core.CUDAPlace):
            places = [
                core.CUDAPlace(i) for i in six.moves.xrange(
                    self._get_number_of_places_(num_places))
            ]
        else:
            places = [
                core.CPUPlace() for _ in six.moves.xrange(
                    self._get_number_of_places_(num_places))
            ]

        if len(iterable) != len(places):
            raise ValueError("feed_parallel takes multiple mini-batches. Each "
                             "mini-batch will be fed on each device. The "
                             "number of devices and the number of mini-batches "
                             "must be the same.")

        place = self.place
        for p, batch in six.moves.zip(places, iterable):
            self.place = p
            yield self.feed(batch)
        self.place = place

    def _get_number_of_places_(self, num_places):
        if num_places is not None:
            return int(num_places)
        elif isinstance(self.place, core.CUDAPlace):
            return len(_cuda_ids())
        else:
            return _cpu_num()

    def decorate_reader(self,
                        reader,
                        multi_devices,
                        num_places=None,
                        drop_last=True):
        """
        Decorate the reader (generator) to fit multiple devices. The reader generates
        multiple mini-batches. Each mini-batch will be fed into a single device.

        Parameters:
            reader(generator): a user defined python generator used to get :code:`mini-batch` of data.
                A :code:`mini-batch` can be regarded as a python generator that returns batches of input 
                entities, just like the below :code:`_mini_batch` in the code example.
            multi_devices(bool): indicates whether to use multiple devices or not.
            num_places(int, optional): if :code:`multi_devices` is True, you can specify the number
                of devices(CPU|GPU) to use; if num_places is None, the function will use all the
                devices of the current machine. Default None.
            drop_last(bool, optional): whether to drop the last round of data if it is not enough to
                feed all devices. Default True.

        Returns:
            :code:`generator`: a new :code:`generator` which returns converted dicts that can be fed into Executor
            
        Raises:
            :code:`ValueError`: If drop_last is False and the data cannot fit devices perfectly.

        Example:
            ..  code-block:: python

                import numpy as np
                import paddle
                import paddle.fluid as fluid
                import paddle.fluid.compiler as compiler

                def reader():
                    def _mini_batch(batch_size):
                        for i in range(batch_size):
                            yield np.random.random([16]).astype('float32'), np.random.randint(10, size=[1])

                    for _ in range(10):
                        yield _mini_batch(np.random.randint(1, 10))

                place_num = 3
                places = [fluid.CPUPlace() for _ in range(place_num)]

                # a simple network sample
                data = fluid.data(name='data', shape=[None, 4, 4], dtype='float32')
                label = fluid.data(name='label', shape=[None, 1], dtype='int64')
                hidden = fluid.layers.fc(input=data, size=10)

                feeder = fluid.DataFeeder(place=places[0], feed_list=[data, label])
                reader = feeder.decorate_reader(reader, multi_devices=True, num_places=3, drop_last=True)

                exe = fluid.Executor(places[0])
                exe.run(fluid.default_startup_program())
                compiled_prog = compiler.CompiledProgram(
                         fluid.default_main_program()).with_data_parallel(places=places)

                for i,data in enumerate(reader()):
                    # print data if you like
                    # print(i, data)
                    ret = exe.run(compiled_prog, feed=data, fetch_list=[hidden])
                    print(ret)

        """

        def __reader_creator__():
            if not multi_devices:
                for item in reader():
                    yield self.feed(item)
            else:
                num = self._get_number_of_places_(num_places)
                item = []
                for batch in reader():
                    item.append(batch)
                    if len(item) == num:
                        yield list(self.feed_parallel(item, num))
                        item = []
                if not drop_last and len(item) != 0:
                    raise ValueError(
                        "The data batches which cannot fit all devices can only "
                        "be dropped; other strategies are not implemented.")

        return __reader_creator__