From 7d01ea47931494c323b56c7e2ecd09cd088e7672 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Tue, 10 May 2022 17:23:24 +0800 Subject: [PATCH] Fix yaml constructor (#5915) * fix yaml constructor in eval & export * fix typo --- ppdet/data/source/__init__.py | 1 + ppdet/engine/trainer.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/ppdet/data/source/__init__.py b/ppdet/data/source/__init__.py index 3854d3d25..e3abb16b6 100644 --- a/ppdet/data/source/__init__.py +++ b/ppdet/data/source/__init__.py @@ -27,3 +27,4 @@ from .category import * from .keypoint_coco import * from .mot import * from .sniper_coco import SniperCOCODataSet +from .dataset import ImageFolder diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index b7737b1e1..e124855f9 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -67,11 +67,13 @@ class Trainer(object): self.is_loaded_weights = False # build data loader + capital_mode = self.mode.capitalize() if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']: - self.dataset = create('{}MOTDataset'.format(self.mode.capitalize( - )))() + self.dataset = self.cfg['{}MOTDataset'.format( + capital_mode)] = create('{}MOTDataset'.format(capital_mode))() else: - self.dataset = create('{}Dataset'.format(self.mode.capitalize()))() + self.dataset = self.cfg['{}Dataset'.format(capital_mode)] = create( + '{}Dataset'.format(capital_mode))() if cfg.architecture == 'DeepSORT' and self.mode == 'train': logger.error('DeepSORT has no need of training on mot dataset.') @@ -82,7 +84,7 @@ class Trainer(object): self.dataset.set_images(images) if self.mode == 'train': - self.loader = create('{}Reader'.format(self.mode.capitalize()))( + self.loader = create('{}Reader'.format(capital_mode))( self.dataset, cfg.worker_num) if cfg.architecture == 'JDE' and self.mode == 'train': @@ -371,6 +373,8 @@ class Trainer(object): def train(self, validate=False): assert self.mode == 'train', "Model not in 'train' mode" Init_mark = False + if validate: + self.cfg.EvalDataset = create("EvalDataset")() sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and self.cfg.use_gpu and self._nranks > 1) -- GitLab