未验证 提交 0e6468c7 编写于 作者: K Kaipeng Deng 提交者: GitHub

refine trainer (#2412)

* refine trainer
上级 2be546c4
......@@ -184,6 +184,7 @@ class BaseDataLoader(object):
batch_sampler=None,
return_list=False):
self.dataset = dataset
self.dataset.check_or_download_dataset()
self.dataset.parse_dataset()
# get data
self.dataset.set_transform(self._sample_transforms)
......
......@@ -16,7 +16,9 @@ from . import coco
# TODO add voc and widerface dataset
from . import voc
#from . import widerface
from . import category
from .coco import *
from .voc import *
#from .widerface import *
from .category import *
......@@ -49,10 +49,10 @@ class COCODataSet(DetDataset):
records = []
ct = 0
catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
cname2cid = dict({
self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
self.cname2cid = dict({
coco.loadCats(catid)[0]['name']: clsid
for catid, clsid in catid2clsid.items()
for catid, clsid in self.catid2clsid.items()
})
if 'annotations' not in coco.dataset:
......@@ -119,7 +119,7 @@ class COCODataSet(DetDataset):
has_segmentation = False
for i, box in enumerate(bboxes):
catid = box['category_id']
gt_class[i][0] = catid2clsid[catid]
gt_class[i][0] = self.catid2clsid[catid]
gt_bbox[i, :] = box['clean_bbox']
is_crowd[i][0] = box['iscrowd']
# check RLE format
......@@ -163,4 +163,4 @@ class COCODataSet(DetDataset):
break
assert len(records) > 0, 'not found any coco record in %s' % (anno_path)
logger.debug('{} samples in file {}'.format(ct, anno_path))
self.roidbs, self.cname2cid = records, cname2cid
self.roidbs = records
......@@ -67,6 +67,10 @@ class DetDataset(Dataset):
return self.transform(roidb)
def check_or_download_dataset(self):
self.dataset_dir = get_dataset_path(self.dataset_dir, self.anno_path,
self.image_dir)
def set_kwargs(self, **kwargs):
self.mixup_epoch = kwargs.get('mixup_epoch', -1)
self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
......
......@@ -20,7 +20,7 @@ import os
import yaml
from collections import OrderedDict
from ppdet.metrics import get_categories
from ppdet.data.source.category import get_categories
from ppdet.utils.logger import setup_logger
logger = setup_logger('ppdet.engine')
......
......@@ -31,7 +31,8 @@ from paddle.static import InputSpec
from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.utils.visualizer import visualize_results
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_categories, get_infer_results
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results
from ppdet.data.source.category import get_categories
import ppdet.utils.stats as stats
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter
......@@ -116,8 +117,8 @@ class Trainer(object):
self._callbacks = []
self._compose_callback = None
def _init_metrics(self):
if self.mode == 'test':
def _init_metrics(self, validate=False):
if self.mode == 'test' or (self.mode == 'train' and not validate):
self._metrics = []
return
classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False
......@@ -126,9 +127,12 @@ class Trainer(object):
bias = self.cfg['bias'] if 'bias' in self.cfg else 0
output_eval = self.cfg['output_eval'] \
if 'output_eval' in self.cfg else None
clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
if self.mode == 'eval' else None
self._metrics = [
COCOMetric(
anno_file=self.dataset.get_anno(),
clsid2catid=clsid2catid,
classwise=classwise,
output_eval=output_eval,
bias=bias)
......@@ -186,6 +190,11 @@ class Trainer(object):
def train(self, validate=False):
assert self.mode == 'train', "Model not in 'train' mode"
# if validation in training is enabled, metrics should be re-init
if validate:
self._init_metrics(validate=validate)
self._reset_metrics()
model = self.model
if self.cfg.fleet:
model = fleet.distributed_model(model)
......
......@@ -15,8 +15,4 @@
from . import metrics
from .metrics import *
from . import category
from .category import *
__all__ = metrics.__all__ \
+ category.__all__
__all__ = metrics.__all__
......@@ -22,10 +22,10 @@ import json
import paddle
import numpy as np
from .category import get_categories
from .map_utils import prune_zero_padding, DetectionMAP
from .coco_utils import get_infer_results, cocoapi_eval
from .widerface_utils import face_eval_run
from ppdet.data.source.category import get_categories
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
......@@ -62,7 +62,9 @@ class COCOMetric(Metric):
assert os.path.isfile(anno_file), \
"anno_file {} not a file".format(anno_file)
self.anno_file = anno_file
self.clsid2catid, self.catid2name = get_categories('COCO', anno_file)
self.clsid2catid = kwargs.get('clsid2catid', None)
if self.clsid2catid is None:
self.clsid2catid, _ = get_categories('COCO', anno_file)
self.classwise = kwargs.get('classwise', False)
self.output_eval = kwargs.get('output_eval', None)
# TODO: bias should be unified
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册