__init__.py 7.0 KB
Newer Older
F
Felix 已提交
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
W
WuHaobo 已提交
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
W
dbg  
weishengyu 已提交
14

G
gaotingquan 已提交
15
import inspect
F
Felix 已提交
16
import copy
17
import random
F
Felix 已提交
18 19
import paddle
import numpy as np
20 21
import paddle.distributed as dist
from functools import partial
F
Felix 已提交
22 23 24
from paddle.io import DistributedBatchSampler, BatchSampler, DataLoader
from ppcls.utils import logger

W
dbg  
weishengyu 已提交
25
from ppcls.data import dataloader
F
Felix 已提交
26
# dataset
W
dbg  
weishengyu 已提交
27 28 29 30
from ppcls.data.dataloader.imagenet_dataset import ImageNetDataset
from ppcls.data.dataloader.multilabel_dataset import MultiLabelDataset
from ppcls.data.dataloader.common_dataset import create_operators
from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
F
Felix 已提交
31
from ppcls.data.dataloader.logo_dataset import LogoDataset
B
Bin Lu 已提交
32
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
W
dbg  
weishengyu 已提交
33
from ppcls.data.dataloader.mix_dataset import MixDataset
S
sibo2rr 已提交
34
from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
35
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
D
dongshuilong 已提交
36
from ppcls.data.dataloader.face_dataset import FiveValidationDataset, AdaFaceDataset
37
from ppcls.data.dataloader.custom_label_dataset import CustomLabelDataset
Z
zh-hike 已提交
38
from ppcls.data.dataloader.cifar import Cifar10, Cifar100, CIFAR10SSL
F
Felix 已提交
39

F
Felix 已提交
40
# sampler
D
dongshuilong 已提交
41
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
W
dbg  
weishengyu 已提交
42
from ppcls.data.dataloader.pk_sampler import PKSampler
W
dbg  
weishengyu 已提交
43
from ppcls.data.dataloader.mix_sampler import MixSampler
44
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler
C
cuicheng01 已提交
45
from ppcls.data import preprocess
W
dbg  
weishengyu 已提交
46
from ppcls.data.preprocess import transform
F
Felix 已提交
47

D
dongshuilong 已提交
48

G
gaotingquan 已提交
49
def create_operators(params, class_num=None):
    """
    Create preprocessing operators based on the config.

    Args:
        params (list): a list of single-key dicts, each mapping an operator
            name (an attribute of ``ppcls.data.preprocess``) to its kwargs
            dict (or None when the operator takes no arguments).
        class_num (int, optional): number of classes; injected as the
            ``class_num`` kwarg for operators whose ``__init__`` accepts it.

    Returns:
        list: the instantiated operator objects, in config order.
    """
    assert isinstance(params, list), ('operator config should be a list')
    ops = []
    for operator in params:
        assert isinstance(operator,
                          dict) and len(operator) == 1, "yaml format error"
        op_name = list(operator)[0]
        # Copy the kwargs dict so injecting class_num below does not mutate
        # the caller's config (the same config may be reused to build
        # several dataloaders).
        param = {} if operator[op_name] is None else dict(operator[op_name])
        op_func = getattr(preprocess, op_name)
        if "class_num" in inspect.getfullargspec(op_func).args:
            param.update({"class_num": class_num})
        op = op_func(**param)
        ops.append(op)

    return ops


72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
    """Seed a dataloader worker subprocess deterministically.

    Invoked in each worker after seeding and before data loading, so that
    every (rank, worker) pair draws from a distinct, reproducible stream.

    Args:
        worker_id (int): Worker id in [0, num_workers - 1]
        num_workers (int): Number of subprocesses used for data loading.
        rank (int): Rank of the process in the distributed environment;
            a constant ``0`` when not distributed.
        seed (int): User-provided base random seed.
    """
    # Per-worker seed: num_workers * rank + worker_id + user seed,
    # unique across workers and ranks.
    derived_seed = seed + worker_id + rank * num_workers
    np.random.seed(derived_seed)
    random.seed(derived_seed)


W
Walter 已提交
88
def build_dataloader(config, mode, device, use_dali=False, seed=None):
    """
    Build a ``paddle.io.DataLoader`` (or a DALI pipeline) for one mode.

    Args:
        config (dict): global config; ``config[mode]`` must contain
            ``dataset``, ``sampler`` and ``loader`` sections.
        mode (str): one of 'Train', 'Eval', 'Test', 'Gallery', 'Query',
            'UnLabelTrain'.
        device: places passed to ``paddle.io.DataLoader``.
        use_dali (bool): if True, delegate entirely to the DALI dataloader.
        seed (int, optional): base seed for deterministic per-worker RNG
            seeding; when None, workers are not explicitly seeded.

    Returns:
        The built data loader.
    """
    assert mode in [
        'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
    ], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
    assert mode in config.keys(), "{} config not in yaml".format(mode)
    # build dataset
    if use_dali:
        from ppcls.data.dataloader.dali import dali_dataloader
        return dali_dataloader(
            config,
            mode,
            paddle.device.get_device(),
            num_threads=config[mode]['loader']["num_workers"],
            seed=seed,
            enable_fuse=True)

    class_num = config.get("class_num", None)
    epochs = config.get("epochs", None)
    config_dataset = config[mode]['dataset']
    # deep-copy so the pop()/update() calls below do not mutate the shared
    # config dict (it may be reused to build other loaders)
    config_dataset = copy.deepcopy(config_dataset)
    dataset_name = config_dataset.pop('name')
    if 'batch_transform_ops' in config_dataset:
        batch_transform = config_dataset.pop('batch_transform_ops')
    else:
        batch_transform = None

    # NOTE(review): dataset/sampler class names are resolved with eval();
    # this assumes the yaml config is trusted input.
    dataset = eval(dataset_name)(**config_dataset)

    logger.debug("build dataset({}) success...".format(dataset))

    # build sampler
    config_sampler = config[mode]['sampler']
    if config_sampler and "name" not in config_sampler:
        # plain sampler config: let DataLoader build its own batch sampler
        batch_sampler = None
        batch_size = config_sampler["batch_size"]
        drop_last = config_sampler["drop_last"]
        shuffle = config_sampler["shuffle"]
    else:
        sampler_name = config_sampler.pop("name")
        # inspect.getargspec was removed in Python 3.11; use getfullargspec
        # (consistent with create_operators above).
        sampler_argspec = inspect.getfullargspec(
            eval(sampler_name).__init__).args
        # some samplers (e.g. epoch-aware ones) need the total epoch count
        if "total_epochs" in sampler_argspec:
            config_sampler.update({"total_epochs": epochs})
        batch_sampler = eval(sampler_name)(dataset, **config_sampler)

    logger.debug("build batch_sampler({}) success...".format(batch_sampler))

    # build batch operator
    def mix_collate_fn(batch):
        # apply batch-level transforms (e.g. mixup-style ops), then collate
        batch = transform(batch, batch_ops)
        # batch each field: transpose list-of-samples into list-of-fields
        slots = []
        for items in batch:
            for i, item in enumerate(items):
                if len(slots) < len(items):
                    slots.append([item])
                else:
                    slots[i].append(item)
        return [np.stack(slot, axis=0) for slot in slots]

    if isinstance(batch_transform, list):
        batch_ops = create_operators(batch_transform, class_num)
        batch_collate_fn = mix_collate_fn
    else:
        batch_collate_fn = None

    # build dataloader
    config_loader = config[mode]['loader']
    num_workers = config_loader["num_workers"]
    use_shared_memory = config_loader["use_shared_memory"]

    # seed each worker deterministically only when a base seed is given
    init_fn = partial(
        worker_init_fn,
        num_workers=num_workers,
        rank=dist.get_rank(),
        seed=seed) if seed is not None else None

    if batch_sampler is None:
        data_loader = DataLoader(
            dataset=dataset,
            places=device,
            num_workers=num_workers,
            return_list=True,
            use_shared_memory=use_shared_memory,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=batch_collate_fn,
            worker_init_fn=init_fn)
    else:
        data_loader = DataLoader(
            dataset=dataset,
            places=device,
            num_workers=num_workers,
            return_list=True,
            use_shared_memory=use_shared_memory,
            batch_sampler=batch_sampler,
            collate_fn=batch_collate_fn,
            worker_init_fn=init_fn)

    logger.debug("build data_loader({}) success...".format(data_loader))
    return data_loader