Commit a41a5bcb, authored by gaotingquan, committed by Wei Shengyu

debug

Parent ab29eaa8
@@ -106,89 +106,91 @@ def build_dataloader(config, *mode, seed=None):
     # build dataset
     if use_dali:
         from ppcls.data.dataloader.dali import dali_dataloader
-        return dali_dataloader(
+        data_loader = dali_dataloader(
             dataloader_config,
             mode[-1],
             paddle.device.get_device(),
             num_threads=num_workers,
             seed=seed,
             enable_fuse=True)
-
-    config_dataset = dataloader_config['dataset']
-    config_dataset = copy.deepcopy(config_dataset)
-    dataset_name = config_dataset.pop('name')
-    if 'batch_transform_ops' in config_dataset:
-        batch_transform = config_dataset['batch_transform_ops']
-    else:
-        batch_transform = None
-
-    dataset = eval(dataset_name)(**config_dataset)
-    logger.debug("build dataset({}) success...".format(dataset))
-
-    # build sampler
-    config_sampler = dataloader_config['sampler']
-    if config_sampler and "name" not in config_sampler:
-        batch_sampler = None
-        batch_size = config_sampler["batch_size"]
-        drop_last = config_sampler["drop_last"]
-        shuffle = config_sampler["shuffle"]
-    else:
-        sampler_name = config_sampler.pop("name")
-        sampler_argspec = inspect.getargspec(eval(sampler_name).__init__).args
-        if "total_epochs" in sampler_argspec:
-            config_sampler.update({"total_epochs": epochs})
-        batch_sampler = eval(sampler_name)(dataset, **config_sampler)
-    logger.debug("build batch_sampler({}) success...".format(batch_sampler))
-
-    # build batch operator
-    def mix_collate_fn(batch):
-        batch = transform(batch, batch_ops)
-        # batch each field
-        slots = []
-        for items in batch:
-            for i, item in enumerate(items):
-                if len(slots) < len(items):
-                    slots.append([item])
-                else:
-                    slots[i].append(item)
-        return [np.stack(slot, axis=0) for slot in slots]
-
-    if isinstance(batch_transform, list):
-        batch_ops = create_operators(batch_transform, class_num)
-        batch_collate_fn = mix_collate_fn
-    else:
-        batch_collate_fn = None
-
-    init_fn = partial(
-        worker_init_fn,
-        num_workers=num_workers,
-        rank=dist.get_rank(),
-        seed=seed) if seed is not None else None
-
-    if batch_sampler is None:
-        data_loader = DataLoader(
-            dataset=dataset,
-            places=paddle.device.get_device(),
-            num_workers=num_workers,
-            return_list=True,
-            use_shared_memory=use_shared_memory,
-            batch_size=batch_size,
-            shuffle=shuffle,
-            drop_last=drop_last,
-            collate_fn=batch_collate_fn,
-            worker_init_fn=init_fn)
     else:
-        data_loader = DataLoader(
-            dataset=dataset,
-            places=paddle.device.get_device(),
-            num_workers=num_workers,
-            return_list=True,
-            use_shared_memory=use_shared_memory,
-            batch_sampler=batch_sampler,
-            collate_fn=batch_collate_fn,
-            worker_init_fn=init_fn)
+        config_dataset = dataloader_config['dataset']
+        config_dataset = copy.deepcopy(config_dataset)
+        dataset_name = config_dataset.pop('name')
+        if 'batch_transform_ops' in config_dataset:
+            batch_transform = config_dataset['batch_transform_ops']
+        else:
+            batch_transform = None
+
+        dataset = eval(dataset_name)(**config_dataset)
+        logger.debug("build dataset({}) success...".format(dataset))
+
+        # build sampler
+        config_sampler = dataloader_config['sampler']
+        if config_sampler and "name" not in config_sampler:
+            batch_sampler = None
+            batch_size = config_sampler["batch_size"]
+            drop_last = config_sampler["drop_last"]
+            shuffle = config_sampler["shuffle"]
+        else:
+            sampler_name = config_sampler.pop("name")
+            sampler_argspec = inspect.getargspec(eval(sampler_name)
+                                                 .__init__).args
+            if "total_epochs" in sampler_argspec:
+                config_sampler.update({"total_epochs": epochs})
+            batch_sampler = eval(sampler_name)(dataset, **config_sampler)
+        logger.debug("build batch_sampler({}) success...".format(
+            batch_sampler))
+
+        # build batch operator
+        def mix_collate_fn(batch):
+            batch = transform(batch, batch_ops)
+            # batch each field
+            slots = []
+            for items in batch:
+                for i, item in enumerate(items):
+                    if len(slots) < len(items):
+                        slots.append([item])
+                    else:
+                        slots[i].append(item)
+            return [np.stack(slot, axis=0) for slot in slots]
+
+        if isinstance(batch_transform, list):
+            batch_ops = create_operators(batch_transform, class_num)
+            batch_collate_fn = mix_collate_fn
+        else:
+            batch_collate_fn = None
+
+        init_fn = partial(
+            worker_init_fn,
+            num_workers=num_workers,
+            rank=dist.get_rank(),
+            seed=seed) if seed is not None else None
+
+        if batch_sampler is None:
+            data_loader = DataLoader(
+                dataset=dataset,
+                places=paddle.device.get_device(),
+                num_workers=num_workers,
+                return_list=True,
+                use_shared_memory=use_shared_memory,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                drop_last=drop_last,
+                collate_fn=batch_collate_fn,
+                worker_init_fn=init_fn)
+        else:
+            data_loader = DataLoader(
+                dataset=dataset,
+                places=paddle.device.get_device(),
+                num_workers=num_workers,
+                return_list=True,
+                use_shared_memory=use_shared_memory,
+                batch_sampler=batch_sampler,
+                collate_fn=batch_collate_fn,
+                worker_init_fn=init_fn)
 
     total_samples = len(
         data_loader.dataset) if not use_dali else data_loader.size
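
The substance of this hunk is a control-flow fix: the DALI branch used to `return` early, so the shared bookkeeping at the end of `build_dataloader` (including the `total_samples` computation, whose `data_loader.size` arm was unreachable) never ran for DALI. Assigning to `data_loader` and moving the generic path into an `else:` lets both branches fall through to the common tail. A minimal, self-contained sketch of the pattern follows; the stub loader classes are hypothetical stand-ins, not PaddleClas code:

```python
class _FakeDaliLoader:
    # DALI iterators report their length via a `size` attribute
    # rather than a `dataset`; the value here is arbitrary.
    size = 1000


class _FakeGenericLoader:
    dataset = list(range(800))


def build_old(use_dali):
    if use_dali:
        # Early return: the shared tail below never ran on the DALI path.
        return _FakeDaliLoader()
    data_loader = _FakeGenericLoader()
    print("total_samples:", len(data_loader.dataset))
    return data_loader


def build_new(use_dali):
    if use_dali:
        data_loader = _FakeDaliLoader()
    else:
        data_loader = _FakeGenericLoader()
    # Both branches now fall through to the shared bookkeeping.
    total_samples = len(
        data_loader.dataset) if not use_dali else data_loader.size
    print("total_samples:", total_samples)
    return data_loader


build_old(use_dali=True)   # prints nothing: the tail was skipped
build_new(use_dali=True)   # prints "total_samples: 1000"
```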
......
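`mix_collate_fn` itself is only re-indented by this commit, but its transpose-and-stack logic is the densest part of the hunk: after running the batch ops, it turns a list of per-sample tuples into one stacked array per field. A standalone sketch of just that stacking step, with the `transform(batch, batch_ops)` preprocessing dropped and made-up sample data:

```python
import numpy as np


def stack_by_field(batch):
    # Transpose a list of per-sample tuples into one list per field,
    # then stack each field into a single batched array.
    slots = []
    for items in batch:
        for i, item in enumerate(items):
            if len(slots) < len(items):
                slots.append([item])
            else:
                slots[i].append(item)
    return [np.stack(slot, axis=0) for slot in slots]


# Two (image, label) samples -> one (2, 3, 4, 4) image batch
# and one (2,) label batch.
samples = [(np.zeros((3, 4, 4)), 7), (np.ones((3, 4, 4)), 2)]
images, labels = stack_by_field(samples)
print(images.shape, labels.shape)  # (2, 3, 4, 4) (2,)
```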
@@ -15,6 +15,7 @@
 import copy
 from collections import OrderedDict
 
+from ..utils import logger
 from .avg_metrics import AvgMetrics
 from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk
 from .metrics import DistillationTopkAcc
......
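The second hunk only adds an import. The `..` resolves one package level up, so inside the metric package's `__init__.py` the line is equivalent to `from ppcls.utils import logger`. As a rough stand-in (hypothetical setup, not the actual `ppcls.utils` implementation), the object it binds behaves like a configured `logging` logger:

```python
import logging

# Hypothetical stand-in for the object bound by `from ..utils import logger`.
logger = logging.getLogger("ppcls")
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler())

logger.debug("build dataset({}) success...".format("SomeDataset"))
```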