diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py index e4232294a1d5a347ea742148fd1f98a4cccac5cf..bcaf272c419fdfd0991aa7c6f4e488c0d7809f7c 100644 --- a/ppcls/data/__init__.py +++ b/ppcls/data/__init__.py @@ -106,89 +106,91 @@ def build_dataloader(config, *mode, seed=None): # build dataset if use_dali: from ppcls.data.dataloader.dali import dali_dataloader - return dali_dataloader( + data_loader = dali_dataloader( dataloader_config, mode[-1], paddle.device.get_device(), num_threads=num_workers, seed=seed, enable_fuse=True) - - config_dataset = dataloader_config['dataset'] - config_dataset = copy.deepcopy(config_dataset) - dataset_name = config_dataset.pop('name') - if 'batch_transform_ops' in config_dataset: - batch_transform = config_dataset['batch_transform_ops'] - else: - batch_transform = None - - dataset = eval(dataset_name)(**config_dataset) - - logger.debug("build dataset({}) success...".format(dataset)) - - # build sampler - config_sampler = dataloader_config['sampler'] - if config_sampler and "name" not in config_sampler: - batch_sampler = None - batch_size = config_sampler["batch_size"] - drop_last = config_sampler["drop_last"] - shuffle = config_sampler["shuffle"] - else: - sampler_name = config_sampler.pop("name") - sampler_argspec = inspect.getargspec(eval(sampler_name).__init__).args - if "total_epochs" in sampler_argspec: - config_sampler.update({"total_epochs": epochs}) - batch_sampler = eval(sampler_name)(dataset, **config_sampler) - - logger.debug("build batch_sampler({}) success...".format(batch_sampler)) - - # build batch operator - def mix_collate_fn(batch): - batch = transform(batch, batch_ops) - # batch each field - slots = [] - for items in batch: - for i, item in enumerate(items): - if len(slots) < len(items): - slots.append([item]) - else: - slots[i].append(item) - return [np.stack(slot, axis=0) for slot in slots] - - if isinstance(batch_transform, list): - batch_ops = create_operators(batch_transform, class_num) - batch_collate_fn = mix_collate_fn - else: - batch_collate_fn = None - - init_fn = partial( - worker_init_fn, - num_workers=num_workers, - rank=dist.get_rank(), - seed=seed) if seed is not None else None - - if batch_sampler is None: - data_loader = DataLoader( - dataset=dataset, - places=paddle.device.get_device(), - num_workers=num_workers, - return_list=True, - use_shared_memory=use_shared_memory, - batch_size=batch_size, - shuffle=shuffle, - drop_last=drop_last, - collate_fn=batch_collate_fn, - worker_init_fn=init_fn) else: - data_loader = DataLoader( - dataset=dataset, - places=paddle.device.get_device(), + config_dataset = dataloader_config['dataset'] + config_dataset = copy.deepcopy(config_dataset) + dataset_name = config_dataset.pop('name') + if 'batch_transform_ops' in config_dataset: + batch_transform = config_dataset['batch_transform_ops'] + else: + batch_transform = None + + dataset = eval(dataset_name)(**config_dataset) + + logger.debug("build dataset({}) success...".format(dataset)) + + # build sampler + config_sampler = dataloader_config['sampler'] + if config_sampler and "name" not in config_sampler: + batch_sampler = None + batch_size = config_sampler["batch_size"] + drop_last = config_sampler["drop_last"] + shuffle = config_sampler["shuffle"] + else: + sampler_name = config_sampler.pop("name") + sampler_argspec = inspect.getargspec(eval(sampler_name) + .__init__).args + if "total_epochs" in sampler_argspec: + config_sampler.update({"total_epochs": epochs}) + batch_sampler = eval(sampler_name)(dataset, **config_sampler) + + logger.debug("build batch_sampler({}) success...".format( + batch_sampler)) + + # build batch operator + def mix_collate_fn(batch): + batch = transform(batch, batch_ops) + # batch each field + slots = [] + for items in batch: + for i, item in enumerate(items): + if len(slots) < len(items): + slots.append([item]) + else: + slots[i].append(item) + return [np.stack(slot, axis=0) for slot in slots] + + if isinstance(batch_transform, list): + batch_ops = create_operators(batch_transform, class_num) + batch_collate_fn = mix_collate_fn + else: + batch_collate_fn = None + + init_fn = partial( + worker_init_fn, num_workers=num_workers, - return_list=True, - use_shared_memory=use_shared_memory, - batch_sampler=batch_sampler, - collate_fn=batch_collate_fn, - worker_init_fn=init_fn) + rank=dist.get_rank(), + seed=seed) if seed is not None else None + + if batch_sampler is None: + data_loader = DataLoader( + dataset=dataset, + places=paddle.device.get_device(), + num_workers=num_workers, + return_list=True, + use_shared_memory=use_shared_memory, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + collate_fn=batch_collate_fn, + worker_init_fn=init_fn) + else: + data_loader = DataLoader( + dataset=dataset, + places=paddle.device.get_device(), + num_workers=num_workers, + return_list=True, + use_shared_memory=use_shared_memory, + batch_sampler=batch_sampler, + collate_fn=batch_collate_fn, + worker_init_fn=init_fn) total_samples = len( data_loader.dataset) if not use_dali else data_loader.size diff --git a/ppcls/metric/__init__.py b/ppcls/metric/__init__.py index 48b37153ac0040ee5b09ec2ed3e82cf5024c7e77..9640bb9cb8d71d4d572df150575ee135d5987a09 100644 --- a/ppcls/metric/__init__.py +++ b/ppcls/metric/__init__.py @@ -15,6 +15,7 @@ import copy from collections import OrderedDict +from ..utils import logger from .avg_metrics import AvgMetrics from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk from .metrics import DistillationTopkAcc