未验证 提交 0e6468c7 编写于 作者: K Kaipeng Deng 提交者: GitHub

refine trainer (#2412)

* refine trainer
上级 2be546c4
...@@ -184,6 +184,7 @@ class BaseDataLoader(object): ...@@ -184,6 +184,7 @@ class BaseDataLoader(object):
batch_sampler=None, batch_sampler=None,
return_list=False): return_list=False):
self.dataset = dataset self.dataset = dataset
self.dataset.check_or_download_dataset()
self.dataset.parse_dataset() self.dataset.parse_dataset()
# get data # get data
self.dataset.set_transform(self._sample_transforms) self.dataset.set_transform(self._sample_transforms)
......
...@@ -16,7 +16,9 @@ from . import coco ...@@ -16,7 +16,9 @@ from . import coco
# TODO add voc and widerface dataset # TODO add voc and widerface dataset
from . import voc from . import voc
#from . import widerface #from . import widerface
from . import category
from .coco import * from .coco import *
from .voc import * from .voc import *
#from .widerface import * #from .widerface import *
from .category import *
...@@ -49,10 +49,10 @@ class COCODataSet(DetDataset): ...@@ -49,10 +49,10 @@ class COCODataSet(DetDataset):
records = [] records = []
ct = 0 ct = 0
catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)}) self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
cname2cid = dict({ self.cname2cid = dict({
coco.loadCats(catid)[0]['name']: clsid coco.loadCats(catid)[0]['name']: clsid
for catid, clsid in catid2clsid.items() for catid, clsid in self.catid2clsid.items()
}) })
if 'annotations' not in coco.dataset: if 'annotations' not in coco.dataset:
...@@ -119,7 +119,7 @@ class COCODataSet(DetDataset): ...@@ -119,7 +119,7 @@ class COCODataSet(DetDataset):
has_segmentation = False has_segmentation = False
for i, box in enumerate(bboxes): for i, box in enumerate(bboxes):
catid = box['category_id'] catid = box['category_id']
gt_class[i][0] = catid2clsid[catid] gt_class[i][0] = self.catid2clsid[catid]
gt_bbox[i, :] = box['clean_bbox'] gt_bbox[i, :] = box['clean_bbox']
is_crowd[i][0] = box['iscrowd'] is_crowd[i][0] = box['iscrowd']
# check RLE format # check RLE format
...@@ -163,4 +163,4 @@ class COCODataSet(DetDataset): ...@@ -163,4 +163,4 @@ class COCODataSet(DetDataset):
break break
assert len(records) > 0, 'not found any coco record in %s' % (anno_path) assert len(records) > 0, 'not found any coco record in %s' % (anno_path)
logger.debug('{} samples in file {}'.format(ct, anno_path)) logger.debug('{} samples in file {}'.format(ct, anno_path))
self.roidbs, self.cname2cid = records, cname2cid self.roidbs = records
...@@ -67,6 +67,10 @@ class DetDataset(Dataset): ...@@ -67,6 +67,10 @@ class DetDataset(Dataset):
return self.transform(roidb) return self.transform(roidb)
def check_or_download_dataset(self):
self.dataset_dir = get_dataset_path(self.dataset_dir, self.anno_path,
self.image_dir)
def set_kwargs(self, **kwargs): def set_kwargs(self, **kwargs):
self.mixup_epoch = kwargs.get('mixup_epoch', -1) self.mixup_epoch = kwargs.get('mixup_epoch', -1)
self.cutmix_epoch = kwargs.get('cutmix_epoch', -1) self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
......
...@@ -20,7 +20,7 @@ import os ...@@ -20,7 +20,7 @@ import os
import yaml import yaml
from collections import OrderedDict from collections import OrderedDict
from ppdet.metrics import get_categories from ppdet.data.source.category import get_categories
from ppdet.utils.logger import setup_logger from ppdet.utils.logger import setup_logger
logger = setup_logger('ppdet.engine') logger = setup_logger('ppdet.engine')
......
...@@ -31,7 +31,8 @@ from paddle.static import InputSpec ...@@ -31,7 +31,8 @@ from paddle.static import InputSpec
from ppdet.core.workspace import create from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.utils.visualizer import visualize_results from ppdet.utils.visualizer import visualize_results
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_categories, get_infer_results from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results
from ppdet.data.source.category import get_categories
import ppdet.utils.stats as stats import ppdet.utils.stats as stats
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter
...@@ -116,8 +117,8 @@ class Trainer(object): ...@@ -116,8 +117,8 @@ class Trainer(object):
self._callbacks = [] self._callbacks = []
self._compose_callback = None self._compose_callback = None
def _init_metrics(self): def _init_metrics(self, validate=False):
if self.mode == 'test': if self.mode == 'test' or (self.mode == 'train' and not validate):
self._metrics = [] self._metrics = []
return return
classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False
...@@ -126,9 +127,12 @@ class Trainer(object): ...@@ -126,9 +127,12 @@ class Trainer(object):
bias = self.cfg['bias'] if 'bias' in self.cfg else 0 bias = self.cfg['bias'] if 'bias' in self.cfg else 0
output_eval = self.cfg['output_eval'] \ output_eval = self.cfg['output_eval'] \
if 'output_eval' in self.cfg else None if 'output_eval' in self.cfg else None
clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
if self.mode == 'eval' else None
self._metrics = [ self._metrics = [
COCOMetric( COCOMetric(
anno_file=self.dataset.get_anno(), anno_file=self.dataset.get_anno(),
clsid2catid=clsid2catid,
classwise=classwise, classwise=classwise,
output_eval=output_eval, output_eval=output_eval,
bias=bias) bias=bias)
...@@ -186,6 +190,11 @@ class Trainer(object): ...@@ -186,6 +190,11 @@ class Trainer(object):
def train(self, validate=False): def train(self, validate=False):
assert self.mode == 'train', "Model not in 'train' mode" assert self.mode == 'train', "Model not in 'train' mode"
# if validation in training is enabled, metrics should be re-init
if validate:
self._init_metrics(validate=validate)
self._reset_metrics()
model = self.model model = self.model
if self.cfg.fleet: if self.cfg.fleet:
model = fleet.distributed_model(model) model = fleet.distributed_model(model)
......
...@@ -15,8 +15,4 @@ ...@@ -15,8 +15,4 @@
from . import metrics from . import metrics
from .metrics import * from .metrics import *
from . import category __all__ = metrics.__all__
from .category import *
__all__ = metrics.__all__ \
+ category.__all__
...@@ -22,10 +22,10 @@ import json ...@@ -22,10 +22,10 @@ import json
import paddle import paddle
import numpy as np import numpy as np
from .category import get_categories
from .map_utils import prune_zero_padding, DetectionMAP from .map_utils import prune_zero_padding, DetectionMAP
from .coco_utils import get_infer_results, cocoapi_eval from .coco_utils import get_infer_results, cocoapi_eval
from .widerface_utils import face_eval_run from .widerface_utils import face_eval_run
from ppdet.data.source.category import get_categories
from ppdet.utils.logger import setup_logger from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__) logger = setup_logger(__name__)
...@@ -62,7 +62,9 @@ class COCOMetric(Metric): ...@@ -62,7 +62,9 @@ class COCOMetric(Metric):
assert os.path.isfile(anno_file), \ assert os.path.isfile(anno_file), \
"anno_file {} not a file".format(anno_file) "anno_file {} not a file".format(anno_file)
self.anno_file = anno_file self.anno_file = anno_file
self.clsid2catid, self.catid2name = get_categories('COCO', anno_file) self.clsid2catid = kwargs.get('clsid2catid', None)
if self.clsid2catid is None:
self.clsid2catid, _ = get_categories('COCO', anno_file)
self.classwise = kwargs.get('classwise', False) self.classwise = kwargs.get('classwise', False)
self.output_eval = kwargs.get('output_eval', None) self.output_eval = kwargs.get('output_eval', None)
# TODO: bias should be unified # TODO: bias should be unified
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册