BatchSampler_cn.rst 1.9 KB
Newer Older
D
dengkaipeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
.. _cn_api_fluid_io_BatchSampler:

BatchSampler
-------------------------------

.. py:class:: paddle.fluid.io.BatchSampler(dataset=None, indices=None, shuffle=False, batch_size=1, drop_last=False)

``fluid.io.DataLoader`` 使用的批次索引采样器,其可以迭代返回mini-batch的索引列表(长度为 ``batch_size`` ,内容为样本索引)。
``fluid.io.DataLoader`` 的 ``batch_sampler`` 参数必须为 ``BatchSampler`` 及其子类实例。 ``BatchSampler`` 子类须实现如下两个方法:

``__iter__`` : 迭代返回mini-batch索引列表。

``__len__`` : 每个epoch中的mini-batch个数。

参数:
    - **dataset** (Dataset) - ``fluid.io.Dataset`` 实例或者实现了 ``__len__`` 接口的python对象,用于生成 ``dataset`` 长度范围的索引。默认值为None。
    - **indices** (list|tuple) - 用于迭代的下标,``dataset`` 的替代参数, ``dataset`` 和 ``indices`` 必须设置其中之一。默认值为None。
    - **shuffle** (bool) - 迭代返回索引之前是否对索引打乱顺序。默认值为False。
    - **batch_size** (int) - 每mini-batch中的索引下标个数。默认值为1。
    - **drop_last** (int) - 是否丢弃因数据集样本数不能被 ``batch_size`` 整除而产生的最后一个不完成的mini-batch索引。默认值为False。

返回:迭代索引列表的迭代器

返回类型: BatchSampler

**代码示例**

.. code-block:: python

    from paddle.fluid.io import BatchSampler, MNIST

    # init with indices
    bs = BatchSampler(indices=list(range(1000)),
                      shuffle=True,
                      batch_size=8,
                      drop_last=True)

    for batch_indices in bs:
        print(batch_indices)

    # init with dataset
    bs = BatchSampler(dataset=MNIST(mode='test')),
                      shuffle=False,
                      batch_size=16,
                      drop_last=False)

    for batch_indices in bs:
        print(batch_indices)