# builder.py (1018 bytes) -- last committed by LielinJiang
import paddle
import numbers
import numpy as np
from paddle.imperative import ParallelEnv

from paddle.incubate.hapi.distributed import DistributedBatchSampler
from ..utils.registry import Registry


# Registry of dataset constructors; build_dataloader() looks one up by
# ``cfg.name`` and calls it with the config.
DATASETS = Registry("DATASETS")


def build_dataloader(cfg, is_train=True):
    """Build a ``paddle.io.DataLoader`` for the dataset described by ``cfg``.

    Args:
        cfg: Dataset config. ``cfg.name`` selects an entry from the
            ``DATASETS`` registry, and the whole ``cfg`` is passed to that
            entry to construct the dataset. Optional keys read through
            ``cfg.get``: ``batch_size`` (default 1) and ``num_workers``
            (default 0).
        is_train: When truthy, the batch sampler is created with both
            ``shuffle`` and ``drop_last`` enabled; otherwise both are
            disabled.

    Returns:
        The ``paddle.io.DataLoader`` wrapping the dataset and its
        ``DistributedBatchSampler``.
    """
    dataset = DATASETS.get(cfg.name)(cfg)

    batch_size = cfg.get('batch_size', 1)
    # Used to be hard-coded to 0; 0 is still the value when the key is absent.
    num_workers = cfg.get('num_workers', 0)

    # Query the parallel environment once. With more than one rank each
    # process uses its own device id; otherwise GPU 0 is used.
    env = ParallelEnv()
    device_id = env.dev_id if env.nranks > 1 else 0
    place = paddle.fluid.CUDAPlace(device_id)

    training = bool(is_train)
    sampler = DistributedBatchSampler(dataset,
                                      batch_size=batch_size,
                                      shuffle=training,
                                      drop_last=training)

    return paddle.io.DataLoader(dataset,
                                batch_sampler=sampler,
                                places=place,
                                num_workers=num_workers)