未验证 提交 7d01ea47 编写于 作者: W wangguanzhong 提交者: GitHub

Fix yaml constructor (#5915)

* fix yaml constructor in eval & export

* fix typo
上级 81f84e23
...@@ -27,3 +27,4 @@ from .category import * ...@@ -27,3 +27,4 @@ from .category import *
from .keypoint_coco import * from .keypoint_coco import *
from .mot import * from .mot import *
from .sniper_coco import SniperCOCODataSet from .sniper_coco import SniperCOCODataSet
from .dataset import ImageFolder
...@@ -67,11 +67,13 @@ class Trainer(object): ...@@ -67,11 +67,13 @@ class Trainer(object):
self.is_loaded_weights = False self.is_loaded_weights = False
# build data loader # build data loader
capital_mode = self.mode.capitalize()
if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']: if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']:
self.dataset = create('{}MOTDataset'.format(self.mode.capitalize( self.dataset = self.cfg['{}MOTDataset'.format(
)))() capital_mode)] = create('{}MOTDataset'.format(capital_mode))()
else: else:
self.dataset = create('{}Dataset'.format(self.mode.capitalize()))() self.dataset = self.cfg['{}Dataset'.format(capital_mode)] = create(
'{}Dataset'.format(capital_mode))()
if cfg.architecture == 'DeepSORT' and self.mode == 'train': if cfg.architecture == 'DeepSORT' and self.mode == 'train':
logger.error('DeepSORT has no need of training on mot dataset.') logger.error('DeepSORT has no need of training on mot dataset.')
...@@ -82,7 +84,7 @@ class Trainer(object): ...@@ -82,7 +84,7 @@ class Trainer(object):
self.dataset.set_images(images) self.dataset.set_images(images)
if self.mode == 'train': if self.mode == 'train':
self.loader = create('{}Reader'.format(self.mode.capitalize()))( self.loader = create('{}Reader'.format(capital_mode))(
self.dataset, cfg.worker_num) self.dataset, cfg.worker_num)
if cfg.architecture == 'JDE' and self.mode == 'train': if cfg.architecture == 'JDE' and self.mode == 'train':
...@@ -371,6 +373,8 @@ class Trainer(object): ...@@ -371,6 +373,8 @@ class Trainer(object):
def train(self, validate=False): def train(self, validate=False):
assert self.mode == 'train', "Model not in 'train' mode" assert self.mode == 'train', "Model not in 'train' mode"
Init_mark = False Init_mark = False
if validate:
self.cfg.EvalDataset = create("EvalDataset")()
sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
self.cfg.use_gpu and self._nranks > 1) self.cfg.use_gpu and self._nranks > 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册