提交 a41a5bcb 编写于 作者: G gaotingquan 提交者: Wei Shengyu

debug

上级 ab29eaa8
...@@ -106,14 +106,14 @@ def build_dataloader(config, *mode, seed=None): ...@@ -106,14 +106,14 @@ def build_dataloader(config, *mode, seed=None):
# build dataset # build dataset
if use_dali: if use_dali:
from ppcls.data.dataloader.dali import dali_dataloader from ppcls.data.dataloader.dali import dali_dataloader
return dali_dataloader( data_loader = dali_dataloader(
dataloader_config, dataloader_config,
mode[-1], mode[-1],
paddle.device.get_device(), paddle.device.get_device(),
num_threads=num_workers, num_threads=num_workers,
seed=seed, seed=seed,
enable_fuse=True) enable_fuse=True)
else:
config_dataset = dataloader_config['dataset'] config_dataset = dataloader_config['dataset']
config_dataset = copy.deepcopy(config_dataset) config_dataset = copy.deepcopy(config_dataset)
dataset_name = config_dataset.pop('name') dataset_name = config_dataset.pop('name')
...@@ -135,12 +135,14 @@ def build_dataloader(config, *mode, seed=None): ...@@ -135,12 +135,14 @@ def build_dataloader(config, *mode, seed=None):
shuffle = config_sampler["shuffle"] shuffle = config_sampler["shuffle"]
else: else:
sampler_name = config_sampler.pop("name") sampler_name = config_sampler.pop("name")
sampler_argspec = inspect.getargspec(eval(sampler_name).__init__).args sampler_argspec = inspect.getargspec(eval(sampler_name)
.__init__).args
if "total_epochs" in sampler_argspec: if "total_epochs" in sampler_argspec:
config_sampler.update({"total_epochs": epochs}) config_sampler.update({"total_epochs": epochs})
batch_sampler = eval(sampler_name)(dataset, **config_sampler) batch_sampler = eval(sampler_name)(dataset, **config_sampler)
logger.debug("build batch_sampler({}) success...".format(batch_sampler)) logger.debug("build batch_sampler({}) success...".format(
batch_sampler))
# build batch operator # build batch operator
def mix_collate_fn(batch): def mix_collate_fn(batch):
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import copy import copy
from collections import OrderedDict from collections import OrderedDict
from ..utils import logger
from .avg_metrics import AvgMetrics from .avg_metrics import AvgMetrics
from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk
from .metrics import DistillationTopkAcc from .metrics import DistillationTopkAcc
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册